#### ATAT Automated Timing Accordance Tool
#### Copyright Jeremy C. Ely, The University of Sheffield, 2017.
#### For details see Ely et al., in prep, GMD.

#This program is free software: you can redistribute it and/or modify
#it under the terms of the GNU General Public License as published by
#the Free Software Foundation, either version 3 of the License, or
#any later version.

#This program is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#GNU General Public License for more details.

#You should have received a copy of the GNU General Public License
#along with this program.  If not, see <http://www.gnu.org/licenses/>.

#### Import Modules (use pip/conda/package manager to install these prerequisites)

import sys

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
import pandas as pd
import scipy.ndimage as ndi
from matplotlib.colors import Normalize
from mpl_toolkits.basemap import Basemap
from netCDF4 import Dataset as ncdf

##############################################################################################################################################################
######################################### Section 1. to Load Geochron data and define later used variables ###################################################

###### User definition of ice free / advance ############# 

data_type = raw_input("Type DEGLACIAL or ADVANCE for different geochronological settings: ")

#### Load geochronological data ####
if data_type == "DEGLACIAL":
    print "Running icefree dates"    
    data_cdf = ncdf(raw_input("Path to deglacial ages netcdf file: "))
elif data_type == "ADVANCE":
    print "Running ice advance dates"
    data_cdf = ncdf(raw_input("Path to ice advance ages netcdf file: "))
else:
    print "ERROR: No or incorrect geochronological context defined, options are DEGLACIAL or ADVANCE"
    sys.exit()

#### Define data dimensions and assign to an array

data_age = data_cdf.variables['deglacial_age'][0,:,:]
data_age[np.isnan(data_age[:,:])] = 0.0
data_mean_age = np.mean(data_age)
if data_mean_age > 0.0:
    data_age = 0.0 - data_age
data_age_msk = np.array(data_age == 0)
data_age_ma = ma.masked_array(data_age,data_age_msk)
data_age_ma_0 = np.nan_to_num(data_age_ma)

data_error = data_cdf.variables['error'][0,:,:]
data_error_ma = ma.masked_invalid(data_error)
#### Treat error differently, depending on ice free/ advance scenario
if data_type == "DEGLACIAL":
    data_age_error_ma = data_age_ma + data_error_ma
elif data_type == "ADVANCE":
    data_age_error_ma = data_age_ma - data_error_ma
data_age_error_ma_0 = np.nan_to_num(data_age_error_ma)


#### Define distance between age observations, for later weighting
#### note for the distance calculation to work, 0 = date, 1 = no date
data_location = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
data_location[:,:] = data_age[:,:]
data_location[data_location == 0.0]= 1
data_location[data_location < 0]=0
data_location = data_location.astype(int)

data_location_index = zip(*np.where(data_location==0))

data_distance = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
for x in xrange (0, len(data_location_index)):
    data_location[data_location_index[x]] = 1
    data_distance_current = ndi.distance_transform_edt(data_location)
    data_distance[data_location_index[x]] = data_distance_current[data_location_index[x]]
    data_location[data_location_index[x]] = 0

data_distance_msk = np.array(data_location > 0)
data_distance_ma = ma.masked_array(data_distance,data_distance_msk)

#### Create a distance weight, based on deviation from the mean
data_distance_weight_ma = ma.sqrt(data_distance_ma/(ma.mean(data_distance_ma)))

#### Load mapping variables from geochron data file 

lons = data_cdf.variables['lon'][0,:]
lats = data_cdf.variables['lat'][0,:]
m = Basemap(width=max(data_cdf.variables["y1"]), height= max(data_cdf.variables["x1"]),resolution='l',projection='aea',lat_0=lats.mean(),lon_0=lons.mean())
xs,ys = m(*(lons,lats))

##############################################################################################################################################################
##########################Section 3. Load ice sheet model output and define later used variables ############################################################# 

#### Load ice sheet model ####
model_cdf = ncdf(raw_input("Path to model output netcdf file: "))

#### User speficifcation of if model has thickness (THK) based (e.g. SIA model), or mask (MSK) based (e.g. SIA+SSA model) extent masks ####
extent_type = raw_input("Input THK or MSK for different model extent definitions: ")

#### Calculate ice free age ####
if data_type == "DEGLACIAL":
    model_identify_deglacial_slices = np.zeros(((len(model_cdf.dimensions["time"])-1),len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    model_identify_deglacial_slices[:,:,:] = np.NAN

    if extent_type == "THK":
        #### Loop through all thicknesses to get timing
        for i in range (0, len(model_cdf.dimensions["time"])-1):
            model_thk_1 = model_cdf.variables['thk'][i,:,:]
            model_thk_2 = model_cdf.variables['thk'][i+1,:,:]
            model_deglaciated_region = ((model_thk_1 > 0) & (model_thk_2 <= 0))
            model_deglaciated_region = model_deglaciated_region * 1
            model_deglaciated_region = model_deglaciated_region * (i+1)
            model_identify_deglacial_slices[i,:,:] = model_deglaciated_region[:,:]

    if extent_type == "MSK":
        ground_mask = raw_input("Grounded ice mask value: ")
    
        for i in range (0, len(model_cdf.dimensions["time"])-1):
            model_msk_1 = model_cdf.variables['mask'][i,:,:]
            model_msk_2 = model_cdf.variables['mask'][i+1,:,:]
            model_deglaciated_region = ((model_msk_1 == ground_mask) & (model_msk_2 != groundmask))
            model_deglaciated_region = model_deglaciated_region * 1
            model_deglaciated_region = model_deglaciated_region * (i+1)
            model_identify_deglacial_slices[i,:,:] = model_deglaciated_region[:,:]

    model_deglaciated_slice = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    model_deglaciated_slice[:,:] = np.amax(model_identify_deglacial_slices[:,:,:], axis = 0)

    #### Conversion to time BP #####
    model_time_start = min(model_cdf.variables["time"])
    model_time_interval = abs((model_time_start - max(model_cdf.variables["time"]))/(len(model_cdf.variables["time"])-1))
    model_deglacial_age = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    model_deglacial_age = model_deglacial_age + model_time_start
    model_deglacial_age = model_deglacial_age + (model_deglaciated_slice * model_time_interval)
    model_deglacial_age[model_deglacial_age==model_time_start] = None
    model_deglacial_age_ma = ma.masked_invalid(model_deglacial_age)
    model_deglacial_age_ma_0 = np.nan_to_num(model_deglacial_age_ma)

#### Or calculate ice advance age #####
if data_type == "ADVANCE":
    model_identify_advance_slices = np.zeros(((len(model_cdf.dimensions["time"])-1),len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    model_identify_advance_slices[:,:,:] = np.NAN

    if extent_type == "THK":
        for i in range (0, len(model_cdf.dimensions["time"])-1):
            model_thk_1 = model_cdf.variables['thk'][i,:,:]
            model_thk_2 = model_cdf.variables['thk'][i+1,:,:]
            model_advance_region = ((model_thk_1 <= 0) & (model_thk_2 > 0))
            model_advance_region = model_advance_region * 1
            model_advance_region = model_advance_region * (i+1)
            model_identify_advance_slices[i,:,:] = model_advance_region[:,:]

    if extent_type == "MSK":
        ground_mask = raw_input("Grounded ice mask value: ")
    
        for i in range (0, len(model_cdf.dimensions["time"])-1):
            model_msk_1 = model_cdf.variables['mask'][i,:,:]
            model_msk_2 = model_cdf.variables['mask'][i+1,:,:]
            model_advance_region = ((model_msk_1 != ground_mask) & (model_msk_2 == groundmask))
            model_advance_region = model_advance_region * 1
            model_advance_region = model_advance_region * (i+1)
            model_identify_advance_slices[i,:,:] = model_advance_region[:,:]

    model_advance_slice = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    model_advance_slice[:,:] = np.amax(model_identify_advance_slices[:,:,:], axis = 0)

    #### Conversion to time BP #####
    model_time_start = min(model_cdf.variables["time"])
    model_time_interval = abs((model_time_start - max(model_cdf.variables["time"]))/(len(model_cdf.variables["time"])-1))
    model_advance_age = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    model_advance_age = model_advance_age + model_time_start
    model_advance_age = model_advance_age + (model_advance_slice * model_time_interval)
    model_advance_age[model_advance_age==model_time_start] = None
    model_advance_age_ma = ma.masked_invalid(model_advance_age)
    model_advance_age_ma_0 = np.nan_to_num(model_advance_age_ma)

##############################################################################################################################################################
################## Section 4.1. For ice-free ages/retreat ages, work out if model deglaciation occurs before geochron data (account for error) ############### 

if data_type == "DEGLACIAL":

    #### Determine if have been ice covered

    comp_covered = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    comp_covered[:,:] = ma.where((model_deglacial_age_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_covered_msk = np.array(comp_covered == 0)
    comp_covered_ma = ma.masked_array(comp_covered,comp_covered_msk)
    comp_covered_numb = ma.count(comp_covered_ma) + 0.0
    comp_covered_percent = (100/(ma.count(data_age_ma)+0.0)) * comp_covered_numb

    #### Identify covered dates that model agrees were ice free

    comp_icefree_agree = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    comp_icefree_agree[:,:] = ma.where((data_age_ma_0[:,:] >= model_deglacial_age_ma_0[:,:]) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_icefree_agree_msk = np.array(comp_icefree_agree == 0)
    comp_icefree_agree_ma = ma.masked_array(comp_icefree_agree, comp_icefree_agree_msk)
    comp_icefree_agree_numb = ma.count(comp_icefree_agree_ma) + 0.0
    comp_icefree_agree_percent = (100/comp_covered_numb) * comp_icefree_agree_numb

    #### Account for error

    comp_icefree_inerror = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    comp_icefree_inerror[:,:] = ma.where((data_age_error_ma_0[:,:] >= model_deglacial_age_ma_0[:,:]) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_icefree_inerror_msk = np.array(comp_icefree_inerror == 0)
    comp_icefree_inerror_ma = ma.masked_array(comp_icefree_inerror, comp_icefree_inerror_msk)
    comp_icefree_inerror_numb = ma.count(comp_icefree_agree_ma) + 0.0
    comp_icefree_inerror_percent = (100/comp_covered_numb) * comp_icefree_inerror_numb

    #### Identify three categories of date (1 = not covered, 2 = covered but disagree,3 = model-data agreement, 4 = model-data agreement within error)

    data_location_formap = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    data_location_formap[:,:] = data_age[:,:]
    data_location_formap[np.isnan(data_location_formap)] = 0
    data_location_formap[data_location_formap < 0] = 1

    comp_covered_loc = ma.where((comp_covered < 0), 1, 0) 
    comp_agree_loc = ma.where((comp_icefree_agree < 0), 1, 0)
    comp_inerror_loc = ma.where((comp_agree_loc == 0) & (comp_icefree_inerror < 0), 1, 0)

############################ Section 4.2. For advance ages, work out if model advance agrees with geochron advance ###########################################

if data_type == "ADVANCE":

    #### Determine if have been ice covered

    comp_covered = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    comp_covered[:,:] = ma.where((model_advance_age_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_covered_msk = np.array(comp_covered == 0)
    comp_covered_ma = ma.masked_array(comp_covered,comp_covered_msk)
    comp_covered_numb = ma.count(comp_covered_ma) + 0.0
    comp_covered_percent = (100/(ma.count(data_age_ma)+0.0)) * comp_covered_numb

    #### Identify covered dates that model agrees were ice free

    comp_advance_agree = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    comp_advance_agree[:,:] = ma.where((data_age_ma_0[:,:] <= model_advance_age_ma_0[:,:]) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_advance_agree_msk = np.array(comp_advance_agree == 0)
    comp_advance_agree_ma = ma.masked_array(comp_advance_agree, comp_advance_agree_msk)
    comp_advance_agree_numb = ma.count(comp_advance_agree_ma) + 0.0
    comp_advance_agree_percent = (100/comp_covered_numb) * comp_advance_agree_numb

    #### Account for error

    comp_advance_inerror = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    comp_advance_inerror[:,:] = ma.where((data_age_error_ma_0[:,:] <= model_advance_age_ma_0[:,:]) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_advance_inerror_msk = np.array(comp_advance_inerror == 0)
    comp_advance_inerror_ma = ma.masked_array(comp_advance_inerror, comp_advance_inerror_msk)
    comp_advance_inerror_numb = ma.count(comp_advance_agree_ma) + 0.0
    comp_advance_inerror_percent = (100/comp_covered_numb) * comp_advance_inerror_numb

    #### Identify three categories of date (1 = not covered, 2 = covered but disagree,3 = model-data agreement, 4 = model-data agreement within error)

    data_location_formap = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
    data_location_formap[:,:] = data_age[:,:]
    data_location_formap[np.isnan(data_location_formap)] = 0
    data_location_formap[data_location_formap < 0] = 1

    comp_covered_loc = ma.where((comp_covered < 0), 1, 0) 
    comp_agree_loc = ma.where((comp_advance_agree < 0), 1, 0)
    comp_inerror_loc = ma.where((comp_agree_loc == 0) & (comp_advance_inerror < 0), 1, 0)

#### Create variable for mapping
comp_categories = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
comp_categories = data_location_formap + comp_covered_loc + comp_agree_loc + comp_inerror_loc
comp_categories = ma.masked_array(comp_categories, data_age_msk)


##############################################################################################################################################################
######################################## Section 5. Produce Map of date categories for dates: ################################################################

#### Note same categories used for both ice free and advance dates

map_date_cmap = colors.ListedColormap(['white', 'red', 'green', 'blue'],'indexed')

map_date_categories = m.pcolor(xs,ys,comp_categories, cmap=map_date_cmap)
plt.clim(1,4)

map_date_categories_colorbar = plt.colorbar(map_date_categories, orientation='vertical')
map_date_categories_colorbar.ax.get_yaxis().set_ticks([])
for j, lab in enumerate(['$1$','$2$','$3$','$4$']):
    map_date_categories_colorbar.ax.text(.5, (2 * j +1)/8.0, lab, ha='center', va='center')
map_date_categories_colorbar.ax.get_yaxis().labelpad = 15
map_date_categories_colorbar.ax.set_ylabel('Data to model agreement category', rotation=270)

m.shadedrelief()
m.drawcoastlines()
m.drawparallels(np.arange(0.,81.,5.),labels=[1,0,0,0])
m.drawmeridians(np.arange(10.,351.,5.),labels=[0,0,0,1])

plt.show()

##############################################################################################################################################################
############################################ Section 6. Calculate RMSE and weighted error between dates ######################################################

#### Calculate RMSE for all dates, regardless of category

comp_all_diff = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
if data_type == "DEGLACIAL":
    comp_all_diff[:,:] = data_age_ma[:,:] - model_deglacial_age_ma[:,:]
elif data_type == "ADVANCE":
    comp_all_diff[:,:] = data_age_ma[:,:] - model_advance_age_ma[:,:]

comp_all_diff_ma = ma.masked_array(comp_all_diff, data_age_msk)
comp_all_rmse = ma.sqrt(ma.mean((comp_all_diff_ma*comp_all_diff_ma)))

#### Calculate weighted RMSE for all dates

comp_all_weighted_diff = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
comp_all_weighted_diff[:,:] = comp_all_diff_ma[:,:] * data_distance_weight_ma[:,:]
comp_all_weighted_diff_ma = ma.masked_array(comp_all_weighted_diff, data_age_msk)

comp_all_weighted_rmse = ma.sqrt(ma.mean((comp_all_weighted_diff_ma*comp_all_weighted_diff_ma)))

#### Calculate RMSE for all dates that covered by ice

comp_covered_diff = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
comp_covered_diff_msk = np.array(comp_categories >= 2)
comp_covered_diff_ma = ma.masked_array(comp_all_diff, comp_covered_msk)

comp_covered_rmse = ma.sqrt(ma.mean((comp_covered_diff_ma*comp_covered_diff_ma)))

#### Calculate weighted RMSE for all dates covered by ice

comp_covered_weighted_diff = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
comp_covered_weighted_diff_ma = ma.masked_array(comp_all_weighted_diff, comp_covered_msk)

comp_covered_weighted_rmse = ma.sqrt(ma.mean((comp_covered_weighted_diff_ma*comp_covered_weighted_diff_ma)))

#### Calculate RMSE for all dates that are ice free within error

comp_agree_diff = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
comp_agree_diff_msk = np.array(comp_categories >= 3)
comp_agree_diff_ma = ma.masked_array(comp_all_diff, comp_agree_diff_msk)

comp_agree_rmse = ma.sqrt(ma.mean((comp_agree_diff_ma*comp_agree_diff_ma)))

#### Calculate RMSE for all dates that are ice free within error

comp_agree_weighted_diff = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))
comp_agree_weighted_diff_ma = ma.masked_array(comp_all_weighted_diff, comp_agree_diff_msk)

comp_agree_weighted_rmse = ma.sqrt(ma.mean((comp_agree_weighted_diff_ma*comp_agree_weighted_diff_ma)))

##############################################################################################################################################################
################################## Section 7. Display map of differences  ####################################################################################

#### User input to define mapping variables
usr_norm = raw_input("Define NONE or WEIGHTED for map display of differences: ")
usr_cover = raw_input("Define ALL, COVERED or INERROR, for dates to be displayed on map: ")
data = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))

if usr_norm == "NONE":
    if usr_cover == "ALL":
        data = comp_all_diff_ma
    elif usr_cover == "COVERED":
        data = comp_covered_diff_ma
    elif usr_cover == "INERROR":
        data = comp_agree_diff_ma
elif usr_norm == "WEIGHTED":
    if usr_cover == "ALL":
        data = comp_all_weighted_diff_ma
    elif usr_cover == "COVERED":
        data = comp_covered_weighted_diff_ma
    elif usr_cover == "INERROR":
        data = comp_agree_weighted_diff_ma

#### make the map
class MidpointNormalize(Normalize):
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        Normalize.__init__(self, vmin, vmax, clip)
        
    def __call__(self, value, clip=None):
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y))

norm = MidpointNormalize(midpoint=0)
abs_dif_map = m.pcolor(xs, ys, data, norm=norm, cmap='seismic')

plt.colorbar(abs_dif_map, orientation='vertical')
map_date_categories_colorbar.ax.get_yaxis().set_ticks([])
map_date_categories_colorbar.ax.get_yaxis().labelpad = 15
map_date_categories_colorbar.ax.set_ylabel('Data-model difference(years)', rotation=270)

m.shadedrelief()
m.drawcoastlines()
m.drawparallels(np.arange(0.,81.,5.),labels=[1,0,0,0])
m.drawmeridians(np.arange(10.,351.,5.),labels=[0,0,0,1])

plt.show()

##############################################################################################################################################################
################################## Section 8. Save relevant metrics to .csv file. ############################################################################

header = np.array(["Number of Dates", "Percentage of dates covered", "RMSE all dates", "Weighted RMSE all dates" ,"RMSE dates covered by model", "Weighted RMSE dates covered by model", "RMSE of dates within error", "Weighted RMSE dates within error"])  
stats = np.array([ma.count(data_age_ma), comp_covered_percent, comp_all_rmse, comp_all_weighted_rmse, comp_covered_rmse, comp_covered_weighted_rmse, comp_agree_rmse, comp_agree_weighted_rmse])

data = pd.DataFrame(stats, index=header)
data.to_csv('ATAT_output.csv', header=None)
