#
# Author: Marlene Klockmann
# in collaboration with Udo v. Toussaint (IPP) and Sebastian Riedel (DLR)
# (c) Helmholtz-Zentrum Hereon, 2022
#
# Put data in the right format for the GP program
#
# Reads in temperature and climate index timeseries and prepares the
# input files for the GP regression. Time is the 1st column, target data
# (here: temperature and AMV) the last. The timeseries are stacked on top
# of each other and can have different lengths. All other columns contain
# the virtual coordinates. If you have nts timeseries, you need nq=nts-1
# virtual coordinates. Currently, the virtual coordinates are obtained by
# multidimensional scaling (MDS) of a given distance matrix Dij that is
# based on cross-correlations and (optionally) standard-deviation ratios.
#
# ==================================================================
# modules ==========================================================
# ==================================================================

import numpy as np
from sklearn.manifold import MDS
import argparse
import sys

## =================================================================
# settings =========================================================
## =================================================================
parser = argparse.ArgumentParser(description='Prepare virtual coordinates for the GP regression')
parser.add_argument('-sy', '--startyear', help='Start year of reconstruction', type=int, required=True)
parser.add_argument('-ey', '--endyear', help='End year of reconstruction', type=int, required=True)
parser.add_argument('-ty', '--testyear', help='Year that devides training and testing parts of the AMV timeseries', type=int, required=True)
parser.add_argument('-smt', '--similarity-threshold', help='Value for selecting only records with similarity value above the threshold. Default threshold is 0.45', type=float, default=0.45)
parser.add_argument('-m2k', '--mask-data', help='If True, create a data set with the temporal coverage mask from the real PAGES2k data', action='store_true')
parser.add_argument('-stdr', '--use-stdratio', help='If True, use std ratio in scaling the distance matrix of the virtual coordinates', action='store_true')
parser.add_argument('-wne', '--white-noise', help='Create a white noise ensemble', action='store_true')
parser.add_argument('-ens', '--ensemble_size', help='Number of noise-ensemble members', type=int, default=20)
parser.add_argument('-subd', '--data-subdir', help='subdirectory where data is stored', type=str, default='mpiesm/2k')
parser.add_argument('-rmean', '--mean-removed', action='store_true')

arg = parser.parse_args()
print(arg)

syr=arg.startyear
eyr=arg.endyear
yrec=arg.testyear
smt=arg.similarity_threshold

sfx='corr'+str(smt)+'_'+str(syr)+'-'+str(eyr)

## =================================================================
# read data ========================================================
## =================================================================

# All inputs of this configuration are read from the perfect-proxy (ppp) directory.
datadir = './Data/gpreg/' + arg.data_subdir + '/ppp_' + sfx
# NOTE(review): unlike the other inputs ('..._amv150yrs.dat') this filename reads
# 'corrma_amv150yrst' -- looks like a transposed 't'; confirm against the file the
# preprocessing step actually writes before renaming either side.
corrmat = np.loadtxt(datadir + '/P2k_NAtl_corrma_amv150yrst.dat')
stdrmat = np.loadtxt(datadir + '/P2k_NAtl_stdrmat_amv150yrs.dat')
data = np.loadtxt(datadir + '/P2k_NAtl_data_amv150yrs.dat')

nts = data.shape[1] - 1   # number of time series (proxy records + AMV); column 0 is time
nq = nts - 1              # number of necessary virtual dims
nt = data.shape[0]        # length of one individual complete timeseries
nd = nts * nt             # number of all samples in time-virtual space

# The distance matrix is built element-wise from these matrices; a wrongly shaped
# file would otherwise be broadcast silently instead of failing here.
if corrmat.shape != (nts, nts):
    raise ValueError(f'correlation matrix has shape {corrmat.shape}, expected ({nts}, {nts})')
if arg.use_stdratio and stdrmat.shape != (nts, nts):
    raise ValueError(f'std-ratio matrix has shape {stdrmat.shape}, expected ({nts}, {nts})')

# remove the time mean of every series unless flagged as already removed
if not arg.mean_removed:
    data[:, 1:] -= np.mean(data[:, 1:], axis=0)

# Coverage mask flattened column by column (order='F'), the same way the data
# columns are stacked into the GP input.
# NOTE(review): assumes tmask is laid out like data[:, 1:] (time x series) -- confirm.
if arg.mask_data:
    tmask = np.loadtxt(datadir + '/P2k_NAtl_mask_amv150yrs.dat')
    gpmask = tmask.reshape(tmask.size, order='F')
    if gpmask.size != nd:
        raise ValueError(f'coverage mask has {gpmask.size} entries, expected {nd} '
                         f'(= {nts} series x {nt} time steps)')


## =================================================================
# prepare the virtual space ========================================
## =================================================================

print('-Creating the virtual coordinates.')
## coordinates from distance matrix
## distances are set by cross-correlations and STD ratio (optional)

# Dissimilarity between series i and j: 1 - corr(i, j), optionally multiplied
# element-wise by the std-ratio matrix.
Dij = np.ones((nts, nts)) - corrmat
rsfx = ''    # output-filename tag recording which distance definition was used

if arg.use_stdratio:
    Dij = Dij * stdrmat
    rsfx = '_stdr'

# exact zero self-distance, whatever the diagonals of corrmat/stdrmat hold
np.fill_diagonal(Dij, 0)

# Embed the nts series as points in nq = nts-1 dimensions via MDS on the
# precomputed dissimilarities; row k of qq holds the coordinates of series k.
# NOTE(review): no random_state is passed, so the embedding (and every file
# written from it) may differ between runs.
embedding = MDS(n_components=nq, dissimilarity='precomputed')
qq = embedding.fit_transform(Dij)

# GP input: one row per (series, time) sample, series stacked on top of each
# other (all proxy records first, AMV last).  Columns: time | nq coords | value.
gpinput = np.zeros((nd, nts + 1))
gpinput[:, 0] = np.tile(data[:, 0], nts)
gpinput[:, -1] = data[:, 1:].reshape(nd, order='F')
# every one of the nt samples of series k carries that series' coordinates qq[k]
gpinput[:, 1:-1] = np.repeat(qq, nt, axis=0)

## write out xstar, i.e. locations for which the GP should be evaluated later
# i.e. the coordinates of the AMV timeseries in nq-space at all points in time
gpxstar = gpinput[(nts - 1) * nt:, :-1]
np.savetxt(datadir + '/xstar' + rsfx + '.dat', gpxstar, fmt='%1.5e')

# create, mask and save the GP input files
#
# Rows of the AMV series (the last nt rows of the stacked input) that lie before
# the test year are dropped from every saved input; with --mask-data a second
# file per input additionally drops every sample whose coverage mask is 0.
gpfmt = '%1.5e'
itest_hits = np.flatnonzero(data[:, 0] == yrec)
if itest_hits.size != 1:
    raise ValueError(f'test year {yrec} must occur exactly once in the time column '
                     f'of the data, found {itest_hits.size} matches')
itest = int(itest_hits[0])   # index of the test year within one series

amv0 = (nts - 1) * nt        # first row of the AMV series in the stacked input
keep_amv = np.ones(nd, dtype=bool)
keep_amv[amv0:amv0 + itest] = False
if arg.mask_data:
    # Both masks refer to row positions of the *full* stacked input, so they are
    # combined before any row is removed (deleting the AMV rows first would shift
    # the positions gpmask refers to whenever the AMV part of the mask has zeros).
    keep_p2k = keep_amv & (gpmask != 0)

# create, mask and save noisy gpinput (NPPs)
if arg.white_noise:

    print('-Creating the ' + str(arg.ensemble_size) + ' noisy inputs first.')
    # same root directory as the perfect data (datadir)
    ndir = './Data/gpreg/' + arg.data_subdir + '/npp_' + sfx

    snoise = np.sqrt(3)
    stag = '_s' + str(np.round(snoise, 2))   # tag used in the NPP filenames

    for it in range(arg.ensemble_size):
        nn = '%02d' % (it + 1)   # ensemble member number, zero-padded to two digits
        npp = np.loadtxt(ndir + '/P2k_NAtl_data_amv150yrs' + stag + '_n' + nn + '.dat')
        npp[:, 1:] -= np.mean(npp[:, 1:], axis=0)
        # same time/coordinate columns as the perfect input, noisy target column
        gpcopy = gpinput.copy()
        gpcopy[:, -1] = npp[:, 1:].reshape(nd, order='F')

        # mask AMV and save noisy data
        np.savetxt(ndir + '/gpinput' + rsfx + stag + '_n' + nn + '.dat',
                   gpcopy[keep_amv], fmt=gpfmt)

        # mask also with tmask
        if arg.mask_data:
            np.savetxt(ndir + '/gpinput' + rsfx + stag + '_n' + nn + '_p2kcoverage.dat',
                       gpcopy[keep_p2k], fmt=gpfmt)

## mask and save perfect gpinput (PPPs)
print('-Saving the perfect inputs.')
gpinput_full = gpinput
# 1. mask AMV for training and testing data; save with only AMV masked
gpinput = gpinput_full[keep_amv]
np.savetxt(datadir + '/gpinput' + rsfx + '.dat', gpinput, fmt=gpfmt)

# 2. mask also according to tmask; save with actual P2k coverage
if arg.mask_data:
    gpinput = gpinput_full[keep_p2k]
    np.savetxt(datadir + '/gpinput' + rsfx + '_p2kcoverage.dat', gpinput, fmt=gpfmt)





