#This script loads the previously trained emulators, predicts at some testing points
#and checks these predictions against the actual simulations.
#Note that this script also shows how predictions can be done in gpflow
#(assuming driving data and calibration parameter values for the prediction
#have been chosen).

import os
import pickle
from datetime import datetime

import gpflow
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

#build merged emulator
#import the training data again (wherever you put this)
#os.path.join keeps the path portable (the original "\\" separator only works on Windows)
Emul_Data = pd.read_csv(os.path.join("Data", "Training_Data_standardised.csv"))


#Step 1 is to load in all the emulators again
#These won't need to be optimised again (we have saved the optimisation results)
#But the structures do need to be rebuilt


#Group the columns by PFT so each emulator can drop the parameters that aren't relevant to it.
#This is a case-sensitive substring match on the column name, so each list also picks up
#that PFT's own columns such as "BT" and "gpp_BT" (these are dropped explicitly later).
BT_names = [s for s in Emul_Data.columns if "BT" in s]
NT_names = [s for s in Emul_Data.columns if "NT" in s]
C3g_names = [s for s in Emul_Data.columns if "C3g" in s]
SH_names = [s for s in Emul_Data.columns if "SH" in s]
Cr_names = [s for s in Emul_Data.columns if "Cr" in s]


#%%
#######Rebuild BT emulator

#drop unwanted columns: the other PFTs' columns, bookkeeping/location columns,
#and BT's own "BT" and output "gpp_BT" columns
BT_Data = Emul_Data.drop(NT_names+C3g_names+SH_names+Cr_names, axis = 1)
BT_Data = BT_Data.drop(["index", "time", "year", "lon", "lat", "X", "Y", "gpp_gb", "BT", "gpp_BT"], axis=1)

#grab input (every remaining column is an emulator input)
X_BT = np.array(BT_Data)

#save names (this is the column order of X_BT; prediction inputs must use the same order)
variable_names_BT = BT_Data.columns

#grab output as an (n, 1) float64 column
y_BT = (np.array(Emul_Data["gpp_BT"])).reshape(-1,1).astype(np.float64)

#remove data (save memory)
del BT_Data

#establish some attributes
N = X_BT.shape[0] #how many samples are there
dim = X_BT.shape[1] #how many continuous dims?

#then active_vars
active_vars = list(range(dim)) #get all the variables

#ARD RBF kernel over all input columns
k_BT = gpflow.kernels.RBF(input_dim = dim, active_dims = active_vars, ARD = True)

#SVGP:
M_BT = min(10*dim, len(X_BT)) # Number of inducing locations
#Initialise inducing locations to a random M inputs in the dataset
#(only needed to build the structure; assign() below overwrites Z with the saved values)
Z_BT = X_BT[np.random.choice(len(X_BT), M_BT, replace=False)].copy()
m_BT = gpflow.models.SVGP(X = X_BT, Y = y_BT, Z = Z_BT, kern=k_BT, mean_function = None, likelihood = gpflow.likelihoods.Gaussian(), minibatch_size=1000)

#And load the optimised parameter values
#NOTE: unpickling can execute arbitrary code - only load result files you created yourself
with open(os.path.join("place_to_save_results", "SVGP_BT.pkl"), 'rb') as fp:
    param_dict = pickle.load(fp)
m_BT.assign(param_dict)

print("re-fit BT emulator")



#%%
#######Rebuild NT emulator

#drop unwanted columns: the other PFTs' columns, bookkeeping/location columns,
#and NT's own "NT" and output "gpp_NT" columns
NT_Data = Emul_Data.drop(BT_names+C3g_names+SH_names+Cr_names, axis = 1)
NT_Data = NT_Data.drop(["index", "time", "year", "lon", "lat", "X", "Y", "gpp_gb", "NT", "gpp_NT"], axis=1)

#grab input (every remaining column is an emulator input)
X_NT = np.array(NT_Data)

#save names (this is the column order of X_NT; prediction inputs must use the same order)
variable_names_NT = NT_Data.columns

#grab output as an (n, 1) float64 column
y_NT = (np.array(Emul_Data["gpp_NT"])).reshape(-1,1).astype(np.float64)

#remove data (save memory)
del NT_Data

#establish some attributes
N = X_NT.shape[0] #how many samples are there
dim = X_NT.shape[1] #how many continuous dims?

#then active_vars
active_vars = list(range(dim)) #get all the variables

#ARD RBF kernel over all input columns
k_NT = gpflow.kernels.RBF(input_dim = dim, active_dims = active_vars, ARD = True)

#SVGP:
M_NT = min(10*dim, len(X_NT)) # Number of inducing locations
#Initialise inducing locations to a random M inputs in the dataset
#(only needed to build the structure; assign() below overwrites Z with the saved values)
Z_NT = X_NT[np.random.choice(len(X_NT), M_NT, replace=False)].copy()
m_NT = gpflow.models.SVGP(X = X_NT, Y = y_NT, Z = Z_NT, kern=k_NT, mean_function = None, likelihood = gpflow.likelihoods.Gaussian(), minibatch_size=1000)

#Now load the optimised parameter values
#NOTE: unpickling can execute arbitrary code - only load result files you created yourself
with open(os.path.join("place_to_save_results", "SVGP_NT.pkl"), 'rb') as fp:
    param_dict = pickle.load(fp)
m_NT.assign(param_dict)

print("re-fit NT emulator")



#%%
#######Rebuild C3g emulator

#drop unwanted columns: the other PFTs' columns, bookkeeping/location columns,
#and C3g's own "C3g" and output "gpp_C3g" columns
C3g_Data = Emul_Data.drop(BT_names+NT_names+SH_names+Cr_names, axis = 1)
C3g_Data = C3g_Data.drop(["index", "time", "year", "lon", "lat", "X", "Y", "gpp_gb", "C3g", "gpp_C3g"], axis=1)

#grab input (every remaining column is an emulator input)
X_C3g = np.array(C3g_Data)

#save names (this is the column order of X_C3g; prediction inputs must use the same order)
variable_names_C3g = C3g_Data.columns

#grab output as an (n, 1) float64 column
y_C3g = (np.array(Emul_Data["gpp_C3g"])).reshape(-1,1).astype(np.float64)

#remove data (save memory)
del C3g_Data

#establish some attributes
N = X_C3g.shape[0] #how many samples are there
dim = X_C3g.shape[1] #how many continuous dims?

#then active_vars
active_vars = list(range(dim)) #get all the variables

#ARD RBF kernel over all input columns
k_C3g = gpflow.kernels.RBF(input_dim = dim, active_dims = active_vars, ARD = True)

#SVGP:
M_C3g = min(10*dim, len(X_C3g)) # Number of inducing locations
#Initialise inducing locations to a random M inputs in the dataset
#(only needed to build the structure; assign() below overwrites Z with the saved values)
Z_C3g = X_C3g[np.random.choice(len(X_C3g), M_C3g, replace=False)].copy()
m_C3g = gpflow.models.SVGP(X = X_C3g, Y = y_C3g, Z = Z_C3g, kern=k_C3g, mean_function = None, likelihood = gpflow.likelihoods.Gaussian(), minibatch_size=1000)

#Now load the optimised parameter values
#NOTE: unpickling can execute arbitrary code - only load result files you created yourself
with open(os.path.join("place_to_save_results", "SVGP_C3g.pkl"), 'rb') as fp:
    param_dict = pickle.load(fp)
m_C3g.assign(param_dict)

print("re-fit C3g emulator")




#%%
#######Rebuild SH emulator

#drop unwanted columns: the other PFTs' columns, bookkeeping/location columns,
#and SH's own "SH" and output "gpp_SH" columns
SH_Data = Emul_Data.drop(BT_names+NT_names+C3g_names+Cr_names, axis = 1)
SH_Data = SH_Data.drop(["index", "time", "year", "lon", "lat", "X", "Y", "gpp_gb", "SH", "gpp_SH"], axis=1)

#grab input (every remaining column is an emulator input)
X_SH = np.array(SH_Data)

#save names (this is the column order of X_SH; prediction inputs must use the same order)
variable_names_SH = SH_Data.columns

#grab output as an (n, 1) float64 column
y_SH = (np.array(Emul_Data["gpp_SH"])).reshape(-1,1).astype(np.float64)

#remove data (save memory)
del SH_Data

#establish some attributes
N = X_SH.shape[0] #how many samples are there
dim = X_SH.shape[1] #how many continuous dims?

#then active_vars
active_vars = list(range(dim)) #get all the variables

#ARD RBF kernel over all input columns
k_SH = gpflow.kernels.RBF(input_dim = dim, active_dims = active_vars, ARD = True)

#SVGP:
M_SH = min(10*dim, len(X_SH)) # Number of inducing locations
#Initialise inducing locations to a random M inputs in the dataset
#(only needed to build the structure; assign() below overwrites Z with the saved values)
Z_SH = X_SH[np.random.choice(len(X_SH), M_SH, replace=False)].copy()
m_SH = gpflow.models.SVGP(X = X_SH, Y = y_SH, Z = Z_SH, kern=k_SH, mean_function = None, likelihood = gpflow.likelihoods.Gaussian(), minibatch_size=1000)

#Now load the optimised parameter values
#NOTE: unpickling can execute arbitrary code - only load result files you created yourself
with open(os.path.join("place_to_save_results", "SVGP_SH.pkl"), 'rb') as fp:
    param_dict = pickle.load(fp)
m_SH.assign(param_dict)

print("re-fit SH emulator")



#%%
#######Rebuild Cr emulator

#drop unwanted columns: the other PFTs' columns, bookkeeping/location columns,
#and Cr's own "Cr" and output "gpp_Cr" columns
Cr_Data = Emul_Data.drop(BT_names+NT_names+C3g_names+SH_names, axis = 1)
Cr_Data = Cr_Data.drop(["index", "time", "year", "lon", "lat", "X", "Y", "gpp_gb", "Cr", "gpp_Cr"], axis=1)

#grab input (every remaining column is an emulator input)
X_Cr = np.array(Cr_Data)

#save names (this is the column order of X_Cr; prediction inputs must use the same order)
variable_names_Cr = Cr_Data.columns

#grab output as an (n, 1) float64 column
y_Cr = (np.array(Emul_Data["gpp_Cr"])).reshape(-1,1).astype(np.float64)

#remove data (save memory)
del Cr_Data

#establish some attributes
N = X_Cr.shape[0] #how many samples are there
dim = X_Cr.shape[1] #how many continuous dims?

#then active_vars
active_vars = list(range(dim)) #get all the variables

#ARD RBF kernel over all input columns
k_Cr = gpflow.kernels.RBF(input_dim = dim, active_dims = active_vars, ARD = True)

#SVGP:
M_Cr = min(10*dim, len(X_Cr)) # Number of inducing locations
#Initialise inducing locations to a random M inputs in the dataset
#(only needed to build the structure; assign() below overwrites Z with the saved values)
Z_Cr = X_Cr[np.random.choice(len(X_Cr), M_Cr, replace=False)].copy()
m_Cr = gpflow.models.SVGP(X = X_Cr, Y = y_Cr, Z = Z_Cr, kern=k_Cr, mean_function = None, likelihood = gpflow.likelihoods.Gaussian(), minibatch_size=1000)

#Now load the optimised parameter values
#NOTE: unpickling can execute arbitrary code - only load result files you created yourself
with open(os.path.join("place_to_save_results", "SVGP_Cr.pkl"), 'rb') as fp:
    param_dict = pickle.load(fp)
m_Cr.assign(param_dict)

print("re-fit Cr emulator")


#all emulators are rebuilt, so the training table is no longer needed (save memory)
del Emul_Data

#%%
####################################################################################
#####################################################################################
#Step 2 is to predict the gpp for each of these emulators (for each PFT)
#%%

#import the validation data
#os.path.join keeps the path portable (the original "\\" separator only works on Windows)
Pred_Data = pd.read_csv(os.path.join("Data", "Testing_Data_standardised.csv"))


#subsample (for speed; we could test against the entire validation set, but 1000 is a lot anyway)
#Note that this is another way in which the results can differ (a different subsample will give slightly different results)
#Rows are drawn WITHOUT replacement so no test point is duplicated (duplicates would be
#counted more than once in the coverage percentage). The count is capped at the data size.
n_sub = min(1000, len(Pred_Data))
Pred_Data = Pred_Data.iloc[np.random.choice(len(Pred_Data), n_sub, replace=False), :]


#Predict BT

#select the emulator inputs by the names saved at training time, so the columns are in
#exactly the order the emulator was trained on (raises KeyError if any input is missing)
X_Pred_BT = np.array(Pred_Data[variable_names_BT])

#predictive mean and full covariance of the latent function;
#the covariance comes back stacked per latent GP, hence the [0] below
BT_pred = m_BT.predict_f_full_cov(X_Pred_BT)

#add the Gaussian likelihood (observation noise) variance to the diagonal.
#Read it from the likelihood directly rather than by row/column position in
#as_pandas_table(), which silently picks the wrong value if the table layout changes.
BT_noise = float(m_BT.likelihood.variance.read_value())
BT_cov = BT_pred[1][0,:,:] + BT_noise*np.eye(BT_pred[1].shape[1])
BT_mean = BT_pred[0]

del BT_pred

print("done BT predictions")


#Now do NT predictions

#select the emulator inputs by the names saved at training time, so the columns are in
#exactly the order the emulator was trained on (raises KeyError if any input is missing)
X_Pred_NT = np.array(Pred_Data[variable_names_NT])

#predictive mean and full covariance of the latent function;
#the covariance comes back stacked per latent GP, hence the [0] below
NT_pred = m_NT.predict_f_full_cov(X_Pred_NT)

#add the Gaussian likelihood (observation noise) variance to the diagonal.
#Read it from the likelihood directly rather than by row/column position in
#as_pandas_table(), which silently picks the wrong value if the table layout changes.
NT_noise = float(m_NT.likelihood.variance.read_value())
NT_cov = NT_pred[1][0,:,:] + NT_noise*np.eye(NT_pred[1].shape[1])
NT_mean = NT_pred[0]

del NT_pred


print("done NT predictions")


#Now do C3g predictions

#select the emulator inputs by the names saved at training time, so the columns are in
#exactly the order the emulator was trained on (raises KeyError if any input is missing)
X_Pred_C3g = np.array(Pred_Data[variable_names_C3g])

#predictive mean and full covariance of the latent function;
#the covariance comes back stacked per latent GP, hence the [0] below
C3g_pred = m_C3g.predict_f_full_cov(X_Pred_C3g)

#add the Gaussian likelihood (observation noise) variance to the diagonal.
#Read it from the likelihood directly rather than by row/column position in
#as_pandas_table(), which silently picks the wrong value if the table layout changes.
C3g_noise = float(m_C3g.likelihood.variance.read_value())
C3g_cov = C3g_pred[1][0,:,:] + C3g_noise*np.eye(C3g_pred[1].shape[1])
C3g_mean = C3g_pred[0]

del C3g_pred


print("done C3g predictions")



#Now do SH predictions

#select the emulator inputs by the names saved at training time, so the columns are in
#exactly the order the emulator was trained on (raises KeyError if any input is missing)
X_Pred_SH = np.array(Pred_Data[variable_names_SH])

#predictive mean and full covariance of the latent function;
#the covariance comes back stacked per latent GP, hence the [0] below
SH_pred = m_SH.predict_f_full_cov(X_Pred_SH)

#add the Gaussian likelihood (observation noise) variance to the diagonal.
#Read it from the likelihood directly rather than by row/column position in
#as_pandas_table(), which silently picks the wrong value if the table layout changes.
SH_noise = float(m_SH.likelihood.variance.read_value())
SH_cov = SH_pred[1][0,:,:] + SH_noise*np.eye(SH_pred[1].shape[1])
SH_mean = SH_pred[0]

del SH_pred


print("done SH predictions")



#Now do Cr predictions

#select the emulator inputs by the names saved at training time, so the columns are in
#exactly the order the emulator was trained on (raises KeyError if any input is missing)
X_Pred_Cr = np.array(Pred_Data[variable_names_Cr])

#predictive mean and full covariance of the latent function;
#the covariance comes back stacked per latent GP, hence the [0] below
Cr_pred = m_Cr.predict_f_full_cov(X_Pred_Cr)

#add the Gaussian likelihood (observation noise) variance to the diagonal.
#Read it from the likelihood directly rather than by row/column position in
#as_pandas_table(), which silently picks the wrong value if the table layout changes.
Cr_noise = float(m_Cr.likelihood.variance.read_value())
Cr_cov = Cr_pred[1][0,:,:] + Cr_noise*np.eye(Cr_pred[1].shape[1])
Cr_mean = Cr_pred[0]

del Cr_pred


print("done Cr predictions")

#%%
####################################################################################################
####################################################################################################

#Step 3 is to actually do the validation
#(check the predictions match up with the simulations)
#%%

#BT
#simulated output at the test points, as an (n, 1) float64 column
y_valid = (np.array(Pred_Data["gpp_BT"])).reshape(-1,1).astype(np.float64)

#predictive mean and marginal variances (diagonal of the full covariance)
mean_valid = BT_mean
var_valid = np.diag(BT_cov).reshape(-1,1)
sd_valid = np.sqrt(var_valid)

#get (2sd) upper and lower bounds
upper_valid = mean_valid + 2*sd_valid
lower_valid = mean_valid - 2*sd_valid

#now check which are in the bounds
within_bounds = (lower_valid < y_valid) & (y_valid < upper_valid)

#what percentage is it (should be roughly 95%)
#np.mean yields a plain scalar; the builtin sum over an (n,1) array gave a length-1 array
credible_value = 100*np.mean(within_bounds)
print(credible_value)


#get standardised errors (mostly within +/-2 if the uncertainty is well calibrated)
std_errs = (y_valid - mean_valid)/sd_valid

#one panel per emulator input: standardised error against that input's value.
#Use BT's own input count rather than the global `dim` (which holds whichever emulator
#was rebuilt last); subplot grid sizes must be ints, not numpy floats.
n_inputs = X_Pred_BT.shape[1]
n_grid = int(np.ceil(np.sqrt(n_inputs)))
plt.figure(figsize=(15,15))
for i in range(n_inputs):
    plt.subplot(n_grid, n_grid, i+1)
    plt.scatter(X_Pred_BT[:,i], std_errs)
    #axhline spans the whole x-range regardless of the inputs' scale
    plt.axhline(2, color = "red")
    plt.axhline(-2, color = "red")
    plt.axhline(0, linestyle='dashed')
plt.show()

#%%
#NT
#simulated output at the test points, as an (n, 1) float64 column
y_valid = (np.array(Pred_Data["gpp_NT"])).reshape(-1,1).astype(np.float64)

#predictive mean and marginal variances (diagonal of the full covariance)
mean_valid = NT_mean
var_valid = np.diag(NT_cov).reshape(-1,1)
sd_valid = np.sqrt(var_valid)

#get (2sd) upper and lower bounds
upper_valid = mean_valid + 2*sd_valid
lower_valid = mean_valid - 2*sd_valid

#now check which are in the bounds
within_bounds = (lower_valid < y_valid) & (y_valid < upper_valid)

#what percentage is it (should be roughly 95%)
#np.mean yields a plain scalar; the builtin sum over an (n,1) array gave a length-1 array
credible_value = 100*np.mean(within_bounds)
print(credible_value)


#get standardised errors (mostly within +/-2 if the uncertainty is well calibrated)
std_errs = (y_valid - mean_valid)/sd_valid

#one panel per emulator input: standardised error against that input's value.
#Use NT's own input count rather than the global `dim` (which holds whichever emulator
#was rebuilt last); subplot grid sizes must be ints, not numpy floats.
n_inputs = X_Pred_NT.shape[1]
n_grid = int(np.ceil(np.sqrt(n_inputs)))
plt.figure(figsize=(15,15))
for i in range(n_inputs):
    plt.subplot(n_grid, n_grid, i+1)
    plt.scatter(X_Pred_NT[:,i], std_errs)
    #axhline spans the whole x-range regardless of the inputs' scale
    plt.axhline(2, color = "red")
    plt.axhline(-2, color = "red")
    plt.axhline(0, linestyle='dashed')
plt.show()



#%%
#C3g
#simulated output at the test points, as an (n, 1) float64 column
y_valid = (np.array(Pred_Data["gpp_C3g"])).reshape(-1,1).astype(np.float64)

#predictive mean and marginal variances (diagonal of the full covariance)
mean_valid = C3g_mean
var_valid = np.diag(C3g_cov).reshape(-1,1)
sd_valid = np.sqrt(var_valid)

#get (2sd) upper and lower bounds
upper_valid = mean_valid + 2*sd_valid
lower_valid = mean_valid - 2*sd_valid

#now check which are in the bounds
within_bounds = (lower_valid < y_valid) & (y_valid < upper_valid)

#what percentage is it (should be roughly 95%)
#np.mean yields a plain scalar; the builtin sum over an (n,1) array gave a length-1 array
credible_value = 100*np.mean(within_bounds)
print(credible_value)


#get standardised errors (mostly within +/-2 if the uncertainty is well calibrated)
std_errs = (y_valid - mean_valid)/sd_valid

#one panel per emulator input: standardised error against that input's value.
#Use C3g's own input count rather than the global `dim` (which holds whichever emulator
#was rebuilt last); subplot grid sizes must be ints, not numpy floats.
n_inputs = X_Pred_C3g.shape[1]
n_grid = int(np.ceil(np.sqrt(n_inputs)))
plt.figure(figsize=(15,15))
for i in range(n_inputs):
    plt.subplot(n_grid, n_grid, i+1)
    plt.scatter(X_Pred_C3g[:,i], std_errs)
    #axhline spans the whole x-range regardless of the inputs' scale
    plt.axhline(2, color = "red")
    plt.axhline(-2, color = "red")
    plt.axhline(0, linestyle='dashed')
plt.show()




#%%
#SH
#simulated output at the test points, as an (n, 1) float64 column
y_valid = (np.array(Pred_Data["gpp_SH"])).reshape(-1,1).astype(np.float64)

#predictive mean and marginal variances (diagonal of the full covariance)
mean_valid = SH_mean
var_valid = np.diag(SH_cov).reshape(-1,1)
sd_valid = np.sqrt(var_valid)

#get (2sd) upper and lower bounds
upper_valid = mean_valid + 2*sd_valid
lower_valid = mean_valid - 2*sd_valid

#now check which are in the bounds
within_bounds = (lower_valid < y_valid) & (y_valid < upper_valid)

#what percentage is it (should be roughly 95%)
#np.mean yields a plain scalar; the builtin sum over an (n,1) array gave a length-1 array
credible_value = 100*np.mean(within_bounds)
print(credible_value)


#get standardised errors (mostly within +/-2 if the uncertainty is well calibrated)
std_errs = (y_valid - mean_valid)/sd_valid

#one panel per emulator input: standardised error against that input's value.
#Use SH's own input count rather than the global `dim` (which holds whichever emulator
#was rebuilt last); subplot grid sizes must be ints, not numpy floats.
n_inputs = X_Pred_SH.shape[1]
n_grid = int(np.ceil(np.sqrt(n_inputs)))
plt.figure(figsize=(15,15))
for i in range(n_inputs):
    plt.subplot(n_grid, n_grid, i+1)
    plt.scatter(X_Pred_SH[:,i], std_errs)
    #axhline spans the whole x-range regardless of the inputs' scale
    plt.axhline(2, color = "red")
    plt.axhline(-2, color = "red")
    plt.axhline(0, linestyle='dashed')
plt.show()




#%%
#Cr
#simulated output at the test points, as an (n, 1) float64 column
y_valid = (np.array(Pred_Data["gpp_Cr"])).reshape(-1,1).astype(np.float64)

#predictive mean and marginal variances (diagonal of the full covariance)
mean_valid = Cr_mean
var_valid = np.diag(Cr_cov).reshape(-1,1)
sd_valid = np.sqrt(var_valid)

#get (2sd) upper and lower bounds
upper_valid = mean_valid + 2*sd_valid
lower_valid = mean_valid - 2*sd_valid

#now check which are in the bounds
within_bounds = (lower_valid < y_valid) & (y_valid < upper_valid)

#what percentage is it (should be roughly 95%)
#np.mean yields a plain scalar; the builtin sum over an (n,1) array gave a length-1 array
credible_value = 100*np.mean(within_bounds)
print(credible_value)


#get standardised errors (mostly within +/-2 if the uncertainty is well calibrated)
std_errs = (y_valid - mean_valid)/sd_valid

#one panel per emulator input: standardised error against that input's value.
#Use Cr's own input count rather than relying on the global `dim`;
#subplot grid sizes must be ints, not numpy floats.
n_inputs = X_Pred_Cr.shape[1]
n_grid = int(np.ceil(np.sqrt(n_inputs)))
plt.figure(figsize=(15,15))
for i in range(n_inputs):
    plt.subplot(n_grid, n_grid, i+1)
    plt.scatter(X_Pred_Cr[:,i], std_errs)
    #axhline spans the whole x-range regardless of the inputs' scale
    plt.axhline(2, color = "red")
    plt.axhline(-2, color = "red")
    plt.axhline(0, linestyle='dashed')
plt.show()
