#!/usr/bin/env python

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import numpy.ma as ma
import netCDF4
import cartopy
import cartopy.crs as ccrs
import string

def make_colormap(seq):
    """Build a piecewise-linear LinearSegmentedColormap named 'CustomMap'.

    ``seq`` interleaves RGB tuples with float break points.  At each float
    position the colour listed immediately before it is the value reached
    from below, and the colour listed immediately after it is the value
    leaving upwards; colours are interpolated linearly between break points.
    Break points 0.0 and 1.0 are added automatically, so ``seq`` should begin
    and end with an RGB tuple and its floats should increase strictly inside
    the open interval (0, 1).
    """
    blank = (None, None, None)
    padded = [blank, 0.0, *seq, 1.0, blank]
    segments = {'red': [], 'green': [], 'blue': []}
    # Walk (previous, current, next) triples; only the float entries define rows.
    for before, stop, after in zip(padded, padded[1:], padded[2:]):
        if not isinstance(stop, float):
            continue
        red_lo, green_lo, blue_lo = before
        red_hi, green_hi, blue_hi = after
        segments['red'].append([stop, red_lo, red_hi])
        segments['green'].append([stop, green_lo, green_hi])
        segments['blue'].append([stop, blue_lo, blue_hi])
    return mcolors.LinearSegmentedColormap('CustomMap', segments)


#==================
# Load datasets
# Each file is opened in a ``with`` block so the handle is closed even when a
# variable lookup or read fails part-way through.
with netCDF4.Dataset('ocn_diags.mpicontrol_d30_GMDplot.oc.nc') as f1:

    # Grid info
    dp      = f1.variables['DEPTH'][:]/100   # plotted as "Depth (m)"; presumably cm -> m
    yr      = f1.variables['DAY'][:]/360     # plotted as "Time (years)"; assumes 360-day years

    # Figure 1
    hflux   = f1.variables['HFLUX'][:]
    wflux   = f1.variables['WFLUX'][:]*1000  # plotted as mm/yr; presumably m -> mm
    sst     = f1.variables['SST'][:]-273.15  # K -> deg C
    sss     = f1.variables['SSS'][:]
    tav     = f1.variables['TAV'][:]-273.15  # K -> deg C
    sav     = f1.variables['SAV'][:]

    # Figure 2 (the ZSST/ZSSS slices keep index 0 of their second axis)
    ttdsg   = f1.variables['TTDSG'][:,:]-273.15
    stdsg   = f1.variables['STDSG'][:,:]
    zsst    = f1.variables['ZSST'][:,0,:]-273.15
    zsss    = f1.variables['ZSSS'][:,0,:]

    # Figure 3
    omax30  = f1.variables['OMAX30'][:]
    dptrans = f1.variables['DRAKE_PASSAGE'][:]

    # Figure 9
    amoc    = f1.variables['AMOC'][:,:,:]
#==================
# Context manager closes the file even if a read below fails.
with netCDF4.Dataset('mpicontrol1b_r6_d30_GMDplot.oc.nc') as f1:

    # Grid info
    lon     = f1.variables['LONGITUDE_T'][:]
    lat     = f1.variables['LATITUDE_T'][:]
    # U-point coordinates with a leading boundary value prepended (0 for
    # longitude, -88 for latitude); the Figure 7/8 map helpers pass these to
    # pcolormesh as cell edges.
    lonu    = np.insert(f1.variables['LONGITUDE_U'][:],0,0)
    latu    = np.insert(f1.variables['LATITUDE_U'][:],0,-88)
    dpe     = f1.variables['DEPTH_EDGES'][:]/100

    # Figure 7: level-0 theta0 / so, time mean of the first 300 records
    sst2d   = np.mean(f1.variables['theta0'][:300,0,:,:],axis=0)-273.15
    sss2d   = np.mean(f1.variables['so'][:300,0,:,:],axis=0)

    # Figure 8: zos, time mean of the first 300 records
    ssh     = np.mean(f1.variables['zos'][:300,:,:],axis=0)
#==================
# Figure 7 - Obs (EN3 objective analysis, 1960-1989 mean per the file name)
# Context managers close each file even if a read fails.
with netCDF4.Dataset('EN3_v2a_ObjectiveAnalysis_1960_1989_mean_on_FORTE_grid_GMDplot.nc') as e1:

    elon    = e1.variables['lon'][:]
    elat    = e1.variables['lat'][:]
    # Level 0, averaged over the first axis; temperature converted K -> deg C
    esst2d  = np.mean(e1.variables['temperature'][:,0,:,:],axis=0)-273.15
    esss2d  = np.mean(e1.variables['salinity'][:,0,:,:],axis=0)

#==================
# Figure 8 - Obs (OCCA v2 sea surface height, first record)
with netCDF4.Dataset('OCCAv2_SSH_on_FORTE_grid_GMDplot.nc') as f1:

    lonssh  = f1.variables['lon'][:]
    latssh  = f1.variables['lat'][:]
    occassh = f1.variables['etan'][0,:,:]
#==================
# Figures 4 and 5
# Context manager closes the file even if a read fails.
with netCDF4.Dataset('Ts_Precip_MSLP_720001_729000_GMDplot.nc') as f1:

    # Grid info: 359.99 appended to longitude and 90 prepended to latitude so
    # the map helpers can use them as cell boundaries / to close contours.
    lona    = np.append(f1.variables['longitude'][:],359.99)
    lata    = np.insert(f1.variables['latitude'][:],0,90)

    # *anmn: 2-D maps; *monclim: month-first stacks of 2-D maps
    STanmn     = f1.variables['STanmn'][:,:]
    STmonclim  = f1.variables['STmonclim'][:,:,:]
    PTanmn     = f1.variables['PTanmn'][:,:]
    PTmonclim  = f1.variables['PTmonclim'][:,:,:]
    PSanmn     = f1.variables['PSanmn'][:,:]
    PSmonclim  = f1.variables['PSmonclim'][:,:,:]

# Seasonal subsets of the monthly climatology, assuming month index 0 = January.
# DJF: roll along the month axis (axis=0) so December comes first, then take
# Dec/Jan/Feb.  axis=0 is essential -- without it np.roll flattens the array
# and shifts it by a single grid point instead of by one month.
PTdjfclim2 = np.roll(PTmonclim,1,axis=0)[:3,:,:]
# JJA: indices 5, 6, 7 (June, July, August).
PTjjaclim2 = PTmonclim[5:8,:,:]

# Three-month seasonal means used in Figure 4
PTdjfclim  = np.mean(PTdjfclim2,axis=0)
PTjjaclim  = np.mean(PTjjaclim2,axis=0)
#==================
# Figure 6 (u_zm.nc: UJJA / UDJF on pressure-latitude axes)
# Context managers close each file even if a read fails.
with netCDF4.Dataset('u_zm.nc') as f1:

    pres    = f1.variables['Pressure'][:]
    lata2   = f1.variables['Latitude'][:]
    UJJA    = f1.variables['UJJA'][:,:]
    UDJF    = f1.variables['UDJF'][:,:]

#==================
# Figure 10 (meridional heat transport: total, Atlantic, Indo-Pacific)
with netCDF4.Dataset('FORTE_MHT_GMDplot.nc') as f1:
    MHT     = f1.variables['MHT'][:]
    MHTA    = f1.variables['MHTA'][:]
    MHTIP   = f1.variables['MHTIP'][:]


# Figure 1
fig, axarr = plt.subplots(3, 2, figsize=(25, 25))

# One entry per panel, row-major: (axes, series, y-limits, y-label, zero line?).
# The zero line is drawn as series*0 against the same time axis.
fig1_panels = (
    (axarr[0, 0], hflux, [-1.6, 1.6],       "W m$^{-2}$",   True),
    (axarr[0, 1], wflux, [-3, 3],           "mm yr$^{-1}$", True),
    (axarr[1, 0], sst,   [18, 19.5],        "$^{o}$C",      False),
    (axarr[1, 1], sss,   [34.8, 35.3],      "psu",          False),
    (axarr[2, 0], tav,   [3.4, 3.8],        "$^{o}$C",      False),
    (axarr[2, 1], sav,   [34.728, 34.732],  "psu",          False),
)
for panel, series, ylim, ylabel, with_zero in fig1_panels:
    if with_zero:
        panel.plot(yr, series, yr, series * 0)
    else:
        panel.plot(yr, series)
    panel.set_ylim(ylim)
    panel.tick_params(labelsize=18)
    panel.set_xlabel("Time (years)", fontsize=20)
    panel.set_ylabel(ylabel, fontsize=20)

# Label subfigures a, b, c, ... in row-major order
axs = axarr.flat
for n, ax in enumerate(axs):
    ax.text(-0.1, 1.05, string.ascii_lowercase[n], transform=ax.transAxes,
            size=20, weight='bold')

fig.savefig('Figures/FORTE2_Figure01.eps', dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)


#Figure 2
# Optional diverging colormap (alternatives: jet, seismic, spectral) with the
# under/over colours pinned to its end colours; the panels below use `mybr`.
# plt.get_cmap is used because plt.cm.get_cmap was removed in matplotlib 3.9,
# and with_extremes returns a copy instead of mutating the registered colormap.
mycmap = plt.get_cmap("seismic")
mycmap = mycmap.with_extremes(under=mycmap(0), over=mycmap(mycmap.N - 1))

c = mcolors.ColorConverter().to_rgb
# Blue -> white (0-0.45), white band (0.45-0.55), white -> red (0.55-1)
mybr = make_colormap(
    [c('blue'), c('white'), 0.45, c('white'), .55, c('white'), c('red')])

fig, axarr = plt.subplots(2, 2, figsize=(20, 14))

# Anomalies relative to the first record, transposed so time runs along x.
# Broadcasting subtracts row 0 from every row (no need to tile it len(yr) times).
ttdsga = (ttdsg - ttdsg[0, :]).T
stdsga = (stdsg - stdsg[0, :]).T
zssta  = (zsst - zsst[0, :]).T
zsssa  = (zsss - zsss[0, :]).T

# (axes, y coordinate, field, contour levels, y label, depth axis?)
fig2_panels = (
    (axarr[0, 0], lat, zssta,  np.arange(-3, 3.1, .2),    "Latitude",  False),
    (axarr[0, 1], lat, zsssa,  np.arange(-2, 2.1, .1),    "Latitude",  False),
    (axarr[1, 0], dp,  ttdsga, np.arange(-2.2, 2.3, .2),  "Depth (m)", True),
    (axarr[1, 1], dp,  stdsga, np.arange(-.35, .38, .05), "Depth (m)", True),
)
for panel, ycoord, field, levels, ylabel, is_depth in fig2_panels:
    cs = panel.contourf(yr, ycoord, field, levels, cmap=mybr, extend='both')
    panel.tick_params(labelsize=18)
    panel.set_xlabel("Time (years)", fontsize=20)
    panel.set_ylabel(ylabel, fontsize=20)
    if is_depth:
        panel.invert_yaxis()  # depth increases downwards
    cbar = fig.colorbar(mappable=cs, orientation='vertical', ax=panel, shrink=.8, pad=0.02)
    cbar.ax.tick_params(labelsize=20)

# Label subfigures a, b, c, ... in row-major order
axs = axarr.flat
for n, ax in enumerate(axs):
    ax.text(-0.1, 1.05, string.ascii_lowercase[n], transform=ax.transAxes,
            size=20, weight='bold')

fig.savefig('Figures/FORTE2_Figure02.eps', dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)


#Figure 3
fig, axarr = plt.subplots(1, 2, figsize=(20, 7))

# Left: OMAX30 series; right: Drake Passage transport (both labelled in Sv)
for panel, series, ylim in ((axarr[0], omax30, [10, 20]),
                            (axarr[1], dptrans, [70, 160])):
    panel.plot(yr, series)
    panel.set_ylim(ylim)
    panel.tick_params(labelsize=18)
    panel.set_xlabel("Time (years)", fontsize=20)
    panel.set_ylabel("Sv", fontsize=20)

# Label subfigures a, b
axs = axarr.flat
for n, ax in enumerate(axs):
    ax.text(-0.1, 1.05, string.ascii_lowercase[n], transform=ax.transAxes,
            size=20, weight='bold')

fig.savefig('Figures/FORTE2_Figure03.eps', dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)


# Figure 4
#==================
def Fglobarra( ax, inarray, cmin=None, cmax=None, pal='seismic', clev=6 , clab=None):
    """Draw one global map panel on the atmosphere grid (Figure 4 variant).

    Shades ``inarray`` with pcolormesh on the module-level ``lona``/``lata``
    coordinates, adds a horizontal colorbar, then overlays black contours and
    labels them with format '%1.0f'.

    ax      -- cartopy GeoAxes to draw on (callers here pass Robinson axes).
    inarray -- 2-D field indexed [latitude, longitude]; presumably shaped
               (len(lata) - 1, len(lona) - 1) -- TODO confirm against the file.
    cmin, cmax -- colour limits for the shading.
    pal     -- colormap name or Colormap object.
    clev    -- contour levels for ax.contour (a count or a list of values).
    clab    -- contour values to label; None labels every contour line.
               plt.clabel expects level values here, so an integer count is
               not accepted.
    """
    ax.set_global()
    ax.coastlines()
    ax.gridlines()
    ax.add_feature(cartopy.feature.LAND, zorder=0)

    # Shading drops data row 0 and uses lata[1:] as latitude boundaries of the
    # remaining rows.  NOTE(review): relies on the exact input grid layout.
    cs = ax.pcolormesh(lona ,lata[1:], inarray[1:,:], cmap=pal, vmin=cmin, vmax=cmax, transform=ccrs.PlateCarree())
    cbar = plt.colorbar(mappable=cs,orientation='horizontal',ax=ax,shrink=.6,pad=0.01, extend='both')
    cbar.ax.tick_params(labelsize=20)

    inarray2 = np.insert(inarray,0,inarray[:,-1],axis=1) # Duplicate last column of array at start to 'join' contours at prime meridian
    # NOTE(review): rows 1.. of inarray2 are paired with lata[1:-1], i.e. data
    # row k+1 is drawn at file latitude k -- confirm this offset is intended.
    cs2 = ax.contour(lona[:], lata[1:-1], inarray2[1:,:], clev, colors='k', transform=ccrs.PlateCarree())
    plt.clabel(cs2, clab,  inline=True, fontsize=16, fmt = '%1.0f')
#==================

c = mcolors.ColorConverter().to_rgb
# Blue -> white (0-0.45), white band (0.45-0.55), white -> red (0.55-1)
mybr = make_colormap(
    [c('blue'), c('white'), 0.45, c('white'), .55, c('white'), c('red')])
# White -> blue (0-0.5), blue -> orange (0.5-0.9), orange -> red (0.9-1)
rain = make_colormap(
    [c('white'), c('blue'), 0.5, c('blue'), c('orange'), 0.9, c('orange'), c('red')])

# Create every panel directly as a Robinson GeoAxes.  Making plain axes with
# plt.subplots and then re-creating each slot via plt.subplot(..., projection=...)
# leaves the plain axes behind on matplotlib >= 3.6, which no longer deletes
# overlapping axes automatically.
fig, axarr = plt.subplots(4, 2, figsize=(25,28),
                          subplot_kw={'projection': ccrs.Robinson()})

# Left column: ST fields -- annual mean, month index 0, month index 6
clev = [-15,-10,-5,0,5,10,15,20,25,30]
clab = [-15,-10,-5,0,5,10,15,20,25,30]
Fglobarra(axarr[0,0], STanmn, cmin=-40, cmax=40, pal="RdBu_r", clev=clev, clab=clab)
Fglobarra(axarr[1,0], STmonclim[0,:,:], cmin=-40, cmax=40, pal="RdBu_r", clev=clev, clab=clab)
Fglobarra(axarr[2,0], STmonclim[6,:,:], cmin=-40, cmax=40, pal="RdBu_r", clev=clev, clab=clab)

# Right column: PT fields -- annual mean, DJF mean, JJA mean
clev = [5,10,15]
clab = [5,10,15]
Fglobarra(axarr[0,1], PTanmn, cmin=0, cmax=15, pal=rain, clev=clev, clab=clab)
Fglobarra(axarr[1,1], PTdjfclim, cmin=0, cmax=15, pal=rain, clev=clev, clab=clab)
Fglobarra(axarr[2,1], PTjjaclim, cmin=0, cmax=15, pal=rain, clev=clev, clab=clab)

# Bottom row: seasonal range (max minus min of the monthly climatology)
clev = [5,10,20,30,40,50]
clab = [5,10,20,30,40,50]
Fglobarra(axarr[3,0], np.abs(np.max(STmonclim,axis=0)-np.min(STmonclim,axis=0)), cmin=0, cmax=45, pal=mybr, clev=clev, clab=clab)

clev = [5,10,15,20,25,30]
clab = [5,10,15,20,25,30]
Fglobarra(axarr[3,1], np.abs(np.max(PTmonclim,axis=0)-np.min(PTmonclim,axis=0)), cmin=0, cmax=15, pal=mybr, clev=clev, clab=clab)

# Label subfigures a, b, c, ... in row-major order
axs = axarr.flat
for n, ax in enumerate(axs):
    ax.text(-0.05, 1.05, string.ascii_lowercase[n], transform=ax.transAxes,
            size=20, weight='bold')

plt.tight_layout()
plt.savefig("Figures/FORTE2_Figure04.png", dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)


# Figure 5
#==================
def Fglobarra( ax, inarray, cmin=None, cmax=None, pal='seismic', clev=6 , clab=None):
    """Draw one global map panel on the atmosphere grid (Figure 5 variant).

    Unlike the Figure 4 variant, the shading uses every data row with the full
    ``lata`` array as latitude boundaries.  Adds a horizontal colorbar, then
    overlays black contours labelled with format '%1.0f'.

    ax      -- cartopy GeoAxes to draw on (callers here pass Robinson axes).
    inarray -- 2-D field indexed [latitude, longitude]; presumably shaped
               (len(lata) - 1, len(lona) - 1) -- TODO confirm against the file.
    cmin, cmax -- colour limits for the shading.
    pal     -- colormap name or Colormap object.
    clev    -- contour levels for ax.contour (a count or a list of values).
    clab    -- contour values to label; None labels every contour line.
               plt.clabel expects level values here, so an integer count is
               not accepted.
    """
    ax.set_global()
    ax.coastlines()
    ax.gridlines()
    ax.add_feature(cartopy.feature.LAND, zorder=0)

    cs = ax.pcolormesh(lona ,lata, inarray, cmap=pal, vmin=cmin, vmax=cmax, transform=ccrs.PlateCarree())
    cbar = plt.colorbar(mappable=cs,orientation='horizontal',ax=ax,shrink=.6,pad=0.01, extend='both')
    cbar.ax.tick_params(labelsize=20)

    inarray2 = np.insert(inarray,0,inarray[:,-1],axis=1) # Duplicate last column of array at start to 'join' contours at prime meridian
    # NOTE(review): rows 1.. of inarray2 are paired with lata[1:-1], i.e. data
    # row k+1 is drawn at file latitude k -- confirm this offset is intended.
    cs2 = ax.contour(lona[:], lata[1:-1], inarray2[1:,:], clev, colors='k', transform=ccrs.PlateCarree())
    plt.clabel(cs2, clab,  inline=True, fontsize=16, fmt = '%1.0f')
#==================

c = mcolors.ColorConverter().to_rgb  # colour name -> RGB; also used for the Figure 6 colormap

# Create every panel directly as a Robinson GeoAxes.  Making plain axes with
# plt.subplots and then re-creating each slot via plt.subplot(..., projection=...)
# leaves the plain axes behind on matplotlib >= 3.6, which no longer deletes
# overlapping axes automatically.
fig, axarr = plt.subplots(2, 2, figsize=(25,14),
                          subplot_kw={'projection': ccrs.Robinson()})

# PS fields: annual mean (a), month index 0 (b), month index 6 (d)
clev = [-5,-3,-1,0,1,3,5,10,15,20]
clab = [-5,-3,-1,0,1,3,5,10,15,20]
Fglobarra(axarr[0,0], PSanmn, cmin=-15, cmax=15, pal="RdBu_r", clev=clev, clab=clab)
Fglobarra(axarr[0,1], PSmonclim[0,:,:], cmin=-15, cmax=15, pal="RdBu_r", clev=clev, clab=clab)
Fglobarra(axarr[1,1], PSmonclim[6,:,:], cmin=-15, cmax=15, pal="RdBu_r", clev=clev, clab=clab)

# Bottom-left (c): seasonal range (max minus min of the monthly climatology)
clev = [5,10,15,20,25,30,40]
clab = [5,10,15,20,25,30,40]
Fglobarra(axarr[1,0], np.abs(np.max(PSmonclim,axis=0)-np.min(PSmonclim,axis=0)), cmin=0, cmax=30, pal="RdBu_r", clev=clev, clab=clab)

# Label subfigures a, b, c, d in row-major order
axs = axarr.flat
for n, ax in enumerate(axs):
    ax.text(-0.05, 1.05, string.ascii_lowercase[n], transform=ax.transAxes,
            size=20, weight='bold')

plt.tight_layout()
plt.savefig("Figures/FORTE2_Figure05.png", dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)


# Figure 6
# Blue -> white (0-0.25), white -> orange (0.25-0.7), orange -> red (0.7-1).
# Relies on the colour converter `c` defined earlier in the script.
myuv = make_colormap(
    [c('blue'), c('white'), 0.25, c('white'), c('orange'), 0.7, c('orange'), c('red')])

fig, axarr = plt.subplots(2, 1, figsize=(20, 14))

levels = np.arange(-10, 45, 5)
clev = np.arange(-10, 45, 5)
# Top panel: UDJF; bottom panel: UJJA (pressure on the inverted y axis)
for panel, wind in ((axarr[0], UDJF), (axarr[1], UJJA)):
    cs = panel.contourf(lata2, pres, wind, levels, cmap=myuv, extend='both')
    cs2 = panel.contour(lata2, pres, wind, levels, colors='k')
    plt.clabel(cs2, clev.astype(int), inline=True, fontsize=14, fmt='%1.0f')
    panel.tick_params(labelsize=18)
    panel.set_xlabel("Latitude", fontsize=20)
    panel.set_ylabel("Pressure (hPa)", fontsize=20)
    panel.invert_yaxis()
    panel.set_xlim(-90, 90)
    cbar = fig.colorbar(mappable=cs, orientation='vertical', ax=panel, shrink=.8, pad=0.02)
    cbar.ax.tick_params(labelsize=20)

# Label subfigures a, b
axs = axarr.flat
for n, ax in enumerate(axs):
    ax.text(-0.05, 1.05, string.ascii_lowercase[n], transform=ax.transAxes,
            size=20, weight='bold')

fig.savefig('Figures/FORTE2_Figure06.png', dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)


# Figure 7
#==================
def Fglobarr( ax, inarray, cmin=None, cmax=None, pal='seismic', clev=6 , clab=None):
    """Draw one global map panel on the ocean grid (Figure 7 variant).

    Shades ``inarray`` with pcolormesh using the module-level ``lonu``/``latu``
    (U-point coordinates with a leading boundary prepended) as cell edges,
    adds a horizontal colorbar, then overlays black contours on the T-point
    ``lon``/``lat`` coordinates, labelled with format '%1.0f'.

    ax      -- cartopy GeoAxes to draw on (callers here pass Robinson axes).
    inarray -- 2-D field indexed [latitude, longitude]; presumably shaped
               (len(lat), len(lon)) -- TODO confirm against the files.
    cmin, cmax -- colour limits for the shading.
    pal     -- colormap name or Colormap object.
    clev    -- contour levels for ax.contour (a count or a list of values).
    clab    -- contour values to label; None labels every contour line.
               plt.clabel expects level values here, so an integer count is
               not accepted.
    """
    from cartopy.util import add_cyclic_point
    ax.set_global()
    ax.coastlines()
    ax.gridlines()
    ax.add_feature(cartopy.feature.LAND, zorder=0)
    # Wrap the first longitude column onto the end so contours close across
    # the seam.  NOTE(review): the matching coordinate is hard-coded as 360.99
    # rather than lon[0] + 360 -- confirm this suits the grid.
    inarrayc = add_cyclic_point(inarray)
    lonsc = np.append(lon,360.99)

    cs = ax.pcolormesh(lonu, latu, inarray, cmap=pal, vmin=cmin, vmax=cmax, transform=ccrs.PlateCarree())
    cbar = plt.colorbar(mappable=cs,orientation='horizontal',ax=ax,shrink=.6,pad=0.01, extend='both')
    cbar.ax.tick_params(labelsize=20)

    cs2 = ax.contour(lonsc ,lat, inarrayc, clev, colors='k', transform=ccrs.PlateCarree())
    plt.clabel(cs2, clab,  inline=True, fontsize=16, fmt = '%1.0f')
#==================

c = mcolors.ColorConverter().to_rgb
# Blue -> white (0-0.45), white band (0.45-0.55), white -> red (0.55-1)
mybr = make_colormap(
    [c('blue'), c('white'), 0.45, c('white'), .55, c('white'), c('red')])

# Create every panel directly as a Robinson GeoAxes.  Making plain axes with
# plt.subplots and then re-creating each slot via plt.subplot(..., projection=...)
# leaves the plain axes behind on matplotlib >= 3.6, which no longer deletes
# overlapping axes automatically.
fig, axarr = plt.subplots(3, 2, figsize=(25,20),
                          subplot_kw={'projection': ccrs.Robinson()})

# Left column: SST -- model (top) and EN3 observations (middle)
clev = [-2,2,6,10,14,18,22,26]
clab = [2,10,18,26]
Fglobarr(axarr[0,0], sst2d, cmin=-2, cmax=28, pal="RdBu_r", clev=clev, clab=clab)
Fglobarr(axarr[1,0], esst2d, cmin=-2, cmax=28, pal="RdBu_r", clev=clev, clab=clab)

# Right column: SSS -- model (top) and EN3 observations (middle)
clev = [30,32,33,34,35,36]
clab = [30,32,33,34,35,36]
Fglobarr(axarr[0,1], sss2d, cmin=30, cmax=37, pal="BrBG_r", clev=clev, clab=clab)
Fglobarr(axarr[1,1], esss2d, cmin=30, cmax=37, pal="BrBG_r", clev=clev, clab=clab)

# Bottom row: model minus observations
clev = list(np.arange(-7,7.1,2))
clab = list(np.arange(-5,5.1,2))
Fglobarr(axarr[2,0], sst2d-esst2d, cmin=-7, cmax=7, pal=mybr, clev=clev, clab=clab)

clev = list(np.arange(-3.5,3.6,1))
clab = list(np.arange(-2.5,2.6,1))
Fglobarr(axarr[2,1], sss2d-esss2d, cmin=-3, cmax=3, pal=mybr, clev=clev, clab=clab)

# Label subfigures a, b, c, ... in row-major order
axs = axarr.flat
for n, ax in enumerate(axs):
    ax.text(-0.05, 1.05, string.ascii_lowercase[n], transform=ax.transAxes,
            size=20, weight='bold')

plt.tight_layout()
plt.savefig("Figures/FORTE2_Figure07.png", dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)


# Figure 8
#==================
def Fglobarr( ax, inarray, cmin=None, cmax=None, pal='seismic', clev=6 , clab=None):
    """Draw one global map panel on the ocean grid (Figure 8 variant).

    Identical to the Figure 7 variant except that contour labels use one
    decimal place ('%1.1f').  Shades ``inarray`` with pcolormesh on the
    module-level ``lonu``/``latu`` cell edges, adds a horizontal colorbar, then
    overlays black contours on the T-point ``lon``/``lat`` coordinates.

    ax      -- cartopy GeoAxes to draw on (callers here pass Robinson axes).
    inarray -- 2-D field indexed [latitude, longitude]; presumably shaped
               (len(lat), len(lon)) -- TODO confirm against the files.
    cmin, cmax -- colour limits for the shading.
    pal     -- colormap name or Colormap object.
    clev    -- contour levels for ax.contour (a count or a list of values).
    clab    -- contour values to label; None labels every contour line.
               plt.clabel expects level values here, so an integer count is
               not accepted.
    """
    from cartopy.util import add_cyclic_point
    ax.set_global()
    ax.coastlines()
    ax.gridlines()
    ax.add_feature(cartopy.feature.LAND, zorder=0)
    # Wrap the first longitude column onto the end so contours close across
    # the seam.  NOTE(review): the matching coordinate is hard-coded as 360.99
    # rather than lon[0] + 360 -- confirm this suits the grid.
    inarrayc = add_cyclic_point(inarray)
    lonsc = np.append(lon,360.99)

    cs = ax.pcolormesh(lonu ,latu, inarray, cmap=pal, vmin=cmin, vmax=cmax, transform=ccrs.PlateCarree())
    cbar = plt.colorbar(mappable=cs,orientation='horizontal',ax=ax,shrink=.6,pad=0.01, extend='both')
    cbar.ax.tick_params(labelsize=20)

    cs2 = ax.contour(lonsc ,lat, inarrayc, clev, colors='k', transform=ccrs.PlateCarree())
    plt.clabel(cs2, clab,  inline=True, fontsize=16, fmt = '%1.1f')
#==================

c = mcolors.ColorConverter().to_rgb
# Blue -> white (0-0.64), white (0.64-0.65), white -> red (0.65-1).  With the
# colour range -1.8..1 used below, zero falls at 1.8/2.8 ~= 0.643, i.e. white.
mybr = make_colormap(
    [c('blue'), c('white'), 0.64, c('white'), .65, c('white'), c('red')])

# Create both panels directly as Robinson GeoAxes.  Making plain axes with
# plt.subplots and then re-creating each slot via plt.subplot(..., projection=...)
# leaves the plain axes behind on matplotlib >= 3.6, which no longer deletes
# overlapping axes automatically.
fig, axarr = plt.subplots(2, 1, figsize=(25,20),
                          subplot_kw={'projection': ccrs.Robinson()})

# Top: model time-mean zos; bottom: OCCA v2 etan
clev = np.arange(-1.4,1.2,.2)
clab = np.arange(-1.4,1.2,.2)
Fglobarr(axarr[0], ssh, cmin=-1.8, cmax=1, pal=mybr, clev=clev, clab=clab)
Fglobarr(axarr[1], occassh, cmin=-1.8, cmax=1, pal=mybr, clev=clev, clab=clab)

# Label subfigures a, b
axs = axarr.flat
for n, ax in enumerate(axs):
    ax.text(-0.05, 1.05, string.ascii_lowercase[n], transform=ax.transAxes,
            size=20, weight='bold')

plt.tight_layout()
plt.savefig("Figures/FORTE2_Figure08.png", dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)


# Figure 9
# Colormap (alternatives: jet, seismic, spectral) with the under/over colours
# pinned to its end colours.  plt.get_cmap is used because plt.cm.get_cmap was
# removed in matplotlib 3.9, and with_extremes returns a copy, so the globally
# registered colormap is left untouched.
mycmap = plt.get_cmap("jet")
mycmap = mycmap.with_extremes(under=mycmap(0), over=mycmap(mycmap.N - 1))

fig, ax = plt.subplots(1, 1, figsize=(20, 7))

# Time mean from record 390 onward
amocm = np.mean(amoc[390:, :, :], axis=0)

levels = np.arange(-4, 22, 1)
# NOTE(review): lat[28:] assumes the AMOC latitude axis starts at T-grid index
# 28 -- confirm against the diagnostics file.
cs = ax.contourf(lat[28:], dp, amocm, levels, cmap=mycmap, extend='both')
ax.tick_params(labelsize=18)
ax.set_xlabel("Latitude", fontsize=20)
ax.set_ylabel("Depth (m)", fontsize=20)
ax.invert_yaxis()  # depth increases downwards
ax.set_xlim(-30, 70)
cbar = fig.colorbar(mappable=cs, orientation='vertical', ax=ax, shrink=.8, pad=0.02)
cbar.ax.tick_params(labelsize=20)

fig.savefig('Figures/FORTE2_Figure09.png', dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)

#Figure 10

def smooth(y, box_pts):
    """Return ``y`` smoothed by a centred running mean of width ``box_pts``.

    Implemented with ``np.convolve(..., mode='same')``, so the result has the
    length of the longer of ``y`` and the kernel.  Near either end fewer
    samples overlap the kernel but the sum is still divided by ``box_pts``, so
    edge values are damped (e.g. [3, 3, 3] -> [2, 3, 2] for box_pts=3).
    """
    kernel = np.full(box_pts, 1.0 / box_pts)
    return np.convolve(y, kernel, mode='same')

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

# Heat transport per basin, 3-point smoothed and divided by 1e15 (plotted as PW).
# latu[1:] drops the boundary latitude prepended when latu was built.
for mht_series, basin in ((MHTA, "Atlantic"), (MHTIP, "Indo-Pacific"), (MHT, "Total")):
    ax.plot(latu[1:], smooth(mht_series, 3)/1e15, label=basin)
ax.plot(latu[1:], MHT*0, color='k')  # zero reference line
ax.set_xlabel("Latitude")
ax.set_ylabel("Heat transport (PW)")
ax.set_xlim([-80, 80])
ax.legend()

fig.savefig('Figures/FORTE2_Figure10.png', dpi=300, transparent=True, bbox_inches='tight', pad_inches=0.1)

