#
# calculate cross correlations between the pseudoproxies,
# plot all locations, remove double locations to reduce data
# size and save time,lon,lat,temperature and amv in one file
#
#
# call from command line (not ipython shell), e.g. like this:
# python amv0_select_pseudoprox.py -sy 1000 -ey 2000 -smt 0.45 -pcm -pni -sdt
# alternatively, the script can also be called from a bash script
#
# (c) Marlene Klockmann, Helmholtz-Zentrum Hereon, 2022

import argparse
import os

import numpy as np
import matplotlib
from matplotlib import cm
from netCDF4 import Dataset
from scipy.stats.stats import pearsonr
from scipy import signal
import cartopy.crs as ccrs
import pandas as pd

## =======================================================================
# options and settings ===================================================
## =======================================================================

# command line options
parser = argparse.ArgumentParser(description='Prepare pseudoproxies based on PAGES2k network')
parser.add_argument('-sy', '--startyear', help='Start year of reconstruction', type=int, required=True)
parser.add_argument('-ey', '--endyear', help='End year of reconstruction', type=int, required=True)

parser.add_argument('-smt', '--similarity-threshold', help='Value for selecting only records with similarity value above the threshold. Default threshold is 0.45', type=float, default=0.45)
parser.add_argument('-pcm', '--plot-corr-metrics', help='Plot different metrics with regard to similarity measure: 150yr running correlation, map with ratio of variance', action='store_true')
parser.add_argument('-pni', '--plot-network-info', help='Plot different maps with information about the network', action='store_true')
parser.add_argument('-sdt', '--save-data', help='Save the selected records and period as .dat file', action='store_true')
parser.add_argument('-bj', '--batch-job', help='Script is run within a batch job (changes the graphical backend, i.e.,if set no plot windows will open)', action='store_true')
parser.add_argument('-subd', '--data-subdir', help='subdirectory where data is stored', type=str, default='mpiesm/2k')
parser.add_argument('-rmean', '--remove-mean', help='Subtract the time mean (over the complete time series) from every record and from the AMV before the data are written', action='store_true')

parser.add_argument('-wne', '--white-noise', help='Create a white noise ensemble', action='store_true')
parser.add_argument('-ens', '--ensemble_size', help='Number of noise-ensemble members', type=int, default=30)

arg = parser.parse_args()
print(arg)

# select the non-interactive Agg backend before pyplot is imported
if arg.batch_job:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 18})

# other settings
datadir='./Data/'+arg.data_subdir        # input files are read from here
savedir='./Data/gpreg/'+arg.data_subdir  # root of all output directories

syr=arg.startyear
eyr=arg.endyear
smt=arg.similarity_threshold

# suffix encoding threshold and period; part of the output directory names
sfx='corr'+str(smt)+'_'+str(syr)+'-'+str(eyr)

# output directory for the selected pseudoproxies; exist_ok avoids the
# check-then-create race of os.path.exists + os.makedirs
pdir=savedir+'/ppp_'+sfx
os.makedirs(pdir,exist_ok=True)

## =================================================================
# functions
## =================================================================

def relamp(yhat,yref):
    """Return the relative amplitude difference of two series.

    The standard deviations of both series are compared; the larger one
    is divided by the smaller one and 1 is subtracted.  The result is
    symmetric in its arguments and is 0 for equal standard deviations.

    Parameters
    ----------
    yhat, yref : array_like
        The two series to compare (passed to ``np.std`` without an axis
        argument, i.e. the standard deviation of all values is used).

    Returns
    -------
    float
        ``max(std)/min(std) - 1``.  ``nan`` is returned if one of the
        standard deviations is nan (instead of failing with an unbound
        local variable).  A zero standard deviation is not treated
        specially: the division is carried out by numpy as is.
    """
    shat=np.std(yhat); sref=np.std(yref)
    # comparisons with nan are always False, so handle nan explicitly
    if np.isnan(shat) or np.isnan(sref):
        return np.nan
    if shat>=sref:
        return shat/sref-1
    return sref/shat-1


def stdratio(yhat,yref):
    """Return the ratio of the standard deviations of two series.

    The larger standard deviation is always divided by the smaller one,
    so the result is symmetric in its arguments and is 1 for equal
    standard deviations.

    Parameters
    ----------
    yhat, yref : array_like
        The two series to compare (passed to ``np.std`` without an axis
        argument, i.e. the standard deviation of all values is used).

    Returns
    -------
    float
        ``max(std)/min(std)``.  ``nan`` is returned if one of the
        standard deviations is nan (instead of failing with an unbound
        local variable).  A zero standard deviation is not treated
        specially: the division is carried out by numpy as is.
    """
    shat=np.std(yhat); sref=np.std(yref)
    # comparisons with nan are always False, so handle nan explicitly
    if np.isnan(shat) or np.isnan(sref):
        return np.nan
    if shat>=sref:
        return shat/sref
    return sref/shat


## =================================================================
# data
## =================================================================

print('- Read data')

# time axis of the complete time series; it is only known for these two
# data sets, so stop right away for anything else instead of running into
# a NameError on `yrs` much later in the script
if arg.data_subdir=='mpiesm/2k':
    yrs=np.arange(-100,2000)
elif arg.data_subdir=='ccsm4':
    yrs=np.arange(850,2006)
else:
    raise ValueError("No time axis defined for data subdirectory '"+arg.data_subdir+"' (expected 'mpiesm/2k' or 'ccsm4')")

# AMV (the offset of 273.15 presumably converts K to degC -- confirm);
# the with-statement closes the file even if reading fails
with Dataset(datadir+'/AMV_ym.nc','r') as data:
    amv=np.asarray(data.variables['var169'][:,0,:],dtype=np.float64)-273.15

# proxies: surface temperature plus latitude/longitude of every location
prnet='PAGES2k_NAtl_ym'
with Dataset(datadir+'/pprox_t2m_sst_'+prnet+'.nc','r') as data1:
    prox=np.asarray(data1.variables['tsurf'][:,:],dtype=np.float64)-273.15
    plat=np.asarray(data1.variables['lat'][:],dtype=np.float64)
    plon=np.asarray(data1.variables['lon'][:],dtype=np.float64)

# additional metadata for the proxies in the NAtl region
# from the PAGES2k data base
meta=pd.read_csv("./Data/prox/nat_p2k.csv")

# sort approx. by geographical location (ascending lat+lon) and keep the
# metadata rows aligned with the sorted coordinates
pz=(plat+plon)
sdx=np.argsort(pz)
plon=plon[sdx]; plat=plat[sdx]
meta=meta.iloc[sdx].reset_index(drop=True)

# proxy columns in sorted order, followed by the AMV
temp=np.hstack((prox[:,sdx],amv))

# remove some points in advance to reduce network size
# points selected manually (based on start date and location);
# the numbers are positions in the sorted network
rap=[2,4,5,6,8,11,16,19,23,28,64]
plon=np.delete(plon,rap); plat=np.delete(plat,rap)
temp=np.delete(temp,rap,axis=1)
meta=meta.drop(rap).reset_index(drop=True)

## =================================================================
# correlation metrics
## =================================================================
print('- Calculate correlations')

## cross-correlation matrix, filled with nan where nothing is computed
nser=temp.shape[1]
corrmat=np.full((nser,nser),np.nan)

# lower triangle incl. diagonal for all rows except the last one,
# based on the complete time series
for ii in range(nser-1):
    for jj in range(ii+1):
        corrmat[ii,jj],p=pearsonr(temp[:,ii],temp[:,jj])

# last row (last column of temp, i.e. the AMV): correlate the linearly
# detrended last 150 time steps only
amv150=signal.detrend(temp[-150:,-1])
for ii in range(nser):
    corrmat[-1,ii],p=pearsonr(signal.detrend(temp[-150:,ii]),amv150)
if arg.plot_corr_metrics:
    plt.figure()
    # 20-step diverging colormap, reused by the following metric plots;
    # pyplot.get_cmap replaces cm.get_cmap, which was removed in Matplotlib 3.9
    cls=plt.get_cmap('RdBu_r',lut=20)
    plt.pcolormesh(corrmat,cmap=cls); plt.colorbar(); plt.clim(-1,1)
    plt.title('Cross-correlation')
 
## relative amplitude and ratio of standard deviations between all series
nser=temp.shape[1]
rampmat=np.full((nser,nser),np.nan)
stdrmat=np.full((nser,nser),np.nan)

# lower triangle incl. diagonal, based on the complete time series
for ii in range(nser):
    ref=temp[:,ii]
    for jj in range(ii+1):
        rampmat[ii,jj]=relamp(temp[:,jj],ref)
        stdrmat[ii,jj]=stdratio(temp[:,jj],ref)

# last row (the AMV) is overwritten: only the linearly detrended
# last 150 time steps are compared
amv150=signal.detrend(temp[-150:,-1])
for ii in range(nser):
    rec150=signal.detrend(temp[-150:,ii])
    rampmat[-1,ii]=relamp(rec150,amv150)
    stdrmat[-1,ii]=stdratio(rec150,amv150)
if arg.plot_corr_metrics:
    # matrix of relative amplitudes
    plt.figure()
    plt.pcolormesh(rampmat,cmap=cls)
    plt.colorbar()
    plt.clim(-4,4)
    plt.title('Relative Amplitude')

    # projection shared by the two maps below
    proj=ccrs.NearsidePerspective(central_longitude=-35, central_latitude=35)

    ## map of all locations, coloured by the relative amplitude in the
    ## last matrix row (w.r.t. the AMV)
    fig=plt.figure(figsize=(8,8))
    ax=plt.axes(projection=proj)
    ax.coastlines()
    z=rampmat[-1,:-1]
    normalize=matplotlib.colors.Normalize(vmin=-4, vmax=4)
    site_colors=[cls(normalize(v)) for v in z]
    ax.scatter(plon,plat,s=100,marker='v',color=site_colors,transform=ccrs.Geodetic())
    ax.set_global()
    ax.set_title('Relative Amplitude wrt AMV')
    cax, _ = matplotlib.colorbar.make_axes(ax)
    cbar = matplotlib.colorbar.ColorbarBase(cax, cmap=cls, norm=normalize)

    ## map of all locations, coloured by the correlation in the last
    ## matrix row; no colorbar is added to this map here
    fig=plt.figure(figsize=(8,8))
    ax=plt.axes(projection=proj)
    ax.coastlines()
    z=corrmat[-1,:-1]
    normalize=matplotlib.colors.Normalize(vmin=-1, vmax=1)
    site_colors=[cls(normalize(v)) for v in z]
    ax.scatter(plon,plat,s=100,marker='v',color=site_colors,transform=ccrs.Geodetic())
    ax.set_global()
    ax.set_title('Corr. with AMV')

## column indices whose value in the last correlation row (AMV) is
## below the threshold; these records are removed from the network
locs=np.argwhere(corrmat[-1,:]<smt).ravel()

# drop them from coordinates, data and metadata and renumber the metadata
plon=np.delete(plon,locs)
plat=np.delete(plat,locs)
temp=np.delete(temp,locs,axis=1)
meta=meta.drop(locs)
meta=meta.set_index(pd.Index(np.arange(0,len(plon))))

if arg.plot_corr_metrics:
    # mark the remaining locations with black dots on the current map
    # axes and add the colorbar
    ax.scatter(plon,plat,s=50,marker='.',color='Black',transform=ccrs.Geodetic())
    ax.set_global()
    cax, _ = matplotlib.colorbar.make_axes(ax)
    cbar = matplotlib.colorbar.ColorbarBase(cax, cmap=cls, norm=normalize)
    # plotted values of the remaining locations (not used further below)
    zz=np.delete(z,locs,axis=0).copy()
# create mask for temporal coverage: 1 where a record has data, else 0
print('- Determine temporal network coverage')
# one column per proxy record (temp holds one more column, the AMV)
tmask=np.zeros((temp.shape[0],temp.shape[1]-1))
ystart=meta['year_start'].values
yend=meta['year_end'].values
for jj in range(tmask.shape[1]):
    y1=int(ystart[jj]); y2=int(yend[jj])
    # record starts after the requested period: column stays 0
    if y1>eyr: continue
    if y1<syr: y1=syr
    # NOTE(review): the end is capped at eyr-1, so the row of year eyr
    # itself is never marked as covered -- confirm that this is intended
    if y2>=eyr: y2=eyr-1
    # positions of y1 and y2 on the (ascending) time axis; searchsorted
    # returns plain integers and clips years outside of yrs to the ends
    # of the axis, where int(np.argwhere(...)) failed or relied on the
    # deprecated conversion of a 2-d array to a scalar
    i1=np.searchsorted(yrs,y1,side='left')
    i2=np.searchsorted(yrs,y2,side='right')
    tmask[i1:i2,jj]=1

# number of records with data in every year
tcov=np.sum(tmask,axis=1)

## =====================================================================
# final network
## =====================================================================

## redo the similarity matrices for the selected network
nser=temp.shape[1]
corrmat=np.full((nser,nser),np.nan)
rampmat=np.full((nser,nser),np.nan)
stdrmat=np.full((nser,nser),np.nan)

# all pairs of columns except the last one (the AMV): entire period
for ii in range(nser-1):
    ref=temp[:,ii]
    for jj in range(nser-1):
        other=temp[:,jj]
        corrmat[ii,jj],p=pearsonr(ref,other)
        rampmat[ii,jj]=relamp(other,ref)
        stdrmat[ii,jj]=stdratio(other,ref)

# last column and last row (AMV): linearly detrended last 150 time steps
amv150=signal.detrend(temp[-150:,-1])
for kk in range(nser):
    rec150=signal.detrend(temp[-150:,kk])
    # column entry first, row entry second, so that the corner element
    # ends up with the value of the row computation
    corrmat[kk,-1],p=pearsonr(rec150,amv150)
    rampmat[kk,-1]=relamp(amv150,rec150)
    stdrmat[kk,-1]=stdratio(amv150,rec150)
    corrmat[-1,kk],p=pearsonr(amv150,rec150)
    rampmat[-1,kk]=relamp(rec150,amv150)
    stdrmat[-1,kk]=stdratio(rec150,amv150)

if arg.plot_corr_metrics:
    # 20-step diverging colormap for both matrices; pyplot.get_cmap
    # replaces cm.get_cmap, which was removed in Matplotlib 3.9
    cls=plt.get_cmap('RdBu_r',lut=20)

    # cross-correlation matrix of the selected network
    plt.figure()
    plt.pcolormesh(corrmat,cmap=cls); plt.colorbar();plt.clim(-1,1)
    plt.title('Cross-Correlation')

    # relative-amplitude matrix of the selected network
    plt.figure()
    plt.pcolormesh(rampmat,cmap=cls); plt.colorbar();plt.clim(-4,4)
    plt.title('Relative Amplitude')

# maps of the selected network, one colour-coded quantity per map
if arg.plot_network_info:
    proj=ccrs.NearsidePerspective(central_longitude=-35, central_latitude=35)

    def _site_map(z,cls,vmin,vmax,title=None):
        """Plot the sites (plon, plat) on a globe, coloured by z.

        z holds one value per site, cls is the colormap and vmin/vmax
        are the limits of the colour scale.  A colorbar is added; the
        title is only set if one is given.  Returns the new figure.
        """
        fig=plt.figure(figsize=(8,8))
        ax=plt.axes(projection=proj)
        ax.coastlines()
        normalize=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
        colors=[cls(normalize(value)) for value in z]
        ax.scatter(plon,plat,s=100,marker='v',color=colors,transform=ccrs.Geodetic())
        ax.set_global()
        if title is not None:
            ax.set_title(title)
        cax, _ = matplotlib.colorbar.make_axes(ax)
        matplotlib.colorbar.ColorbarBase(cax, cmap=cls, norm=normalize)
        return fig

    # pyplot.get_cmap replaces cm.get_cmap, which was removed in Matplotlib 3.9

    # similarity measure: last row of the correlation matrix (untitled map)
    _site_map(corrmat[-1,:-1],plt.get_cmap('RdBu_r',lut=20),-1,1)

    # start year of every record
    _site_map(meta['year_start'].values,plt.get_cmap('jet',lut=21),0,2000,title='Start year')

    # end year of every record
    _site_map(meta['year_end'].values,plt.get_cmap('jet',lut=21),1950,2000,title='End year')
   
# network size (number of records with data) over time; this figure is
# created independently of the plot options
plt.figure(figsize=(8,3))
plt.plot(yrs,tcov)
plt.xlabel('Year (CE)')
plt.ylabel('# of records')
plt.ylim(0,30)
plt.xlim(syr,eyr)
plt.tight_layout()

# rows of the time axis inside the requested period
rows=np.argwhere((yrs>=syr) & (yrs<=eyr)).squeeze()

# summary: sum of the coverage counts and total size of the mask
# (years x records) within the period
summary=(
    '-------------------------------',
    'Number of selected sites: '+str(len(plon)),
    'Startyear is '+str(syr)+', endyear is '+str(eyr),
    'Total number of points within P2k is '+str(np.sum(tcov[rows])),
    'Total number of points with complete coverage is '+str(tmask[rows,:].size),
    '-------------------------------',
)
for line in summary:
    print(line)

## write out datasets (perfect PPs)

# optionally subtract the mean over the complete time series from every
# column of temp (in place)
if arg.remove_mean:
    temp-=temp.mean(axis=0)

if arg.save_data:
    print('Saving selected records...')
    # first column: years, remaining columns: temp
    data=np.hstack((yrs[:,np.newaxis],temp))
    # rows of the time axis inside the requested period
    rows=np.argwhere((yrs>=syr)&(yrs<=eyr)).squeeze()
    # NOTE(review): 'corrma' and 'yrst' in two of the file names look like
    # typos; they are kept because other code may read exactly these names
    outputs=(
        ('/P2k_NAtl_data_amv150yrs.dat',data[rows,:]),
        ('/P2k_NAtl_mask_amv150yrs.dat',tmask[rows,:]),
        ('/P2k_NAtl_corrma_amv150yrst.dat',corrmat),
        ('/P2k_NAtl_rampmat_amv150yrs.dat',rampmat),
        ('/P2k_NAtl_stdrmat_amv150yrs.dat',stdrmat),
        ('/P2k_NAtl_latlon_amv150yrst.dat',np.vstack((plat,plon))),
    )
    for fname,values in outputs:
        np.savetxt(pdir+fname,values,fmt='%1.4e')


## if desired, create an ensemble of noisy PPs
if arg.white_noise:
    print('Creating a '+str(arg.ensemble_size)+'-member white-noise ensemble')

    # noise amplitude in units of the standard deviation of each record
    snoise=np.sqrt(3)

    def pnoise(prox,amp):
        """Return prox with Gaussian white noise added to every column.

        prox is a 2-d array; the noise of each column has a standard
        deviation of amp times the standard deviation of that column.
        The input array is not modified.
        """
        noise=np.random.randn(prox.shape[0],prox.shape[1])*amp*np.std(prox,axis=0)
        return prox+noise

    # output directory for the noisy pseudoproxies
    ndir=savedir+'/npp_'+sfx
    os.makedirs(ndir,exist_ok=True)

    # Slice years, records and AMV directly from yrs/temp.  The stacked
    # array `data` only exists if --save-data was given; without it
    # `data` is still the closed netCDF file handle and indexing fails.
    sel=(yrs>=syr)&(yrs<=eyr)
    prox=temp[sel,:-1]
    yy=yrs[sel].astype(np.float64).reshape(-1,1)
    aa=temp[sel,-1].reshape(-1,1)

    fbase=ndir+'/P2k_NAtl_data_amv150yrs_s'+str(np.round(snoise,2))+'_n'
    for it in range(arg.ensemble_size):
        # member number, zero-padded to at least two digits
        nn=str(it+1).zfill(2)
        proxwn=pnoise(prox,snoise)
        # same column layout as the noise-free file: year, records, AMV
        datawn=np.hstack((yy,proxwn,aa))
        np.savetxt(fbase+nn+'.dat',datawn,fmt='%1.4e')

# open the plot windows unless the script runs as a batch job
if not arg.batch_job:
    plt.show()

