#! /usr/bin/env python3

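# Draws a satGP-computed GP mean field and its uncertainty for every time
# step, overlays a scatter plot of the observations, and saves one image
# per time step into the experiment directory.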

# usage:
# ./plot_field_and_unc.py path/to/experiment/directory Area

# where Area is one of the areadict keys in areas.py (and in
# gaussian_proc.h)

import matplotlib.pyplot as plt

from pylab import *
import numpy as np
import time
import sys
import datetime
import os

from mpl_toolkits.basemap import Basemap

from areas import areas, get_latres

# Validate the command line up front: everything below reads sys.argv[1]
# (experiment directory) and this area key, and would otherwise fail with
# an opaque IndexError/KeyError.
if len(sys.argv) < 3 or sys.argv[2] not in areas:
    sys.exit("usage: ./plot_field_and_unc.py path/to/experiment/directory Area\n"
             "where Area is one of: " + ", ".join(map(str, areas)))
areadict = areas[sys.argv[2]]

def _smooth_in_time(series, sigma=5.0, half_width=10):
    """Return a copy of *series* with every column smoothed along axis 0.

    Each column (one grid cell's time series) is convolved with a Gaussian
    of standard deviation *sigma* steps, truncated to ``2*half_width + 1``
    taps and normalized to sum to one. The kernel is centred on each
    sample, so the output keeps the input's number of rows even when the
    series is shorter than the kernel (``np.convolve(..., mode='same')``
    would return the kernel's length in that case). NaNs propagate to
    every output sample whose window contains them.
    """
    offsets = np.arange(-half_width, half_width + 1, dtype=float)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()
    n_steps = series.shape[0]
    smoothed = np.array(series, dtype=float, copy=True)
    if n_steps == 0:
        return smoothed
    for col, column in enumerate(series.T):
        full = np.convolve(column, kernel, mode='full')
        # full[k] is centred on input sample k - half_width.
        smoothed[:, col] = full[half_width:half_width + n_steps]
    return smoothed


def get_fielddata(max_rows=100, smooth_in_time=False, directory=None):
    """Load the satGP mean and uncertainty fields of an experiment.

    Reads ``gp_mean.txt`` and ``gp_unc.txt`` from the experiment directory.
    Each row is one time step and each column one grid cell. Exact zeros
    are treated as "no value" and replaced by NaN.

    Parameters
    ----------
    max_rows : int or None
        Read at most this many time steps (rows). ``None`` reads them all.
        Default 100.
    smooth_in_time : bool
        If true, smooth every grid cell's time series with a normalized
        Gaussian (sigma = 5 steps, 21 taps). This can give a smoother video,
        e.g. when observations were selected randomly. Off by default.
    directory : str or None
        Experiment directory. Defaults to ``sys.argv[1]``.

    Returns
    -------
    data, std, latres, gridres
        The mean field; the square root of the values in ``gp_unc.txt``
        (shown as a standard deviation); the latitude resolution reported by
        ``get_latres`` (used as the number of grid rows when reshaping); and
        the latitude extent of one grid row in degrees.
    """
    if directory is None:
        directory = sys.argv[1]
    data = np.loadtxt(os.path.join(directory, "gp_mean.txt"), ndmin=2,
                      max_rows=max_rows)
    data[data == 0] = np.nan
    unc = np.loadtxt(os.path.join(directory, "gp_unc.txt"), ndmin=2,
                     max_rows=max_rows)
    unc[unc == 0] = np.nan
    latres = get_latres(data, areadict)
    gridres = 1. * (areadict['latmax'] - areadict['latmin']) / latres
    if smooth_in_time:
        data = _smooth_in_time(data)
        unc = _smooth_in_time(unc)
    return data, np.sqrt(unc), latres, gridres

def get_daylist(n_steps=None, directory=None):
    """Return the day offset of every time step.

    If ``daylist.txt`` in the experiment directory can be read and has
    exactly one entry per time step, its values are returned as a float
    array. Otherwise the steps are numbered consecutively with
    ``range(n_steps)``.

    Parameters
    ----------
    n_steps : int or None
        Number of time steps. Defaults to ``len(data)`` of the module-level
        field array.
    directory : str or None
        Experiment directory. Defaults to ``sys.argv[1]``.
    """
    if n_steps is None:
        n_steps = len(data)
    if directory is None:
        directory = sys.argv[1]
    path = os.path.join(directory, "daylist.txt")
    try:
        days = np.loadtxt(path, ndmin=1)
    except FileNotFoundError:
        # No daylist is the normal case: fall back to 0..n_steps-1.
        return range(n_steps)
    except (OSError, ValueError) as err:
        print("Ignoring unreadable", path, "-", err)
        return range(n_steps)
    if len(days) != n_steps:
        return range(n_steps)
    return days

def get_winds(directory=None):
    """Load the wind components used to draw arrows on the mean-field map.

    Winds are used only when both ``u_winds.txt`` and ``v_winds.txt`` exist
    in the experiment directory. If only one of them exists, a warning is
    printed and winds are skipped rather than failing on the missing file.

    Parameters
    ----------
    directory : str or None
        Experiment directory. Defaults to ``sys.argv[1]``.

    Returns
    -------
    u_winds, v_winds, plot_winds
        The two loaded arrays (``None`` when winds are not used) and a flag
        telling whether they were loaded.
    """
    if directory is None:
        directory = sys.argv[1]
    u_path = os.path.join(directory, "u_winds.txt")
    v_path = os.path.join(directory, "v_winds.txt")
    have_u = os.path.exists(u_path)
    have_v = os.path.exists(v_path)
    if have_u != have_v:
        print("Warning: need both u_winds.txt and v_winds.txt; not plotting winds")
    plot_winds = have_u and have_v
    print("plot_winds:", plot_winds)
    if not plot_winds:
        return None, None, False
    print("loading")
    return np.loadtxt(u_path), np.loadtxt(v_path), True

# Load the fields, the per-step day offsets and (optionally) the winds.
data, unc, latres, gridres = get_fielddata()
daylist = get_daylist()
u_winds, v_winds, plot_winds = get_winds()

# Colour limits shared by every frame; adjust freely.
v0, v1 = np.nanmin(data), np.nanmax(data)
v0unc, v1unc = np.nanmin(unc), np.nanmax(unc)
print('Grid resolution:', gridres)

# Map window of the selected area.
min_lon, max_lon = areadict['lonmin'], areadict['lonmax']
min_lat, max_lat = areadict['latmin'], areadict['latmax']
# Width/height ratio of the area; not referenced elsewhere in this script.
aspect = (max_lon - min_lon) / (max_lat - min_lat)

# Two stacked map panels. A figure size of (25.1, 21.9) works well for the
# World area.
fig, axes = subplots(2, 1, figsize=(25.1, 21.9))

panel_titles = ("(a) Concentration (ppm)", "(b) Uncertainty (std)")
for panel, heading in zip(axes, panel_titles):
    panel.set_title(heading, fontsize=30)

# Both panels show the same window. Longitudes are offset by -180 here;
# NOTE(review): presumably the area definitions use a 0..360 convention,
# confirm in areas.py.
map_window = dict(llcrnrlon=min_lon - 180, llcrnrlat=min_lat,
                  urcrnrlon=max_lon - 180, urcrnrlat=max_lat)
b_data = Basemap(resolution='i', ax=axes[0], **map_window)
b_unc = Basemap(resolution='i', ax=axes[1], **map_window)

b_data.drawcoastlines()
b_unc.drawcoastlines()

# Images for the first time step; each later frame only swaps their data.
first_mean = data[0].reshape((latres, -1))
im_data = b_data.imshow(first_mean, vmin=v0, vmax=v1,
                        cmap='jet', interpolation='bilinear')
first_unc = unc[0].reshape((latres, -1))
im_unc = b_unc.imshow(first_unc, vmin=v0unc, vmax=v1unc,
                      cmap='jet', interpolation='bilinear')

# Grid dimensions of one flattened field row (reshaped as (nlat, nlon)).
nlon = int(len(data[0])/latres)
nlat = int(latres)

# With dense grids, arrows should not be drawn in every cell: keep every
# wind_stride-th grid row and column. FIXME: with wind_stride > 1 the kept
# arrows sit on the first cell of each stride block, not at the block
# centre.
wind_stride = 1
if plot_winds:
    # Assumes each wind file row is one flattened (nlat, nlon) field in the
    # same cell order as gp_mean.txt -- TODO confirm against satGP output.
    u_winds = u_winds.reshape((-1, nlat, nlon))[:, ::wind_stride, ::wind_stride]
    v_winds = v_winds.reshape((-1, nlat, nlon))[:, ::wind_stride, ::wind_stride]
    # Cell centres. gridres is the latitude spacing; longitude cells have
    # their own width, which differs unless the cells are square.
    dlon = (max_lon - min_lon) / nlon
    dlat = gridres
    # Longitudes are +180 here vs -180 for the map window, i.e. the same
    # modulo 360; NOTE(review): relies on Basemap's latlon=True handling to
    # bring them into the map range.
    lons = linspace(min_lon + dlon/2, max_lon - dlon/2, nlon) + 180
    lats = linspace(min_lat + dlat/2, max_lat - dlat/2, nlat)
    X, Y = meshgrid(lons[::wind_stride], lats[::wind_stride])
    X = X.reshape(-1)
    Y = Y.reshape(-1)
    im_winds = b_data.quiver(X, Y, u_winds[0], v_winds[0], latlon=True, color='black')


from mpl_toolkits.axes_grid1 import make_axes_locatable

# Give each map panel a slim colour-bar axis glued to its right edge.
divider0, divider1 = (make_axes_locatable(ax) for ax in axes)
cax0, cax1 = (divider.append_axes("right", size="2%", pad=0.05)
              for divider in (divider0, divider1))

cb0, cb1 = (plt.colorbar(mappable, cax=bar_ax)
            for mappable, bar_ax in ((im_data, cax0), (im_unc, cax1)))
for colour_bar in (cb0, cb1):
    colour_bar.ax.tick_params(labelsize=15)


def nan_helper(y):
    """Locate the NaNs of a 1-D array.

    Returns ``(mask, index)``. *mask* is the boolean NaN mask of *y*.
    *index* turns a boolean mask into the integer positions of its True
    entries, e.g. ``index(mask)`` or ``index(~mask)``.
    """
    def index(mask):
        return mask.nonzero()[0]

    nan_mask = np.isnan(y)
    return nan_mask, index


def interpolate_nans_out(data):
    """Replace the NaNs of a 1-D array in place by linear interpolation.

    Each NaN gets the value interpolated linearly (by array index) between
    its nearest finite neighbours. NaNs before the first or after the last
    finite entry take that entry's value (``np.interp``'s end behaviour).
    For a flattened 2-D field this runs along the flattened order. An array
    with no finite entry (including an empty one) is left unchanged, since
    ``np.interp`` would raise on an empty set of sample points.

    Returns *data* for convenience.
    """
    nans = np.isnan(data)
    if nans.all() or not nans.any():
        return data
    positions = np.arange(data.size)
    data[nans] = np.interp(positions[nans], positions[~nans], data[~nans])
    return data


# Date stamp near the bottom-right of each panel; filled in per frame.
t0, t1 = (ax.text(.9, 0.01, "", color='black', fontsize=20,
                  transform=ax.transAxes, horizontalalignment="center")
          for ax in axes)


def draw_scatter(b_data, i):
    """Render and save frame *i*: fields, winds, date stamps and observations.

    Points the module-level images (im_data, im_unc), the date texts
    (t0, t1) and, when winds are loaded, the quiver (im_winds) at time step
    *i*. Overlays the observations of days daylist[i]-2 .. daylist[i]+2 on
    the mean-field map and saves the figure as
    ``<experiment dir>/img_<i zero-padded to 4>.jpg``. Finally removes the
    observation scatters so the next frame starts clean.

    Side effect: NaNs in data[i] are replaced in place by
    interpolate_nans_out before drawing.

    Parameters
    ----------
    b_data : Basemap
        Map of the mean-field panel; the observations are scattered onto it.
    i : int
        Row index into data / unc / daylist (the time step to draw).
    """
    # fig.suptitle(str(datetime.date(2014,9,6) + datetime.timedelta(days=i)))
    # The date shown is the hard-coded start date 2014-09-06 plus daylist[i] days.
    t0.set_text(str(datetime.date(2014,9,6) + datetime.timedelta(days=daylist[i])))
    t1.set_text(str(datetime.date(2014,9,6) + datetime.timedelta(days=daylist[i])))
    # Fill NaN gaps of the mean field along the flattened (row-major) grid
    # before showing it; the uncertainty field is shown as is.
    interpolate_nans_out(data[i])
    im_data.set_array(data[i].reshape((latres, -1)))
    im_unc.set_array(unc[i].reshape((latres, -1)))
    if plot_winds:
        im_winds.set_UVC(u_winds[i], v_winds[i]) #, np.sqrt(u_winds[i]**2 + v_winds[i]**2))

    # Scatter plot the observations
    scats = []
    # One style per day offset -2..+2 from the current day. The middle entry
    # (the current day) is largest, with a thicker white edge. These lists
    # must have 2*neighboring_days+1 entries.
    edgecolors = ['black', 'black', 'white', 'black', 'black']
    markers = ['o', 'o', 'o', 'o', 'o']
    linewidths = [1,1,2,1,1]
    sizes = [100, 175, 250, 175, 100]
    neighboring_days = 2
    ii = int(daylist[i])
    # Running min/max of the observed values; only printed below.
    scattermin = nan; scattermax = nan
    for j in range(ii-neighboring_days, ii+neighboring_days+1):
        # Position of day j within the window; indexes the style lists.
        ll = j - (ii-neighboring_days)
        # Observations are read from daydatas/day_<j>.txt; days without a
        # file are skipped.
        scatterdatafile = sys.argv[1] + "/daydatas/day_" + str(j) + ".txt"
        if os.path.exists(scatterdatafile):
            print("day", j, "we have obs in", scatterdatafile)
            # NOTE(review): an empty day file would make scatterdata[:,2]
            # below raise IndexError -- confirm such files never occur.
            scatterdata = loadtxt(scatterdatafile, ndmin=2)
            scattermin = nanmin(np.array([scattermin, nanmin(scatterdata[:,2])]))
            scattermax = nanmax(np.array([scattermax, nanmax(scatterdata[:,2])]))
            # Rows are expected to have 5 columns: column 0 is passed as
            # latitude, column 1 as longitude (Basemap scatter with
            # latlon=True takes lon, lat) and column 2 as the value.
            scatterdata = reshape(scatterdata, (-1,5))
            # Same colormap and vmin/vmax as the mean field, so marker
            # colours are comparable with the map underneath.
            scats.append(b_data.scatter(scatterdata[:,1], scatterdata[:,0],
                                        c=scatterdata[:,2],
                                        s=sizes[ll], edgecolors=edgecolors[ll], linewidths=linewidths[ll],
                                        cmap='jet', latlon=True, vmin=v0, vmax=v1, marker=markers[ll]))

    # Combined range of observations and the current field; informational
    # only, since the colour-bar rescaling below is disabled.
    minval = nanmin(np.array([scattermin, nanmin(data[i])]))
    maxval = nanmax(np.array([scattermax, nanmax(data[i])]))
    print("Scatter min/max:", minval, maxval)
    # cb0.set_clim(vmin=minval, vmax=maxval)
    # cb0.draw_all()
    fig.canvas.draw()
    plt.tight_layout()
    savefig(sys.argv[1] + "/img_" + str(i).zfill(4) + ".jpg",
            bbox_inches='tight', pad_inches=0.1, dpi=110)
    # "\r" keeps the progress counter on one console line.
    print("fig:", i, end="\r")
    # Drop this frame's observation markers before the next frame.
    for s in scats:
        s.remove()


# Render and save one image per time step. To render only selected frames
# while debugging, skip the others here, e.g.
#     if frame not in [55, 329, 513, 634]: continue
for frame in range(data.shape[0]):
    draw_scatter(b_data, frame)


# If you don't want to struggle with matplotlib animations, the saved
# images can be encoded into a video, e.g. with something like:

# ffmpeg -f image2 -threads 0 -r 15 -i pngtest/img_%04d.jpg -y -vcodec libx264 -crf 22  -vf scale=1080:-1 ChinaSea.mp4
# ffmpeg -r 1/15 -pattern_type glob -i img_'*'.jpg -c:v libx264 -vf "fps=25,format=yuv420p,scale=1280:-2" ChinaSea_newer_test.mp4
# (ffmpeg uses only the last -vf given for a stream, so all filters must go in one chain.)
