#########################################
# Stephane SAUX PICART: stux@pml.ac.uk
#  created 26th October 2011
#########################################


import numpy as np



def _binary_outside_range(data, lower, upper):
   """ Build the binary image of `data` for the range [lower, upper).
       Unmasked pixels are set to 0 inside the range and to 1 outside of it;
       the mask of `data` is re-applied to the result.
   """
   in_range     = np.ma.logical_and(data >= lower, data < upper)
   out_of_range = np.ma.logical_or(data >= upper, data < lower)
   
   # Two passes: zero the pixels inside the range, then set the pixels outside
   # the range to one. filled() drops the mask, so it is restored at the end.
   binary = np.ma.filled(np.ma.masked_where(in_range, data), 0.)
   binary = np.ma.filled(np.ma.masked_where(out_of_range, binary), 1.)
   
   return np.ma.masked_where(np.ma.getmaskarray(data), binary)
#--- end of _binary_outside_range


def _fraction_outside(binary):
   """ Fraction of the unmasked pixels of a 0/1 image that are non-zero
       (i.e. outside the range). Returns None when no pixel can be counted.
   """
   nb_outside = np.ma.masked_where(binary == 0., binary).count()
   nb_inside  = np.ma.masked_where(binary == 1., binary).count()
   nb_total   = nb_outside + nb_inside
   if nb_total <= 0:
      return None
   return float(nb_outside)/float(nb_total)
#--- end of _fraction_outside


def wavelet(data1, data2, redfacs, wavelet_ranges_def=[20.,20.,20.,20.,20.]):
   """ Intensity-scale skill score of an estimation against an observation.
      
      See Casati et al. (2004), A new intensity-scale approach for the verification
      of spatial precipitation forecasts, Meteorol. Appl., 141-154.
      
      Stages for each range:
         0. convert data into binary arrays (0 inside the range, 1 outside)
         1. copy data into a larger square array of side 2*redfacs[-1]
         2. difference between model and observations (creates binary error image)
         3. 2D Haar wavelet decomposition of the binary error image
         4. MSE of each wavelet component and corresponding skill score
      
      Changes with respect to Casati et al. (2004):
         - binary maps are made according to ranges rather than user-defined cutoffs.
           Ranges are defined using quantiles (see dynamic_range).
         - occluded (masked) data are taken into account
      
      Inputs:
         data1, data2       = 2D masked arrays (data1: satellite/observation, data2: model/estimation).
                              Must have the same size and the same mask, and must fit in a
                              square of side 2*redfacs[-1].
         redfacs            = reduction factors, in increasing order [2,4,8,16,32,...].
                              Any sequence accepted by np.asarray.
         wavelet_ranges_def = list of percentages to be included in each range,
                              default is [20.,20.,20.,20.,20.] (never modified here)
      
      Outputs (in this order):
         ss_array                   = skill score (dimension = number of reduction factors x number of ranges).
                                      Entries stay at 0 when the score is not computed
                                      (no valid pixel, or fraction of pixels outside the range equal to 0 or 1).
         cutoffs1, cutoffs2         = lower value of each range, for data1 and data2
         cutoffs_b1, cutoffs_b2     = lower value of each range + max (for plotting purposes)
      
      Raises ValueError when data2 does not fit in the padded square array.
      
      Side effect: prints one line per range with its cutoff and epsilon.
   """
   
   # Convert first, so that plain lists/tuples of reduction factors are accepted
   redfacs = np.asarray(redfacs)
   nr      = np.size(redfacs)
   
   # Range definition for satellite (data1) and model (data2)
   cutoffs1, cutoffs_b1 = dynamic_range(data1, wavelet_ranges_def)
   cutoffs2, cutoffs_b2 = dynamic_range(data2, wavelet_ranges_def)
   nc = np.size(cutoffs1)
   
   # Initialise MSE and SS arrays (the MSE is only used internally)
   mse_array = np.zeros((nr,nc))
   ss_array  = np.zeros((nr,nc))
   
   # Size of the data and of the padded square array (same for every range)
   width  = data2.shape[0]
   height = data2.shape[1]
   side   = int(redfacs[-1])*2
   if width > side or height > side:
      raise ValueError("data of shape (%d, %d) do not fit in the %d x %d array implied by redfacs"
                       % (width, height, side, side))
   
   # Loop through all ranges
   for count_co, co in enumerate(cutoffs1):
      # 0. convert data into binary arrays
      #-----------------------------------
      # NOTE(review): the upper bound of a range is exclusive and the last bound is the
      # maximum of the data, so pixels equal to the maximum are outside every range.
      # Confirm this is intended.
      data1_bin = _binary_outside_range(data1, cutoffs_b1[count_co], cutoffs_b1[count_co+1])
      data2_bin = _binary_outside_range(data2, cutoffs_b2[count_co], cutoffs_b2[count_co+1])
      
      # 1. copy data into larger array (fully masked outside the data)
      #---------------------------------------------------------------
      new_data1 = np.ma.masked_equal(np.ma.zeros((side,side)), 0.)
      new_data2 = np.ma.masked_equal(np.ma.zeros((side,side)), 0.)
      
      # Python slices exclude their upper bound: 0:width copies every row
      new_data1[0:width,0:height] = data1_bin[0:width,0:height]
      new_data2[0:width,0:height] = data2_bin[0:width,0:height]
      
      # 2. difference between model and observations (creates binary error image)
      #--------------------------------------------------------------------------
      bin_diff = new_data2 - new_data1
      
      # Fraction of valid pixels outside the range, for satellite and model.
      # The model fraction only serves to check that the model has valid pixels.
      epsilon_sat = _fraction_outside(data1_bin)
      epsilon_mod = _fraction_outside(data2_bin)
      if epsilon_sat is None or epsilon_mod is None:
         continue
      
      # The skill score is normalised with the satellite (observation) fraction
      epsilon = epsilon_sat
      
      print("  Cutoff =  %s   epsilon =  %s" % (co, epsilon))
      if epsilon > 0.:
         # 3. 2D Haar wavelet on the binary error image
         #---------------------------------------------
         previous_father = bin_diff
         # The weights are carried through the decomposition but do not enter the MSE
         weight = np.ma.zeros(bin_diff.shape) + 1.
         
         # Loop through all reduction factors
         for count_fac, fac in enumerate(redfacs):
            # Compute father; restore the mask of the error image on it
            father, next_weight = haarFather(bin_diff, fac)
            father.mask = np.ma.getmaskarray(bin_diff).copy()
            
            # Compute mother = difference between two successive fathers
            mother, weight = haarMother(previous_father, father, fac//2, weight)
            previous_father = father
            
            # 4. Compute MSE and skill score
            #-------------------------------
            mse_array[count_fac,count_co] = (mother*mother).mean()
            weight = next_weight
            
            # The reference MSE 2*epsilon*(1-epsilon) is shared equally between the scales
            if epsilon < 1.:
               ss_array[count_fac,count_co] = 1. - (mse_array[count_fac,count_co]/(2.*epsilon*(1.-epsilon)/float(nr)))
         #--- endfor fac
      #--- endif epsilon
   #--- endfor co
   
   return ss_array, cutoffs1, cutoffs2, cutoffs_b1, cutoffs_b2
#--- end of wavelet function



def haarFather(data, redfac):
   """ Compute the father wavelet component of a 2D masked array.
       
       The array is cut into boxes of redfac x redfac pixels. Every box holding at
       least one unmasked pixel is filled with the mean of its unmasked pixels.
       
       Inputs:
          data   = 2D masked array
          redfac = size of a box (integer)
       
       Outputs:
          father = masked array of the size of data holding the box means
          weight = masked array of the size of data holding, for each box,
                   the fraction of its pixels that are unmasked
       
       Boxes that are entirely masked, and the rows/columns left over when the size
       is not a multiple of redfac, stay at 0 with the mask of data. Writing a box
       clears the mask over that box in both outputs.
   """
   
   shape = data.shape
   
   # Outputs start at zero, masked like the input
   father = np.ma.masked_where(data.mask==True, np.ma.zeros(shape))
   weight = np.ma.masked_where(data.mask==True, np.ma.zeros(shape))
   
   # Number of complete boxes along each dimension
   nb_box_rows = int(shape[0]/redfac)
   nb_box_cols = int(shape[1]/redfac)
   
   for box_col in range(nb_box_cols):
      cols = slice(box_col*redfac, box_col*redfac+redfac)
      for box_row in range(nb_box_rows):
         rows = slice(box_row*redfac, box_row*redfac+redfac)
         box = data[rows,cols]
         
         # Nothing to average in a box without any valid pixel
         if box.mask.all():
            continue
         
         father[rows,cols] = box.mean()
         
         # Weight of the box = fraction of valid (non-cloudy) pixels
         weight[rows,cols] = float(box.count())/float(np.size(box))
      # endfor box_row
   # endfor box_col
   
   return father, weight
# end of haarFather function



def haarMother(data, father, fac, weight):
   """ Compute the mother wavelet component knowing the father.
       
       The mother is the difference data - father. `fac` is not used and
       `weight` is handed back unchanged.
   """
   
   return data - father, weight
# end of haarMother function



#######################################################################################################################
def dynamic_range(data, list_percentage):
   ''' Find the dynamic ranges of data based on predefined percentages.
       
       Inputs:
          data            = array (masked or not, any dimension); masked values are ignored
          list_percentage = list_percentage[i] is the percentage of pixels wanted in the range i,
                            for example: [25,25,25,25]. The sum should equal 100.
       
       Outputs:
          range_min = array of the lower value of each range (one per percentage)
          range_bnd = same values followed by the maximum of the data
       
       The lower value of range i is the sorted sample found at the cumulated percentage
       of the previous ranges. Raises IndexError when data has no unmasked value.
   '''
   
   list_data = sorted(np.ma.compressed(data).tolist())
   n = len(list_data)
   
   range_min = []
   per_tot = 0
   for per in list_percentage:
      # Index of the first sample of the range. It is clamped to the last sample so that
      # a range starting at 100% (e.g. a trailing 0% range) does not run past the data.
      range_ind_min = min(int(n*per_tot/100.), n-1)
      range_min.append(list_data[range_ind_min])
      per_tot = per_tot+per
   
   # Boundaries = lower values + overall maximum
   range_bnd = range_min + [max(list_data)]
   
   return np.asarray(range_min), np.asarray(range_bnd)
#--- end of dynamic_range




#if __name__ == "__main__":
   # This provides an example comparing two masked arrays provided by the user (a and b).
   
   # a = 2D masked array (dimension=dim1 x dim2)
   # b = 2D masked array (dimension=dim1 x dim2)
   
   # redfacs = np.array([2,4,8,16,32, ... , closest power of 2 <= max dimension of a and b])
            
   # ss_array, wavelet_cutoffs1, wavelet_cutoffs2, wavelet_cutoffs_b1, wavelet_cutoffs_b2 = wavelet(a,b,redfacs)
   
