#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed May 3 16:56:55 2017

@author: davevanwees

Driver script for the carbon Model: sets up model runs and executes the timestep-wise (month-by-month) calculation
by calling the Model's algorithm and diagnose steps.

"""

import numpy as np
import matplotlib.pyplot as plt
from scipy import io
import netCDF4 as netcdf
import sys
import os
import glob
import tempfile
import time as timer
import h5py
import gdal, gdalconst
import copy
import gc
import pandas as pd
import multiprocessing
import traceback
from functools import partial
from pyhdf.SD import SD, SDC
from cdo import *; cdo = Cdo()
plt.rc('image', interpolation='none')       # change default image interpolation
#del CDF_MOD_NETCDF4, CDF_MOD_SCIPY, CDO_PY_VERSION

wdir =     '/Volumes/Mac_HD/Work/Vici_project/koolstof_model/'            # Model directory
wdir_geo = '/Volumes/Mac_HD/Work/Data/MODIS_geolocation_500m/'       # geolocation file directory
ddir =     '/Volumes/Mac_HD/Work/Data/'                              # data directory

sys.path.append(wdir+'Scripts/')
try:
    print 'Reloading class'
    reload(sys.modules['carbonmodel_class3'])
    from carbonmodel_class3 import Model
except:
    print 'Importing class'
    from carbonmodel_class3 import Model

try:
    map(os.remove, glob.glob(tempfile.tempdir + '/cdoPy*'))  # remove CDO tempfiles, bug in CDO module.
    print 'CDO temp files cleaned.'
except:
    pass


#d = np.load(ddir+'FC_Leeuwen2014/'+'Leeuwen_2014_processed_v2.npy')[()]
d = np.load(ddir+'FC_Leeuwen2014/'+'Leeuwen_2014_Africa'+'.npy')[()]
results = d['results']
measlist = d['measlist']
fyears = results['ftimes']['fyears']
fmonths = results['ftimes']['fmonths']

df0 = pd.DataFrame([results['year'], results['month']]).transpose()
df1 = pd.DataFrame([results['lat'], results['lon']]).transpose()
df2 = pd.concat([pd.DataFrame([measlist[i][0], measlist[i][1], measlist[i][2]]).transpose() for i in range(len(measlist))], ignore_index=True)
df3 = pd.concat([df1, df2, df0], axis=1, ignore_index=True)
samplecoor = df3.drop_duplicates().as_matrix()
print 'Note %s duplicate(s) is/are removed at location(s): %s' % (len(df3) - len(samplecoor), np.where(df3.duplicated()==True)[0])

    
print 'debug'



''' Syntax to use:
- worker_function(i)   # run a single job, without logging.
- logger(worker_function, i)   # run a single job, with stdout logged to wdir + 'logs/'.
- Model.multiprocess(partial(logger, worker_function), range(10), 3)    # run jobs for i in range(10), using 3 processors.
worker_function must take exactly one argument, so bind the other arguments of a worker first, e.g.:
def wrapper(i):     return worker_Leeuwen(res=500, mode=1, i=i)
- Model.multiprocess(partial(logger, wrapper), range(3), 3)
'''


def logger(worker_function, i):
    
    '''
    Used for correct logging of overhead prints into terminal.
    '''
    
    time0 = timer.time()
    
    try:
        printtxt = open(wdir+'logs/'+'log_'+str(foldername[:-1])+'_'+str(i)+'.txt', 'w')      # redirect stdout to text file
        sys.stdout = printtxt
        print 'Process: '+str(multiprocessing.current_process())
        
        
        status = worker_function(i)
        
        
        print
        print 'Done, duration = %s (h,m,s)' % (timer.strftime("%H:%M:%S", timer.gmtime(timer.time() - time0)))
        
    except (KeyboardInterrupt, SystemExit):
        print '|| From worker: Caught KeyboardInterrupt or sys.exit(), terminating'
        raise KeyboardInterruptError()  # raise 'error' (fake) that is caught in except.
        # the whole interrupt structure works only with this fake error.
    except Exception, e:
        print '|| From worker: Received error: %r' % (e,)
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        print(exc_type, fname, exc_tb.tb_lineno)
        traceback.print_exc()  # prints to console/terminal by default
        traceback.print_exc(file=sys.stdout)

    finally:
        print 'Worker Finished %s' % timer.strftime("%a %d %b %Y %X")

        printtxt.close()

        printcon = sys.__stdout__
        timer.sleep(1)
        
        if status == 'worker_done':
            print >> printcon, str(multiprocessing.current_process()) + ' FINISHED %s' % i
        else:
            print >> printcon, str(multiprocessing.current_process()) + 'Instance not made for %s' % i
        

        sys.stdout.close()
        sys.stdout = sys.__stdout__  # restore stdout back to console/terminal


def worker_Leeuwen(res, mode, i):
    
    '''
    Runs model only for model pixels that contain field data from Leeuwen et al. (2014).
    '''
    
    sdir = wdir + 'results/' + foldername
    printcon = sys.__stdout__
    
    border = 0
    if res == 500:
        n = Model(sdir, mode=mode, resolution=res, tile=samplecoor[i][2], firemode='on', act='active', sampleloc=[[float(samplecoor[i, 0]), float(samplecoor[i, 1])], border], s_pixelmode=True)
    elif res != 500:
        n = Model(sdir, mode=mode, resolution=res, tile='global', firemode='on', act='active', sampleloc=[[float(samplecoor[i, 0]), float(samplecoor[i, 1])], 0], s_pixelmode=True)
    
    n.fyear = str(int(fyears[i]))
    n.fmonth = int(fmonths[i])
    
    print 'START: mode = %s - tile %s %s' % (n.mode, measlist[i][0], timer.strftime("%a %d %b %Y %X"))
    print >> printcon, str(multiprocessing.current_process()) + ' START: i = %s, mode = %s - tile %s' % (i, n.mode, measlist[i][0])  # duplicate print in terminal
    
    for step in range(n.n_years):
        for month in np.arange(0, 12):
            n.algorithm2(step, month)
            n.diagnose(step, month)
            
        print ('\r{0} / ' + str(n.n_years)).format(step+1),
        print >> printcon, ('\r' + str(multiprocessing.current_process()) + ' {0} / ' + str(n.n_years)).format(step+1),
        sys.stdout.flush()      # overwrites previous print
        printcon.flush()        # overwrites previous print
    
    if mode == 1:
        n.save()
    elif mode == 2:
        n.save(addtxt='FC')
    
    n = None
    gc.collect()
    return 'worker_done'


def worker_tile(mode, tile, firemode='on', s_years=None, sampleloc=None, addtxt=None):
    
    ''' 
    Execute model on a per-tile basis. Defaults to 500-meter resolution.
    '''
    
    # NEW version where setup_vars is decoupled from algorithm.
    
    sdir = wdir + 'results/' + foldername
    printcon = sys.__stdout__
    
    if addtxt is None: addtxt = ''
    
    res = 500
    
    s_pixelmode = False
    if sampleloc != None: s_pixelmode=True
    
    n = Model(sdir, mode=mode, resolution=res, tile=tile, firemode=firemode, act='active', sampleloc=sampleloc, s_pixelmode=s_pixelmode)
    
    if s_years is not None:
        n.load_spinup_pools(s_years=s_years, addtxt=addtxt)
    elif mode == 2:
        n.load_spinup_pools(addtxt=addtxt)
    
    print 'START: mode = %s - tile %s %s' % (n.mode, tile, timer.strftime("%a %d %b %Y %X"))
    print >> printcon, str(multiprocessing.current_process()) + ' START: mode = %s - tile %s' % (n.mode, tile)  # duplicate print in terminal

    for step in range(n.n_years):
        
        if mode == 2:
            n.setup_vars(step)
        
        for month in np.arange(0, 12):
            
            n.algorithm(step, month)
            if mode == 1:
                n.diagnose(step, month)
            elif mode == 2:
                n.diagnose(step, month, diagf='light')
            
        print ('\r{0} / ' + str(n.n_years)).format(step+1),  # overwrites previous print
        print >> printcon, ('\r' + str(multiprocessing.current_process()) + ' {0} / ' + str(n.n_years)).format(step+1),
        sys.stdout.flush()
        printcon.flush()
    
    if (s_years is not None) and (mode == 2):   addtxt += '_spin%syears' % s_years
    
    n.save(ext='.nc', addtxt=addtxt)
    
    n = None
    gc.collect()
    return 'worker_done'


def worker_global(res, mode, firemode='on', s_years=None, sampleloc=None, addtxt=None):
    
    ''' 
    Execute model on a global level.
    '''
    
    sdir = wdir + 'results/' + foldername
    
    if addtxt is None: addtxt = ''
    
    if res == 500:     rescheck = '500m'
    else:
        rescheck = '%03.0fdeg' % round(res * 100)  # converts e.g. 0.25 -> '025deg'
    if rescheck not in sdir:
        print 'WARNING, folder name resolution not the same as given!'; sys.exit()
    
    time0 = timer.time()
    
    s_pixelmode = False
    if sampleloc != None: s_pixelmode=True
    
    n = Model(sdir, mode=mode, resolution=res, tile='Africa', firemode=firemode, act='active', sampleloc=sampleloc, s_pixelmode=s_pixelmode)
    
    n = C(filename)
    n.regrid(res)
    
    
    if s_years is not None:
        n.load_spinup_pools(s_years=s_years, addtxt=addtxt)
    elif mode == 2:
        n.load_spinup_pools(addtxt=addtxt)
    
    for step in range(n.n_years):
        time_loop = timer.time()
        for month in np.arange(0, 12):
            n.algorithm2(step, month)
            n.diagnose(step, month)
            
        print ('\r{0} / ' + str(n.n_years) + ', duration (incl. read) = {1} s' + '\n').format(step+1, timer.time()-time_loop) ,  # overwrites previous print
        sys.stdout.flush()
    print
    
    print 'Done, duration = %s (h,m,s)' % (timer.strftime("%H:%M:%S", timer.gmtime(timer.time() - time0)))
    
    if (s_years is not None) and (mode == 2):   addtxt += '_spin%syears' % s_years
    
    n.save(ext='.nc', addtxt=addtxt)
    print
    
    n = None
    gc.collect()
    return 'worker_done'


def worker_global_perbiome(res, mode, biome, firemode='on', s_years=None, sampleloc=None, addtxt=None):
    
    '''
    Run model globally, seperately per biome.
    '''
    
    sdir = wdir + 'results/' + foldername
    
    if addtxt is None: addtxt = ''
    
    if res == 500:
        rescheck = '500m'
    else:
        rescheck = '%03.0fdeg' % round(res * 100)  # converts e.g. 0.25 -> '025deg'
    if rescheck not in sdir:
        print 'WARNING, folder name resolution not the same as given!'; sys.exit()
    
    time0 = timer.time()
    
    s_pixelmode = False
    if sampleloc != None: s_pixelmode=True
    
    n = Model(sdir, mode=mode, resolution=res, tile='Africa', firemode=firemode, act='active', sampleloc=sampleloc, s_pixelmode=s_pixelmode)
    n.modsettings['biome_method'] = 'perbiome'
    
    n.biome = biome
    
    addtxt += 'biome%02.f' % int(biome)
    
    if s_years is not None:
        n.load_spinup_pools(s_years=s_years, addtxt=addtxt)
    elif mode == 2:
        n.load_spinup_pools(addtxt=addtxt)
    
    for step in range(n.n_years):
        time_loop = timer.time()
        for month in np.arange(0, 12):
            n.algorithm2(step, month)
            n.diagnose(step, month)
            
        print ('\r{0} / ' + str(n.n_years) + ', duration (incl. read) = {1} s' + '\n').format(step+1, timer.time() - time_loop),  # overwrites previous print
        sys.stdout.flush()
    print
    
    print 'Done, duration = %s (h,m,s)' % (timer.strftime("%H:%M:%S", timer.gmtime(timer.time() - time0)))
    
    if (s_years is not None) and (mode == 2):   addtxt += '_spin%syears' % s_years
    
    n.save(ext='.nc', addtxt=addtxt)
    print
    
    n = None
    gc.collect()
    return 'worker_done'


def worker_jobscript():
    
    tile_bash = sys.argv[1]
    print 'Processing tile: %s' % tile_bash
    
    foldername = 'run7_500m_Africa_fire-on-off_speedspin-off_changes-combined_new-abiotic/'
    sdir = wdir + 'results/' + foldername
    
    resolution = 500
    if resolution == 500:     res = '500m'
    else:
        res = '%03.0fdeg' % round(resolution * 100)  # converts e.g. 0.25 -> '025deg'
    time0 = timer.time()

    n = Model(sdir, mode=1, resolution=resolution, tile=tile_bash, firemode='off', act='active')
    f = copy.deepcopy(n)
    f.add_fire()
    
    for step in range(n.n_years):
        time_loop = timer.time()
        for month in np.arange(0, 12):
            n.algorithm2(step, month)
            n.diagnose(step, month)

            f.algorithm2(step, month)
            f.diagnose(step, month)

        print ('\r{0} / ' + str(n.n_years) + ', duration (incl. read) = {1} s' + '\n').format(step, timer.time() - time_loop),  # overwrites previous print
        sys.stdout.flush()
    print

    print 'Done, duration = %s (h,m,s)' % (timer.strftime("%H:%M:%S", timer.gmtime(timer.time() - time0)))

    n.save()
    f.save()


def calc_effective_turn():
    '''
    Compute effective turnover times of the litter, CWD, leaf and stem pools
    for the 0.25-degree 'Africa' model, with per-biome statistics.

    Each scalar returned by n.construct_scalars() is converted with
    1 / (scalar * 12), which presumably turns a monthly rate into a time in
    years — TODO confirm. Litter and CWD are additionally averaged over axis 0.

    Returns
    -------
    dict with keys 'litt', 'cwd', 'leaf', 'stem'. Each value is element [1] of
    Model.biome_calc() for that pool, computed over 'total' plus
    Model.biomes_UMD.keys()[1:14].
    '''
    foldername = 'debug_testrun/'
    # NOTE(review): uses Model.wdir, not the module-level wdir — confirm intended.
    sdir = Model.wdir + 'results/' + foldername
    n = Model(sdir, mode=2, resolution=0.25, tile='Africa', act='passive')
    moist, abiot, litt_eturn, cwd_eturn, leaf_eturn, stem_eturn = n.construct_scalars()
    
    # 1.0 (not 1) keeps these float divisions under Python 2 even for integer input.
    litt_eturn = np.mean(1.0 / (litt_eturn * 12), axis=0)
    cwd_eturn = np.mean(1.0 / (cwd_eturn * 12), axis=0)
    leaf_eturn = 1.0 / (leaf_eturn * 12)
    stem_eturn = 1.0 / (stem_eturn * 12)
    
    #litt_eturn[litt_eturn > 2] = np.nan
    #cwd_eturn[cwd_eturn > 10] = np.nan
    
    def per_biome(values):
        # Build a fresh biome list on every call, exactly as the inline calls did.
        biomes = ['total'] + Model.biomes_UMD.keys()[1:13+1]
        return Model.biome_calc(values, biomes, 0.25, '2009', 'Africa')[1]
    
    litt_eturn_biome = per_biome(litt_eturn)
    cwd_eturn_biome = per_biome(cwd_eturn)
    leaf_eturn_biome = per_biome(leaf_eturn)
    stem_eturn_biome = per_biome(stem_eturn)
    
    return {'litt': litt_eturn_biome, 'cwd': cwd_eturn_biome,
            'leaf': leaf_eturn_biome, 'stem': stem_eturn_biome}
#calc_effective_turn()






''' Execution jobs (make one to your wishes). Examples below: '''
# Uncomment the job(s) you want to run. The worker functions read the
# module-level `foldername` at call time, so each job uses whichever
# `foldername` was assigned most recently above it. worker_global(_perbiome)
# aborts unless that folder name contains the resolution tag
# (e.g. '025deg' for res=0.25, '013deg' for 0.125, '005deg' for 0.05).


# --- Leeuwen et al. (2014) field-sample runs at 500 m (one job per row of samplecoor) ---
# temp_func binds res/mode so logger/multiprocess can call it with only i.
foldername = 'FL_500m_Africa_example/'
def temp_func(i):   return worker_Leeuwen(res=500, mode=1, i=i)
#Model.multiprocess(partial(logger, temp_func), range(len(samplecoor)), 4)
# Redefinition: later multiprocess calls use this mode-2 version.
def temp_func(i):   return worker_Leeuwen(res=500, mode=2, i=i)
#Model.multiprocess(partial(logger, temp_func), range(len(samplecoor)), 4)



# --- Per-tile 500 m runs ---
# NOTE: the second assignment immediately overrides the region-based tile list
# with a two-tile subset; remove it to run all NHAF + SHAF tiles.
tiles = sorted(set(Model.regions['NHAF'][1] + Model.regions['SHAF'][1])) #+ Model.regions['MIDE'][1]))
tiles = ['h12v09', 'h13v09']

def temp_func(tile):    return worker_tile(mode=1, tile=tile)
#Model.multiprocess(partial(logger, temp_func), tiles, 2)
def temp_func(tile):    return worker_tile(mode=2, tile=tile)
#Model.multiprocess(partial(logger, temp_func), tiles, 2)



# --- Regional runs at 0.25 degree (whole run, and per biome) ---
foldername = 'FL_025deg_Africa_17-06_2/'
#worker_global(res=0.25, mode=1)
#worker_global(res=0.25, mode=2)
#worker_global(res=0.25, mode=1, firemode='off')
#worker_global(res=0.25, mode=2, firemode='off')

#for biome in Model.biomes_UMD.keys()[:13+1]:
    #worker_global_perbiome(res=0.25, mode=1, biome=biome, addtxt='percNOMISC')
    #worker_global_perbiome(res=0.25, mode=2, biome=biome, addtxt='percNOMISC')

# for biome in Model.biomes_UMD.keys()[:13+1]:
#     worker_global_perbiome(res=0.25, mode=1, biome=biome)
#     worker_global_perbiome(res=0.25, mode=2, biome=biome)
#     worker_global_perbiome(res=0.25, mode=1, biome=biome, firemode='off')
#     worker_global_perbiome(res=0.25, mode=2, biome=biome, firemode='off')


# --- Regional runs at 0.125 degree ---
foldername = 'FL_013deg_Africa_17-06_2/'
# worker_global(res=0.125, mode=1)
# worker_global(res=0.125, mode=2)
# worker_global(res=0.125, mode=1, firemode='off')
# worker_global(res=0.125, mode=2, firemode='off')


# --- Regional runs at 0.05 degree ---
foldername = 'FL_005deg_Africa_17-06_2/'
# worker_global(res=0.05, mode=1)
# worker_global(res=0.05, mode=2)
# worker_global(res=0.05, mode=1, firemode='off')
# worker_global(res=0.05, mode=2, firemode='off')



