# -*- coding: utf-8 -*-
"""
Plot grounded area against time for all 3 MISMIP+ experiments (7 branches),
and plot the grounding lines at 0, 100 and 200 years.

Created on Tue Sep  8 09:29:09 2015

@author: s.l.cornford@bris.ac.uk
"""

from netCDF4 import Dataset
import numpy as np
import matplotlib.pyplot as plt

def tscale(time):
    """
    Map time onto a square-root axis so that early times are stretched out.

    time: scalar or array-like of times (yr)
    returns sqrt(time) as computed by numpy
    """
    stretched = np.sqrt(time)
    return stretched
def intscale(time):
    """
    Undo tscale: map a square-root-axis coordinate back to time by squaring.

    time: scalar or numpy array (plain lists are not supported by ``**``)
    """
    return pow(time, 2)

def garplot(ncfile, label, color, marker):
    """
    add a plot of grounded area against time to current axes

    ncfile: path to a netCDF file containing "groundedArea" and "time"
    label:  legend label for the curve
    color:  marker face colour
    marker: matplotlib marker style
    returns the maximum grounded area, in the plotted units
            (groundedArea * 1e-9, i.e. 1000 km^2 if groundedArea is in m^2)
    """
    # 'with' guarantees the file is closed even if a read fails; the [:]
    # slices load the data into memory, so they stay valid after closing.
    with Dataset(ncfile, 'r') as ncid:
        # 1e-6: m^2 -> km^2, 1e-3: km^2 -> 1000 km^2 (assumes m^2 on file)
        gar = ncid.variables["groundedArea"][:] * 1e-6 * 1e-3
        time = ncid.variables["time"][:]
    # The marker comes from the 'marker' keyword only; also giving one in
    # the fmt string ('o-') makes matplotlib warn about a redundant marker.
    plt.plot(tscale(time), gar, '-', mfc=color,
             color='black', label=label, marker=marker)
    return np.max(gar)

def glplot(ncfile, times, colora, label):
    """
    add a plot of grounding line points to current axes.

    makes use of the numpy.ma.MaskedArray when reading xGL, yGL: np.min and
    np.max on a MaskedArray ignore masked (unused) entries.

    ncfile: path to a netCDF file containing "time", "xGL" and "yGL";
            the second axis of xGL/yGL is indexed by time
    times:  output times to plot; each must match a value in "time" exactly
    colora: one colour per entry of times
    label:  legend prefix; each series is labelled '<label>, t = <time>'
    returns (lxmin, lxmax), the x-extent in km over all plotted points
            (assuming xGL is stored in metres). If nothing is plotted, this
            is (inf, -inf).
    """
    # Start from +/-inf so that the extent is not tied to a particular
    # domain size and can be combined safely across several files.
    lxmax = -np.inf
    lxmin = np.inf
    # 'with' closes the file when we are done (the original never closed it).
    with Dataset(ncfile, 'r') as ncid:
        time = ncid.variables["time"][:]
        xvar = ncid.variables["xGL"]
        yvar = ncid.variables["yGL"]
        for t, color in zip(times, colora):
            seq = (time == t)
            xGL = xvar[:, seq] * 1e-3
            yGL = yvar[:, seq] * 1e-3
            lxmax = max(np.max(xGL), lxmax)
            lxmin = min(np.min(xGL), lxmin)
            plt.plot(xGL, yGL, 's', ms=3, mfc=color,
                     mec=color, label=label + ', t = ' + format(t))
    return lxmin, lxmax

plt.figure(figsize=(7, 10))

# Upper panel: grounding-line positions at selected times for each branch.
plt.subplot(211)

# (netCDF file, output times to plot, one colour per time, legend prefix)
GL_PLOTS = [
    ('Ice1r-example.nc', [0, 100], ['black', 'red'], 'Ice1r'),
    ('Ice1ra-example.nc', [200], ['orange'], 'Ice1ra'),
    ('Ice1rr-example.nc', [200], ['yellow'], 'Ice1rr'),
    ('Ice2r-example.nc', [100], ['blue'], 'Ice2r'),
    ('Ice2ra-example.nc', [200], ['purple'], 'Ice2ra'),
    ('Ice2rr-example.nc', [200], ['pink'], 'Ice2rr'),
]

# Accumulate the x-extent over every branch and set the limits once all
# points are drawn, so that no branch's grounding line is clipped.
xmin, xmax = np.inf, -np.inf
for ncfile, times, colors, label in GL_PLOTS:
    lo, hi = glplot(ncfile, times, colors, label)
    xmin, xmax = min(xmin, lo), max(xmax, hi)
plt.xlim([xmin - 50.0, xmax + 50.0])

plt.legend(frameon=True, borderaxespad=0, fontsize='small', loc='right')
plt.xlabel(r'$x$ (km)')
plt.ylabel(r'$y$ (km)')


# Lower panel: grounded area against time on a sqrt(time) axis.
plt.subplot(212)
# Grey vertical guide lines at t = 100 yr and t = 200 yr.
for t_mark in (100, 200):
    plt.plot(tscale([t_mark, t_mark]), [0, 100], color="grey")
plt.xlim(tscale([0, 1000]))
plt.ylim([25, 40])

# Label the ticks with the exact integer years. Labelling them with
# intscale(tscale(t)) shows round-off such as 10.000000000000002 or 100.0.
tick_years = [0, 10, 50, 100, 200, 400, 800]
plt.xticks(tscale(tick_years), [str(t) for t in tick_years])
plt.xlabel(r'Time,  $t$ (yr)')
# garplot scales groundedArea to 1000 km^2: an area, so km^2, not km^3.
plt.ylabel(r'Grounded Area (1000 km$^2$)')

# (netCDF file, legend label, marker face colour, marker) for each branch:
# Ice0 uses diamonds, the Ice1 branches circles, the Ice2 branches squares.
AREA_PLOTS = [
    ('Ice0-example.nc', 'Ice0', 'grey', 'd'),
    ('Ice1r-example.nc', 'Ice1r', 'red', 'o'),
    ('Ice1ra-example.nc', 'Ice1ra', 'orange', 'o'),
    ('Ice1rr-example.nc', 'Ice1rr', 'yellow', 'o'),
    ('Ice2r-example.nc', 'Ice2r', 'blue', 's'),
    ('Ice2ra-example.nc', 'Ice2ra', 'purple', 's'),
    ('Ice2rr-example.nc', 'Ice2rr', 'pink', 's'),
]
for ncfile, label, color, marker in AREA_PLOTS:
    garplot(ncfile, label, color, marker)

plt.legend(loc='lower left', ncol=3, frameon=True,
           borderaxespad=0, fontsize='small')

plt.savefig("plot_example.pdf")
