#! /usr/bin/env python3

# Plots time series of selected observations vs fitted mean function
# for various places. Not a particularly beautiful script but
# works. Creates plots like Fig. 1 in "Efficient multi-scale Gaussian
# process regression for massive remote ...", geosci model dev 2020.

# Usage: e.g.
# ./plot_mf_timeseries_with_data.py ../experiments/exp_where_local_mf_coeffs_where_learned

from glob import glob
from netCDF4 import Dataset
from pylab import *
import scipy.optimize
import sys


def datestring_from_fname(f):
    """Return the (year, month, day) integers encoded in a data file name.

    Everything up to and including the last '/' is discarded first, so
    underscores in directory names do not matter. The third
    '_'-separated field of the remaining name is then read as three
    consecutive two-digit numbers.

    The year is returned exactly as written in the name, i.e. as a
    two-digit number; no century is added.
    """
    basename = f.rsplit('/', 1)[-1]
    stamp = basename.split('_')[2]
    y, m, d = (int(stamp[k:k + 2]) for k in (0, 2, 4))
    return y, m, d

# Directory with the input data files. Only the file names are used
# here: the dates encoded in the first and last name (in sorted order)
# define the length of the plotted period.
datadir = "../data/oco2_v9"
files = sorted(glob(datadir + "/*"))
print(files)
if not files:
    # Without any file there is no first/last date to compute ndays from.
    sys.exit("No data files found in " + datadir)

# Directory given on the command line; the coefficient files and the
# per-grid-cell data files are read from it and the figure is saved there.
if len(sys.argv) < 2:
    sys.exit("Usage: " + sys.argv[0] + " <experiment_directory>")
beta_dir = sys.argv[1]

# Grid cell size handed to find_closest_index (presumably in degrees,
# as it is combined with latitude/longitude bounds -- TODO confirm).
gridres = 2.

# Number of days between the dates of the first and the last data file.
ndays = (datetime.date(*datestring_from_fname(files[-1]))
         - datetime.date(*datestring_from_fname(files[0]))).days


def find_closest_index(lat, lon, latmin, latmax0, lonmin, lonmax0, gridres):
    """Return the flat index of the grid cell centre nearest to (lat, lon).

    Cell centres sit half a cell (``gridres/2``) past every ``gridres``
    step taken from ``latmin`` up to, but not including, ``latmax0``;
    the longitude centres are built the same way from ``lonmin`` and
    ``lonmax0``.

    The result is ``nlon*latidx + lonidx``, where ``nlon`` is the number
    of longitude centres, i.e. latitude is the slow index. The matched
    centres and the two indices are printed as a side effect.
    """
    half_cell = gridres*.5
    centre_lats = np.arange(latmin, latmax0, gridres) + half_cell
    centre_lons = np.arange(lonmin, lonmax0, gridres) + half_cell

    latidx = np.abs(centre_lats - lat).argmin()
    lonidx = np.abs(centre_lons - lon).argmin()

    print("Closest points: lat, gridlon:", lat, centre_lats[latidx], lon, centre_lons[lonidx])
    print('latidx, lonidx:', latidx, lonidx)
    return len(centre_lons)*latidx + lonidx


# Candidate locations and their (latitude, longitude) pairs. Negative
# longitudes are labelled W and non-positive latitudes S by the legend
# code below.
places = "St. Petersburg, Mauna Loa, Perth, Washington DC, New Delhi, Ulan Bator".split(", ")
coords = [(60.0,30.3), (19.5, -155.6), (-32.0, 115.9), (38.9, -77.3), (28.6, 77.2), (46.9, 106.9)]

places += "Azores, Bogota, Lagos".split(", ")
# NOTE(review): the Azores lie in the western hemisphere and Lagos
# (Nigeria) in the eastern one, so the longitude signs of the first and
# last entry below look swapped. Left unchanged because correcting them
# moves the plotted grid cells -- confirm before editing.
coords += [(37.7, 25.7), (4.7, -74.1), (6.5, -3.3)]

# Finish getting the legend strings: "<name> (<lat>° N|S, <lon>° E|W)"
for i, p in enumerate(places):
    xlat, xlon = coords[i]
    D0 = ' W' if xlon < 0 else ' E'
    D1 = ' N' if xlat > 0 else ' S'
    # Raw strings: "\d" is not a valid escape sequence in a normal
    # string literal and triggers a warning on recent Python versions.
    places[i] = (p + " (" + str(abs(xlat)) + r"$\degree$" + D1 + ', '
                 + str(abs(xlon)) + r"$\degree$" + D0 + ')')

# One colour per plotted location, assigned in plotting order.
colors = ['red', 'green', 'magenta','black', 'cyan', 'blue', 'brown', 'yellow', 'orange']

# Pick only these locations (positions in the full `places` list):
indexes = [0,1,2,3,5,6,8]

# Order the picked locations alphabetically by their legend string and
# remember where each of them sits in the full list.
selected_labels = sorted(places[i] for i in indexes)
alphaindexes = [places.index(label) for label in selected_labels]

# From here on `places` and `coords` hold only the picked locations.
places = [places[i] for i in alphaindexes]
coords = [coords[i] for i in alphaindexes]

# Flat grid-cell index of every picked location on a grid spanning
# latitudes -85..85 and longitudes -180..180.
inds = [find_closest_index(lat, lon, -85, 85, -180, 180, gridres) for lat, lon in coords]

filestride = 1 # debug if > 1

# When True, a second panel showing observation minus mean function is
# added below the main one, for evaluating how good the fit was.
plot_diff = False

if not plot_diff:
    # Single panel: mean function curves and observations only.
    fig, ax0 = subplots(1, figsize=(16,7))
else:
    # Two stacked panels: fits on top (ax0), differences below (ax1).
    fig, axes = subplots(2,1, figsize=(24,13))
    ax0, ax1 = axes


# 1. Plot the mean function curve
# Load the coefficient fields, one text file per coefficient, each
# flattened to 1-D so it can be addressed with a flat grid-cell index.
beta_fnames = [beta_dir + "/" + name
               for name in ("beta0.txt", "beta1.txt", "beta2.txt", "beta3.txt", "delta.txt")]
beta_data = np.array([loadtxt(path).reshape(-1) for path in beta_fnames])
print("beta data shape:", shape(beta_data))

def f(x, b0, b1, b2, b3, d):
    """Evaluate the mean function at x.

    This is the mean function used to describe the XCO2 variation from
    place to place; change it according to your needs.

    x is first rescaled so that 365.25 units of x correspond to one full
    cycle (2*pi). With that rescaled value the result is the sum of
      * b0 * sin(x + d)    -- one oscillation per cycle,
      * b1 * cos(2*x + d)  -- two oscillations per cycle,
      * b3 * x             -- a term linear in the rescaled x,
      * b2                 -- a constant offset.
    """
    phase = x/365.25*2*pi
    annual = b0*sin(phase + d)
    semiannual = b1*cos(2*phase + d)
    trend = b3*phase
    return annual + semiannual + trend + b2

# The beta coefficients: for every selected grid cell, the five values
# read from the coefficient files, in the order expected by f.
bvals = [[beta_data[row][cell] for row in range(5)] for cell in inds]

# Mean function value for every day of the period (rows) and every
# place (columns).
mfdata = zeros((ndays, len(places)))
for col, coeffs in enumerate(bvals):
    for day in range(ndays):
        mfdata[day, col] = f(day, *coeffs)

# Draw each mean function as a solid line
for series, label, color in zip(mfdata.T, places, colors):
    ax0.plot(series, label=label, color=color, lw=3)


# 2. Plot the points through which the fit was made
for i in range(len(places)):
    # Local data that was used for the fit: a text file named after the
    # flat grid-cell index. ndmin=2 keeps the array two-dimensional even
    # when the file holds a single row, so the column indexing below works.
    localdata = loadtxt(beta_dir + '/' + str(inds[i]), ndmin=2)

    # Column 2 is converted to the x axis of the plot and column 3 is
    # the plotted value (presumably a time stamp in seconds with an
    # offset of 1e7, and XCO2 -- TODO confirm against the file writer).
    obs_days = (localdata[:,2] - 1e7)/86400
    obs_vals = localdata[:,3]
    ax0.plot(obs_days, obs_vals, '.', color=colors[i])

    if plot_diff:
        # Sort by time and subtract the fitted mean function from the
        # observations to show the residual of the fit.
        so = argsort(obs_days)
        tt = obs_days[so]
        co2 = obs_vals[so]
        mfvals = zeros_like(tt)
        for j, tval in enumerate(tt):
            mfvals[j] = f(tval, *bvals[i])

        localdiffdata = co2 - mfvals
        ax1.plot(tt, localdiffdata, '.', label="difference at " + places[i], color=colors[i], markersize=10)

# Hard-coded x positions of the year ticks (365 apart, starting at 114)
# and the labels drawn at them.
year_ticks = [114 + 365*k for k in range(4)]
year_labels = ['2015', '2016', '2017', '2018']


def _mark_years(ax):
    # Put the year ticks, their rotated labels and the x axis title on ax.
    ax.set_xticks(year_ticks)
    ax.set_xticklabels(year_labels, rotation=70)
    ax.set_xlabel('Year')


if plot_diff:
    # Lower panel: residuals on a fixed +-7 range with a grid.
    ax1.set_ylim([-7,7])
    ax1.set_xlim([0, ndays])
    ax1.grid(True)
    _mark_years(ax1)
    ax1.set_ylabel('Deviation')

# Main panel: fits and observations over the whole period.
_mark_years(ax0)
ax0.set_ylabel('XCO2 (ppm)')
#ax0.set_ylim([392, 412])
ax0.set_xlim([0, ndays])

ax0.legend()
# ax0.set_title(beta_dir)
# The figure is written into the experiment directory itself.
savefig(beta_dir + "/mean_function_fits.png", bbox_inches='tight', pad_inches=0.05, dpi=300)
