import os

import numpy as np
import pandas as pd
import xarray as xr

from sklearn import cluster


def main(features=('SST', 'ChlLog', 'MLDLog', 'LDEO_pCO2_interp', 'EKE_conv'),
         n_clusters=21):
    """Load the data, select the clustering features and build the output name.

    Parameters
    ----------
    features : sequence of str, optional
        Column names of the loaded table to cluster on.
    n_clusters : int, optional
        Number of clusters; used for the ``cNN`` part of the output file name.

    Notes
    -----
    NOTE(review): the feature subset and output file name are prepared here but
    never handed to ``minibatch_kmeans``; the original did not call it either.
    Confirm whether the clustering call was meant to be part of ``main``.
    """
    data, varkeys = load_data()

    # A tuple inside ``.loc`` would be read as one MultiIndex key, so force a list.
    features = list(features)
    df = data.loc[:, features]

    # One letter per feature (looked up in varkeys), sorted so the file name
    # does not depend on the order in which the features were given.
    var_letters = ''.join(sorted(varkeys[k] for k in features))
    sname = './output/01_clus41_mbkmeans/mbkmeans_c{:02d}_{}.nc'.format(n_clusters, var_letters)

    print(data.shape)
    print(varkeys)


def load_data(fname='./data/train_predict.hdf'):
    """Read the yearly tables from ``fname`` and add the ``EKE_conv`` column.

    Parameters
    ----------
    fname : str, optional
        HDF file holding one table per year under the keys ``y1982`` .. ``y2016``.

    Returns
    -------
    data : pandas.DataFrame
        The yearly tables concatenated with their original index, plus an
        ``EKE_conv`` column filled from the clustering climatology.
    varkeys : pandas.Series
        First data column of ``./output/clustering/var_keys.csv``, indexed by
        the file's first column.
    """
    # NOTE: this changes the process-wide working directory; all relative
    # paths used from here on are resolved against it.
    os.chdir('/home/lgregor/projects/global_co2_ml/')

    years = range(1982, 2017)
    frames = []
    for y in years:
        print(y, end=', ')
        frames.append(pd.read_hdf(fname, 'y%d' % y))
    data = pd.concat(frames, ignore_index=False)

    # Pull the values out inside the ``with`` so the netCDF handle is released.
    with xr.open_dataset('./data/climatologies/climatology_clustering.nc') as clim:
        eke_clim = clim.EKE_conv.values

    # Repeat the climatology once per loaded year and flatten it to one value
    # per table row.  The intermediate reshape doubles as a size check.
    # Assumes a 12-step climatology on a 180 x 360 grid (35 * 12 = 420 in the
    # original) — TODO confirm against the climatology file.
    n_years = len(years)
    eke = eke_clim[None].repeat(n_years, axis=0)
    data['EKE_conv'] = eke.reshape(n_years * 12, 180, 360).reshape(-1)

    # Previously ``varkeys`` was returned without ever being defined.
    varkeys = pd.read_csv('./output/clustering/var_keys.csv', index_col=0, header=None).iloc[:, 0]

    return data, varkeys


def subset_data(data, x=None):
    """Return the columns of ``data`` named in ``x``.

    Parameters
    ----------
    data : pandas.DataFrame
        Table to select columns from.
    x : str or sequence of str, optional
        Column name(s) to keep.  When omitted or ``None`` the default
        clustering feature set is used.

    Returns
    -------
    pandas.DataFrame
        ``data`` restricted to the requested columns, in the requested order.
    """
    if x is None:
        x = ['SST', 'ChlLog', 'MLDLog', 'LDEO_pCO2_interp', 'EKE_conv']
    elif isinstance(x, str):
        # A bare string would otherwise be split into single characters below.
        x = [x]

    # list() so a tuple is treated as several labels, not one MultiIndex key.
    return data.loc[:, list(x)]


def minibatch_kmeans(df_subset, varkeys, sname, n_clusters=21):
    """Cluster ``df_subset`` with MiniBatchKMeans and write the result to netCDF.

    Parameters
    ----------
    df_subset : pandas.DataFrame
        One column per clustering feature.  Rows containing NaN are left out of
        the fit.  The index must be a MultiIndex with levels named ``time``,
        ``lat`` and ``lon``, because the labels are put back on that grid.
    varkeys : pandas.Series or mapping
        Accepted for interface compatibility only; it is not used because the
        output file name is supplied through ``sname``.
    sname : str
        Path of the netCDF file to write.
    n_clusters : int, optional
        Number of clusters to fit.

    Notes
    -----
    The written dataset holds ``cluster_centers`` (cluster x feature) and
    ``hard_clusters`` (time x lat x lon), with the fit inertia and the
    estimator's string form as attributes.
    """
    # Plain list of names: assigning ``columns.name`` below must not touch the
    # column index of the caller's frame.
    features = list(df_subset.columns)

    df_noNaN = df_subset.dropna()

    print(n_clusters, end=': clustering, ')
    kmn = cluster.MiniBatchKMeans(
        n_clusters=n_clusters,
        max_iter=150,
        max_no_improvement=25,
        batch_size=32,
        verbose=0,
        n_init=200)

    kmn.fit(df_noNaN)

    print('reindexing labels onto grid', end=', ')

    ##############################
    #####  SAVING TO NETCDF  #####
    ##############################
    xds = xr.Dataset()

    # Cluster centres as a (cluster, feature) array.
    cc = pd.DataFrame(kmn.cluster_centers_, columns=features)
    cc.columns.name = 'feature'
    cc.index.name = 'cluster'
    cc = cc.stack().to_xarray()

    # Labels exist only for the NaN-free rows; reindexing onto the full input
    # index puts them back on the grid with NaN where no fit was possible.
    cm = pd.Series(kmn.labels_, index=df_noNaN.index)
    cm = cm.reindex(index=df_subset.index)
    cm = cm.to_xarray()
    cm = cm.transpose('time', 'lat', 'lon')

    xds['cluster_centers'] = cc
    xds['hard_clusters'] = cm
    xds.attrs['inertia'] = kmn.inertia_
    xds.attrs['method'] = str(kmn)

    print('saving file')

    xds.to_netcdf(sname)


if __name__ == "__main__":
    # Pass the run configuration explicitly: the feature list matches the one
    # in subset_data and the cluster count matches minibatch_kmeans.
    main(
        features=['SST', 'ChlLog', 'MLDLog', 'LDEO_pCO2_interp', 'EKE_conv'],
        n_clusters=21,
    )
