#
# Author: Marlene Klockmann
# in collaboration with Udo v. Toussaint (IPP) and Sebastian Riedel (DLR)
# (c) Helmholtz-Zentrum Hereon, 2022
#
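# Perform the embedded GP regression with the Python package GPflow.
# NOTE: the script relies on gpflow.train, gpflow.training.monitor and
# gpflow.defer_build, i.e. the GPflow 1.x (TensorFlow 1) API.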
#
# ========================================================================
# modules ================================================================
# ========================================================================

import argparse
import gpflow
import gpflow.training.monitor as mon
import tensorflow as tf
import numpy as np
import pickle
import os

# ========================================================================
# options and settings ===================================================
# ========================================================================

# Command-line interface of the script. Every option defined here ends up
# as an attribute of `arg`, which the remainder of the script reads directly.
parser = argparse.ArgumentParser(description='GP regression in embedding space')

# --- reconstruction period and selection of the input data set ---
parser.add_argument('-sy', '--startyear', help='Start year of reconstruction',
                    type=int, required=True)
parser.add_argument('-ey', '--endyear', help='End year of reconstruction',
                    type=int, required=True)
parser.add_argument('-smt', '--similarity-threshold', help='Similarity threshold.',
                    type=float, default=0.45)
# optional index of a noise ensemble member; None selects the noise-free input
parser.add_argument('-nem', '--noise-ens-member', type=int)
parser.add_argument('-m2k', '--mask-data',
                    help='If True, create a data set with the temporal coverage mask from the real PAGES2k data',
                    action='store_true')
parser.add_argument('-stdr', '--use-stdratio',
                    help='If True, use std ratio in scaling the distance matrix of the virtual coordinates',
                    action='store_true')

# --- experiment name and data location ---
parser.add_argument('-exp', '--exp-name',
                    help='Name of the GP experiment (and folder where to store the output and model params)',
                    type=str, required=True)
parser.add_argument('-subd', '--data-subdir', help='subdirectory where data is stored',
                    type=str, default='mpiesm/2k')

# --- model and optimisation settings ---
# Only 'gpr' and 'svgp' are handled when the model is built, so any other
# value is rejected right here instead of failing later with an undefined
# model variable.
parser.add_argument('-gpm', '--gp-model', help='GP Model (gpr or svgp)',
                    type=str, required=True, choices=('gpr', 'svgp'))
parser.add_argument('-mxi', '--maxiter', help='Maximum number of iterations',
                    type=int, default=1000)
# -k1l is used for the 1-D kernel, -k2l for the (nq+1)-D kernel
parser.add_argument('-k1l', '--time-kernel-lscale', help='inital kernel lengthscale',
                    type=float, default=0.1)
parser.add_argument('-k2l', '--vdim-kernel-lscale', help='inital kernel lengthscale',
                    type=float, default=0.5)

# --- session handling, time normalisation, noise and mini batches ---
parser.add_argument('-lpp', '--load-prev-params',
                    help='If set, load previously trained model params',
                    action='store_true')
parser.add_argument('-mon', '--monitor-session',
                    help='Enables checkpointing and monitoring via tensorboard)',
                    action='store_true')
parser.add_argument('-re', '--restore-session',
                    help='Restore session from checkpoints (works only if monitoring and checkpointing were active)',
                    action='store_true')
parser.add_argument('-ntm', '--normtime-mean',
                    help='If set, normalise timestep to mean of Dij',
                    action='store_true')
parser.add_argument('-sn', '--sigma_noise', help='Specify a value for sigma_noise',
                    type=float, required=False)
parser.add_argument('-mb', '--mini-batches', help='Specify a mini batch size',
                    type=int, required=False)

arg = parser.parse_args()
print(arg)

# Output location: each experiment writes into its own folder below gpdir.
gpdir = './Data/gpreg/experiments/'
expdir = '{}{}/'.format(gpdir, arg.exp_name)

if not os.path.exists(expdir):
    print('- Creating {}.'.format(expdir))
    os.makedirs(expdir)
else:
    print('- Directory {} already exists. No action necessary.'.format(expdir))

# Keep a record of the parsed arguments of this run next to its output.
arg_log_path = expdir + arg.exp_name + '_arg_log.txt'
with open(arg_log_path, 'w') as text_file:
    print(arg, file=text_file)


# ========================================================================
# data prep ==============================================================
# ========================================================================

print('-------------------------------------')
print('- Loading data...')
print('-------------------------------------')
# Expected layout of the input table:
#   - first column: time
#   - last column:  target data (the entire data set reshaped into one vector)
#   - all columns in between: virtual coordinates
# With nts timeseries, nq=nts-1 virtual coordinates are needed.
# All data except the time is assumed to be given as anomalies already.

# shorthands for the options that enter the directory / file names
syr = arg.startyear
eyr = arg.endyear
smt = arg.similarity_threshold

# name fragments: common suffix, std-ratio suffix, coverage-mask suffix
sfx = 'corr{}_{}-{}'.format(smt, syr, eyr)
rsfx = '_stdr' if arg.use_stdratio else ''
msfx = '_p2kcoverage' if arg.mask_data else ''

if arg.noise_ens_member is None:
    # no ensemble member requested: read from the 'ppp_' directory
    datadir = './Data/gpreg/' + arg.data_subdir + '/ppp_' + sfx
    infile = 'gpinput' + rsfx + msfx + '.dat'
else:
    # ensemble member requested: read from the 'npp_' directory; the file
    # name carries the rounded value of snoise and the member number
    snoise = np.sqrt(3)

    # member number as string, with a leading zero below 10
    nn = str(arg.noise_ens_member)
    if arg.noise_ens_member < 10:
        nn = '0' + nn

    datadir = './Data/gpreg/' + arg.data_subdir + '/npp_' + sfx
    infile = ('gpinput' + rsfx
              + '_s' + str(np.round(snoise, 2))
              + '_n' + nn
              + msfx + '.dat')
   

# training table: one row per sample, columns = time, virtual coords, target
datain = np.loadtxt(datadir + '/' + infile)
nt = datain.shape[0]        # number of samples
nq = datain.shape[1] - 2    # number of virtual dims (all but time and target)

# input array for the GP (time + virtual coordinates); the time column
# (column 0) stays zero here and is filled in by the timestep scaling step
gpdata = np.zeros((nt, nq + 1))

print('- Number of virtual dims: {}'.format(nq))
print('- Number of observations: {}'.format(nt))

## coordinates at which the GP is evaluated later
# the code assumes that the GP is to be evaluated for all timesteps;
# the size of the 1st dimension of gpxstar is used to normalise the
# timestep if --normtime-mean is set
xdir = './Data/gpreg/' + arg.data_subdir + '/ppp_' + sfx
gpxstar = np.loadtxt(xdir + '/' + 'xstar' + rsfx + '.dat')

# independent copy of the original time coordinate as a column vector,
# taken before gpxstar[:,0] may be rescaled; it is written out with the results
yrs = gpxstar[:, 0].copy().reshape(gpxstar.shape[0], 1)

## target array: last column of the table,
## with a trailing dim because GPflow gives an error otherwise
gptarget = datain[:, -1].reshape((nt, 1))

## input array: copy the virtual coordinates (everything between
## the time column and the target column)
gpdata[:, 1:] = datain[:, 1:-1]

# timestep scaling:
# either timesteps are left as they are (i.e. in years),
# or the timestep is set to the mean of the
# distance matrix Dij (-ntm, used throughout the paper)

if arg.normtime_mean:
    print('- Normalising time')
    dijdir = './Data/gpreg/' + arg.data_subdir + '/ppp_' + sfx

    # NOTE(review): the file name 'corrma_..._amv150yrst' looks like a
    # transposed 't' compared to 'stdrmat_..._amv150yrs' below - confirm that
    # it matches the file on disk.
    corrmat = np.loadtxt(dijdir + '/P2k_NAtl_corrma_amv150yrst.dat')

    # distance matrix: one minus correlation, optionally weighted by the
    # std ratio
    Dij = np.ones((nq + 1, nq + 1)) - corrmat
    if arg.use_stdratio:
        # the std-ratio matrix is only read when it is actually used, so a
        # run without -stdr does not depend on this file being present
        stdrmat = np.loadtxt(dijdir + '/P2k_NAtl_stdrmat_amv150yrs.dat')
        Dij = Dij * stdrmat

    dt = np.mean(Dij)          # new timestep
    nys = gpxstar.shape[0]     # number of evaluation timesteps

    # map the time range of the training data linearly onto an interval of
    # length dt*nys centred on zero; the evaluation times are transformed
    # with the same tmin/tmax
    tmin = np.min(datain[:, 0]); tmax = np.max(datain[:, 0])
    gpdata[:, 0] = dt*nys*((datain[:, 0]-tmin)/(tmax-tmin))-(dt*0.5*nys)
    gpxstar[:, 0] = dt*nys*(gpxstar[:, 0]-tmin)/(tmax-tmin)-(dt*0.5*nys)
else:
    # keep the time column as read from the input (gpxstar stays untouched)
    gpdata[:, 0] = datain[:, 0]


# ========================================================================
# build, train and evaluate the gp model =================================
# ========================================================================

print('-------------------------------------')
print('- GP model ...')
print('-------------------------------------')

print('- Building ...')
print('- GP Model: '+arg.gp_model)

# Only these two model types are implemented below. Fail with a clear message
# instead of running into an undefined `gpm` further down.
if arg.gp_model not in ('gpr', 'svgp'):
    raise ValueError("Unknown GP model '" + str(arg.gp_model)
                     + "' (expected 'gpr' or 'svgp')")

# the model is assembled inside defer_build() and compiled explicitly below
with gpflow.defer_build():

    # sum of two RBF kernels: a 1-D one initialised with the time lengthscale
    # and an (nq+1)-D one initialised with the virtual-dim lengthscale
    # NOTE(review): no active_dims are given; presumably k1 acts on the first
    # input column (time) and k2 on all columns - confirm against GPflow
    # defaults
    k1 = gpflow.kernels.RBF(input_dim=1,
                            lengthscales=arg.time_kernel_lscale)
    k2 = gpflow.kernels.RBF(input_dim=nq+1,
                            lengthscales=arg.vdim_kernel_lscale)
    k = k1+k2  # combine kernels

    if arg.gp_model == 'gpr':
        gpm = gpflow.models.GPR(gpdata,
                                gptarget,
                                kern=k, name='GPR')
    else:  # 'svgp'
        # every 10th training input serves as initial inducing point
        svgp_kwargs = dict(Z=gpdata[::10, :],
                           kern=k,
                           likelihood=gpflow.likelihoods.Gaussian(),
                           name='SVGP')
        # mini batches are only requested when a batch size was given
        if arg.mini_batches is not None:
            svgp_kwargs['minibatch_size'] = arg.mini_batches
        gpm = gpflow.models.SVGP(gpdata, gptarget, **svgp_kwargs)

    if arg.sigma_noise is not None:
        # this has not been tested yet
        # NOTE(review): the option is called sigma_noise but the value is
        # assigned to the likelihood *variance* unchanged - confirm whether
        # it should be squared.
        print('- Keeping likehood variance fixed as '+str(arg.sigma_noise))
        gpm.likelihood.variance.trainable = False
        gpm.likelihood.variance = arg.sigma_noise

gpm.compile()
print('LML before the optimisation: %f' % gpm.compute_log_likelihood())

## building the monitor with tensorboard
if arg.monitor_session:
    print('- Initialise Monitor ...')
    session = gpm.enquire_session()
    global_step = mon.create_global_step(session)

    # checkpoint task: condition = every 10 iterations, written to
    # <expdir>monitor-saves
    saver_task = mon.CheckpointTask(expdir+'monitor-saves').with_name('saver')\
        .with_condition(mon.PeriodicIterationCondition(10))\
        .with_exit_condition(True)

    file_writer = mon.LogdirWriter(expdir+'model-tensorboard')

    # tensorboard task for the model: condition = every 10 iterations
    model_tboard_task = mon.ModelToTensorBoardTask(file_writer, gpm).with_name('model_tboard')\
        .with_condition(mon.PeriodicIterationCondition(10))\
        .with_exit_condition(True)

    # tensorboard task for the log marginal likelihood: every 100 iterations
    lml_tboard_task = mon.LmlToTensorBoardTask(file_writer, gpm).with_name('lml_tboard')\
        .with_condition(mon.PeriodicIterationCondition(100))\
        .with_exit_condition(True)

    monitor_tasks = [model_tboard_task, lml_tboard_task, saver_task]
    monitor = mon.Monitor(monitor_tasks, session, global_step)

if arg.restore_session:
    print('- Restore previous session ...')
    if not arg.monitor_session:
        # `session` is only created in the monitoring branch above; without
        # -mon fetch it from the model here, otherwise restoring would stop
        # with a NameError
        session = gpm.enquire_session()
    mon.restore_session(session, expdir+'monitor-saves')

if arg.load_prev_params and not arg.restore_session:
    # only works for perfect proxies (not yet for noise ensembles)
    # NOTE: pickle.load executes arbitrary code from the file - only load
    # parameter files written by this script / from a trusted source.
    with open(expdir+arg.exp_name+'_gp_params', 'rb') as fp:
        param_dict = pickle.load(fp)
    gpm.assign(param_dict)

print('- Training...')

if not arg.monitor_session:
    # plain run: Scipy optimiser, no monitoring
    scipy_optimiser = gpflow.train.ScipyOptimizer()
    scipy_optimiser.minimize(gpm, maxiter=arg.maxiter, disp=True)
else:
    # monitored run: Adam optimiser, with the monitor passed as step callback
    print('- Check progress on Tensorboard ...')
    print('- Start Tensorboard with tensorboard --logdir <name-of-monitoring-dir>')
    optimiser = gpflow.train.AdamOptimizer(0.01)
    with mon.Monitor(monitor_tasks, session, global_step,
                     print_summary=True) as monitor:
        optimiser.minimize(gpm,
                           step_callback=monitor,
                           maxiter=arg.maxiter,
                           global_step=global_step)
    # the tensorboard writer is closed once the optimisation has finished
    file_writer.close()

lml_after = gpm.compute_log_likelihood()
print('LML after the optimisation: %f' % lml_after)

# evaluate the GP at the locations given in gpxstar
print('- Evaluating ...')
gpmean, gpvar = gpm.predict_y(gpxstar)
print(gpm)

# store the trained model params so they can be reused, and write the
# GP output together with the original time coordinate (yrs)
print('- Saving ...')

# runs on a noise ensemble member carry the member number in all file names
if arg.noise_ens_member is not None:
    member_tag = '_n' + str(nn)
else:
    member_tag = ''

out_prefix = expdir + arg.exp_name

with open(out_prefix + '_gp_params' + member_tag, 'wb') as fp:
    pickle.dump(gpm.read_trainables(), fp)

# first column: years, second column: predicted mean / variance
mean_table = np.hstack((yrs, gpmean))
np.savetxt(out_prefix + '_gpmean' + member_tag + '.dat', mean_table, fmt='%1.4e')
var_table = np.hstack((yrs, gpvar))
np.savetxt(out_prefix + '_gpvar' + member_tag + '.dat', var_table, fmt='%1.4e')


