#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 10 11:10:55 2016

Supporting script for Model. Contains all geospatial functions required by Model.
The functions in this script are imported into the main Model script.

@author: davevanwees
"""

import numpy as np
import netCDF4 as netcdf
from pyhdf.SD import SD, SDC
import os
import osr
import gdal, gdalconst
import pyproj
import matplotlib.pyplot as plt
import openpyxl as excl
import time as timer
plt.rc('image', interpolation='none')
gdal.UseExceptions()


@classmethod
def load_degbox(cls, resdeg, tile):
    
    '''
    Build the degree-grid window that belongs to a MODIS tile, a region or the whole globe.
    
    :param cls: class providing ``wdir_geo`` (path prefix of the geo files) and ``region_box`` [dict].
    :param resdeg: resolution of the degree grid [float], e.g. 0.25.
    :param tile: MODIS tile name (e.g. 'h19v09'), a key of ``region_box``, or any other name (e.g. 'global') for the full grid.
    :return: indexdeg, sizedeg, lat, lon.
        indexdeg: dict with 'lat' and 'lon' index vectors into the global degree grid.
        sizedeg: dict with the lengths of those two vectors.
        lat, lon: cell-midpoint coordinates in degrees, cut to the same window.
    '''
    geo_dir = cls.wdir_geo
    regions = cls.region_box
    
    # Resolution tag used in the file names, e.g. 0.25 -> '025deg'.
    res_tag = '%03.0fdeg' % float(('%0.2f' % round(resdeg, 2)).lstrip('0.'))
    
    if tile[::3] == 'hv':
        # MODIS tile: the file holds, per tile, the first and last degree-cell index of the bounding box.
        extremes = np.load(geo_dir + '00_Pixel_location_' + res_tag + '_extremes_top.npy')[()][tile]
        rows = np.arange(extremes['lat'][0], extremes['lat'][1] + 1)
        cols = np.arange(extremes['lon'][0], extremes['lon'][1] + 1)
    else:
        # 'global' or a named region: start from the full grid ...
        rows = np.arange(0, int(180/resdeg))
        cols = np.arange(0, int(360/resdeg))
        
        if tile in regions:
            # ... and cut it down. The box bounds are rescaled by 0.25/resdeg,
            # so they are presumably stored as 0.25-degree indices (end exclusive).
            scale = 0.25/resdeg
            lat_bounds = [int(bound * scale) for bound in regions[tile][0]]
            lon_bounds = [int(bound * scale) for bound in regions[tile][1]]
            rows = rows[lat_bounds[0] : lat_bounds[1]]
            cols = cols[lon_bounds[0] : lon_bounds[1]]
    
    indexdeg = {'lat': rows, 'lon': cols}
    sizedeg = {'lat': len(rows), 'lon': len(cols)}
    
    # Global cell midpoints (north to south, west to east), restricted to the window.
    lat = np.arange(90 - resdeg / 2.0, -90, -resdeg)[rows[0]:rows[-1] + 1]
    lon = np.arange(-180 + resdeg / 2.0, 180, resdeg)[cols[0]:cols[-1] + 1]
    
    return indexdeg, sizedeg, lat, lon


@classmethod
def load_area(cls, res, tile):
    
    '''
    Return the pixel area and the summed tile area, both in [m2].
    
    :param cls: class providing ``wdir_geo``, ``load_degbox`` and ``tile_mask``.
    :param res: 250, 500 or 1000 for a MODIS meter grid; any other value is treated as a degree resolution [float].
    :param tile: MODIS tile name, 'global', or a region name.
    :return: area_pix, area_tile.
        Meter grid: area_pix is the area of a single pixel, area_tile the area of a full tile.
        Degree grid: area_pix is a 2D array (copy) with the area of every degree cell in the tile/region window;
        area_tile is the tile-mask weighted sum for a MODIS tile and an empty list otherwise.
    '''
    
    geo_dir = cls.wdir_geo
    pixel_500m = 463.312716528      # side of a 500 m MODIS pixel in [m]
    
    if res in [250, 500, 1000]:
        side = pixel_500m * (res / 500.0)
        area_pix = side ** 2
        pixels_per_side = (500 * 2400) / res
        return area_pix, pixels_per_side * pixels_per_side * area_pix
    
    # Degree grid: cell area = number of MODIS pixels in the cell times the area of one such pixel.
    res_tag = '%03.0fdeg' % float(('%0.2f' % round(res, 2)).lstrip('0.'))  # e.g. 0.25 -> '025deg'
    pixel_counts = np.load(geo_dir + '00_Total_pixels_' + res_tag + '.npy')
    indexdeg, _, _, _ = cls.load_degbox(res, tile)
    cell_area = pixel_counts * pixel_500m ** 2
    
    rows, cols = indexdeg['lat'], indexdeg['lon']
    area_pix = cell_area[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1].copy()  # confine to the tile/region window.
    
    if tile[::3] != 'hv':
        # 'global' or a region: no summed area is computed.
        return area_pix, []
    
    tilemask, _ = cls.tile_mask(res, tile)
    return area_pix, np.sum(area_pix * tilemask)


@classmethod
def tilemap(cls, sampleloc, mapping, mask_earth=True, m250=False, m1000=False):
    
    '''
    Convert between geographic coordinates and MODIS sinusoidal tile indices.
    Works like the MODIS tile calculator: https://landweb.modaps.eosdis.nasa.gov/cgi-bin/developer/tilemap.cgi
    A geolocation array can be made using 'reverse' mapping and:  index_x, index_y = np.meshgrid(range(2400), range(2400))
    Also works for arrays!
    FORWARD mapping = [lat, lon] -> [tile, index_y, index_x]
    REVERSE mapping = [tile, index_y, index_x] -> [lat, lon]    # Also works for non-integer and negative indices, e.g. the ul tile corner is ['h19v09', -0.5, -0.5]
    :param cls: not used.
    :param sampleloc: location [lat, lon] in degrees ('forward') or [tile, index_y, index_x] ('reverse', 'reverse_m').
    :param mapping: 'forward', 'reverse' (output in degrees) or 'reverse_m' (output in meters).
    :param mask_earth: reverse mapping only: set locations outside the earth to NaN [bool]
    :param m250: set True for the 250 meter MODIS grid (4800 x 4800 pixels per tile).
    :param m1000: set True for the 1 km MODIS grid (1200 x 1200 pixels per tile).
    :return: forward: [tile, index_y, index_x] for scalar or 1d input, [hn, vn, index_y, index_x] for multi-dimensional input.
             reverse: [lat, lon] in degrees; reverse_m: [y, x] in meters.
    :raises ValueError: if mapping is not one of the supported values, or if m250 and m1000 are both True.
    '''
    
    # Validate first: an unknown mapping used to fall through to an unbound 'output',
    # and both resolution flags together gave a pixel size that did not match the pixel count.
    if mapping != 'forward' and 'reverse' not in mapping:
        raise ValueError("mapping must be 'forward', 'reverse' or 'reverse_m', got %r" % (mapping,))
    if m250 == True and m1000 == True:
        raise ValueError("m250 and m1000 are mutually exclusive")
    
    sphere_radius = 6371007.181     # radius of the sphere used by the MODIS sinusoidal grid in [m]
    proj4str = ("+proj=sinu +a=%f +b=%f +units=m" % (sphere_radius, sphere_radius))
    p_modis_grid = pyproj.Proj(proj4str)
    
    # Pixel size of the 500 m grid; equals ((20015109.354 * 2) / 36) / 2400, i.e. grid width / 36 tiles / 2400 pixels.
    realres = 463.312716528
    ndim = 2400     # pixels per tile side
    
    if m250 == True:
        realres = realres / 2
        ndim = 4800
        
    if m1000 == True:
        realres = realres * 2
        ndim = 1200
    
    T = ndim * realres      # tile side in [m]
    
    if mapping == 'forward':     # [lat, lon] -> tile and pixel index
        
        lat = sampleloc[0]
        lon = sampleloc[1]
        
        x, y = p_modis_grid(lon, lat)
        
        # Position in units of tiles, counted from the upper-left corner of the 36 x 18 tile grid.
        lon_frac = (x / T) + 36 / 2
        lat_frac = - (y / T) + 18 / 2
        
        hn = np.floor(lon_frac).astype(int)
        vn = np.floor(lat_frac).astype(int)
        
        index_x = np.floor((lon_frac - hn) * ndim).astype(int)      # floor to get integer index. Checked -> floor is correct!
        index_y = np.floor((lat_frac - vn) * ndim).astype(int)
        
        if index_x.ndim == 0:                       # a single location.
            tilen = 'h%02dv%02d' % (hn, vn)
            output = [tilen, index_y, index_x]
            
        elif index_x.ndim == 1:                     # a 1d list/array of locations.
            tilen = ['h%02dv%02d' % (hni, vni) for hni, vni in zip(hn, vn)]
            output = [tilen, index_y, index_x]
            
        else:                                       # a multi-dimensional array: tile numbers are returned separately.
            output = [hn, vn, index_y, index_x]
    
    else:   # 'reverse' or 'reverse_m': tile and pixel index -> coordinates
        
        tilen = sampleloc[0]
        index_x = sampleloc[2]
        index_y = sampleloc[1]
        
        hn = int(tilen[1:3])
        vn = int(tilen[4:])
        
        lon_frac = (index_x + 0.5) / ndim + hn      # +0.5 to get the cell center.
        lat_frac = (index_y + 0.5) / ndim + vn
        
        x = (lon_frac - 36 / 2) * T
        y = - (lat_frac - 18 / 2) * T
        
        lon, lat = p_modis_grid(x, y, inverse=True)
        
        if mask_earth == True:
            
            # On the sinusoidal projection the map edge at latitude phi lies at |x| = pi * R * cos(phi).
            # See https://en.wikipedia.org/wiki/Sinusoidal_projection
            phi = np.deg2rad(lat)
            x_border = np.deg2rad(180.0) * sphere_radius * np.cos(phi)
            
            outside_earth = np.abs(x) > x_border
            
            if np.ndim(outside_earth) == 0:         # a single location.
                if outside_earth:
                    lat = np.nan
                    lon = np.nan
                    y = np.nan
                    x = np.nan
            else:
                lat[outside_earth] = np.nan
                lon[outside_earth] = np.nan
                y[outside_earth] = np.nan
                x[outside_earth] = np.nan
        
        if mapping == 'reverse_m':
            output = [y, x]
        else:
            output = [lat, lon]
        
    return output


@classmethod
def MODIS_tile_corners(cls, tile, proj_m=False):
    
    '''
    Return the extent of a 500 m MODIS tile (2400 x 2400 pixels), measured at the outer pixel edges.
    
    :param cls: class providing ``tilemap``.
    :param tile: MODIS tile name [str], e.g. 'h19v09'.
    :param proj_m: False -> extent in degrees, True -> projected corner coordinates in meters.
    :return: proj_m False: lon_min, lon_max, lat_min, lat_max in degrees, rounded to 6 decimals.
             proj_m True:  ulx, lrx, uly, lry in meters.
    '''
    
    # Outer edges of the first and last pixel, as (name, index_y, index_x), in evaluation order.
    first_edge, last_edge = -0.5, 2399.5
    order = ('ul', 'lr', 'ur', 'll')
    tile_edges = {'ul': (first_edge, first_edge), 'lr': (last_edge, last_edge),
                  'ur': (first_edge, last_edge), 'll': (last_edge, first_edge)}
    
    if proj_m == False:        # tile lat lon extremes in [degrees].
        
        corner = {}
        for name in order:
            row, col = tile_edges[name]
            corner[name] = cls.tilemap([tile, row, col], mapping='reverse', mask_earth=True)
        
        if np.isnan(np.array([corner[name] for name in order])).any():
            # One or more corners lie outside the earth: look for the outermost pixels that are still on it.
            # 'reverse_m' returns projected y and x in meters, NaN where a pixel is outside the earth.
            index_x, index_y = np.meshgrid(range(2400), range(2400))
            y_m, x_m = cls.tilemap([tile, index_y, index_x], mapping='reverse_m', mask_earth=True)
            
            grid_shape = (2400, 2400)
            extreme = {}
            extreme['ul'] = np.unravel_index(np.nanargmin(x_m - y_m), grid_shape)
            extreme['lr'] = np.unravel_index(np.nanargmin(y_m - x_m), grid_shape)
            extreme['ur'] = np.unravel_index(np.nanargmax(y_m + x_m), grid_shape)
            extreme['ll'] = np.unravel_index(np.nanargmin(y_m + x_m), grid_shape)
            
            # Step half a pixel outwards from the pixel center to its outer edge.
            outward = {'ul': (-0.5, -0.5), 'lr': (0.5, 0.5), 'ur': (-0.5, 0.5), 'll': (0.5, -0.5)}
            edge = {}
            for name in order:
                edge[name] = (extreme[name][0] + outward[name][0], extreme[name][1] + outward[name][1])
                corner[name] = cls.tilemap([tile, edge[name][0], edge[name][1]], mapping='reverse', mask_earth=True)
            
            # If a corner is still fully NaN, at least retrieve its latitude (unmasked).
            for name in order:
                if np.isnan(np.array(corner[name])).all():
                    corner[name][0] = cls.tilemap([tile, edge[name][0], edge[name][1]], mapping='reverse', mask_earth=False)[0]
        
        corner_lats = [corner[name][0] for name in order]
        corner_lons = [corner[name][1] for name in order]
        
        lon_min = np.nanmin(corner_lons)
        lon_max = np.nanmax(corner_lons)
        lat_min = np.nanmin(corner_lats)
        lat_max = np.nanmax(corner_lats)
        
        if np.isnan(np.array([corner[name] for name in order])).any():
            # A longitude is still missing: extend to the date line on the side of the grid the tile is on.
            if int(tile[1:3]) <= 17:
                lon_min = -180.0
            else:
                lon_max = 180.0
        
        # 1e-6 degree is about 0.1 meter, more than enough precision.
        return np.around(lon_min, 6), np.around(lon_max, 6), np.around(lat_min, 6), np.around(lat_max, 6)
    
    
    elif proj_m == True:      # tile projection corners in [m].
        
        corner = {}
        for name in order:
            row, col = tile_edges[name]
            corner[name] = cls.tilemap([tile, row, col], mapping='reverse_m', mask_earth=False)
        
        # tilemap returns [y, x]; only the upper-left and lower-right corners are reported.
        upper_left = corner['ul']
        lower_right = corner['lr']
        
        return upper_left[1], lower_right[1], upper_left[0], lower_right[0]


@classmethod
def load_proj(cls, res, tile, projtype):
    ''' Return the spatial reference and GDAL geotransform of a model grid.
    
    :param cls: class providing ``region_box``, ``MODIS_tile_corners`` and ``load_degbox``.
    :param res: 250, 500 or 1000 for a MODIS meter grid (sinusoidal projection); any other value is
                treated as a degree resolution (EPSG:4326).
    :param tile: MODIS tile name, or 'global' / region name for degree grids.
    :param projtype: 'wkt' or 'proj4' to get the projection as a string; any other value returns the
                     osr.SpatialReference object itself.
    :return: dsSRS, dsGEO. The geotransform is (ulx, W-E pixel size, rotation, uly, rotation, N-S pixel size (negative)).
    '''
    region_box = cls.region_box     # read as in the original; not used further in this function.
    
    srs = osr.SpatialReference()
    
    if res in [250, 500, 1000]:
        
        # MODIS sinusoidal projection on a sphere.
        # Source: http://spatialreference.org/ref/sr-org/6842/
        # See also: https://lpdaac.usgs.gov/sites/default/files/public/product_documentation/mcd12_user_guide_v6.pdf
        sphere_radius = 6371007.181
        srs.ImportFromProj4("+proj=sinu +a=%f +b=%f +units=m" % (sphere_radius, sphere_radius))
        
        pixel = 463.312716528       # 500 m grid pixel size in [m]
        if res == 250:
            pixel = pixel / 2
        elif res == 1000:
            pixel = pixel * 2
        
        ulx, _, uly, _ = cls.MODIS_tile_corners(tile, proj_m=True)
    
    else:    # degree grid.
        
        srs.ImportFromEPSG(4326)
        pixel = res
        
        # load_degbox returns cell midpoints, GDAL wants the outer top-left corner of the first pixel:
        # shift half a pixel west and half a pixel north.
        _, _, lat, lon = cls.load_degbox(res, tile)
        ulx = lon[0] - res/2.0
        uly = lat[0] + res/2.0
    
    dsGEO = (ulx, pixel, 0, uly, 0, -pixel)
    
    if projtype == 'wkt':
        return srs.ExportToWkt(), dsGEO
    if projtype == 'proj4':
        return srs.ExportToProj4(), dsGEO
    return srs, dsGEO


@classmethod
def construct_geolocation(cls, tilen, mask_earth=True, mres=None):
    
    '''
    Compute the latitude and longitude of every pixel center in a MODIS tile.
    
    :param cls: class providing ``tilemap``.
    :param tilen: MODIS tile name [str], e.g. 'h19v09'.
    :param mask_earth: passed on to tilemap; set pixels outside the earth to NaN [bool].
    :param mres: MODIS grid resolution in meters: 250, 500 or 1000. None means 500.
    :return: lats_geo, lons_geo: whatever tilemap returns for the 'reverse' mapping of the full pixel index grid.
    :raises ValueError: if mres is not a supported resolution.
    '''
    
    if mres is None: mres = 500
    if mres not in (250, 500, 1000):
        # Any other value would give a pixel count that does not match the pixel size used by tilemap.
        raise ValueError("mres must be 250, 500 or 1000, got %r" % (mres,))
    
    # Number of row and column pixels in a MODIS tile. Floor division keeps this an integer
    # under true division as well, which range() below requires.
    npixels = (500 * 2400) // int(mres)
    
    index_x, index_y = np.meshgrid(range(npixels), range(npixels))
    lats_geo, lons_geo = cls.tilemap([tilen, index_y, index_x], mapping='reverse', mask_earth=mask_earth,
                                     m250=bool(mres == 250), m1000=bool(mres == 1000))
    
    # No need to save: loading a Float64 .tif geolocation takes about as long as calculating it. Float64 precision is necessary.
    
    return lats_geo, lons_geo


@classmethod
def loci(cls, res, sampleloc):
    
    ''' Find the grid index of a sample location (in degrees), and for MODIS grids the tile it is located in.
    
    Parameters
    ----------
    :cls: class providing ``tilemap`` (only used for meter resolutions).
    :res: resolution. Meter resolutions as 250, 500 or 1000; any other value is treated as a degree
          resolution (e.g. 0.05, 0.10, 0.25).
    :sampleloc: sample location in coordinate degrees. Valid input: [lat, lon] (2-item list, array or tuple)
    :return: loc, tile. loc is (index_y, index_x); for meter resolutions the index inside the MODIS tile
             at that resolution, for degree resolutions the index in the global degree grid.
             tile is the MODIS tile name, or 'global' for degree resolutions.
    '''

    if res in [250, 500, 1000]:
        
        # Tell tilemap which MODIS grid is meant: without the flags it always answers on the
        # 500 m grid, which gave wrong pixel indices for res 250 and 1000.
        tile, y, x = cls.tilemap([sampleloc[0], sampleloc[1]], mapping='forward',
                                 m250=bool(res == 250), m1000=bool(res == 1000))
        loc = y, x
    
    else:   # if res is degrees.
        lat = np.arange(90 - res / 2.0, -90, -res)  # cell midpoints, north to south
        lon = np.arange(-180 + res / 2.0, 180, res)  # cell midpoints, west to east
        
        y = np.abs(lat - sampleloc[0]).argmin()  # closest degree cell midpoint
        x = np.abs(lon - sampleloc[1]).argmin()
        loc = y, x
        tile = 'global'
        
    return loc, tile


@classmethod
def tile_index(cls, res, sampleloc, tile, indexdeg=None):
    
    '''
    Return the model-grid indices of a sample location and whether it lies inside the tile/region.
    If inside, the indices of the sample location in the tile/region are returned. If outside, NaN indices are returned.
    :param cls: class providing ``loci`` and ``load_degbox``.
    :param res: resolution of the model grid. 250, 500 or 1000 are meter grids; any other value is a
                degree resolution (0.25, 0.125, 0.05, ...). This is the same split that loci uses.
    :param sampleloc: sample location [lat, lon] in degrees
    :param tile: MODIS tile or degree region ('global', 'Africa', ..).
    :param indexdeg: degree index vectors of the tile/region as returned by load_degbox; loaded when None.
                     Only used for degree resolutions.
    :return: index_y, index_x (grid indices, NaN when outside), in_out ('inside' or 'outside')
    :raises ValueError: for a meter resolution combined with a tile that is not a MODIS tile.
    '''
    
    is_modis_tile = tile[::3] == 'hv'
    is_meter_res = res in [250, 500, 1000]
    
    if is_meter_res and not is_modis_tile:
        # Nothing below can decide inside/outside for this combination (it used to end in an UnboundLocalError).
        raise ValueError("meter resolution %r requires a MODIS tile, got %r" % (res, tile))
    
    loc, _ = cls.loci(res, sampleloc)
    index_y, index_x = loc
    
    if is_modis_tile:       # inside if the location falls in this very MODIS tile.
        _, tile_check = cls.loci(500, sampleloc)
        in_out = 'inside' if tile_check == tile else 'outside'
    
    if not is_meter_res:
        
        if indexdeg is None:
            indexdeg, _, _, _ = cls.load_degbox(res, tile)
        
        # loci returns global degree indices; make them relative to the tile/region window.
        index_y = index_y - indexdeg['lat'][0]
        index_x = index_x - indexdeg['lon'][0]
        
        if not is_modis_tile:
            # Region or global: inside if the relative index exists in the window.
            inside_lat = index_y in indexdeg['lat'] - indexdeg['lat'][0]
            inside_lon = index_x in indexdeg['lon'] - indexdeg['lon'][0]
            in_out = 'inside' if (inside_lat and inside_lon) else 'outside'
    
    if in_out == 'outside':
        index_y = np.nan
        index_x = np.nan
    
    return index_y, index_x, in_out


@classmethod
def border_check(cls, index_y, index_x, border):
    
    '''
    Check if a sample location including its border stays inside the upper and left grid boundaries,
    and shrink the border if it does not. Only the lower index bound (0) is checked; the grid size is not known here.
    :param cls: not used.
    :param index_y: lat index in model grid [int]; may be NaN when there is no valid location.
    :param index_x: lon index in model grid [int]; may be NaN when there is no valid location.
    :param border: number of pixels surrounding the sample location [int]
    :return: border [int] (new, corrected border), border_changed [bool] (whether the border had to be corrected)
    '''
    
    # NaN indices (no valid sample location) cannot be checked; leave the border as it is.
    # Previously this case crashed with "min() arg is an empty sequence".
    if np.isnan(index_y) or np.isnan(index_x):
        return border, False
    
    # Most negative distance to the grid edge; a negative value is the number of pixels the border sticks out.
    overshoot = min(index_y - border, index_x - border)
    
    if overshoot >= 0:
        return border, False
    
    return border + overshoot, True


@classmethod
def grid_conv(cls, tile, resdeg):
    
    '''
    Construct the grid for conversion of degree data to 500-meter tiles, using nearest neighbour interpolation.
    Every MODIS pixel gets the row and column of the degree cell it falls in, counted from the
    first degree row/column that occurs in the tile.
    :param cls: class providing ``construct_geolocation``.
    :param tile: MODIS tile [str]
    :param resdeg: resolution of degree grid [float]
    :return: ilatc_ar, ilonc_ar. Integer index arrays for latitude and longitude, with the shape of the
             geolocation arrays. Pixels outside the earth get index 99999.
    '''
    
    lats_geo, lons_geo = cls.construct_geolocation(tile)
    outside_earth = np.isnan(lons_geo)
    
    cells_per_degree = int(1 / resdeg)
    
    # Degree-cell index of every pixel in the global grid (row 0 at 90N, column 0 at 180W).
    # Kept as float on purpose: pixels outside the earth are NaN here, and casting NaN to int is undefined.
    lat_index = np.floor(np.abs(lats_geo - 90.0) * cells_per_degree)
    lon_index = np.floor((lons_geo + 180.0) * cells_per_degree)
    lat_index[outside_earth] = np.nan
    lon_index[outside_earth] = np.nan
    
    # Make the indices relative to the first degree row/column present in the tile.
    ilatc_ar = lat_index - np.nanmin(lat_index)
    ilonc_ar = lon_index - np.nanmin(lon_index)
    
    ilatc_ar[outside_earth] = 99999     # set an index that will never occur in the degree grid range.
    ilonc_ar[outside_earth] = 99999
    
    return ilatc_ar.astype(int), ilonc_ar.astype(int)


@classmethod
def haversine(cls, lat1, lat2, lon1, lon2, R=None):
    '''
    Great-circle distance between two coordinates on a sphere, using the haversine formula.
    Note the argument order: both latitudes first, then both longitudes.
    Input in degrees, scalars or arrays. The result has the unit of R.
    :param cls: not used.
    :param R: sphere radius; defaults to 6378137 m (the WGS84 radius used throughout this function).
    :return: distance, in meters for the default radius.
    # Source: http://www.movable-type.co.uk/scripts/latlong.html
    # Or: https://en.wikipedia.org/wiki/Haversine_formula
    '''
    
    radius = 6378137 if R is None else R
    
    phi1, phi2, lam1, lam2 = [np.deg2rad(angle) for angle in (lat1, lat2, lon1, lon2)]
    
    # Squared half chord length between the two points.
    half_chord_sq = np.sin((phi2 - phi1) / 2) ** 2 + np.cos(phi1) * np.cos(phi2) * np.sin((lam2 - lam1) / 2) ** 2
    central_angle = 2 * np.arctan2(np.sqrt(half_chord_sq), np.sqrt(1 - half_chord_sq))
    
    # Note: can be used to calculate an exact pixel area, if needed.
    return radius * central_angle


@classmethod
def pixel_area(cls, lat, resdeg):
    '''
    Approximate area in [m2] of a degree pixel at a given latitude.
    Approximate, only meant for small pixels; for an exact area use haversine.
    :param cls: not used.
    :param lat: latitude in degrees, scalar or array.
    :param resdeg: pixel size in degrees.
    '''
    
    earth_radius = 6378137     # WGS84 earth radius in m.
    
    # Meridional extent of the pixel; the zonal extent is the same length shrunk by cos(lat).
    north_south = earth_radius * np.tan(np.deg2rad(resdeg))
    east_west = north_south * np.cos(np.deg2rad(lat))
    
    return east_west * north_south


@classmethod
def MODIS_pixloc(cls, resdeg, tile, mode='corners'):
    
    '''
    Locate the pixels of a 500 m MODIS tile on the global degree grid.
    
    :param cls: class providing ``wdir_geo`` and ``construct_geolocation``.
    :param resdeg: resolution in degrees [float]
    :param tile: MODIS tile [str]
    :param mode: 'corners' or 'counts'.
        'corners': first and last degree-cell index (lat and lon) that contain pixels of the tile.
        'counts': number of tile pixels in every cell of the global degree grid. The result is read from
                  wdir_geo + 'tilemask/tilemask_<res>_<tile>.tif' when that file exists, otherwise it is
                  computed and written to that file.
    :return: tile_pix. 'corners': dict {'lat': (min, max), 'lon': (min, max)}; 'counts': 2D int array.
    
    NOTE(review): any other mode leaves tile_pix unassigned, so the final return raises UnboundLocalError.
    '''
    
    wdir_geo = cls.wdir_geo
    
    ress = '%03.0fdeg' % float(('%0.2f' % round(resdeg,2)).lstrip('0.'))  # converts e.g. 0.25 -> '025deg'
    
    if mode == 'corners':
        
        lats_geo, lons_geo = cls.construct_geolocation(tile)
        outside_earth = np.isnan(lons_geo)
        # Degree-cell index of every pixel: row 0 at 90N, column 0 at 180W.
        # NOTE(review): NaN (outside earth) is cast to int here and only masked afterwards; recent NumPy warns about that cast.
        lat_index = np.floor(np.abs(lats_geo - 90.0) * int(1 / resdeg)).astype(int)  # for 0.25 deg: 0 : 720, max(lat_index) = 719 
        lon_index = np.floor((lons_geo + 180.0) * int(1 / resdeg)).astype(int)  # for 0.25 deg: 0 : 1440, max(lon_index) = 1439
        
        # Back to float so that pixels outside the earth can be NaN and are skipped by nanmin/nanmax.
        lat_index = lat_index.astype(float)
        lon_index = lon_index.astype(float)
        lat_index[outside_earth] = np.nan
        lon_index[outside_earth] = np.nan
        
        ind_lat_ex = (np.nanmin(lat_index).astype(int), np.nanmax(lat_index).astype(int))
        ind_lon_ex = (np.nanmin(lon_index).astype(int), np.nanmax(lon_index).astype(int))
        tile_pix = {'lat': ind_lat_ex, 'lon': ind_lon_ex}
    
    elif mode == 'counts':
        
        tilemask_name = 'tilemask_' + ress + '_' + tile
        if os.path.isfile(wdir_geo + 'tilemask/' + tilemask_name + '.tif'):
            # Cached result available: read it back.
            ds = gdal.Open(wdir_geo + 'tilemask/' + tilemask_name + '.tif', gdalconst.GA_ReadOnly)
            tile_pix = ds.ReadAsArray().astype('float64')
            ds = None   # dereference the dataset so that GDAL closes the file.
            
        else:
            # Python 2 print statement; the trailing comma keeps 'Done.' below on the same output line.
            print 'no tile mask %s found: creating one ... ' % tile,
            
            lats_geo, lons_geo = cls.construct_geolocation(tile)
            outside_earth = np.isnan(lons_geo)
            lat_index = np.floor(np.abs(lats_geo - 90.0) * int(1 / resdeg)).astype(int)  # for 0.25 deg: 0 : 720, max(lat_index) = 719 
            lon_index = np.floor((lons_geo + 180.0) * int(1 / resdeg)).astype(int)  # for 0.25 deg: 0 : 1440, max(lon_index) = 1439
            
            # Sentinel index for pixels outside the earth; removed again after counting.
            lat_index[outside_earth] = 99999
            lon_index[outside_earth] = 99999
            #lat_vector = np.unique(lat_index[~np.isnan(lat_index)])
            #lon_vector = np.unique(lon_index[~np.isnan(lon_index)])
            
            # One (lat index, lon index) tuple per pixel, stored in an object array so that
            # np.unique can count how often every degree cell occurs.
            # NOTE(review): hard-coded to 2400 x 2400 pixels, and relies on Python 2 zip() returning a list;
            # whether the assignment keeps the tuples intact depends on the NumPy version - confirm.
            arr = np.empty((2400 * 2400), dtype='O')
            arr[:] = zip(lat_index.flatten(), lon_index.flatten())
            arr = arr.reshape((2400, 2400))
            
            uni, cou = np.unique(arr, return_counts=True)
            
            # Old algorithm, more than twice as slow.
            # time0=timer.time()
            # for x in range(0, 2400):
            #     for y in range(0, 2400):
            #         if lat_index[x, y] < (180 * int(1 / resdeg)) and lon_index[x, y] < (360 * int(1 / resdeg)):  # excluded pixels have value
            #             total_pix[lat_index[x, y], lon_index[x, y]] = total_pix[lat_index[x, y], lon_index[x, y]] + 1  # max value = 3600; 60*60 (number of MODIS pixels in 0.25 deg)
            # dur0 = timer.time() - time0
            
            # Drop the count of the outside-earth sentinel, if present.
            uni = uni.tolist()
            cou = cou.tolist()
            if (99999, 99999) in uni:
                delindex = uni.index((99999, 99999))
                del uni[delindex], cou[delindex]
            
            # Scatter the counts into a grid covering the whole globe.
            tile_pix = np.zeros((180 * int(1 / resdeg), 360 * int(1 / resdeg))).astype(int)
            tile_pix[tuple(np.array(uni).T)] = cou
            
            # Cache the result as a compressed single-band GeoTIFF (no projection or geotransform is set here).
            driver = gdal.GetDriverByName('GTiff')
            ds = driver.Create(wdir_geo + 'tilemask/' + tilemask_name + '.tif', 360 * int(1 / resdeg), 180 * int(1 / resdeg), 1, gdal.GDT_UInt16, options=['COMPRESS=LZW', 'INTERLEAVE=BAND', 'TILED=YES'])
            ds.GetRasterBand(1).WriteArray(tile_pix)
            ds = None   # dereference the dataset so that GDAL flushes and closes the file.
            
            print 'Done.'
        
        tile_pix = tile_pix.astype(int)
    
    return tile_pix


@classmethod
def tile_mask_new(cls, resdeg, tiles, frac=False, extent=None):
    
    '''
    Build a degree-grid mask of the area covered by one or more MODIS tiles.
    
    :param cls: class providing ``wdir_geo``, ``MODIS_pixloc`` and ``load_degbox``.
    :param resdeg: resolution of the degree grid [float]
    :param tiles: MODIS tile name [str] or a list of tile names.
    :param frac: False -> covered cells are 1; otherwise keep the covered fraction of every cell
                 (tile pixels in the cell divided by the total number of MODIS pixels in the cell).
    :param extent: None -> crop the mask to the bounding box of the covered cells; otherwise a tile/region name
                   understood by load_degbox (e.g. 'global', 'Africa') to crop to.
    :return: fullmask, bound_box.
        fullmask: 2D float array, NaN where not covered.
        bound_box: [first row, last row, first column, last column] of the covered cells (inclusive),
                   in global grid indices when extent is None, otherwise relative to the extent window.
    '''
    
    wdir_geo = cls.wdir_geo
    
    ress = '%03.0fdeg' % float(('%0.2f' % round(resdeg, 2)).lstrip('0.'))  # converts e.g. 0.25 -> '025deg'
    
    # A single tile name is accepted as well (np.string_ values are str/bytes instances).
    if isinstance(tiles, (str, bytes)):
        tiles = [tiles]
    
    grid_shape = (int(180 / resdeg), int(360 / resdeg))
    fullmask = np.zeros(grid_shape)
    
    for tilen in tiles:
        
        tile_pix = cls.MODIS_pixloc(resdeg, tilen, mode='counts')
        
        if np.shape(tile_pix) == grid_shape:
            # Counts already cover the global grid (this is what MODIS_pixloc creates).
            fullmask += tile_pix
        else:
            # Counts cover only the degree box around the tile: add them at the position of that box.
            indexdeg, _, _, _ = cls.load_degbox(resdeg, tilen)
            fullmask[indexdeg['lat'][0]:indexdeg['lat'][-1] + 1, indexdeg['lon'][0]:indexdeg['lon'][-1] + 1] += tile_pix
    
    total_pix = np.load(wdir_geo + '00_Total_pixels_' + ress + '_new.npy')  # number of MODIS pixels in every degree grid cell.
    with np.errstate(divide='ignore', invalid='ignore'):    # cells without any MODIS pixel divide by zero; they are not > 0 below.
        fullmask = fullmask / total_pix
    
    if frac == False:
        fullmask[fullmask > 0] = 1      # make the mask boolean-like instead of fractional.
    
    index = np.where(fullmask > 0)
    bound_box = [index[0].min(), index[0].max(), index[1].min(), index[1].max()]
    
    if extent is None:
        # Rows are cropped with the latitude bounds, columns with the longitude bounds
        # (the columns used to be cropped with the latitude bounds as well).
        fullmask = fullmask[bound_box[0]: bound_box[1] + 1, bound_box[2]: bound_box[3] + 1]
        
    else:
        indexdeg, _, _, _ = cls.load_degbox(resdeg, extent)  # extent is e.g. global, Africa, or other items in the region_box dictionary.
        fullmask = fullmask[indexdeg['lat'][0]:indexdeg['lat'][-1] + 1, indexdeg['lon'][0]:indexdeg['lon'][-1] + 1]
        # Express the bounding box relative to the extent window.
        bound_box[0] = bound_box[0] - indexdeg['lat'][0]
        bound_box[1] = bound_box[1] - indexdeg['lat'][0]
        bound_box[2] = bound_box[2] - indexdeg['lon'][0]
        bound_box[3] = bound_box[3] - indexdeg['lon'][0]
    
    fullmask[fullmask == 0] = np.nan
    
    return fullmask, bound_box


def load_dates(years, ncfile=None):
    
    '''
    Load the time stamps of the requested years from the time axis of a netCDF file.
    
    :param years: year [str] or list of years (anything int() accepts), e.g. '2010' or ['2010', '2011'].
    :param ncfile: path of the netCDF file to read the 'time' variable from. Defaults to the
                   Era-Interim monthly air temperature file on the original author's machine.
    :return: 1D array with the date objects of all time steps that fall in the requested years,
             concatenated in the order of ``years``.
    '''
    
    if ncfile is None:
        ddir = '/Volumes/Mac_HD/Work/Data/'
        fdir = 'Era-Interim/AIRT/ECMWF_Era-interim-monthlyofdaily_2000-2017_AIRT_grid-025x025deg.nc'
        ncfile = ddir + fdir
    
    if isinstance(years, (str, bytes)):
        years = [years]
    
    ds = netcdf.Dataset(ncfile, 'r')
    try:
        time_var = ds.variables['time']
        datum = netcdf.num2date(time_var[:], units=time_var.units, calendar=time_var.calendar)
    finally:
        ds.close()      # also release the file when reading or decoding the time axis fails.
    
    yearn = np.asarray([date.year for date in datum])  # year of every time step
    dates = np.zeros((0))
    for folder in years:
        year_index = np.where(yearn == int(folder))[0]
        dates = np.concatenate((dates, datum[year_index]), axis=0)
    
    return dates


def init_vars(self):
    
    ''' 
    Initialize the variables used in the model on the given model instance.
    
    Scalars start at None or 0 and the diagnostics containers as empty dicts. Depending on
    self.firemode ('on' or 'off') either the fire diagnostics dicts or zero-valued fire and
    mortality arrays are created; for any other firemode neither is created.
    
    Every dict and array is created per attribute, so nothing is shared between attributes.
    As the original note explains: containers that are only modified in place would otherwise live
    on the class and be shared by all instances, so they have to be initialized on the instance first.
    '''
    
    # Run bookkeeping.
    for name in ('fyear', 'fmonth', 'biome', 's_years'):
        setattr(self, name, None)
    
    # Pools, output fluxes and counters.
    for name in ('leaf', 'gras', 'stem', 'cwd', 'litt', 'soil', 'slow',
                 'rootf', 'rootc',
                 'leaf_out', 'stem_out', 'cwd_out', 'atmos', 'fire_count'):
        setattr(self, name, 0)
    
    # Diagnostics containers.
    for name in ('diagt_pool', 'diagm_pool', 'diagt_out', 'diagm_out',
                 'diagt_param', 'diagm_param', 'diagf'):
        setattr(self, name, {})
    
    if self.firemode == 'on':
        
        for name in ('diagt_fire', 'diagm_fire'):
            setattr(self, name, {})
    
    elif self.firemode == 'off':
        
        for name in ('leaf_fire', 'gras_fire', 'stem_fire', 'cwd_fire', 'litt_fire',
                     'leaf_mort', 'gras_mort', 'stem_mort', 'rootf_mort', 'rootc_mort',
                     'AGB_fire', 'LIT_fire', 'cum_fire', 'tot_fire'):
            setattr(self, name, np.array([0]))
    
    self.looptest1 = None


def main_tile(self):
    
    print
    if type(self.resolution) == float:
        self.indexdeg, self.sizedeg, _, _ = self.load_degbox(self.resolution, self.tile)
    elif self.resolution == 500:
        self.indexdeg, self.sizedeg, _, _ = self.load_degbox(0.25, self.tile)
    
    if self.sampleloc != None:
        self.sampley, self.samplex, in_out = self.tile_index(self.resolution, self.sampleloc[0], self.tile, self.indexdeg)
        
        if in_out == 'inside':      print 'sample location %s inside tile, index %s.' % (self.sampleloc[0], [self.sampley, self.samplex])
        elif in_out == 'outside':   
            if self.tile != 'global':   print 'sample location %s outside tile ... using default sampleloc %s (=tile midpoint).' % (self.sampleloc[0], [self.sampley, self.samplex])
        
        print 'border %s' % self.sampleloc[1],
        self.sampleloc[1], border_changed = self.border_check(self.sampley, self.samplex, self.sampleloc[1])
        if border_changed == False:
            print 'inside bounds.'
        if border_changed == True:
            print 'outside bounds, using largest possible bound: %s.' % self.sampleloc[1]

    self.index025deg, self.size025deg, _, _ = self.load_degbox(0.25, self.tile)
    if self.resolution == 500:
        self.ilatc_ar, self.ilonc_ar = self.grid_conv(self.tile, 0.25)
    
    dates = load_dates(self.years)
        
    self.area, _ = self.load_area(self.resolution, self.tile)
    
    init_vars(self)


@classmethod
def loci_old(cls, res, sampleloc):
    ''' Finds for a sample location (in degrees) in which MODIS tile it is located, and which index it has in that tile.
    
    Parameters
    ----------
    :res: resolution. Insert degree resolutions as float (0.05, 0.10, 0.25), and meter resolutions as integer (500)
    :sampleloc: sample location in coordinate degrees. Valid input: [lat, lon] (2-item list, array or tuple)
    
    Returns
    -------
    :loc: (row, column) index of the cell whose midpoint is closest to sampleloc. For degree resolutions this is
          an index into the global degree grid, for res == 500 an index into the pixel grid of the selected tile.
    :tile: 'global' for degree resolutions, otherwise the selected MODIS tile name ('hXXvYY').
    
    Raises
    ------
    ValueError: if res is neither a float nor 500, if no tile bounds contain the location, or if none of the
                candidate tiles has a readable geolocation file.
    '''
    
    wdir_geo = cls.wdir_geo
    
    def read_geoloc(tilename):
        # Reads the 500m pixel midpoint coordinates of one tile and always closes the HDF file again.
        hdf = SD(wdir_geo + 'MODIS_geoloc_500m_' + tilename + '.hdf', SDC.READ)
        try:
            lats_geo = hdf.select('Latitude').get()
            lons_geo = hdf.select('Longitude').get()
        finally:
            hdf.end()
        return lats_geo, lons_geo
    
    if isinstance(res, float):
        lat = np.arange(90 - res / 2.0, -90, -res)  # cell midpoint coordinates, north to south
        lon = np.arange(-180 + res / 2.0, 180, res)  # cell midpoint coordinates, west to east
        
        y = np.abs(lat - sampleloc[0]).argmin()  # closest degree cell midpoint
        x = np.abs(lon - sampleloc[1]).argmin()
        return (y, x), 'global'
    
    if res != 500:
        # Previously this fell through to an UnboundLocalError at the return statement.
        raise ValueError('Unsupported resolution %r: use a float (degrees) or 500 (meters)' % (res,))
    
    # MODIS tile corner bounds coordinates, excel file.
    # NOTE(review): absolute, machine-specific path; left unchanged here.
    wb = excl.load_workbook(filename='/Volumes/Mac_HD/Work/Vici_project/Preprocessing/sn_bound_10deg.xlsx', read_only=True, data_only=True)
    ws = wb.get_sheet_by_name('Sheet1')
    bounds = np.array([[i.value for i in j] for j in ws['A2':'F649']])
    
    # Columns as used below: 0 = v index, 1 = h index, 2-3 = lon min/max, 4-5 = lat min/max.
    # Find all tiles whose bounds contain the location.
    bounds_tile = bounds[(sampleloc[0] >= bounds[:, 4]) \
                         & (sampleloc[0] < bounds[:, 5]) \
                         & (sampleloc[1] >= bounds[:, 2]) \
                         & (sampleloc[1] < bounds[:, 3])]
    
    if len(bounds_tile) == 0: raise ValueError('No tile selected, no MODIS tile available for location')
    
    candidates = ['h%02dv%02d' % (int(hv[1]), int(hv[0])) for hv in bounds_tile]
    
    if len(candidates) == 1:
        # Single candidate: a missing or unreadable file is an error and propagates to the caller.
        tile = candidates[0]
        lats_geo, lons_geo = read_geoloc(tile)
        dist = np.abs(lats_geo - sampleloc[0]) + np.abs(lons_geo - sampleloc[1])
    else:
        # Overlapping bounds: pick the tile that contains the pixel closest to the location.
        # Only the best distance array so far is kept, so the winning tile is read only once.
        tile, dist, best_min = None, None, None
        for tilen in candidates:
            try:
                lats_geo, lons_geo = read_geoloc(tilen)
            except Exception:
                # Tile not available, skipping (KeyboardInterrupt / SystemExit are no longer swallowed).
                continue
            cand = np.abs(lats_geo - sampleloc[0]) + np.abs(lons_geo - sampleloc[1])
            cand_min = np.min(cand)
            if best_min is None or cand_min < best_min:  # strict '<': the first tile wins a tie
                tile, dist, best_min = tilen, cand, cand_min
        if tile is None:
            raise ValueError('No tile selected, none of the candidate tiles %s has a readable geolocation file' % candidates)
    
    loc = np.unravel_index(dist.argmin(), dist.shape)  # tuple of location indices, [lat, lon]
    
    return loc, tile


@classmethod
def grid_conv_old(cls, tile, resdeg):
    ### Apply grid conversion for degree grid to 500m.
    
    wdir_geo = cls.wdir_geo
    load_degbox = cls.load_degbox
    
    _, _, lat, lon = load_degbox(resdeg, tile)
    
    ress = '%03.0fdeg' % float(('%0.2f' % round(resdeg,2)).lstrip('0.'))  # converts e.g. 0.25 -> '025deg'
    gridconv_name = 'gridconv_' + ress + '-500m_' + tile
    
    if os.path.isfile(wdir_geo + 'gridconv/' + gridconv_name + '.tif'):
        #print 'conversion grid found: loading ... ',
        ds = gdal.Open(wdir_geo + 'gridconv/' + gridconv_name + '.tif', gdalconst.GA_ReadOnly)
        ilatc_ar = ds.GetRasterBand(1).ReadAsArray()
        ilonc_ar = ds.GetRasterBand(2).ReadAsArray()
        ds = None
    else:
        print 'no conversion grid found: creating one ... ',
        hdf = SD(wdir_geo + 'MODIS_geoloc_500m_' + tile + '.hdf', SDC.READ)
        lats_geo = hdf.select('Latitude').get()     # 500m pixel midpoint coordinates.
        lons_geo = hdf.select('Longitude').get()
        hdf = []
        
        ilatc_ar = np.zeros((2400, 2400)).astype('int')
        ilonc_ar = np.zeros((2400, 2400)).astype('int')
        
        for y in range(0, 2400):
            for x in range(0, 2400):
                ilatc_ar[y, x] = np.abs(lat - lats_geo[y, x]).argmin()
                ilonc_ar[y, x] = np.abs(lon - lons_geo[y, x]).argmin()
        
        driver = gdal.GetDriverByName('GTiff')
        ds = driver.Create(wdir_geo + 'gridconv/' + gridconv_name + '.tif', 2400, 2400, 2, gdal.GDT_Byte, options=['COMPRESS=LZW', 'INTERLEAVE=BAND', 'TILED=YES'])
        ds.GetRasterBand(1).WriteArray(ilatc_ar)
        ds.GetRasterBand(2).WriteArray(ilonc_ar)
        ds = None
    
    #if np.min(ilatc_ar) != 0 or np.min(ilonc_ar) != 0:
    #    print 'CAUTION, conversion grid not minimum 0 ... ',
        
    return ilatc_ar, ilonc_ar


@classmethod
def tile_mask(cls, resdeg, tiles, frac='no', extent=None):
    ''' Returns degree mask of 500m tile borders.
    tiles = list of tiles.
    extent = global or not. If global, the mask will be placed in a global grid.
    '''
    wdir_geo = cls.wdir_geo
    
    ress = '%03.0fdeg' % float(('%0.2f' % round(resdeg,2)).lstrip('0.'))  # converts e.g. 0.25 -> '025deg'

    if (type(tiles) is str) or (type(tiles) is np.string_): tiles = [tiles]
    
    fullmask = np.zeros((int(180 / resdeg), int(360 / resdeg)))
    bound_box = [int(180 / resdeg), 0, int(360 / resdeg), 0]  # top, bottom, left, right
    
    for tile in tiles:
        
        tilemask_name = 'tilemask_' + ress + '_' + tile
        if os.path.isfile(wdir_geo + 'tilemask/' + tilemask_name + '.tif'):
            # print 'tile mask %s found: loading ... ' % tile,
            ds = gdal.Open(wdir_geo + 'tilemask/' + tilemask_name + '.tif', gdalconst.GA_ReadOnly)
            tile_pix = ds.ReadAsArray().astype('float64')
            ds = None
        else:
            time0 = timer.time()
            print 'no tile mask %s found: creating one ... ' % tile,
            hdf = SD(wdir_geo + 'MODIS_geoloc_500m_' + tile + '.hdf', SDC.READ)
            lats_geo = hdf.select('Latitude').get()     # 500m pixel midpoint coordinates.
            lons_geo = hdf.select('Longitude').get()
            hdf = []
            
            indexdeg, sizedeg, _, _ = cls.load_degbox(resdeg, tile)
            lat_index = np.floor((lats_geo + 90.0) * int(1 / resdeg)).astype(int)  # 0 : 720, max(lat_index) = 719 
            lon_index = np.floor((lons_geo + 180.0) * int(1 / resdeg)).astype(int)  # 0 : 1440, max(lon_index) = 1439
            
            tile_pix = np.zeros((int(180.0/resdeg), int(360.0/resdeg)))
            for x in range(0, 2400):
                for y in range(0, 2400):
                    if lat_index[x, y] < (180 * int(1 / resdeg)) and lon_index[x, y] < (360 * int(1 / resdeg)):  # excluded pixels have value
                        tile_pix[lat_index[x, y], lon_index[x, y]] = tile_pix[lat_index[x, y], lon_index[x, y]] + 1  # max value = 3600; 60*60 (number of MODIS pixels in 0.25 deg)
            
            tile_pix = np.flipud(tile_pix)
            tile_pix = tile_pix[indexdeg['lat'][0]:indexdeg['lat'][-1] + 1, indexdeg['lon'][0]:indexdeg['lon'][-1] + 1]
            
            driver = gdal.GetDriverByName('GTiff')
            ds = driver.Create(wdir_geo + 'tilemask/' + tilemask_name + '.tif', sizedeg['lon'], sizedeg['lat'], 1, gdal.GDT_UInt16, options=['COMPRESS=LZW', 'INTERLEAVE=BAND', 'TILED=YES'])
            #ds = driver.Create(wdir_geo + 'tilemask/' + tilemask_name + '.tif', int(360.0/resdeg), int(180.0/resdeg), 1, gdal.GDT_Int16, options=['COMPRESS=LZW', 'INTERLEAVE=BAND', 'TILED=YES'])
            ds.GetRasterBand(1).WriteArray(tile_pix)
            ds = None
            
            print 'Done, duration = %s s' % (timer.time() - time0)
        
        indexdeg, _, lat, lon = cls.load_degbox(resdeg, tile)
        fullmask[indexdeg['lat'][0]:indexdeg['lat'][-1] + 1, indexdeg['lon'][0]:indexdeg['lon'][-1] + 1] += tile_pix
        #fullmask = fullmask + tile_pix
        
        bound_box[0] = min(indexdeg['lat'][0], bound_box[0])  # top
        bound_box[1] = max(indexdeg['lat'][-1], bound_box[1])  # bottom
        bound_box[2] = min(indexdeg['lon'][0], bound_box[2])  # left
        bound_box[3] = max(indexdeg['lon'][-1], bound_box[3])  # right
    
    
    total_pix = np.load(wdir_geo + '00_Total_pixels_' + ress + '.npy')  # number of MODIS pixels in every degree grid cell.
    fullmask = fullmask / total_pix     # Not all borders are fractions because total_pix is the same fraction at those points, giving 1.
    # But this doesnt matter because those edges are at the edge of available data anyway, so they only contain ocean.
    
    if frac == 'no':
        fullmask[fullmask > 0] = 1      # if frac is 'yes', keep fractions, otherwise make boolean.
    
    if extent is None:
        fullmask = fullmask[bound_box[0]:bound_box[1] + 1, bound_box[2]:bound_box[3] + 1]
    else:
        indexdeg, _, _, _ = cls.load_degbox(resdeg, extent)     # extent is e.g. global, Africa, or other items in the region_box dictionary.
        fullmask = fullmask[indexdeg['lat'][0]:indexdeg['lat'][-1] + 1, indexdeg['lon'][0]:indexdeg['lon'][-1] + 1]
        bound_box[0] = bound_box[0] - indexdeg['lat'][0]
        bound_box[1] = bound_box[1] - indexdeg['lat'][0]
        bound_box[2] = bound_box[2] - indexdeg['lon'][0]
        bound_box[3] = bound_box[3] - indexdeg['lon'][0]
    
    fullmask[fullmask == 0] = np.nan
    
    return fullmask, bound_box



# Running this file directly does nothing beyond the module-level imports and definitions above.
if __name__ == '__main__':
    pass