#### ATAT Automated Timing Accordance Test
#### (C) Jeremy C. Ely, University of Sheffield.
#### For details see Ely et al., in prep, GMD.

#This program is free software: you can redistribute it and/or modify
#it under the terms of the GNU General Public License as published by
#the Free Software Foundation, either version 3 of the License, or
#any later version.

#This program is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#GNU General Public License for more details.

#You should have received a copy of the GNU General Public License
#along with this program.  If not, see <http://www.gnu.org/licenses/>.

#### Eight arguments specified at the command-line are required to run ATAT using the following command:
## python ATATv1.0.1.py a1 a2 a3 a4 a5 a6 a7 a8
## Whereby:
##  a1 = The type of data to be tested [DEGLACIAL or ADVANCE]
##  a2 = [Path to data file]
##  a3 = Margin definition criteria [THK or MSK]
##  a4 = [Path to model run file]
##  a5 = Value of ice extent mask, where present [An integer, or 0 if not applicable]
##  a6 = Mapping option - consider use of margin border or not in mapping [BORDER or NONE]
##  a7 = Mapping option - Plot RMSE or wRMSE [NONE or WEIGHTED]
##  a8 = Mapping option - Which category of dates to plot in difference map. [ALL, COVERED or INERROR] 

#### An example bash script is included with this software. For more details see Ely et al., in prep GMD.

#### Import modules (use pip, conda or another package manager to install these prerequisites)

import sys
import warnings

from netCDF4 import Dataset as ncdf 
import numpy as np
import numpy.ma as ma
import scipy.ndimage as ndi
from skimage.morphology import square
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.colors import Normalize
from mpl_toolkits.basemap import Basemap
import pandas as pd

#### Some expected warnings (e.g. all-NaN slices in the np.nanmax / np.nanmin calls below) are silenced.
#### The stdlib warnings module is used directly: np.warnings was never public API and is gone in NumPy 1.24.
warnings.filterwarnings('ignore')

##############################################################################################################################################################
######################################### Section 1. to Load Geochron data and define later used variables ###################################################

###### User definition of ice free / advance #############

#### Demo arguments, used ONLY when the script is started without any command-line arguments,
#### so that arguments given on the command line are always honoured.
demo_argv = ["ATATv1.0.1.py","DEGLACIAL","deglacial_ages_subset.nc","MSK","pism_min1.nc","2","BORDER","WEIGHTED","INERROR"]
if len(sys.argv) == 1:
    print("No command-line arguments given, running with the demo arguments")
    sys.argv = list(demo_argv)
elif len(sys.argv) != 9:
    print("ERROR: Eight arguments are required, see the usage notes at the top of this file")
    sys.exit(1)

data_type = sys.argv[1]

#### Load geochronological data ####
#### (the context is checked before the data file is opened; single-argument print() works in Python 2 and 3)
if data_type == "DEGLACIAL":
    print("Running icefree dates")
elif data_type == "ADVANCE":
    print("Running ice advance dates")
else:
    print("ERROR: No or incorrect geochronological context defined, options are DEGLACIAL or ADVANCE")
    sys.exit(1)

data_cdf = ncdf(sys.argv[2])

#### Define data dimensions and assign to an array
grid_shape = (len(data_cdf.dimensions["y1"]), len(data_cdf.dimensions["x1"]))

#### Ages: NaN becomes 0, and 0 is used throughout the script to mean "no date in this cell"
data_age = data_cdf.variables['deglacial_age'][:,:]
data_age[np.isnan(data_age[:,:])] = 0.0
#### Ages are compared as negative numbers, so flip the sign if they were supplied as positive
data_mean_age = np.mean(data_age)
if data_mean_age > 0.0:
    data_age = 0.0 - data_age
#### A sample elevation of exactly 0 is treated as missing
data_elevation = data_cdf.variables['elevation'][:,:]
data_elevation[data_elevation==0]=np.nan
data_topography = data_cdf.variables['topg'][:,:]
#### Mask (True = excluded) of every cell without a date
data_age_msk = np.array(data_age == 0)
data_age_ma = ma.masked_array(data_age,data_age_msk)
data_elevation_ma = ma.masked_array(data_elevation, data_age_msk)
data_topography_ma = ma.masked_array(data_topography, data_age_msk)
data_age_ma_0 = np.nan_to_num(data_age_ma)
#### Vertical uncertainty of each sample: distance between its elevation and the grid topography
data_elevation_uncertainty = abs(data_topography - data_elevation)
data_elevation_uncertainty_ma = ma.masked_array(data_elevation_uncertainty,data_age_msk)

data_error = data_cdf.variables['error'][:,:]
data_error_ma = ma.masked_invalid(data_error)
#### Treat error differently, depending on ice free/ advance scenario:
#### the error is added to deglacial ages and subtracted from advance ages
if data_type == "DEGLACIAL":
    data_age_error_ma = data_age_ma + data_error_ma
elif data_type == "ADVANCE":
    data_age_error_ma = data_age_ma - data_error_ma
data_age_error_ma_0 = np.nan_to_num(data_age_error_ma)

#### Define distance between age observations, for later weighting:
#### dated cells (negative age) are set to 1, then a 10x10 box sum (wrapping at the grid edges)
#### gives each cell the number of dated cells around it (assuming every dated age is negative here)
data_location = np.zeros(grid_shape)
data_location[:,:] = data_age[:,:]
data_location[data_location < 0] = 1
data_location = data_location.astype(int)
search = square(10)
data_distance_weight_ma = ndi.convolve(data_location, search, mode='wrap')

#### Longitude / latitude of every grid point, taken from the geochron data file
lons = data_cdf.variables['lon'][:,:]
lats = data_cdf.variables['lat'][:,:]

#### Lambert azimuthal equal-area map sized from the largest x1 / y1 coordinate.
#### The projection centre is fixed below; edit lat_0 / lon_0 for another region, or centre the
#### map on the data with lat_0=lats.mean(), lon_0=lons.mean().
map_width = abs(max(data_cdf.variables["x1"]))
map_height = abs(max(data_cdf.variables["y1"]))
m = Basemap(width=map_width, height=map_height, resolution='l', projection='laea', lat_0=56.0, lon_0=-5.0)

#### Project every grid point into map coordinates
xs, ys = m(lons, lats)

##############################################################################################################################################################
##########################Section 3. Load ice sheet model output and define later used variables ############################################################# 

#### Load ice sheet model ####

#### How the model ice margin is defined: from ice thickness (THK) or from a mask value (MSK).
#### Any other value would leave every model age undefined, so stop early with a clear message.
extent_type = sys.argv[3]
if extent_type not in ("THK", "MSK"):
    print("ERROR: No or incorrect margin definition criteria defined, options are THK or MSK")
    sys.exit(1)

model_cdf = ncdf(sys.argv[4])

#### 3x3 neighbourhood used to widen the model margin by one cell (margin uncertainty)
margin_uncert_size = square(3)

#### Calculate ice free age ####
#### For every model interval (time slice i -> i+1) record, per cell, the time at which the cell
#### became ice free or the ice surface dropped below the sample elevation. The deglacial age is
#### then the maximum of those times over all intervals (NaN where it never happened).
if data_type == "DEGLACIAL":
    grid_shape = (len(data_cdf.dimensions["y1"]), len(data_cdf.dimensions["x1"]))
    n_intervals = len(model_cdf.dimensions["time"]) - 1
    model_time = model_cdf.variables["time"][:]
    #### One layer per interval, NaN = nothing happened (np.nan: the np.NAN alias is gone in NumPy 2.0)
    model_identify_deglacial_slices = np.full((n_intervals,) + grid_shape, np.nan)
    model_identify_deglacial_slices_border = np.full((n_intervals,) + grid_shape, np.nan)

    if extent_type == "THK":
        #### Loop through all thicknesses to get timing then identify neighbouring pixels to account for margin uncertainty
        for i in range(0, n_intervals):
            model_thk_1 = model_cdf.variables['thk'][i,:,:]
            model_thk_2 = model_cdf.variables['thk'][i+1,:,:]
            #### Ice present at slice i but gone at slice i+1
            model_deglaciated_region = ((model_thk_1 > 0) & (model_thk_2 <= 0)).astype(int)
            #### Widen by one cell for margin uncertainty
            model_deglaciated_region_border = ndi.convolve(model_deglaciated_region,margin_uncert_size,mode='wrap')
            #### A sample above the ice surface is exposed even though ice remains in the cell
            model_surface_elevation = data_topography + model_thk_2
            model_deglaciated_thinning = ((model_surface_elevation < data_elevation) & (model_thk_2 > 0)).astype(int)
            model_deglaciated_region = model_deglaciated_region + model_deglaciated_thinning
            model_deglaciated_thinning_uncert = ((model_surface_elevation < (data_elevation+data_elevation_uncertainty)) & (model_thk_2 > 0)).astype(int)
            model_deglaciated_region_border = model_deglaciated_region_border + model_deglaciated_thinning_uncert
            #### NOTE(review): the core field uses time i while the border field uses time i+1 - confirm intended
            model_identify_deglacial_slices[i,:,:] = np.where(model_deglaciated_region > 0, model_time[i], np.nan)
            model_identify_deglacial_slices_border[i,:,:] = np.where(model_deglaciated_region_border > 0, model_time[i+1], np.nan)

    elif extent_type == "MSK":
        #### Mask value marking ice in the model output (command-line argument 5)
        ground_mask = int(sys.argv[5])
        for i in range(0, n_intervals):
            model_msk_1 = model_cdf.variables['mask'][i,:,:]
            model_msk_2 = model_cdf.variables['mask'][i+1,:,:]
            #### Ice present at slice i but gone at slice i+1
            model_deglaciated_region = ((model_msk_1 == ground_mask) & (model_msk_2 != ground_mask)).astype(int)
            model_deglaciated_region_border = ndi.convolve(model_deglaciated_region,margin_uncert_size,mode='wrap')
            #### A sample above the ice surface is exposed even though ice remains in the cell
            model_surface_elevation = data_topography + model_cdf.variables['thk'][i+1,:,:]
            model_deglaciated_thinning = ((model_surface_elevation < data_elevation) & (model_msk_2 == ground_mask)).astype(int)
            model_deglaciated_region = model_deglaciated_region + model_deglaciated_thinning
            model_deglaciated_thinning_uncert = ((model_surface_elevation < (data_elevation+data_elevation_uncertainty)) & (model_msk_2 == ground_mask)).astype(int)
            model_deglaciated_region_border = model_deglaciated_region_border + model_deglaciated_thinning_uncert
            model_identify_deglacial_slices[i,:,:] = np.where(model_deglaciated_region > 0, model_time[i], np.nan)
            model_identify_deglacial_slices_border[i,:,:] = np.where(model_deglaciated_region_border > 0, model_time[i+1], np.nan)

    #### Maximum time over all intervals; NaN where the model never deglaciated the cell
    model_deglacial_age = np.nanmax(model_identify_deglacial_slices, axis = 0)
    model_deglacial_age_border = np.nanmax(model_identify_deglacial_slices_border, axis = 0)
    model_deglacial_age_ma = ma.masked_invalid(model_deglacial_age)
    model_deglacial_age_ma_0 = np.nan_to_num(model_deglacial_age_ma)
    model_deglacial_age_border_ma = ma.masked_invalid(model_deglacial_age_border)
    model_deglacial_age_border_ma_0 = np.nan_to_num(model_deglacial_age_border_ma)

#### Or calculate ice advance age #####
#### For every model interval (time slice i -> i+1) record, per cell, the time at which ice arrived.
#### The advance age is then the minimum of those times over all intervals (NaN where ice never arrived).
if data_type == "ADVANCE":
    grid_shape = (len(data_cdf.dimensions["y1"]), len(data_cdf.dimensions["x1"]))
    n_intervals = len(model_cdf.dimensions["time"]) - 1
    model_time = model_cdf.variables["time"][:]
    #### Both stacks are allocated before the loops fill them; NaN = no advance in that interval
    model_identify_advance_slices = np.full((n_intervals,) + grid_shape, np.nan)
    model_identify_advance_slices_border = np.full((n_intervals,) + grid_shape, np.nan)

    if extent_type == "THK":
        for i in range(0, n_intervals):
            model_thk_1 = model_cdf.variables['thk'][i,:,:]
            model_thk_2 = model_cdf.variables['thk'][i+1,:,:]
            #### No ice at slice i but ice at slice i+1
            model_advance_region = ((model_thk_1 <= 0) & (model_thk_2 > 0)).astype(int)
            #### Widen by one cell for margin uncertainty
            model_advance_region_border = ndi.convolve(model_advance_region,margin_uncert_size,mode='wrap')
            #### Store the advance TIME (not the 0/1 region) so that nanmin below yields an age
            model_identify_advance_slices[i,:,:] = np.where(model_advance_region > 0, model_time[i], np.nan)
            model_identify_advance_slices_border[i,:,:] = np.where(model_advance_region_border > 0, model_time[i], np.nan)

    elif extent_type == "MSK":
        #### Mask value marking ice; converted to int so it can be compared with the model mask field
        ground_mask = int(sys.argv[5])
        for i in range(0, n_intervals):
            model_msk_1 = model_cdf.variables['mask'][i,:,:]
            model_msk_2 = model_cdf.variables['mask'][i+1,:,:]
            #### No ice at slice i but ice at slice i+1
            model_advance_region = ((model_msk_1 != ground_mask) & (model_msk_2 == ground_mask)).astype(int)
            model_advance_region_border = ndi.convolve(model_advance_region,margin_uncert_size,mode='wrap')
            model_identify_advance_slices[i,:,:] = np.where(model_advance_region > 0, model_time[i], np.nan)
            model_identify_advance_slices_border[i,:,:] = np.where(model_advance_region_border > 0, model_time[i], np.nan)

    #### Minimum time over all intervals; NaN where the model never advanced over the cell
    model_advance_age = np.nanmin(model_identify_advance_slices, axis = 0)
    model_advance_age_border = np.nanmin(model_identify_advance_slices_border, axis = 0)
    model_advance_age_ma = ma.masked_invalid(model_advance_age)
    model_advance_age_ma_0 = np.nan_to_num(model_advance_age_ma)
    model_advance_age_border_ma = ma.masked_invalid(model_advance_age_border)
    model_advance_age_border_ma_0 = np.nan_to_num(model_advance_age_border_ma)


##############################################################################################################################################################
################## Section 4.1. For ice-free ages/retreat ages, work out if model deglaciation occurs before geochron data (account for error) ############### 

if data_type == "DEGLACIAL":

    grid_shape = (len(data_cdf.dimensions["y1"]), len(data_cdf.dimensions["x1"]))
    #### Number of dated cells. Percentages are NaN (not a ZeroDivisionError) when a denominator is 0.
    comp_dates_numb = ma.count(data_age_ma) + 0.0

    #### Every comp_* field below holds the data age where its condition holds and 0 elsewhere,
    #### so "== 0" builds the mask (True = excluded) of the cells failing the condition.

    #### Determine if have been ice covered (the model gives the cell a deglacial age)

    comp_covered = np.zeros(grid_shape)
    comp_covered[:,:] = ma.where((model_deglacial_age_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_covered_msk = np.array(comp_covered == 0)
    comp_covered_ma = ma.masked_array(comp_covered,comp_covered_msk)
    comp_covered_numb = ma.count(comp_covered_ma) + 0.0
    comp_covered_percent = 100.0 * comp_covered_numb / comp_dates_numb if comp_dates_numb > 0 else np.nan

    #### Determine if have been ice covered within border
    comp_covered_border = np.zeros(grid_shape)
    comp_covered_border[:,:] = ma.where((model_deglacial_age_border_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_covered_border_msk = np.array(comp_covered_border == 0)
    comp_covered_border_ma = ma.masked_array(comp_covered_border,comp_covered_border_msk)
    comp_covered_border_numb = ma.count(comp_covered_border_ma) + 0.0
    comp_covered_border_percent = 100.0 * comp_covered_border_numb / comp_dates_numb if comp_dates_numb > 0 else np.nan

    #### Identify covered dates that model agrees were ice free (data age >= model deglacial age)

    comp_icefree_agree = np.zeros(grid_shape)
    comp_icefree_agree[:,:] = ma.where((data_age_ma_0[:,:] >= model_deglacial_age_ma_0[:,:]) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_icefree_agree_msk = np.array(comp_icefree_agree == 0)
    comp_icefree_agree_ma = ma.masked_array(comp_icefree_agree, comp_icefree_agree_msk)
    comp_icefree_agree_numb = ma.count(comp_icefree_agree_ma) + 0.0
    comp_icefree_agree_percent = 100.0 * comp_icefree_agree_numb / comp_covered_numb if comp_covered_numb > 0 else np.nan

    #### Identify covered dates that model agrees were ice free accounting for border

    comp_icefree_agree_border = np.zeros(grid_shape)
    comp_icefree_agree_border[:,:] = ma.where((data_age_ma_0[:,:] >= model_deglacial_age_border_ma_0[:,:]) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_icefree_agree_border_msk = np.array(comp_icefree_agree_border == 0)
    comp_icefree_agree_border_ma = ma.masked_array(comp_icefree_agree_border, comp_icefree_agree_border_msk)
    comp_icefree_agree_border_numb = ma.count(comp_icefree_agree_border_ma) + 0.0
    comp_icefree_agree_border_percent = 100.0 * comp_icefree_agree_border_numb / comp_covered_border_numb if comp_covered_border_numb > 0 else np.nan

    #### Account for age error (data age + error >= model age). The cell must also be covered by the
    #### model, since these counts are percentages of the covered dates.

    comp_icefree_inerror = np.zeros(grid_shape)
    comp_icefree_inerror[:,:] = ma.where((data_age_error_ma_0[:,:] >= model_deglacial_age_ma_0[:,:]) & (model_deglacial_age_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_icefree_inerror_msk = np.array(comp_icefree_inerror == 0)
    comp_icefree_inerror_ma = ma.masked_array(comp_icefree_inerror, comp_icefree_inerror_msk)
    comp_icefree_inerror_numb = ma.count(comp_icefree_inerror_ma) + 0.0
    comp_icefree_inerror_percent = 100.0 * comp_icefree_inerror_numb / comp_covered_numb if comp_covered_numb > 0 else np.nan

    #### Account for age error and border

    comp_icefree_inerror_border = np.zeros(grid_shape)
    comp_icefree_inerror_border[:,:] = ma.where((data_age_error_ma_0[:,:] >= model_deglacial_age_border_ma_0[:,:]) & (model_deglacial_age_border_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_icefree_inerror_border_msk = np.array(comp_icefree_inerror_border == 0)
    comp_icefree_inerror_border_ma = ma.masked_array(comp_icefree_inerror_border, comp_icefree_inerror_border_msk)
    comp_icefree_inerror_border_numb = ma.count(comp_icefree_inerror_border_ma) + 0.0
    comp_icefree_inerror_border_percent = 100.0 * comp_icefree_inerror_border_numb / comp_covered_border_numb if comp_covered_border_numb > 0 else np.nan

    #### Indicators for the four categories of date (1 = not covered, 2 = covered but disagree,
    #### 3 = model-data agreement, 4 = model-data agreement within error)

    data_location_formap = np.zeros(grid_shape)
    data_location_formap[:,:] = data_age[:,:]
    data_location_formap[np.isnan(data_location_formap)] = 0
    data_location_formap[data_location_formap < 0] = 1

    comp_covered_loc = ma.where((comp_covered < 0), 1, 0)
    comp_agree_loc = ma.where((comp_icefree_agree < 0), 1, 0)
    #### Agreement reached only by allowing for the age error
    comp_inerror_loc = ma.where((comp_agree_loc == 0) & (comp_icefree_inerror < 0), 1, 0)

    #### Same four-category indicators considering border

    comp_covered_loc_border = ma.where((comp_covered_border < 0), 1, 0)
    comp_agree_loc_border = ma.where((comp_icefree_agree_border < 0), 1, 0)
    comp_inerror_loc_border = ma.where((comp_agree_loc_border == 0) & (comp_icefree_inerror_border < 0), 1, 0)

############################ Section 4.2. For advance ages, work out if model advance agrees with geochron advance ###########################################

if data_type == "ADVANCE":

    grid_shape = (len(data_cdf.dimensions["y1"]), len(data_cdf.dimensions["x1"]))
    #### Number of dated cells. Percentages are NaN (not a ZeroDivisionError) when a denominator is 0.
    comp_dates_numb = ma.count(data_age_ma) + 0.0

    #### Every comp_* field below holds the data age where its condition holds and 0 elsewhere,
    #### so "== 0" builds the mask (True = excluded) of the cells failing the condition.

    #### Determine if have been ice covered (the model gives the cell an advance age)

    comp_covered = np.zeros(grid_shape)
    comp_covered[:,:] = ma.where((model_advance_age_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_covered_msk = np.array(comp_covered == 0)
    comp_covered_ma = ma.masked_array(comp_covered,comp_covered_msk)
    comp_covered_numb = ma.count(comp_covered_ma) + 0.0
    comp_covered_percent = 100.0 * comp_covered_numb / comp_dates_numb if comp_dates_numb > 0 else np.nan

    #### Determine if have been ice covered within border

    comp_covered_border = np.zeros(grid_shape)
    comp_covered_border[:,:] = ma.where((model_advance_age_border_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_covered_border_msk = np.array(comp_covered_border == 0)
    comp_covered_border_ma = ma.masked_array(comp_covered_border,comp_covered_border_msk)
    comp_covered_border_numb = ma.count(comp_covered_border_ma) + 0.0
    comp_covered_border_percent = 100.0 * comp_covered_border_numb / comp_dates_numb if comp_dates_numb > 0 else np.nan

    #### Identify covered dates where the model advance agrees with the data (data age <= model advance age).
    #### The cell must be covered: an undefined model age is stored as 0 and would otherwise always "agree".

    comp_advance_agree = np.zeros(grid_shape)
    comp_advance_agree[:,:] = ma.where((data_age_ma_0[:,:] <= model_advance_age_ma_0[:,:]) & (model_advance_age_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_advance_agree_msk = np.array(comp_advance_agree == 0)
    comp_advance_agree_ma = ma.masked_array(comp_advance_agree, comp_advance_agree_msk)
    comp_advance_agree_numb = ma.count(comp_advance_agree_ma) + 0.0
    comp_advance_agree_percent = 100.0 * comp_advance_agree_numb / comp_covered_numb if comp_covered_numb > 0 else np.nan

    #### Identify covered dates where the model advance agrees considering border

    comp_advance_agree_border = np.zeros(grid_shape)
    comp_advance_agree_border[:,:] = ma.where((data_age_ma_0[:,:] <= model_advance_age_border_ma_0[:,:]) & (model_advance_age_border_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_advance_agree_border_msk = np.array(comp_advance_agree_border == 0)
    comp_advance_agree_border_ma = ma.masked_array(comp_advance_agree_border, comp_advance_agree_border_msk)
    comp_advance_agree_border_numb = ma.count(comp_advance_agree_border_ma) + 0.0
    comp_advance_agree_border_percent = 100.0 * comp_advance_agree_border_numb / comp_covered_border_numb if comp_covered_border_numb > 0 else np.nan

    #### Account for error (data age - error <= model age), covered cells only

    comp_advance_inerror = np.zeros(grid_shape)
    comp_advance_inerror[:,:] = ma.where((data_age_error_ma_0[:,:] <= model_advance_age_ma_0[:,:]) & (model_advance_age_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_advance_inerror_msk = np.array(comp_advance_inerror == 0)
    comp_advance_inerror_ma = ma.masked_array(comp_advance_inerror, comp_advance_inerror_msk)
    comp_advance_inerror_numb = ma.count(comp_advance_inerror_ma) + 0.0
    comp_advance_inerror_percent = 100.0 * comp_advance_inerror_numb / comp_covered_numb if comp_covered_numb > 0 else np.nan

    #### Account for error and border

    comp_advance_inerror_border = np.zeros(grid_shape)
    comp_advance_inerror_border[:,:] = ma.where((data_age_error_ma_0[:,:] <= model_advance_age_border_ma_0[:,:]) & (model_advance_age_border_ma_0[:,:] < 0) & (data_age_ma_0[:,:] < 0), data_age_ma_0, 0)
    comp_advance_inerror_border_msk = np.array(comp_advance_inerror_border == 0)
    comp_advance_inerror_border_ma = ma.masked_array(comp_advance_inerror_border, comp_advance_inerror_border_msk)
    comp_advance_inerror_border_numb = ma.count(comp_advance_inerror_border_ma) + 0.0
    comp_advance_inerror_border_percent = 100.0 * comp_advance_inerror_border_numb / comp_covered_border_numb if comp_covered_border_numb > 0 else np.nan

    #### Indicators for the four categories of date (1 = not covered, 2 = covered but disagree,
    #### 3 = model-data agreement, 4 = model-data agreement within error)

    data_location_formap = np.zeros(grid_shape)
    data_location_formap[:,:] = data_age[:,:]
    data_location_formap[np.isnan(data_location_formap)] = 0
    data_location_formap[data_location_formap < 0] = 1

    comp_covered_loc = ma.where((comp_covered < 0), 1, 0)
    comp_agree_loc = ma.where((comp_advance_agree < 0), 1, 0)
    #### Agreement reached only by allowing for the age error
    comp_inerror_loc = ma.where((comp_agree_loc == 0) & (comp_advance_inerror < 0), 1, 0)

    #### Same four-category indicators considering border

    comp_covered_loc_border = ma.where((comp_covered_border < 0), 1, 0)
    comp_agree_loc_border = ma.where((comp_advance_agree_border < 0), 1, 0)
    comp_inerror_loc_border = ma.where((comp_agree_loc_border == 0) & (comp_advance_inerror_border < 0), 1, 0)
    
#### Create variable for mapping choosing if border or not.
#### Category of each dated cell: 1 = not covered, 2 = covered but disagree, 3 = model-data agreement,
#### 4 = model-data agreement only within error. Summing the 0/1 indicators gave 3 for both agreement
#### kinds (category 4 was unreachable), so the category is chosen explicitly. Undated cells are masked.
comp_categories = ma.masked_array(
    np.select([np.asarray(comp_agree_loc) > 0, np.asarray(comp_inerror_loc) > 0, np.asarray(comp_covered_loc) > 0],
              [3.0, 4.0, 2.0], default=1.0),
    data_age_msk)

comp_categories_border = ma.masked_array(
    np.select([np.asarray(comp_agree_loc_border) > 0, np.asarray(comp_inerror_loc_border) > 0, np.asarray(comp_covered_loc_border) > 0],
              [3.0, 4.0, 2.0], default=1.0),
    data_age_msk)

##############################################################################################################################################################
######################################## Section 5. Produce Map of date categories for dates: ################################################################

#### Note same categories used for both ice free and advance dates.
#### With colour limits 1-4 the four colours map to categories 1 white, 2 red, 3 green, 4 blue.

map_date_cmap = colors.ListedColormap(['white', 'red', 'green', 'blue'],'indexed')

border = sys.argv[6]
if border == "NONE":
    map_date_values = comp_categories
elif border == "BORDER":
    map_date_values = comp_categories_border
else:
    print("ERROR: No or incorrect border option defined, options are BORDER or NONE")
    sys.exit(1)
map_date_categories = m.pcolormesh(xs,ys,map_date_values, cmap=map_date_cmap)
plt.clim(1,4)

#### Colorbar with one centred label per category instead of numeric ticks
map_date_categories_colorbar = plt.colorbar(map_date_categories, orientation='vertical')
map_date_categories_colorbar.ax.get_yaxis().set_ticks([])
for j, lab in enumerate(['$1$','$2$','$3$','$4$']):
    map_date_categories_colorbar.ax.text(.5, (2 * j +1)/8.0, lab, ha='center', va='center')
map_date_categories_colorbar.ax.get_yaxis().labelpad = 15
map_date_categories_colorbar.ax.set_ylabel('Data to model agreement category', rotation=270)

m.shadedrelief()
m.drawcoastlines()
m.drawparallels(np.arange(0.,81.,5.),labels=[1,0,0,0])
m.drawmeridians(np.arange(10.,351.,5.),labels=[0,0,0,1])

#### str.replace returns a new string: assign it so "/" really is removed from output file names
model_name = str(sys.argv[4]).replace('/','_')

plt.savefig('CategoryMap' + str(data_type) + str(model_name) + '.png')

##############################################################################################################################################################
############################################ Section 6. Calculate RMSE and weighted error between dates ######################################################

#### Calculate RMSEs for all dates, regardless of category.
#### Throughout this section a True mask entry EXCLUDES that cell from the statistic (numpy.ma convention).

grid_shape = (len(data_cdf.dimensions["y1"]), len(data_cdf.dimensions["x1"]))
comp_all_diff = np.zeros(grid_shape)
comp_all_diff_border = np.zeros(grid_shape)

if data_type == "DEGLACIAL":
    comp_all_diff[:,:] = data_age_ma[:,:] - model_deglacial_age_ma[:,:]
    #### NOTE(review): this difference stays a masked array (cells without a model border age remain
    #### masked), whereas the other three are copied into plain arrays, which drops that mask - confirm
    #### which treatment of dates the model never reached is intended.
    comp_all_diff_border = data_age_ma[:,:] - model_deglacial_age_border_ma
elif data_type == "ADVANCE":
    comp_all_diff[:,:] = data_age_ma[:,:] - model_advance_age_ma[:,:]
    comp_all_diff_border[:,:] = data_age_ma[:,:] - model_advance_age_border_ma[:,:]

comp_all_diff_ma = ma.masked_array(comp_all_diff, data_age_msk)
comp_all_diff_border_ma = ma.masked_array(comp_all_diff_border, data_age_msk)

comp_all_rmse = ma.sqrt(ma.mean(comp_all_diff_ma*comp_all_diff_ma))
comp_all_border_rmse = ma.sqrt(ma.mean(comp_all_diff_border_ma*comp_all_diff_border_ma))

#### Calculate weighted RMSEs for all dates (differences multiplied by data_distance_weight_ma)

comp_all_weighted_diff = np.zeros(grid_shape)
comp_all_weighted_diff_border = np.zeros(grid_shape)

comp_all_weighted_diff[:,:] = comp_all_diff_ma[:,:] * data_distance_weight_ma[:,:]
comp_all_weighted_diff_ma = ma.masked_array(comp_all_weighted_diff, data_age_msk)

comp_all_weighted_diff_border[:,:] = comp_all_diff_border_ma[:,:]*data_distance_weight_ma[:,:]
comp_all_weighted_diff_border_ma = ma.masked_array(comp_all_weighted_diff_border, data_age_msk)

comp_all_weighted_rmse = ma.sqrt(ma.mean(comp_all_weighted_diff_ma*comp_all_weighted_diff_ma))
comp_all_weighted_border_rmse = ma.sqrt(ma.mean(comp_all_weighted_diff_border_ma*comp_all_weighted_diff_border_ma))

#### Calculate RMSE for all dates covered by ice: comp_covered_msk / comp_covered_border_msk exclude
#### the dates the model never covered. Border statistics use the border differences.

comp_covered_diff_ma = ma.masked_array(comp_all_diff, comp_covered_msk)
comp_covered_diff_border_ma = ma.masked_array(comp_all_diff_border, comp_covered_border_msk)

comp_covered_rmse = ma.sqrt(ma.mean(comp_covered_diff_ma*comp_covered_diff_ma))
comp_covered_border_rmse = ma.sqrt(ma.mean(comp_covered_diff_border_ma*comp_covered_diff_border_ma))

#### Calculate weighted RMSE for all dates covered by ice

comp_covered_weighted_diff_ma = ma.masked_array(comp_all_weighted_diff, comp_covered_msk)
comp_covered_weighted_diff_border_ma = ma.masked_array(comp_all_weighted_diff_border, comp_covered_border_msk)

comp_covered_weighted_rmse = ma.sqrt(ma.mean(comp_covered_weighted_diff_ma*comp_covered_weighted_diff_ma))
comp_covered_weighted_border_rmse = ma.sqrt(ma.mean(comp_covered_weighted_diff_border_ma * comp_covered_weighted_diff_border_ma))

#### Calculate RMSE for all dates in agreement, including agreement within error (categories 3 and 4):
#### exclude every cell below category 3 and every cell without a date.

comp_agree_diff_msk = np.array(comp_categories < 3) | data_age_msk
comp_agree_diff_ma = ma.masked_array(comp_all_diff, comp_agree_diff_msk)

comp_agree_diff_border_msk = np.array(comp_categories_border < 3) | data_age_msk
comp_agree_diff_border_ma = ma.masked_array(comp_all_diff_border, comp_agree_diff_border_msk)

comp_agree_rmse = ma.sqrt(ma.mean(comp_agree_diff_ma*comp_agree_diff_ma))
comp_agree_border_rmse = ma.sqrt(ma.mean(comp_agree_diff_border_ma * comp_agree_diff_border_ma))

#### Calculate weighted RMSE for all dates in agreement, including agreement within error

comp_agree_weighted_diff_ma = ma.masked_array(comp_all_weighted_diff, comp_agree_diff_msk)
comp_agree_weighted_diff_border_ma = ma.masked_array(comp_all_weighted_diff_border, comp_agree_diff_border_msk)

comp_agree_weighted_rmse = ma.sqrt(ma.mean(comp_agree_weighted_diff_ma*comp_agree_weighted_diff_ma))
comp_agree_weighted_border_rmse = ma.sqrt(ma.mean(comp_agree_weighted_diff_border_ma*comp_agree_weighted_diff_border_ma))

##############################################################################################################################################################
################################## Section 7. Display map of differences  ####################################################################################
plt.close("all")
usr_norm = sys.argv[7]
usr_cover = sys.argv[8]

#### Difference field to map, keyed by (border option, RMSE option, date category option)
difference_fields = {
    ("NONE", "NONE", "ALL"): comp_all_diff_ma,
    ("NONE", "NONE", "COVERED"): comp_covered_diff_ma,
    ("NONE", "NONE", "INERROR"): comp_agree_diff_ma,
    ("NONE", "WEIGHTED", "ALL"): comp_all_weighted_diff_ma,
    ("NONE", "WEIGHTED", "COVERED"): comp_covered_weighted_diff_ma,
    ("NONE", "WEIGHTED", "INERROR"): comp_agree_weighted_diff_ma,
    ("BORDER", "NONE", "ALL"): comp_all_diff_border_ma,
    ("BORDER", "NONE", "COVERED"): comp_covered_diff_border_ma,
    ("BORDER", "NONE", "INERROR"): comp_agree_diff_border_ma,
    ("BORDER", "WEIGHTED", "ALL"): comp_all_weighted_diff_border_ma,
    ("BORDER", "WEIGHTED", "COVERED"): comp_covered_weighted_diff_border_ma,
    ("BORDER", "WEIGHTED", "INERROR"): comp_agree_weighted_diff_border_ma,
}
data = difference_fields.get((border, usr_norm, usr_cover))
if data is None:
    #### Unknown combination: keep the all-zero fallback so the run still finishes, but report it
    print("WARNING: No or incorrect mapping options defined (NONE or WEIGHTED; ALL, COVERED or INERROR), the difference map will only show zeros")
    data = np.zeros((len(data_cdf.dimensions["y1"]),len(data_cdf.dimensions["x1"])))

#### Hide every cell without a date
data = ma.masked_array(data, data_age_ma.mask)

#### make the map
class MidpointNormalize(Normalize):
    """Colour normalisation that pins `midpoint` to the centre of the colormap.

    Values are mapped piecewise-linearly: vmin -> 0, midpoint -> 0.5, vmax -> 1
    (np.interp clamps anything outside [vmin, vmax]). The result is wrapped in a
    masked array. The `clip` argument of __call__ is accepted but ignored.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super(MidpointNormalize, self).__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        anchors = [self.vmin, self.midpoint, self.vmax]
        positions = [0, 0.5, 1]
        scaled = np.interp(value, anchors, positions)
        return np.ma.masked_array(scaled)

#### Centre the diverging colormap on a zero data-model difference
norm = MidpointNormalize(midpoint=0)
abs_dif_map = m.pcolor(xs,ys, data, norm=norm, cmap='seismic')
#### Label THIS figure's colorbar; the category-map colorbar belongs to a figure that was already closed.
#### Its numeric ticks are kept because they give the difference values.
abs_dif_colorbar = plt.colorbar(abs_dif_map, orientation='vertical')
abs_dif_colorbar.ax.get_yaxis().labelpad = 15
abs_dif_colorbar.ax.set_ylabel('Data-model difference(years)', rotation=270)

m.shadedrelief()
m.drawcoastlines()
m.drawparallels(np.arange(0.,81.,5.),labels=[1,0,0,0])
m.drawmeridians(np.arange(10.,351.,5.),labels=[0,0,0,1])

plt.savefig('DifferenceMap' + str(data_type) + str(model_name) + '.png')
plt.close()

##############################################################################################################################################################
################################## Section 8. Save relevant metrics to .csv file. ############################################################################

#### make advance / ice free same for output.
#### The "within error" columns report the within-error percentages, matching their header text.
if data_type == "DEGLACIAL":
    comp_agree_percent = comp_icefree_inerror_percent
    comp_agree_border_percent = comp_icefree_inerror_border_percent
elif data_type == "ADVANCE":
    comp_agree_percent = comp_advance_inerror_percent
    comp_agree_border_percent = comp_advance_inerror_border_percent

header = np.array(["Number of Dates", "RMSE all dates", "Weighted RMSE all dates" ,"RMSE all dates model margin uncertainty", "Weighted RMSE all dates model margin uncertainty",
                   "Percentage of dates covered", "RMSE dates covered by model", "Weighted RMSE dates covered by model",
                   "Percentage of dates covered with model margin uncertainty", "RMSE dates covered with model margin uncertainty", "Weighted RMSE dates covered with model margin uncertainty",
                   "Percentage of dates within error", "RMSE of dates within error", "Weighted RMSE dates within error",
                   "Percentage of dates within error with model margin uncertainty", "RMSE of dates within error with model margin uncertainty", "Weighted RMSE dates within error with model margin uncertainty"])
stats_values = [ma.count(data_age_ma), comp_all_rmse, comp_all_weighted_rmse, comp_all_border_rmse, comp_all_weighted_border_rmse,
                comp_covered_percent, comp_covered_rmse, comp_covered_weighted_rmse,
                comp_covered_border_percent, comp_covered_border_rmse, comp_covered_weighted_border_rmse,
                comp_agree_percent, comp_agree_rmse, comp_agree_weighted_rmse,
                comp_agree_border_percent, comp_agree_border_rmse, comp_agree_weighted_border_rmse]
#### An RMSE over an empty selection comes back masked; write NaN instead of a misleading 0
stats = np.array([np.nan if ma.is_masked(value) else float(value) for value in stats_values])

#### save data - file name refers to the type of data tested and model_name
results = pd.DataFrame(stats, index=header)
results.to_csv("ATAT_" + str(data_type) + model_name + ".csv", header=None)
