#!/usr/bin/env python3

# This file is for plotting calibration results from mcmc or nlopt.
# FIXME NEEDS CLEANUP AND DOCUMENTATION

from pylab import *

import os
# plt.rcdefaults()
# plt.style.use('ggplot')
from sys import argv
#from glob import glob

#from itertools import product
from scipy.stats import gaussian_kde
#from scipy.stats import norm as N


def print_prob_that_some_was_accepted_after_latest(data):
    """Print debugging statistics about the rejections after the last accept.

    Parameters
    ----------
    data : 2-d numpy array
        One row per drawn sample. The last column is the accept flag
        (1.0 = accepted) and the second-to-last column is the cost function
        value; these are the only columns read here.

    The function prints the number of accepted samples, the number of drawn
    samples, the indices of the (up to) ten latest accepted samples and the
    probability that at least one of the samples drawn after the latest
    accepted one would have been accepted. It returns None.
    """
    acc = data[:, -1]
    acceptinds = np.where(acc == 1.)[0]
    print('accepted:', len(acceptinds))
    print('samples drawn:', len(data))
    if len(acceptinds) == 0:
        # Without any accepted sample there is no reference cost function
        # value to compare the later proposals against.
        print('no accepted samples; nothing to compare against')
        return
    acceptind = acceptinds[-1]
    print('latest acceptinds:', acceptinds[-10:])
    cfs = data[acceptind+1:, -2]
    # Acceptance probability of each later proposal relative to the latest
    # accepted sample. It is capped at 1 so that (1 - R) stays a probability;
    # without the cap a proposal with a lower cost function value makes the
    # product below negative or larger than one.
    R = np.minimum(1., np.exp(-(cfs - data[acceptind, -2])))
    # P(at least one accepted) = 1 - P(all rejected)
    prob = 1 - np.prod(1 - R)
    print('prob that some was accepted after last:', prob)

def chainify(c):
    """Turn a table of proposed samples into an MCMC chain, in place.

    Each row of ``c`` has the format ``par1 ... parn cfvalue accept``.
    A row whose accept flag (last entry) is 0.0 is overwritten with the
    previous row, so that the chain repeats the last accepted state; rows
    flagged 1.0 are kept. Because the whole row is copied, a replaced row
    also carries the accept flag of the row it was copied from, which makes
    a second call on the same data a no-op.

    A rejected *first* row has no predecessor and is left unchanged (it is
    not replaced by the last row of the table).

    Parameters
    ----------
    c : 2-d numpy array (or list of rows)
        Modified in place.

    Returns
    -------
    The same object ``c``.

    Raises
    ------
    SystemExit
        With exit status 1 if an accept flag is neither 0.0 nor 1.0.
    """
    for j, x in enumerate(c):
        flag = x[-1]
        if flag == 1.0:
            continue
        if flag == 0.0:
            # j - 1 would wrap around to the last row for j == 0.
            if j > 0:
                c[j] = c[j-1]
            continue
        print('unable to chainify %s' % (x,))
        # Same effect as sys.exit(1) without requiring ``sys`` in scope.
        raise SystemExit(1)
    return c

def get_mcmc_data(datafile='mcmcresults.txt', parfile='parameters.txt', plot_parnames=True):
    """Load the sampler output and build the MCMC chain from it.

    Parameters
    ----------
    datafile : str
        Text file with one header line (skipped) followed by rows of the
        form ``par1 ... parn cfvalue accept``.
    parfile : str
        Text file whose first line holds the whitespace-separated parameter
        names; any ``#`` characters on that line are dropped. Only read
        when ``plot_parnames`` is true.
    plot_parnames : bool
        If false, generic names ``param.0``, ``param.1``, ... are used
        instead of reading ``parfile``.

    Returns
    -------
    data : 2-d array
        The chain: rejected rows replaced by the preceding row (see
        ``chainify``).
    origdata : 2-d array
        The proposed points as read from the file, after row filtering.
    parameters : list of str
        Parameter names.
    npar : int
        Number of parameter columns (all columns except the last two).
    m : 1-d bool array
        Mask over the rows of the raw file selecting those whose cost
        function value compares less than infinity (inf and NaN rows are
        dropped).
    """
    data = np.loadtxt(datafile, skiprows=1, ndmin=2)
    npar = np.shape(data)[1] - 2

    # FIXME this section should read the names somewhere once they are available
    if plot_parnames:
        with open(parfile, 'r') as f:
            parameters = f.readline().replace('#', '').split()
        print('parameters:', parameters)
    else:
        parameters = ["param." + str(i) for i in range(npar)]

    m = data[:, -2] < np.inf
    data = data[m]
    # chainify() works in place, so keep an untouched copy of the proposals.
    origdata = data.copy()
    print_prob_that_some_was_accepted_after_latest(data)
    data = chainify(data)
    return data, origdata, parameters, npar, m


def get_MAP_bin(chain, nbins):
    """Return the centre of the fullest histogram bin for each 4-column group.

    The columns of ``chain`` are processed in consecutive groups of four
    (trailing columns that do not fill a whole group are ignored). For each
    group a 4-dimensional histogram with ``nbins`` bins per axis is built,
    and the centre coordinates of the bin holding the most samples are
    collected. The result is a flat 1-d array with four values per group.
    """
    group_size = 4
    n_groups = chain.shape[1] // group_size
    centres = []
    for g in range(n_groups):
        first = g * group_size
        block = chain[:, first:first + group_size]
        counts, edges = np.histogramdd(block, bins=(nbins,) * group_size)
        # Multi-index of the most populated bin (first one on ties).
        peak = np.unravel_index(np.argmax(counts, axis=None), counts.shape)
        centres.append([(axis_edges[pos] + axis_edges[pos + 1]) / 2
                        for axis_edges, pos in zip(edges, peak)])
    return np.array(centres).reshape(-1)



def _load_optional_vector(path):
    """Load a numeric text file as a 1-d array, or return None if unavailable.

    The files read through this helper only decorate the chain plot, so a
    missing or unparsable file is not treated as an error.
    """
    try:
        return np.atleast_1d(loadtxt(path))
    except (OSError, ValueError):
        return None


def plot_file(burnin=0.):
    """Plot the trace of every parameter of the chain in the working directory.

    The data is obtained from ``get_mcmc_data()``. One subplot per parameter
    shows the proposed points and the chain (both thinned to roughly 5000
    points) together with horizontal lines for several point estimates:
    the centre of the fullest 1-d histogram bin, the median, the centre of
    the first bin holding more than 80% of the fullest bin's count, the
    sample with the lowest cost function value and the result of
    ``get_MAP_bin``. The same numbers are written into the subplot title.

    If present in the working directory, ``x_low.txt`` / ``x_high.txt`` set
    the y-limits and ``x_true.txt`` adds a line per parameter. Each file is
    loaded independently of the others.

    Parameters
    ----------
    burnin : float
        Fraction of the rows dropped from the start before plotting. The
        lowest-cost sample and the ``get_MAP_bin`` estimate are computed
        before this cut; the median and the 1-d histograms after it.

    Returns
    -------
    data, origdata, parameters, npar, m
        As returned by ``get_mcmc_data()``, with the burn-in rows removed
        from ``data`` and ``origdata``.
    """
    data, origdata, parameters, npar, m = get_mcmc_data()

    # argmax of the negated cost function == row with the lowest cost.
    MAPpars = data[argmax(-data[:,-2])]
    MAPbinpars = get_MAP_bin(data, 20)
    dataskip = int(burnin*len(data))
    data = data[dataskip:]
    origdata = origdata[dataskip:]
    medianpars = median(data, axis=0)

    fig = figure(figsize=(12,9))
    x_low = _load_optional_vector("x_low.txt")
    x_high = _load_optional_vector("x_high.txt")
    true_values = _load_optional_vector("x_true.txt")

    ncols = max(2, npar//4)
    # add_subplot needs integers; "/" would give a float under Python 3.
    nrows = (npar-1)//ncols + 1
    stride = max(len(data)//5000, 1)
    inserts = "=///////"
    ax = None
    for k in range(npar):
        ax = fig.add_subplot(nrows, ncols, k+1)
        ax.yaxis.get_major_formatter().set_powerlimits((-1, 2))
        print(stride)
        ax.plot(origdata[:,k][::stride], '.', alpha=0.8, markersize=.2, color='r', label='Proposed points')
        ax.plot(data[:,k][::stride], '.', alpha=0.8, markersize=.4, color='g', label='MCMC chain')

        # Find mode by looking at 1d histograms
        h1d = histogram(data[:,k], bins=20)
        print(h1d)
        counts, edges = h1d
        a0 = argmax(counts)
        # First bin whose count exceeds 80% of the fullest bin's count.
        a80 = where(counts > .80*np.max(counts))[0][0]
        print("chosen:", a0, a80)
        mm0 = (edges[a0] + edges[a0+1])/2
        mm = (edges[a80] + edges[a80+1])/2

        ax.axhline(mm0, color='cyan', linestyle='-', lw=1, label='1d MAP bin')
        ax.axhline(medianpars[k], color='magenta', linestyle='-', lw=1, label='median')
        ax.axhline(mm, color='black', linestyle='-', lw=1, label="Bin of 80% of max bin's samples")
        ax.axhline(MAPpars[k], color='blue', linestyle='-', lw=1, label='MAP parameter values')
        estimates = [mm0, medianpars[k], mm, MAPpars[k]]
        # get_MAP_bin returns four values per complete group of four columns,
        # which can be fewer than npar.
        # NOTE(review): it is handed the full table including the cost and
        # accept columns -- confirm that this is intended.
        if k < len(MAPbinpars):
            ax.axhline(MAPbinpars[k], color='red', linestyle='-', lw=1, label='Best bin of each kernel')
            estimates.append(MAPbinpars[k])
        if true_values is not None and k < len(true_values):
            ax.axhline(true_values[k], color='b', linestyle='-', lw=1, label='True value')
        plt.setp(ax.get_xticklabels(), visible=False) # no xticks
        ax.grid(True)

        # Title: "<name> = v1 / v2 / ..." in the order of `estimates`.
        t = parameters[k]
        for ii, tt in enumerate(estimates):
            if "_t" in parameters[k]:
                # value divided by 3600*24 and suffixed with 'd'
                tt2 = str(around(tt/3600/24, 2)) + 'd'
            else:
                tt2 = str(around(tt, 3))
            t = t + ' ' + inserts[ii] + ' ' + tt2
        ax.set_title(t)

        have_bounds = (x_low is not None and x_high is not None
                       and k < min(len(x_low), len(x_high)))
        if have_bounds and np.isfinite(x_low[k]) and np.isfinite(x_high[k]):
            ax.set_ylim(x_low[k], x_high[k])
        ax.set_xlim(0, len(data)//stride)

    # Only the last subplot carries the legend; nothing to do without subplots.
    if ax is not None:
        ax.legend()

    return data, origdata, parameters, npar, m


def plot_single_2dmarginal(arr, ax=None):
    """Scatter-plot two parameters against each other and overlay KDE contours.

    Parameters
    ----------
    arr : 2-d numpy array of shape (ndata, 2)
        First column is plotted on the x-axis, second on the y-axis.
    ax : matplotlib axes, optional
        Axes to draw into; the current axes are used when omitted.

    The points are thinned to roughly 1200 before plotting. A Gaussian
    kernel density estimate of the thinned points is evaluated on a
    100 x 100 grid spanning the axes limits, and three contour lines are
    drawn at density levels chosen so that the grid cells above the level
    hold about 85%, 50% and 15% of the total gridded density. If the
    density estimate or the contour drawing fails, the function returns
    with only the scatter plot drawn. Returns None.
    """
    from scipy.optimize import minimize

    if ax is None:
        ax = plt.gca()

    xv, yv = arr.T
    print((xv.min(), xv.max(), yv.min(), yv.max()))
    stride = max(1, len(arr)//1200) # plot bunch of points what ever the len of arr
    xpts = xv[::stride]
    ypts = yv[::stride]
    ax.plot(xpts, ypts, '.', color='black', alpha=.5, markersize=1.)

    try:
        # Fails e.g. with too few or degenerate points.
        KR = gaussian_kde([xpts, ypts], bw_method=0.45)
    except (ValueError, np.linalg.LinAlgError):
        return

    ngrid = 100
    xlo, xhi = ax.get_xlim()
    ylo, yhi = ax.get_ylim()
    xs = np.linspace(xlo, xhi, ngrid)
    ys = np.linspace(ylo, yhi, ngrid)

    # All (x, y) grid pairs with x varying slowest, i.e. the same order as
    # itertools.product(xs, ys).
    X = np.repeat(xs, ngrid)
    Y = np.tile(ys, ngrid)
    Z = KR.evaluate([X, Y])
    peak = np.argmax(Z)
    print(X[peak], Y[peak])
    N = 1./np.sum(Z)

    def calculate_mass_above_threshold(zval):
        # Fraction of the gridded density lying in cells denser than zval.
        return np.sum(Z[Z > zval])*N

    def contourfinder(prob):
        # Density level whose super-level set holds (about) `prob` of the mass.
        def contour_cf(zval):
            mass = calculate_mass_above_threshold(zval)
            return (mass - prob)**2
        res = minimize(contour_cf, (np.max(Z) + np.min(Z))*.5, method='Nelder-Mead')
        return res['x'][0]

    # Decreasing mass gives increasing density levels, as contour() requires.
    contourlevs = [0.85, 0.5, 0.15]
    contourz = [contourfinder(t) for t in contourlevs]
    try:
        # Z is ordered x-major; transpose so that rows correspond to y.
        ax.contour(xs, ys, Z.reshape((ngrid, ngrid)).T, contourz, linewidths=[3,3,3],
                   colors=['black', 'r', 'b'], zorder=2, alpha=1)
    except (ValueError, TypeError):
        # Best effort: e.g. the level search did not give increasing levels.
        return
    # Contour labels are deliberately not drawn; they clutter the small panels.


def plot_2dmarginals(arr, parnames, burnin=0.5, corr_matrices=None, mainchainindex=None, truepars=None, bestpars=None):
    """Draw a triangle plot of pairwise marginals and correlation coefficients.

    A new figure is created. Panels with ``j >= i`` (lower triangle) show
    parameter ``i`` against parameter ``j+1`` via ``plot_single_2dmarginal``;
    panels with ``j < i`` show correlation values as coloured cells, taken
    from row ``j`` and column ``(i+1) % l`` of the correlation matrices after
    their first row has been dropped.

    Parameters
    ----------
    arr : 2-d numpy array, shape (nsamples, l)
        Chain with one column per parameter.
    parnames : sequence of str
        Axis/title labels, one per column of ``arr``.
    burnin : float
        Fraction of rows dropped from the start of ``arr``.
    corr_matrices : array-like of shape (nchains, l, l), optional
        Correlation matrices of several chains to show side by side in the
        upper triangle. When omitted (or empty) the correlation matrix of
        ``arr`` itself is used.
    mainchainindex : int, optional
        Required together with ``corr_matrices``: index of the matrix that
        corresponds to ``arr``; its numbers are typeset in bold.
    truepars, bestpars : indexable, optional
        Per-parameter values marked with an 'X' / 'P' in every marginal.

    Raises
    ------
    SystemExit
        If ``corr_matrices`` is given without ``mainchainindex``, or if a
        correlation entry cannot be looked up.
    """
    from itertools import product

    npoints = len(arr)
    J = int(len(arr)*burnin)
    # Thin the post-burn-in rows to about npoints. With npoints == len(arr)
    # the stride is always 1, i.e. every remaining row is kept.
    stride = max(1, int(len(arr)*(1-burnin))//max(npoints, 1))
    arr = arr[J:][::stride]
    print('arr length:', len(arr))
    print("mean: ", np.mean(arr, axis=0))
    print("median: ", np.median(arr, axis=0))

    scaleforfig = 2.0
    fig = figure(figsize=(10*scaleforfig,11*scaleforfig))
    l = shape(arr)[1]
    fsize=18
    cmap = plt.get_cmap('bwr')
    if corr_matrices is None or len(corr_matrices) == 0:
        # First row dropped: row j of the result belongs to parameter j+1.
        corr_matrices = [corrcoef(arr.T)[1:,:]]
    elif mainchainindex is None:
        print('''with manually added corr matrices, please give the mainchainindex
        parameter corresponding to the index of the lower triangle data in
        array "arr".''')
        raise SystemExit()
    else:
        corr_matrices = np.asarray(corr_matrices)
        print(shape(corr_matrices))
        corr_matrices = corr_matrices[:,1:,:]
        print(shape(corr_matrices))

    for i,j in product(range(l), range(l-1)):
        ax = fig.add_subplot(l+1,l,l*(j)+i+1, adjustable='box')
        if j<i:
            # ---- upper triangle: correlation cells ----
            if j == 0:
                title(parnames[(i+1)%l], fontsize=fsize)
            try:
                if len(corr_matrices) == 1:
                    C = [corr_matrices[0][j,(i+1)%(l)]]
                else:
                    C = corr_matrices[:,j,(i+1)%(l)]
            except (IndexError, TypeError):
                print('i, j, l:', i, j, l)
                print('(i+1)%l:', (i+1)%l)
                print('FAIL!!')
                raise SystemExit()
            if len(C) == 1:
                cc = sign(C[0])*C[0]**2 #Let's make the color gradient a little steeper towards the end
                # NOTE(review): the number written into the cell is this signed
                # square, not the correlation coefficient itself -- confirm.
                imshow([[cc]], vmin=-1, vmax=1, cmap=cmap)
                plt.text(0, 0, str(around(cc,2)), ha="center", va="center", fontsize=20, color=['black', 'white'][int(cc > 0.7 or cc < -0.4)])
            else: # plotting in two columns
                # Pad to an even number of entries (integer arithmetic; "/"
                # yields floats under Python 3 and breaks zeros()/reshape()).
                Cfilled = zeros((len(C)+1)//2*2)
                Cfilled[:len(C)] = C

                # This is a stupid workaround to get non-square pixels. Assumes nx == 2
                nx, ny = (2, len(Cfilled)//2)
                print('shape Cfilled:', shape(Cfilled))
                Cfilled_grown = np.array([[Cfilled[k],]*nx*ny for k in range(len(Cfilled))]).reshape(nx*ny,nx,-1)
                for row in range(nx*ny-1):
                    Cfilled_grown[row][1] = Cfilled_grown[row+1][0]
                Cfilled_grown[1:][::2] = Cfilled_grown[::2]

                print('Cfilled_grown:', Cfilled_grown.reshape((-1,ny)))
                imshow(Cfilled_grown.reshape(nx*ny,-1), vmin=-1, vmax=1, cmap=cmap)
                # Hard-coded text positions: at most six chains can be labelled.
                x1 = 1.5; x2 = 4.5; y1 = 1; y2 = 3; y3 = 5
                locations = [(x1,y1), (x2,y1), (x1,y2), (x2,y2), (x1,y3), (x2,y3)]
                for ii, xx in enumerate(C):
                    textcolor = 'black'
                    if xx > 0.7 or xx < -0.4:
                        textcolor = 'white'
                    corrstr = str(around(C[ii],2))
                    if ii == mainchainindex:
                        corrstr = r"$\mathbf{" + corrstr + r"}$"
                    else:
                        corrstr = r"$" + corrstr + r"$"
                    plt.text(locations[ii][0]-.5, locations[ii][1]-.5, corrstr,
                             ha="center", va="center",fontsize=fsize-4, color=textcolor)

            xticks([])
            yticks([])
            continue

        # ---- lower triangle: 2-d marginal of parameter i vs parameter j+1 ----
        xticks(rotation=45)
        yticks(rotation=45)
        plot_single_2dmarginal(arr[:,[i,(j+1)%l]], ax=ax)
        xlim(np.min(arr[:,i]), np.max(arr[:,i]))
        ylim(np.min(arr[:,(j+1)%l]), np.max(arr[:,(j+1)%l]))
        if truepars is not None:
            plot(truepars[i], truepars[j+1], 'X', markersize=10, color='darkorange')
        if bestpars is not None:
            plot(bestpars[i], bestpars[j+1], 'P', markersize=10, color='black')

        print((around(corrcoef(arr[:,[i,(j+1)%l]].T)[0][1],3)))

        # The rest is about getting text labels right etc.
        ax.yaxis.offsetText.set_fontsize(fsize-4) # exponent font sizes
        ax.xaxis.offsetText.set_fontsize(fsize-4)
        if j == l-2: # bottom
            xlabel(parnames[i], fontsize=fsize, labelpad=10)
            ax.xaxis.get_major_formatter().set_powerlimits((-2, 2))
            ax.xaxis.set_label_coords(.45, -.55)
            ax.xaxis.get_offset_text().set_x(0.9)
            if i != 0:
                plt.setp(ax.get_yticklabels(), visible=False)
        if i == 0:
            ylabel(parnames[j+1], fontsize=fsize, labelpad=10) # Don't truncate when using latex labels
            ax.yaxis.set_label_coords(-.55, .5)
            ax.yaxis.get_major_formatter().set_powerlimits((-2, 2))
            if j != l-2:
                plt.setp(ax.get_xticklabels(), visible=False)
        if j != l-2 and i != 0:
            plt.setp(ax.get_xticklabels(), visible=False)
            plt.setp(ax.get_yticklabels(), visible=False)
        for xlabel_i in ax.get_xticklabels():
            xlabel_i.set_fontsize(fsize-1) # number sizes on axes
        for ylabel_i in ax.get_yticklabels():
            ylabel_i.set_fontsize(fsize-1) # number sizes on axes


if __name__ == '__main__':
    # Usage: <script> <results directory>
    # The directory must contain the files read by get_mcmc_data();
    # x_low.txt, x_high.txt and x_true.txt are optional.
    if len(argv) < 2:
        raise SystemExit('usage: ' + argv[0] + ' <results directory>')
    # `argv` is imported explicitly at the top of the file; `sys` itself is not.
    os.chdir(argv[1])
    data, origdata, parameters, npar, m = plot_file(burnin=.5)
    # Proposed point with the lowest cost function value (after burn-in).
    # NOTE(review): computed but not handed to plot_2dmarginals(bestpars=...)
    # -- confirm whether it should be.
    bestpars =  origdata[argsort(origdata[:,-2])[0]][:-2]
    savefig('chains.pdf', dpi=300, bbox_inches='tight', pad_inches=0.0)

    data = chainify(data)
    # x_true.txt is optional, as in plot_file().
    try:
        truepars = np.atleast_1d(loadtxt("x_true.txt"))
    except OSError:
        truepars = None
    # burnin may already be applied above, so skip it here with burnin=0.0
    plot_2dmarginals(data[:, :npar], parameters[0:npar], burnin=0.0, truepars=truepars)
    savefig('2dmarginals.pdf', dpi=300, bbox_inches='tight', pad_inches=0.0)
