import Nio
import numpy as np
import function.MVIETool_fun as MVIE
import os
import xarray as xr
#############################################################
# Calculate Mean of standard deviations (M_SD) derived from 
# two reanalysis datasets (JRA55 & NNRP) used in examples of Sec.5
# with and without area weighting, respectively.
#############################################################
# Input datasets are not provided because the data files are too large.
# REA datasets are monthly datasets for 1961-2000 with a resolution of
#  2.5x2.5 covering the Northern Hemisphere.
# The results are provided in Mul-REA.Mean-SD.areaweight.nc and
# Mul-REA.Mean-SD.noareaweight.nc
#############################################################
# Number of years in the record (1961-2000)
nyear = 40

# Latitude range and grid spacing used to build the area weights
latS = 0.
latN = 90.
Resolution_lat =  2.5 
# Number of latitude points, both end points included.
# The builtin int replaces np.int, which was removed in NumPy 1.24;
# round() guards against floating-point error in the division.
nlat = int(round((latN-latS)/Resolution_lat+1))

lat = np.linspace(latS,latN,num=nlat)
# presumably one weight per latitude in `lat` -- see MVIE.latWgt
lat_wgt = MVIE.latWgt(lat)

# Names of the variables as stored in the reanalysis files.
# Each inner list is one evaluated field; vector fields list two components.
varnames = [["SST"],["SLP"],["Q600"],["T850"],["u850","v850"],["u200","v200"]]
# Name used for each field in the output files (same order as varnames)
var_output = ["SST","SLP","Q600","T850","uv850","uv200"]

# Season names
season_names = ["DJF","MAM","JJA","SON"]

# Open Reanalysis datasets 
f_REA1 = xr.open_dataset("./REA_data/example.JRA55.1961-2000.nc")
f_REA2 = xr.open_dataset("./REA_data/example.NNRP.1961-2000.nc")

# Calendar months belonging to each season, in the order DJF, MAM, JJA, SON
_season_months = [(12,1,2),(3,4,5),(6,7,8),(9,10,11)]

# Read datetime in Reanalysis datasets and create one boolean time mask
# per season (True where the month of the time step belongs to the season)
time1_ind = f_REA1["time"].indexes["time"]
time1_slice = [np.isin(time1_ind.month,months) for months in _season_months]

# The masks for the second dataset must be built from its own time axis
# (the original read f_REA1 here, which is only correct when both files
# share an identical time axis).
time2_ind = f_REA2["time"].indexes["time"]
time2_slice = [np.isin(time2_ind.month,months) for months in _season_months]

# Create lat slice
# NOTE(review): lat1_slice/lat2_slice are not used in the visible part of
# this script -- the input files are presumably already restricted to
# latS..latN; confirm before relying on them.
lat1 = f_REA1["lat"]
lat1_slice = np.logical_and(lat1>=latS,lat1<=latN)
lat2 = f_REA2["lat"]
lat2_slice = np.logical_and(lat2>=latS,lat2<=latN)

# Create new files to save M_SD.
# Any existing output file is removed first so that it can be created anew;
# os.remove is used instead of shelling out to `rm`, so this is portable
# and raises on failure.
for _out_name in ("Mul-REA.Mean-SD.areaweight.nc","Mul-REA.Mean-SD.noareaweight.nc"):
   if os.path.isfile(_out_name):
      os.remove(_out_name)

f_out1 = Nio.open_file("Mul-REA.Mean-SD.areaweight.nc","c")
f_out2 = Nio.open_file("Mul-REA.Mean-SD.noareaweight.nc","c")

# Every output variable is a single value stored on a length-1 dimension
f_out1.create_dimension('var',1)
f_out2.create_dimension('var',1)

# Per-season accumulators for the multivariable M_SD
# (1: with area weighting, 2: without area weighting)
MeanSD_Mul1 = np.zeros(len(season_names),float)
MeanSD_Mul2 = np.zeros(len(season_names),float)

# Calculate M_SD for individual variables

def _seasonal_means(var,time_mask):
   """Return the seasonal mean of `var` for every year.

   `var` is indexed with the boolean `time_mask` along "time"; the selected
   months are then grouped in consecutive triplets (nyear groups of 3) and
   averaged, giving an array of shape (nyear, ...) with the remaining
   dimensions of `var` unchanged.
   """
   data = np.asarray(var.sel(time=time_mask))
   return np.mean(np.reshape(data,(nyear,3)+data.shape[1:]),axis=1)


def _mean_sd(var1,var2,mask1,mask2,wgt):
   """Return (weighted, unweighted) mean SD ratio of two datasets.

   At each grid point the standard deviation across the two datasets'
   climatological seasonal means is divided by the interannual standard
   deviation of the two-dataset average. The ratio is then averaged over
   the grid, once weighted by `wgt` and once with a plain NaN-ignoring mean.
   """
   season1 = _seasonal_means(var1,mask1)
   season2 = _seasonal_means(var2,mask2)

   # Reference: average of the two datasets, year by year
   ref = (season1+season2)/2.

   # Climatological seasonal mean of each dataset, stacked on axis 0
   clim = np.array([np.mean(season1,axis=0),np.mean(season2,axis=0)],dtype=float)

   sd_ratio = np.std(clim,axis=0)/np.std(ref,axis=0)

   # NOTE(review): the denominator sums all weights, including those of
   # grid points where sd_ratio is NaN; kept as in the original so that
   # results stay comparable with the provided output files.
   weighted = np.nansum(wgt*sd_ratio)/np.nansum(wgt)
   unweighted = np.nanmean(sd_ratio)
   return weighted,unweighted


print('%-20s'%"Individual M_SD:",'%-13s'%"With areaWgt",'%-13s'%"No areaWgt")

# Latitude weights repeated along the last (longitude) axis -> (nlat, nlon).
# Built once up front; the grid of the first variable defines nlon.
latWgt_Matrix = np.transpose(np.tile(lat_wgt,(f_REA1[varnames[0][0]].shape[2],1)))

for names,out_name in zip(varnames,var_output):
   # One (dataset 1, dataset 2) pair per component of the field
   components = [(f_REA1[name],f_REA2[name]) for name in names]

   for j,season in enumerate(season_names):
      results = [_mean_sd(c1,c2,time1_slice[j],time2_slice[j],latWgt_Matrix)
                 for c1,c2 in components]

      # Vector fields: average the M_SD of the components
      MeanSD_var1 = sum(r[0] for r in results)/len(results)
      MeanSD_var2 = sum(r[1] for r in results)/len(results)

      # Output name comes from var_output, so any vector field is labelled
      # correctly (the original hard-coded uv850/uv200).
      saveVarname = out_name+"_"+season

      MeanSD_Mul1[j] +=  MeanSD_var1
      MeanSD_Mul2[j] +=  MeanSD_var2

      f_out1.create_variable(saveVarname,'d',('var',))
      f_out1.variables[saveVarname][:] = np.float64(MeanSD_var1)
      f_out2.create_variable(saveVarname,'d',('var',))
      f_out2.variables[saveVarname][:] = np.float64(MeanSD_var2)

      print('%-16s'%saveVarname,'%13.3f'%MeanSD_var1,'%13.3f'%MeanSD_var2)

# Calculate M_SD of multivariable field:
# average of the individual M_SD values over all fields, per season
MeanSD_Mul1 /= len(var_output)
MeanSD_Mul2 /= len(var_output)

print('%-20s'%"Multivariable field:",'%-13s'%"With areaWgt",'%-13s'%"No areaWgt")

for j,season in enumerate(season_names):
   saveVarname = "MulFie_"+season

   f_out1.create_variable(saveVarname,'d',('var',))
   f_out1.variables[saveVarname][:] = np.float64(MeanSD_Mul1[j])
   f_out2.create_variable(saveVarname,'d',('var',))
   f_out2.variables[saveVarname][:] = np.float64(MeanSD_Mul2[j])

   print('%-16s'%saveVarname,'%13.3f'%MeanSD_Mul1[j],'%13.3f'%MeanSD_Mul2[j])

# Close the output files explicitly so that everything written above is
# flushed to disk instead of relying on interpreter shutdown.
f_out1.close()
f_out2.close()

#############################################################




     

     
