
import os
import sys
import json
import warnings
from glob import glob

import numpy as np
import xarray as xr
import pandas as pd

warnings.filterwarnings("ignore")

from sklearn import metrics
from sklearn import model_selection
from sklearn import ensemble
from sklearn import svm
from sklearn import preprocessing
from sklearn import pipeline
from sklearn import neural_network

sys.path.append('/mnt/lustre/users/lgregor/projects/global_co2_ml/scripts')
import misc_tools as mt


class regression:
    """Namespace of per-cluster regression trainers.

    Every public static method takes ``(x, y, s)`` and returns a fitted
    estimator that exposes ``predict`` and carries a ``parameters`` attribute
    summarising the selected hyper-parameters.

    x : feature table, one row per sample; sliced positionally (``x[::n]``).
    y : target values aligned with ``x``; must provide ``.size``.
    s : group labels aligned with ``x``; must provide ``.unique()``.
        Used to keep all samples of a group inside a single CV fold.

    The command-line entry point of this file lists the public attributes of
    this class to discover the available algorithms, so helper methods must
    keep a leading underscore.
    """

    @staticmethod
    def _group_folds(x, y, s):
        """Return a list of GroupKFold ``(train_idx, test_idx)`` pairs.

        The number of folds is capped at six and at the number of samples and
        distinct groups of the data that is actually passed in, so callers
        must hand over exactly the (sub-sampled) arrays they will search on.
        """
        n_split = int(np.min([6, y.size, s.unique().size]))
        splitter = model_selection.GroupKFold(n_splits=n_split)
        return list(splitter.split(x, y, s))

    @staticmethod
    def _tune_forest(estimator, max_features, n_jobs, x, y, s):
        """Grid-search a tree ensemble on all samples and refit the winner.

        Shared implementation of ``etr`` and ``rfr``; only the estimator, the
        ``max_features`` candidates and the worker count differ between them.
        """
        # NOTE(review): the `iid` argument only exists in older scikit-learn
        # releases - confirm the pinned version before upgrading.
        search = model_selection.GridSearchCV(
            estimator=estimator,
            param_grid={
                'min_samples_leaf': [3, 5, 9, 13],
                'max_features': max_features},
            cv=regression._group_folds(x, y, s),
            iid=False,
            n_jobs=n_jobs)
        search.fit(x, y)

        # Fit the selected configuration once more with parallelism enabled
        # on the estimator itself.
        model = search.best_estimator_
        model.n_jobs = n_jobs
        model.fit(x, y)

        model.parameters = {
            'min_samples_leaf': model.min_samples_leaf,
            'max_features': model.max_features}

        return model

    @staticmethod
    def xgb(x, y, s):
        """Gradient-boosted trees tuned over max_depth / min_child_weight.

        The grid is evaluated with ``xgboost.cv`` on a strided sub-sample
        (stride grows with ``y.size``); the final model is fitted on all
        samples with the best depth, child weight and number of rounds.
        """
        import itertools as itt
        import xgboost as xgb

        lr = 0.05
        depths = [3, 5, 7, 9, 11]
        childs = [3, 5, 8, 11]

        step = int(1 + (y.size // 8000))
        sub_x, sub_y, sub_s = x[::step], y[::step], s[::step]

        # The fold indices are positions inside the sub-sample, so the matrix
        # handed to xgb.cv has to be built from that same sub-sample.
        folds = regression._group_folds(sub_x, sub_y, sub_s)
        dtrain = xgb.DMatrix(sub_x, label=sub_y)

        results = []
        for d, c in itt.product(depths, childs):
            # NOTE(review): 'verbose'/'silent' are accepted by older xgboost
            # releases only - confirm against the installed version.
            history = xgb.cv(
                params={'max_depth': d,
                        'min_child_weight': c,
                        'learning_rate': lr,
                        'verbose': False,
                        'silent': True},
                dtrain=dtrain,
                folds=folds,
                num_boost_round=1000,
                early_stopping_rounds=10)

            # One row per boosting round that was kept, hence the row count
            # (not the last 0-based index) is the number of trees.
            results.append([
                history['test-rmse-mean'].values[-1],
                len(history),
                d, c])

        results = pd.DataFrame(
            data=np.array(results, dtype=float).reshape(-1, 4),
            columns=['rmse', 'n_estimators', 'depth', 'weight'])
        best = results.sort_values(by='rmse').iloc[0]
        d = best['depth']
        n = best['n_estimators']
        w = best['weight']

        model = xgb.XGBRegressor(
            learning_rate=lr,
            max_depth=int(d),
            n_estimators=int(n),
            min_child_weight=int(w),
            n_jobs=24,
            silent=True)
        model.fit(x, y)

        model.parameters = (
            'max_depth: {d}, '
            'min_child_weight: {w}, '
            'n_estimators: {n}, '
            'learning_rate: {l}'
        ).format(d=d, n=n, w=w, l=lr)

        return model

    @staticmethod
    def etr(x, y, s):
        """Extra-trees regressor (300 trees) tuned by grouped grid search."""
        return regression._tune_forest(
            ensemble.ExtraTreesRegressor(n_estimators=300),
            [3, 4, 5], 22, x, y, s)

    @staticmethod
    def rfr(x, y, s):
        """Random-forest regressor (300 trees) tuned by grouped grid search."""
        return regression._tune_forest(
            ensemble.RandomForestRegressor(n_estimators=300),
            [3, 4, 5, 6], 24, x, y, s)

    @staticmethod
    def svr(x, y, s):
        """Standard-scaled RBF NuSVR returned as a scaler + SVR pipeline.

        The scaler is fitted on all samples. The grid search and the final
        fit each use their own strided sub-sample of the scaled features.
        """
        cv_step = int(1 + (y.size // 4000))

        scaler = preprocessing.StandardScaler()
        scaled = scaler.fit_transform(x)

        search = model_selection.GridSearchCV(
            estimator=svm.NuSVR(kernel='rbf'),
            param_grid={
                'gamma': [0.01, 0.05, 0.1, 0.5, 1],
                'C': [1000, 500, 100, 50],
                'nu': [0.5]},
            cv=regression._group_folds(
                x[::cv_step], y[::cv_step], s[::cv_step]),
            iid=False,
            n_jobs=24)
        search.fit(scaled[::cv_step], y[::cv_step])

        # Final fit uses a denser sub-sample than the search did.
        estimator = search.best_estimator_
        fit_step = int(1 + (y.size // 10000))
        estimator.fit(scaled[::fit_step], y[::fit_step])

        model = pipeline.Pipeline([
            ('standard_scaler', scaler),
            ('svr', estimator)])

        model.parameters = str({
            'C': estimator.C,
            'gamma': estimator.gamma,
            'nu': estimator.nu,
            'SVs': estimator.support_.size
        })

        return model

    @staticmethod
    def ffn(x, y, s):
        """Two-hidden-layer MLP regressor tuned over alpha and layer width.

        The first hidden layer is searched between ``y.size // 60`` and
        ``y.size // 30`` units, the second is fixed at 20 units. Search and
        final fit each use their own strided sub-sample.
        """
        cv_step = int(1 + (y.size // 5000))
        widest = y.size // 30
        first_layer = np.unique(np.linspace(widest // 2, widest, 5).astype(int))
        layer_grid = [[int(width), 20] for width in first_layer]

        model = neural_network.MLPRegressor(
            batch_size='auto', learning_rate='adaptive',
            early_stopping=True, n_iter_no_change=10, max_iter=150)

        # refit=False: only the best parameters are needed, the final model
        # is configured and fitted separately below.
        search = model_selection.GridSearchCV(
            refit=False,
            estimator=model,
            param_grid={
                "alpha": [0.001, 0.1, 10],
                "hidden_layer_sizes": layer_grid},
            cv=regression._group_folds(
                x[::cv_step], y[::cv_step], s[::cv_step]),
            iid=False,
            n_jobs=24)
        search.fit(x[::cv_step], y[::cv_step])

        params = dict(batch_size=25, max_iter=1000, early_stopping=True,
                      shuffle=False, verbose=False)
        params.update(search.best_params_)
        model.set_params(**params)

        fit_step = int(1 + (y.size // 10000))
        model.fit(x[::fit_step], y[::fit_step])

        model.parameters = str({
            'hidden_layer_size': model.hidden_layer_sizes,
            'batch_size': model.batch_size,
            'alphaL2': model.alpha,
        })

        return model


def main(func=regression.xgb, clus_flist=()):
    """Train and save cluster regressions for every cluster file given.

    func       : one of the ``regression`` trainers, called as
                 ``func(x, y, groups)`` for each cluster.
    clus_flist : iterable of cluster netCDF paths. They are resolved after
                 changing into the project directory, so relative paths are
                 relative to that directory.

    Raises FileNotFoundError if any entry of ``clus_flist`` is not a file.
    """
    os.chdir('/mnt/lustre/users/lgregor/projects/global_co2_ml/')

    clus_flist = np.sort(clus_flist)
    # Explicit check instead of `assert`, which disappears under `python -O`.
    isfile = np.array([os.path.isfile(f) for f in clus_flist], dtype=bool)
    if not isfile.all():
        raise FileNotFoundError(
            "Following are not files: {}".format(str(clus_flist[~isfile])))

    # Target and predictor columns are the same for every cluster file.
    y = 'pCO2'
    x = [
        'SST', 'SSTanom',
        'SALT', 'ACO2', 'MLDLog',
        'ChlLog', 'CHLanom',
        'ERA_WIND', 'ERA_VWND', 'ERA_UWND',
        'TIME_COS', 'TIME_SIN',
        'ICE',
    ]

    for clus_fname in clus_flist:
        print(clus_fname)
        # The table is reloaded per file because cluster labels and
        # predictions are written into it as new columns.
        data, clus = load_data(data_fname="./data/train_predict.hdf", clus_fname=clus_fname)

        _, prd, tst, trn = split_train_test(data, x, y)

        train_clusters(data, clus, func, x, y, trn, tst, prd, sdir="./output/02_regress/")


def load_data(data_fname='./data/train_predict.hdf', clus_fname=''):
    """Load the yearly feature tables and attach cluster labels to them.

    Reads the keys ``y1982`` ... ``y2016`` from ``data_fname`` and stacks them
    into one table, then opens ``clus_fname`` and stores its flattened
    ``hard_clusters`` variable in a new ``clusters`` column. When the number
    of labels differs from the number of rows, the labels are repeated once
    per distinct ``YEAR`` value.

    Returns ``(data, clus)``; ``clus`` is the opened cluster dataset with the
    file name recorded in ``clus.attrs['fname']``.
    """
    # LOAD PREDICTION DATASET
    yearly_tables = []
    for year in range(1982, 2017):
        print(year, end=', ')
        yearly_tables.append(pd.read_hdf(data_fname, 'y%d' % year))
    data = pd.concat(yearly_tables, ignore_index=False)
    print('')

    # LOAD CLUSTERS
    clus = xr.open_dataset(clus_fname)
    n_years = data.YEAR.unique().size
    labels = clus.hard_clusters.values.ravel()
    if labels.size != data.shape[0]:
        # presumably a climatology that has to be repeated for every year
        labels = np.tile(labels, n_years)
    data['clusters'] = labels

    clus.attrs['fname'] = clus_fname

    return data, clus


def split_train_test(data, x, y):
    """Build boolean row masks for observed, predictable, test and train rows.

    data : table holding the columns in ``x``, the column ``y`` and ``YEAR``.
    x    : list of feature column names.
    y    : target column name.

    Returns ``(obs, prd, tst, trn)`` as boolean arrays with one entry per row:
    ``obs`` - target is present; ``prd`` - every feature is present;
    ``tst`` - both present and the row falls in a held-out year;
    ``trn`` - both present and the row falls in any other year.
    """
    print('Getting training data')
    held_out_years = [1984, 1990, 1995, 2000, 2005, 2010, 2014]

    obs = data[y].notna().values
    prd = data[x].notna().all(axis=1).values
    in_test_year = data.YEAR.isin(held_out_years).values

    usable = obs & prd
    tst = usable & in_test_year
    trn = usable & ~in_test_year

    return obs, prd, tst, trn


def _warn_skipped(var_name, err):
    """Report on stderr that an optional output variable was left out."""
    print('Skipping `{}` ({}: {})'.format(var_name, type(err).__name__, err),
          file=sys.stderr)


def train_clusters(data, clus, func, x, y, trn, tst, prd, sdir="./output/"):
    """Fit one regression per cluster, predict everywhere, save as netCDF.

    data : table with a ``clusters`` column, a ``YEAR`` column, the feature
           columns ``x`` and the target column ``y``. Its index must be a
           three-level MultiIndex whose levels are used as time, lat and lon.
           # assumes rows are the full time x lat x lon product in that
           # order, since columns are reshaped to that shape - TODO confirm
    clus : opened cluster dataset providing ``method``, ``cluster``,
           ``feature``, ``hard_clusters`` and the ``fname`` attribute.
    func : trainer called as ``func(train_x, train_y, train_years)``; the
           returned model must offer ``predict`` and ``parameters``.
    trn, tst, prd : boolean row masks (train rows, test rows, rows where all
           features are available).
    sdir : directory the output file is written to.

    A ``pco2_raw`` column is added to ``data`` in place. The output always
    contains ``pco2_raw`` and ``pco2_smooth``; the remaining variables are
    best effort and are reported on stderr when they cannot be built.

    Raises ValueError when ``data`` holds no non-NaN cluster label.
    """
    stdout = 'C :    TOTAL  TRAIN   TEST    R2   RMSE   PARAMS\n'
    stdout += '------------------------------------------------\n'
    print(stdout, end='')

    feature_importances = []
    clusters = np.sort(data['clusters'].unique())
    clusters = clusters[~np.isnan(clusters)]
    if clusters.size == 0:
        # Without this, the code below fails on an undefined `model`.
        raise ValueError('`data` contains no valid cluster labels to train on')

    for c in clusters:
        i = data['clusters'] == c
        j = i & prd
        itrn = i & trn
        itst = i & tst

        trnx = data.loc[itrn, x].astype(np.float64)
        trny = data.loc[itrn, y].astype(np.float64)
        tstx = data.loc[itst, x].astype(np.float64)
        tsty = data.loc[itst, y].astype(np.float64)
        trng = data.loc[itrn, 'YEAR']

        model = func(trnx, trny, trng)
        if hasattr(model, 'feature_importances_'):
            feature_importances.append(model.feature_importances_)

        # Only the held-out scores are logged.
        tsth = model.predict(tstx)
        tst_r2 = metrics.r2_score(tsty, tsth)
        tst_rmse = metrics.mean_squared_error(tsty, tsth)**0.5

        results = "{:02d}: {: >8}{: >7}{: >7}".format(int(c), i.sum(), itrn.sum(), itst.sum())
        results += '{: >6.2f}{: >7.2f}   {}\n'.format(tst_r2, tst_rmse, model.parameters)
        stdout += results
        print(results, end='')

        # Predict for every row of this cluster that has all features.
        pred = data.loc[j, x].astype(np.float64)
        data.loc[j, 'pco2_raw'] = model.predict(pred)

    ###########################################
    ##  TRANSFORMING DATA TO XARRAY DATASET  ##
    ###########################################
    xds = xr.Dataset()

    time = data.index.levels[0]
    lat = data.index.levels[1]
    lon = data.index.levels[2]
    shape = time.size, lat.size, lon.size
    grid_coords = {'time': time, 'lat': lat, 'lon': lon}

    cluster_meth = clus.method.split('(')[0]
    cluster_nums = clus.cluster.size
    cluster_feat = clus.feature.values.tolist()
    cluster_name = clus.fname[:-3].split('_')[-1]
    regres_meth = get_model_name(model)

    sname = '{}/clusRegr_{}{:02d}{}_{}.nc'.format(sdir, cluster_meth, cluster_nums, cluster_name, regres_meth)
    print(sname, end='')

    xds['pco2_raw'] = xr.DataArray(
        data=data.pco2_raw.values.reshape(*shape),
        dims=('time', 'lat', 'lon'),
        coords=grid_coords,
        attrs={'long_name': 'sea surface partial pressure of CO2',
            'units': 'uatm',
            'features': str(x),
            'training_log': stdout,
            'method': regres_meth,
            'description': (
                'Estimated from features for clusters '
                'estimated from climatological data. '
                'See `clusters` for more information '
                'about the clusters and method.')})

    xds['pco2_smooth'] = xr.DataArray(
        data=mt.convolve_timespace_3d(xds.pco2_raw, kern_size=[15, 15]).values,
        dims=('time', 'lat', 'lon'),
        coords=grid_coords,
        attrs={'long_name': 'sea surface partial pressure of CO2 (smoothed)',
            'units': 'uatm',
            'description': (
                'This is the same as pco2_raw, with the exception '
                'that the data has been smoothed with a 2d convolution. '
                'A 2d gaussian filter is used to smooth the data. ')})

    # The variables below are optional extras: a failure must not discard the
    # trained predictions, but it is reported instead of being hidden.
    try:
        xds['clusters'] = xr.DataArray(
            data=clus.hard_clusters.values,
            dims=('time', 'lat', 'lon'),
            coords=grid_coords,
            attrs={'long_name': 'climatological CO2 clusters for regression',
                'features': str(cluster_feat),
                'n_clusters': cluster_nums,
                'method': cluster_meth,
                'description': (
                    'Data clustered based on `features` using '
                    'method shown in `method` attribute. ')})
    except Exception as err:
        _warn_skipped('clusters', err)

    try:
        xds['train_locs'] = xr.DataArray(
            data=trn.reshape(*shape),
            dims=('time', 'lat', 'lon'),
            attrs={'standard_name': 'train_locations',
                'n_obs': trn.sum(),
                'description': (
                    "Training sample locations for machine learning "
                    "based grouped splitting using years achieving "
                    "roughly a 80:20 split between train and test data.")})
    except Exception as err:
        _warn_skipped('train_locs', err)

    try:
        xds['test_locs'] = xr.DataArray(
            data=tst.reshape(*shape),
            dims=('time', 'lat', 'lon'),
            attrs={'standard_name': 'test_locations_for_machine_learning',
                'n_obs': tst.sum(),
                'description': (
                    "Test sample locations for machine learning "
                    "based grouped splitting using years achieving "
                    "roughly a 80:20 split between train and test data.")})
    except Exception as err:
        _warn_skipped('test_locs', err)

    try:
        # Context manager so the biome file handle is released again.
        with xr.open_dataset('./output/01_cluster/fay2014_b23_abci.nc') as biomes:
            biome_map = biomes['hard_clusters'][0].values.squeeze()
        xds['biomes'] = xr.DataArray(
            data=biome_map,
            dims=('lat', 'lon'),
            attrs={'description': "Mean biomes from Fay and McKinley (2014)",
                'biome_names': str(clus.cluster.values)})
    except Exception as err:
        _warn_skipped('biomes', err)

    try:
        xds['socat'] = xr.DataArray(
            data=data.pCO2.values.reshape(*shape),
            dims=('time', 'lat', 'lon'),
            coords=grid_coords,
            attrs={'long_name': 'surface pCO2 observations',
                'units': 'uatm',
                'version': 'v5',
                'description': (
                    'Sea surface pCO2 ship measurements from the '
                    '5th version of the SOCAT database. ')})
    except Exception as err:
        _warn_skipped('socat', err)

    if feature_importances:
        # One row exists per trained cluster; fall back to the trained labels
        # when the cluster file lists a different number of clusters.
        cluster_labels = clus.cluster.values
        if np.size(cluster_labels) != len(feature_importances):
            cluster_labels = clusters
        xds['feature_importances'] = xr.DataArray(
            data=np.array(feature_importances),
            dims=('cluster', 'feature'),
            coords={'cluster': cluster_labels, 'feature': x},
            attrs={'description': 'feature importances as derived from CART based methods.'})

    for key in xds.data_vars:
        xds[key].encoding = {'zlib': True, 'complevel': 4}

    try:
        xds.to_netcdf(sname)
        print('')
    except Exception as err:
        print('\nCould not write {} ({}: {})'.format(sname, type(err).__name__, err),
              file=sys.stderr)
        # splitext keeps the directory part intact; splitting on the first
        # '.' would cut a path such as './output/...' down to ''.
        sname = os.path.splitext(sname)[0] + "_copy.nc"
        print('Saved to new name: {}'.format(sname))
        xds.to_netcdf(sname)


def get_model_name(model):
    """Return a short algorithm name derived from the model's string form.

    For objects with a ``steps`` attribute the last step's estimator is used.
    The name is the text before the first ``(`` of ``str(...)``, with the
    words ``Regressor`` and ``Regression`` removed.
    """
    final_estimator = model.steps[-1][-1] if hasattr(model, 'steps') else model
    name = str(final_estimator).partition('(')[0]
    return name.replace('Regressor', '').replace('Regression', '')


if __name__ == "__main__":
    # Usage: <script> regress_algo clus_fname [clus_fname ...]
    import argparse

    warnings.filterwarnings("ignore")

    parser = argparse.ArgumentParser()

    # Public attributes of `regression` are the selectable algorithms.
    func_names = [a for a in dir(regression) if not a.startswith('_')]
    func_pretty = str(func_names).replace("'", "")
    # `choices` lets argparse reject an unknown algorithm with a usage
    # message; an `assert` would be skipped entirely under `python -O`.
    parser.add_argument(
        "regression_function", metavar='regress_algo', type=str,
        choices=func_names,
        help="choose regression {}".format(func_pretty))
    parser.add_argument(
        'cluster_filenames', metavar='clus_fnames', type=str, nargs='+',
        help='a list of cluster files (from output/clus40)')

    args = parser.parse_args()

    func = getattr(regression, args.regression_function)
    flist = args.cluster_filenames

    main(func, flist)
