#!/usr/bin/python3

import numpy as np
from matplotlib import pyplot as plt
import os
import zlib
import lzma
import bz2

# Panel letters used in subplot titles, in alphabetical order.
panels = "a b c d e f g h i j".split()
# Shared plot styling; the plotting functions index these lists by signal
# number (entry 0 for the first signal, entry 1 for the second).
cmap = plt.get_cmap("tab10")
cols = [cmap(index) for index in (0, 3)]
point = ["x", "+"]
lw = [2, 1]

# Extension (and therefore format) of every saved figure.
filetype = "png"
# filetype="pdf" !! Uncomment this to get .pdf figures

def compsize(inp, alg):
    """Return the number of bytes ``inp`` occupies after compression.

    Parameters:
        inp: bytes-like data handed to the compressor.
        alg: compressor name, one of 'zlib', 'bzip2' or 'lzma'.

    Raises:
        ValueError: if ``alg`` is not one of the supported names.
    """
    compressors = (
        ('zlib', zlib.compress),
        ('bzip2', bz2.compress),
        ('lzma', lzma.compress),
    )
    for name, compress in compressors:
        if alg == name:
            return len(compress(inp))
    raise ValueError("Can't process %s. Only  zlib, bzip2 , and lzma so far."%(alg,))



def genGaussian(N, background, sigma):
    """Return ``N`` float32 samples of ``background`` plus Gaussian noise.

    The noise is standard-normal scaled by ``sigma``. The global NumPy random
    state is reseeded with 3 on every call, so repeated calls with the same
    arguments return the same array.
    """
    np.random.seed(3)  # fixed seed keeps the plots reproducible
    noise = np.random.normal(0, 1, N)
    signal = background + sigma * noise
    return signal.astype(np.float32)

def gen53(N, background, sigma):
    """Generate a reproducible, correlated random signal around ``background``.

    A one-sided spectrum of ``N//2 + 1`` harmonics with random phases is
    built, its lowest harmonics are removed (and one is rescaled) so that the
    summed squared amplitude of the remaining ones equals ``sigma**2``, and
    the spectrum is transformed back with ``np.fft.irfft``.

    The global NumPy random state is reseeded with 3 on every call.

    NOTE(review): the original comment advertised a -5/3 power spectrum, but
    the active formula makes the amplitude fall off as ``(k+1)**-2``; the
    commented-out alternative uses exponent 1.666. Confirm which is intended.

    Returns:
        float32 array holding the inverse transform plus ``background``
        (the irfft output has ``2 * (N // 2)`` samples, i.e. ``N`` for even N).
    """
    npts=N//2+1
    np.random.seed(3) ## Make plots reproducible

    # Build the spectrum: deterministic amplitude times a random unit phasor
    sigf = np.pi * npts/((np.arange(npts)+1.)**(2.))*np.exp(2*np.random.random((npts,))*np.pi*1j)
    # Alternative amplitude law, kept for reference:
    #sigf = np.pi * npts/(np.arange(npts)**(1.666)+1e-10)*np.exp(2*np.random.random((npts,))*np.pi*1j)
    sigf[0] = 0. # zero constant component

    # Trim lowest harmonics to get the desired norm
    norm = sigma*sigma
    # sigE[j] is the summed squared amplitude of harmonics j and above
    sigE = np.cumsum(np.flip(np.abs(sigf)**2)) [::-1]
    i = np.argmax(sigE < norm)  # Index of the first frequency that has to be fully preserved
    # NOTE(review): np.argmax returns 0 when no element (or already the first
    # one) satisfies the condition; i-1 is then negative and the slices below
    # wrap around. Confirm the inputs never reach that case.
    sigf[0:i-1] = 0. # Zero the lower frequencies

    Ei = np.abs(sigf[i-1])**2 # Energy in the frequency that is partly cut
    sigf[i-1] *= np.sqrt(1. - (sigE[i-1] - norm)/Ei) # Scale it to remove the excess energy (sigE[i-1] - norm)
    a = N / np.sqrt(2) * np.fft.irfft(sigf)

    # Disabled check of the resulting RMS:
    ### print( np.sqrt(np.mean(a*a))) Check variance

    return np.array(a + background,  dtype=np.float32)




# Structure function
def strucf(sig, n):
    """Return offsets ``1..n-1`` and the RMS difference of ``sig`` at each.

    For every offset the first ``len(sig) - n`` samples are compared with the
    same-length window shifted by that offset.

    Returns:
        (offsets, values): an integer array ``arange(1, n)`` and a float32
        array of the same length.
    """
    offsets = np.arange(1, n)
    head = sig[0:-n]
    values = [
        np.sqrt(np.mean((head - sig[lag:lag - n]) ** 2))
        for lag in range(1, n)
    ]
    return offsets, np.array(values, dtype=np.float32)
#for i in 
#print(x.shape)
#print(a)

# Trimming methods that are plotted and tabulated, in panel order.
methods = ["shave", "groom", "halfshave", "round", "groomhalf", "groomav"]
def TrimPrecision(a, keepbits, method):
    """Return a copy of float32 array ``a`` with reduced mantissa precision.

    The 23 - ``keepbits`` lowest mantissa bits ("tail bits") of every value
    are altered according to ``method``:

    * 'shave'     - tail bits are zeroed.
    * 'set'       - tail bits are set to one.
    * 'groom'     - even-indexed values are shaved, odd-indexed values are set.
    * 'halfshave' - tail bits are zeroed, then the highest tail bit is set.
    * 'round'     - shave of ``2*a - shave(a)``, i.e. rounding to nearest.
    * 'groomhalf' - 'groom' followed by 'halfshave'.
    * 'groomav'   - 'groom', then every value but the first is replaced by
                    the mean of itself and its predecessor.

    Parameters:
        a: numpy array of dtype float32; it is not modified.
        keepbits: number of explicit mantissa bits to keep, 0..23.
        method: one of the names listed above.

    Raises:
        ValueError: if ``keepbits`` is outside 0..23.
        KeyError: if ``method`` is not recognised.
    """
    assert (a.dtype == np.float32)
    if not 0 <= keepbits <= 23:
        raise ValueError("keepbits must be within 0..23, got %r" % (keepbits,))

    out = a.copy()
    # Unsigned view with uint32 masks: the masks exceed the int32 range, and
    # NumPy >= 2 refuses to combine such Python ints with an int32 array.
    b = out.view(dtype=np.uint32)
    maskbits = 23 - keepbits
    mask = np.uint32((0xFFFFFFFF >> maskbits) << maskbits)  # bits to keep
    tail = np.uint32((1 << maskbits) - 1)                   # bits to drop
    # Highest dropped bit; there is none when all 23 bits are kept.
    half = np.uint32((1 << (maskbits - 1)) if maskbits > 0 else 0)

    if method == 'shave':
        b &= mask
    elif method == 'set':
        b |= tail
    elif method == 'groom':
        b[0::2] &= mask
        b[1::2] |= tail
    elif method == 'halfshave':  # Shave, but set the next bit
        b &= mask
        b |= half
    elif method == 'round':
        b &= mask
        # a + (a - shave(a)) crosses the next kept step exactly when the
        # dropped remainder is at least half a step, so shaving it rounds.
        out[:] = 2*a - out
        b &= mask
    elif method == 'groomhalf':  # Groom and then apply half-cut
        out[:] = TrimPrecision(a, keepbits, 'groom')
        b &= mask
        b |= half
    elif method == 'groomav':  # Recover with running mean
        out_groom = TrimPrecision(a, keepbits, 'groom')
        out[:1] = out_groom[:1]  # slice form also copes with empty input
        out[1:] = 0.5 * (out_groom[1:] + out_groom[0:-1])
    else:
        raise KeyError("Wrong trim method")

    return out

def pack8(a):
    """Pack ``a`` linearly into signed 8-bit integers.

    The value range of ``a`` is mapped onto -127..127 and each value is
    rounded to the nearest integer step, so the reconstruction
    ``pck*sc + off`` is within about half a step of the original.

    Returns:
        (pck, sc, off): the int8 array, the float32 scale and the float32
        offset. For constant input the scale is 1 and all packed values are 0.
    """
    ma = np.amax(a)
    mi = np.amin(a)
    ndrv = 254
    off = (0.5*(ma+mi)).astype(np.float32)
    sc = ((ma-mi)/ndrv).astype(np.float32)
    if sc == 0:
        # Constant input: any non-zero scale avoids the division by zero.
        sc = np.float32(1)
    # Round to nearest; a bare astype() would truncate toward zero.
    pck = np.rint((a-off)/sc).astype(np.int8)
    return pck, sc, off

def pack16(a):
    """Pack ``a`` linearly into signed 16-bit integers.

    The value range of ``a`` is mapped onto -32767..32767 and each value is
    rounded to the nearest integer step, so the reconstruction
    ``pck*sc + off`` is within about half a step of the original.

    Returns:
        (pck, sc, off): the int16 array, the float32 scale and the float32
        offset. For constant input the scale is 1 and all packed values are 0.
    """
    ma = np.amax(a)
    mi = np.amin(a)
    ndrv = 65534
    off = (0.5*(ma+mi)).astype(np.float32)
    sc = ((ma-mi)/ndrv).astype(np.float32)
    if sc == 0:
        # Constant input: any non-zero scale avoids the division by zero.
        sc = np.float32(1)
    # Round to nearest; a bare astype() would truncate toward zero.
    pck = np.rint((a-off)/sc).astype(np.int16)
    return pck, sc, off

def unpack(pck,sc, off):
    """Reconstruct values from packed integers: ``pck`` times ``sc`` plus ``off``."""
    scaled = pck*sc
    return scaled + off


#
# Original signals
#
def plotSignals(signals, titles, nplot):
    """Plot the first ``nplot`` samples of each signal side by side.

    One panel per signal is drawn in a 1x2 grid, with the y-range set to the
    signal mean plus/minus 3.99 standard deviations. The figure is saved as
    "Signals.<filetype>" and then cleared.
    """
    fig = plt.figure(figsize=(8, 2.5))

    for idx, series in enumerate(signals):
        centre = np.mean(series)
        half_range = 3.99 * np.std(series)

        panel = fig.add_subplot(1, 2, idx + 1)
        panel.set_ylim([centre - half_range, centre + half_range])
        panel.plot(series[0:nplot], color=cols[idx], linewidth=lw[idx])
        panel.set_xlabel("Count")
        if idx == 0:
            # Only the left-most panel carries the y-axis label.
            panel.set_ylabel("Value")
        panel.set_title("%s) %s" % (panels[idx], titles[idx]))

    fig.subplots_adjust(top=0.92, bottom=0.17, wspace=0.15, left=0.07, right=0.99)
    outname = "Signals." + filetype
    fig.savefig(outname)

    #os.system('xli %s'%(figname,))
    fig.clf()

def shuffle2(tst):
    """Byte-shuffle a buffer of 2-byte elements.

    Returns bytes of the same length as ``tst`` holding the first byte of
    every element followed by the second byte of every element. A trailing
    byte that does not form a complete element is appended unchanged.
    """
    N = len(tst) // 2
    # Slice order is start:stop:step, so byte plane k is tst[k:2*N:2].
    planes = [bytes(tst[k:2*N:2]) for k in range(2)]
    return b"".join(planes) + bytes(tst[2*N:])

def shuffle4(tst):
    """Byte-shuffle a buffer of 4-byte elements.

    Returns bytes of the same length as ``tst`` holding the first byte of
    every element, then the second, third and fourth bytes. Trailing bytes
    that do not form a complete element are appended unchanged.
    """
    N = len(tst) // 4
    # Slice order is start:stop:step, so byte plane k is tst[k:4*N:4].
    planes = [bytes(tst[k:4*N:4]) for k in range(4)]
    return b"".join(planes) + bytes(tst[4*N:])

### Compressibility plot
def CompressionPlot(signals, titles):
    """Plot compressed-size ratios against keep-bits for zlib, bzip2 and lzma.

    For every compressor one row of two panels is drawn (linear scale on the
    left, log scale on the right). Each signal is rounded to 0..23 keep-bits
    with TrimPrecision(..., 'round') and compressed both as is and after a
    4-byte shuffle; sizes are given relative to the float32 size of the
    signal. Markers show the 8- and 16-bit linear packing results (the latter
    also with a 2-byte shuffle), placed at the largest keep-bits value whose
    rounding error still exceeds the packing error.

    The figure is saved as "Zlib.<filetype>", handed to the ``xli`` viewer
    through ``os.system`` and then cleared.

    Parameters:
        signals: sequence of float32 arrays.
        titles: one title per signal; its first word is used as legend label.
    """
    #comprnames = ['zlib']; compfig=fig
    comprnames = ['zlib', 'bzip2', 'lzma']
    compfig = plt.figure(figsize=(8, 8))
    kbarray = list(range(24))
    nrows = len(comprnames)

    for iCompr, comprname in enumerate(comprnames):
        # One list of size ratios per signal.
        compsizes = [[] for _ in signals]
        compsizesshuffle = [[] for _ in signals]
        ax1 = compfig.add_subplot(nrows, 2, 2 + 2*iCompr)
        ax2 = compfig.add_subplot(nrows, 2, 1 + 2*iCompr)
        ax1.set_ylim([0.002, 1])
        ax2.set_ylim([0, 1])

        for iv, y in enumerate(signals):
            ## Packing
            Nbytes = len(y) * 4  # size of the unpacked float32 data
            pck8, sc, off = pack8(y)
            err8 = np.amax(np.abs((y - unpack(pck8, sc, off))))
            comp8 = compsize(pck8.tobytes(), comprname) / Nbytes

            pck16, sc, off = pack16(y)
            err16 = np.amax(np.abs(y - unpack(pck16, sc, off)))
            comp16 = compsize(pck16.tobytes(), comprname) / Nbytes
            comp16s = compsize(shuffle2(pck16.tobytes()), comprname) / Nbytes

            # Reset for every signal so that a value from the previous signal
            # is never reused; falls back to the lowest keep-bits value when
            # the rounding error never exceeds the packing error.
            efbits8 = efbits16 = kbarray[0]
            for keepbits in kbarray:
                trimmed = TrimPrecision(y, keepbits, 'round')
                maxerr = np.amax(np.abs(y - trimmed))
                if maxerr > err8:
                    efbits8 = keepbits
                if maxerr > err16:
                    efbits16 = keepbits
                tst = trimmed.tobytes()
                compsizes[iv].append(compsize(tst, comprname) / Nbytes)
                tst1 = shuffle4(tst)
                compsizesshuffle[iv].append(compsize(tst1, comprname) / Nbytes)

            ax1.scatter([efbits8, efbits16, efbits16], [comp8, comp16, comp16s], color=cols[iv], marker=point[iv], s=30)
            ax2.scatter([efbits8, efbits16, efbits16], [comp8, comp16, comp16s], color=cols[iv], marker=point[iv], s=30)

        for iv, (sizes, sizess) in enumerate(zip(compsizes, compsizesshuffle)):
            lab = titles[iv].split()[0]
            ax1.semilogy(kbarray, sizes, label=lab, color=cols[iv], linewidth=lw[iv])
            ax1.semilogy(kbarray, sizess, label=lab + ", shuffle", color=cols[iv], linestyle='dashed', linewidth=lw[iv])
            ax2.plot(kbarray, sizes, label=lab, color=cols[iv], linewidth=lw[iv])
            ax2.plot(kbarray, sizess, label=lab + ", shuffle", color=cols[iv], linestyle='dashed', linewidth=lw[iv])
        ax1.set_title("%s) %s, log scale" % (panels[1 + 2*iCompr], comprname))
        ax2.set_title("%s) %s, linear scale" % (panels[2*iCompr], comprname))
        ax1.grid()
        ax2.grid()

        if iCompr == nrows - 1:
            # Legend and x-labels only on the bottom row.
            ax1.legend()
            for ax in [ax1, ax2]:
                ax.set_xlabel("Keep-bits")

        ax2.set_ylabel("Compressed size")
    compfig.subplots_adjust(top=0.97, bottom=0.06, hspace=0.3, wspace=0.3, left=0.07, right=0.98)  # For four-panel

    figname = "Zlib." + filetype
    compfig.savefig(figname)

    os.system('xli %s' % (figname,))
    compfig.clf()


#### Linear function
def plotLinear(methods):
    """Plot trimmed against original values for a linear ramp around 256.

    Four panels (8, 6, 4 and 2 keep-bits) show the identity line and the
    output of TrimPrecision for every method except the last two entries of
    ``methods``. The figure is saved as "Linear-256.<filetype>" and cleared.
    """
    discr = 256
    x = np.arange(int(discr * 0.9), int(discr * 1.1), dtype=np.float32)
    y_limits = [int(discr * 0.85), int(discr * 1.3)]
    shown = methods[:-2]

    fig = plt.figure(figsize=(8, 5))
    for im, keepbits in enumerate((8, 6, 4, 2)):
        ax = fig.add_subplot(2, 2, im + 1)
        ax.plot(x, x, label="ref")
#        ax.set_xlim([300,360])
        ax.set_ylim(y_limits)

        for method in shown:
            ax.plot(x, TrimPrecision(x, keepbits, method), label=method)

        ax.title.set_text("%s) %d keepbits" % (panels[im], keepbits))
        if im == 0:
            ax.legend()
        if im % 2 == 0:
            # Left column
            ax.set_ylabel("Trimmed value")
        if im > 1:
            # Bottom row
            ax.set_xlabel("Original value")

    fig.subplots_adjust(top=0.95, bottom=0.1, hspace=0.3, wspace=0.3, left=0.07, right=0.98)
    outname = "Linear-%d.%s" % (int(discr), filetype)
    fig.savefig(outname)
    fig.clf()


##Structure function plots
def PlotStrucFun(sig, sigsuff, methods, ylim):
    """Plot structure functions of ``sig`` before and after precision trimming.

    One panel per method is drawn in a 2x3 grid; each shows the structure
    function (offsets 1..31) of the original signal and of the signal trimmed
    to 8, 6, 4 and 2 keep-bits. The figure is saved as
    "Keepbits-<sigsuff>-bg-<rounded mean>.<filetype>", handed to the ``xli``
    viewer through ``os.system`` and then cleared.

    Parameters:
        sig: float32 signal array.
        sigsuff: signal tag used in the file name.
        methods: TrimPrecision method names, one per panel.
        ylim: y-axis limits applied to every panel.
    """
    fig = plt.figure(figsize=(8, 5))

    # Both depend only on the signal, so compute them once; this also keeps
    # bg defined for the file name when methods is empty.
    bg = np.round(np.mean(sig))
    refX, refF = strucf(sig, 32)

    for im, method in enumerate(methods):
        ax = fig.add_subplot(2, 3, im + 1)

        #ax.set_ylim([35,85])
        ax.set_ylim(ylim)

        ax.plot(refX, refF, label="ref")

        for keepbits in [8, 6, 4, 2]:
            acut = TrimPrecision(sig, keepbits, method)
            strucX, strf = strucf(acut, 32)
            ax.plot(strucX, strf, label="%d bit" % (keepbits))

        ax.title.set_text("%s) %s" % (panels[im], method))
        if im == 0:
            ax.legend()
        if im % 3 == 0:
            # Left column
            ax.set_ylabel("Structure function")
        if im > 2:
            # Bottom row
            ax.set_xlabel("Offset")
        ax.grid()

    # Adjust this figure explicitly rather than pyplot's current figure.
    fig.subplots_adjust(top=0.95, bottom=0.1, hspace=0.3, wspace=0.15, left=0.07, right=0.98)
    figname = "Keepbits-%s-bg-%d.%s" % (sigsuff, bg, filetype)
    fig.savefig(figname)
    os.system('xli %s' % (figname,))
    fig.clf()

## Tables

def PrintTables(sigs, sigsuffs, methods):
    """Write one LaTeX table of relative trimming errors per signal.

    For every signal a file "Stats-<suffix>-table.tex" is written with one
    row per method and one column for each of 2, 4 and 6 keep-bits. A cell
    holds sqrt(mean((1 - trimmed/original)**2)).
    """
    bits = [2, 4, 6]

    for sig, signame in zip(sigs, sigsuffs):
        with open("Stats-%s-table.tex"%(signame), 'wt') as outf:
            outf.write('\\begin{tabular}{r r r r r}\n\\hline\n')

            # Header row: empty first cell, then one cell per keep-bits value
            header = ["%10s"%(' ',)]
            header.extend("%10s"%('& %d keep-bit'%(b,),) for b in bits)
            outf.write("".join(header) + "\\\\\n\\hline\n")

            for method in methods:
                cells = ["%10s"%(method)]
                for bit in bits:
                    acut = TrimPrecision(sig, bit, method)
                    NRMSE = np.sqrt(np.mean((1.-acut/sig)**2))
                    cells.append("%10s"%('& %7.5f'%(NRMSE,),))
                outf.write("".join(cells) + "\\\\\n")

            outf.write('\\end{tabular}')

def main():
    """Generate the two test signals, then produce all figures and tables."""
    n_samples = 4096 * 8
    n_shown = 4096
    mean_level = 380
    spread = 50

    titles = ["Correlated signal", "Uncorrelated signal"]

    correlated = gen53(n_samples, mean_level, spread)
    uncorrelated = genGaussian(n_samples, mean_level, spread)
    signals = [correlated, uncorrelated]

    plotSignals(signals, titles, n_shown)
    plotLinear(methods)
    CompressionPlot(signals, titles)
    PlotStrucFun(uncorrelated, 'rnd', methods, [30, 85])
    PlotStrucFun(correlated, 'sig53', methods, [0, 32])

    PrintTables([uncorrelated, correlated], "rnd sig53".split(), methods)


# Produce the figures and tables only when the file is run as a script.
if __name__ == "__main__":
    main()

##plt.show()
#plt.pause()


