import numpy as np
from numpy import dtype, arange
import numpy.ma as ma
import matplotlib.pyplot as plt 
#from matplotlib.colors import LogNorm , SymLogNorm
#from mpl_toolkits.drivermap import Basemap, cm
import time
import scipy.ndimage
from subprocess import call
import sys 
#from mpi4py import MPI 
#import matplotlib.gridspec as gridspec
import datetime as dt
#from matplotlib.colors import LinearSegmentedColormap


def printVar(frame,globs):
    """Draw every variable of a frame into its axes and save the figure as PNG.

    For each name ``v`` in ``frame.plot.plotList`` the settings object
    ``frame.plot.settingList[v]`` selects the target axes (``setts.ax``), the
    plot type (contourf, contour, contourLabeled, pcolormesh, barbs, bar2met,
    dotLines), the colour bar (``setts.cbar``) and optional map decorations.
    The data come from ``frame.array[v]``.  Afterwards the main axes get the
    model banner and the step/init titles, and the figure gets a licence note.
    The figure is then written to
    ``<globs.fileName>_<frame.vid>_<frame.nr zero-padded to 3>.png``.

    Map projection: laea (Lambert Azimuthal Equal Area; area is preserved,
    shape is not).  Basemap resolution codes: c (crude), l (low),
    i (intermediate), h (high), f (full).

    Raises ValueError for an unsupported ``globs.domain`` when a map is
    needed, or when a gridded plot type is requested without
    ``setts.withMap`` (no map coordinates exist then).
    """
    lons_mean = np.mean(globs.lons)
    lats_mean = np.mean(globs.lats)
    fp = frame.plot
    date = fp.initDate + dt.timedelta(seconds=fp.secPerStep * (frame.ts + 1))

    # clear existing data from figure
    for ax in fp.fig.axes:
        ax.cla()

    # Previously the loop body held only the settings lookup, so everything
    # below ran once, for the last entry of plotList only, and meteograms were
    # drawn only when withMap was set.  The whole body now runs per variable.
    for v in fp.plotList:
        setts = fp.settingList[v]
        cb = setts.cbar
        ptype = setts.plotType
        ret = None  # mappable for the colour bar; reset so it never leaks between variables

        # draw map
        m = x = y = None
        if setts.withMap:
            m = _makeMap(globs.domain, lats_mean, lons_mean, setts.ax)
            x, y = m(globs.lons, globs.lats)
        elif ptype in ("contourf", "contour", "contourLabeled", "pcolormesh", "barbs"):
            raise ValueError("plot type %r needs setts.withMap (map coordinates)" % (ptype,))

        # draw data
        d = frame.array[v].copy()
        if setts.m == True:
            d = ma.masked_array(d, mask=globs.mask)

        if ptype == "contourf":
            ret = setts.ax.contourf(x, y, d[0], cb.contLev, vmin=cb.minV, alpha=0.8, vmax=cb.maxV,
                                    norm=cb.norm, cmap=cb.cmap, zorder=setts.zo)
        elif ptype in ("contour", "contourLabeled"):
            ret = setts.ax.contour(x, y, d[0], cb.contLev, vmin=cb.minV, vmax=cb.maxV, norm=cb.norm,
                                   cmap=cb.cmap, colors=cb.colors, linethick=0.02, zorder=setts.zo)
            if ptype == "contourLabeled":
                cl = setts.ax.clabel(ret, fmt='%.0f', fontweight='bold', fontsize=3)
                for txt in cl:
                    txt.set_bbox(dict(facecolor='white', edgecolor='none', pad=0))
                    txt.set_rotation("horizontal")
                    txt.set_color("black")
        elif ptype == "pcolormesh":
            ret = setts.ax.pcolormesh(x, y, d[0], vmin=cb.minV, vmax=cb.maxV, norm=cb.norm,
                                      cmap=cb.cmap, zorder=setts.zo)
        elif ptype == "barbs":
            # thin the grid to every 25th point in both directions; the index
            # must be a tuple (NumPy treats a list of arrays as a fancy index)
            yy = np.arange(0, y.shape[0], 25)
            xx = np.arange(0, x.shape[1], 25)
            points = tuple(np.meshgrid(yy, xx))
            ret = setts.ax.barbs(x[points], y[points], d[0][points], d[1][points],
                                 barbcolor=cb.colors, length=4, zorder=setts.zo, linewidth=0.6)
        elif ptype == "bar2met":
            eps = np.finfo(float).eps
            snow = setts.ax.vlines(d[0, 0], eps, eps + d[0, 1], color='r')
            rain = setts.ax.vlines(d[0, 0], eps + d[0, 1], d[0, 2] + d[0, 1], color='b')
            _finishMeteogram(setts.ax, d[0, 0], date, cb, True)
            setts.ax.legend([snow, rain], ['Snow', 'Rain'], fontsize=4)
        elif ptype == "dotLines":
            setts.ax.plot(d[0, 0], d[0, 1], 'b.-', markersize=3)
            _finishMeteogram(setts.ax, d[0, 0], date, cb, False)

        if setts.ax != fp.main:
            setts.ax.set_title(setts.title, size=6)

        # draw colorbar (only when this variable produced a mappable)
        if cb.ax is not None and ret is not None:
            cbar = plt.colorbar(ret,
                                cax=cb.ax,
                                ticks=cb.ticks,
                                orientation=cb.orientation)
            cbar.ax.tick_params(labelsize=7)
            cbar.ax.set_yticklabels(cb.ticklabels)
            cbar.set_label(cb.title, size=8)

        # draw countries, coasts and cities
        if setts.withMap:
            if setts.withLines:
                if globs.domain == "eu":
                    drawMapEU(m, x, y, setts.ax)
                elif globs.domain == "nrw":
                    drawMapNRW(m, x, y, setts.ax, topo=globs.topo)
            if setts.withMapTexture:
                m.drawmapboundary(fill_color='aqua', zorder=0)
                m.fillcontinents(color='coral', lake_color='aqua', zorder=0)
                #m.shadedrelief()  #needs drivermap with PIL

    # print main title
    if globs.domain == "eu":
        fp.main.text(25000, 25000, u'TerrSysMP v1.1, rv001; EU @ 12.5x12.5km\u00B2 COSMO, CLM & ParFlow results', zorder=12,
                     fontsize=4.75, fontweight="bold", va='bottom', bbox=dict(facecolor=(1, 1, 1, 0.7), edgecolor=(1, 1, 1, 0), pad=0.225))
    if globs.domain == "nrw":
        fp.main.text(1000, 1000, u'TerrSysMP v1.1, rv001; NRW @ 0.5x0.5km\u00B2 CLM & ParFlow results', zorder=12,
                     fontsize=4.75, fontweight="bold", va='bottom', bbox=dict(facecolor=(1, 1, 1, 0.7), edgecolor=(1, 1, 1, 0), pad=0.225))

    fp.main.set_title(fp.title, loc='left', size=6)
    fp.main.set_title('step: ' + date.strftime("%Y-%m-%d %H:%M UTC") + '\ninit.: ' + fp.initDate.strftime("%Y-%m-%d %H:%M UTC"), loc='right', size=6)

    # print Creative Commons note
    fp.fig.text(0.98, 0.025, 'Simulation Results / Analyses / Visualization by HPSC TerrSys /CC BY-NC-ND 4.0',
                fontsize=7, color='gray',
                ha='right', va='bottom', alpha=0.5)

    # write as png
    fp.fig.savefig(globs.fileName + '_' + frame.vid + '_' + str(frame.nr).zfill(3) + '.png', dpi=200, facecolor='w', edgecolor='w', format='png')


def _makeMap(domain, lats_mean, lons_mean, ax):
    """Return the laea Basemap for *domain* ("eu" or "nrw") drawn into *ax*."""
    # NOTE(review): Basemap is not imported in this module (its import line is
    # commented out at the top of the file); it must be provided elsewhere.
    if domain == "eu":
        return Basemap(width=5800000, height=5950000, resolution='l', projection='laea',
                       lat_ts=lats_mean, lat_0=lats_mean + 1.3, lon_0=lons_mean + 0.5, ax=ax)
    if domain == "nrw":
        return Basemap(width=153000, height=153000, resolution='h', projection='laea',
                       lat_ts=lats_mean, lat_0=lats_mean + 0.051, lon_0=lons_mean + 0.079, ax=ax)
    raise ValueError("unsupported domain %r (expected 'eu' or 'nrw')" % (domain,))


def _finishMeteogram(ax, dates, now, cbar, applyLimits):
    """Add grid, current-step marker, date ticks (label every 12th) and y label to a meteogram axis."""
    for axis in (ax.yaxis, ax.xaxis):
        axis.grid(True, linestyle='-', which='major', color='lightgrey', alpha=0.5)
    if applyLimits:
        if cbar.minV is not None and cbar.maxV is not None:
            ax.set_ylim(bottom=cbar.minV, top=cbar.maxV)
        if cbar.norm is not None:
            ax.set_yscale('log')
    yt = ax.get_yticks()
    ax.vlines(now, np.amin(yt), np.amax(yt), color="g")
    ax.set_axisbelow(True)
    ax.set_xticks(dates)
    labels = [stamp.strftime("%m-%d %Hh") if i % 12 == 0 else "" for i, stamp in enumerate(dates)]
    ax.set_xticklabels(labels, ha="left", size=4)
    ax.set_ylabel(cbar.title, size=4.5)
    ax.yaxis.set_tick_params(labelsize=4)

#-------------------------------------------------------------------------------



def printCity(name,lat,lon,domain,m,ax,hali='left'):
    """Mark a city on map *m* with a small square and a bold label on *ax*.

    Parameters
    ----------
    name : label text
    lat, lon : geographic position, converted with ``m(lon, lat)``
    domain : "nrw" or "eu"; selects the label offset in map units
             (3000 for nrw, 65000 for eu)
    m : Basemap-like callable with ``scatter``
    ax : axes receiving the text
    hali : 'right' puts the label left of the marker (right-aligned);
           any other value is treated as 'left'

    Raises ValueError for an unknown domain.  Previously this surfaced as an
    UnboundLocalError only after the marker had been drawn.
    """
    margins = {"nrw": 3000, "eu": 65000}
    marg = margins.get(domain)
    if marg is None:
        raise ValueError("unknown domain %r (expected 'nrw' or 'eu')" % (domain,))

    x, y = m(lon, lat)
    m.scatter(x, y, s=2, marker=',', color='k', zorder=12)
    if hali == 'right':
        off = -marg
    else:
        hali = 'left'
        off = marg
    ax.text(x + off, y, name, fontsize=4, bbox=dict(facecolor=(1, 1, 1, 0.6), edgecolor=(1, 1, 1, 0), pad=1.5),
            fontweight='bold', zorder=12, ha=hali, va='center', color='k')

def drawMapNRW(m,x,y,ax,topo=None):
    """Decorate an NRW-domain Basemap with lines, graticule, topography and cities.

    The lines are coasts, borders and rivers, and the graticule is 0.5 degrees.
    Topography contours at 250/500/750 (presumably metres) are drawn only when
    *topo* is given; it was previously contoured unconditionally, so the
    documented default ``topo=None`` crashed.
    x, y are the projected grid coordinates matching *topo*.
    """
    m.drawcoastlines(linewidth=1.2, color='#151515', zorder=10)
    m.drawcountries(linewidth=1.2, color='#282828', linestyle='-', zorder=10)
    m.drawrivers(linewidth=1.0, color='#0000FF', zorder=10)

    m.drawmeridians(np.r_[0:20:0.5], labels=[0, 0, 0, 1], fontsize=6, linewidth=0.5, color='#606060', dashes=[2, 2], zorder=11)
    m.drawparallels(np.r_[40:60:0.5], labels=[1, 0, 0, 0], fontsize=6, linewidth=0.5, color='#606060', dashes=[2, 2], zorder=11)

    if topo is not None:
        # 'hold' (removed from Matplotlib in 3.0) and 'linethick' (not a contour
        # argument) were dropped; neither had an effect on the drawn lines.
        ax.contour(x, y, topo, [250, 500, 750], colors='#808080', zorder=10)

    cities = (
        ('Juelich', 50.916966, 6.371537, 'left'),
        ('Cologne', 50.936389, 6.952778, 'left'),
        ('Bonn', 50.727647, 7.109601, 'left'),
        ('Duesseldorf', 51.222089, 6.777019, 'left'),
        ('Koblenz', 50.351882, 7.594615, 'right'),
        # ('Bitburg', 49.964940, 6.521794, 'left'),
        ('Maastricht', 50.844460, 5.691512, 'left'),
        ('Aachen', 50.771486, 6.084142, 'left'),
        ('Nuerburg', 50.344968, 6.978922, 'left'),
        ('Malmedy', 50.421748, 6.029222, 'left'),
    )
    for name, lat, lon, hali in cities:
        printCity(name, lat, lon, "nrw", m, ax, hali)

def drawMapEU(m,x,y,ax,topo=None):
    """Decorate an EU-domain Basemap with domain outline, lines, graticule and cities.

    The outline follows the first and last rows and columns of the projected
    grid x, y.  The lines are coasts and borders, the graticule is 10 degrees,
    and the cities are labelled markers.  *topo* is accepted for signature
    parity with drawMapNRW but is not used (river and topography drawing are
    disabled for this domain).
    """
    outline = (
        (x[0, :], y[0, :]),     # first row
        (x[-1, :], y[-1, :]),   # last row
        (x[:, 0], y[:, 0]),     # first column
        (x[:, -1], y[:, -1]),   # last column
    )
    for xs, ys in outline:
        m.plot(xs, ys, marker=None, color='k', linewidth=1.5)

    m.drawcoastlines(linewidth=0.5, color='#151515', zorder=10)
    m.drawcountries(linewidth=0.5, color='#282828', linestyle='-', zorder=10)

    m.drawmeridians(np.r_[-50:80:10], labels=[0, 0, 0, 1], fontsize=6, linewidth=0.5, color='#606060', dashes=[2, 2], zorder=11)
    m.drawparallels(np.r_[10:80:10], labels=[1, 0, 0, 0], fontsize=6, linewidth=0.5, color='#606060', dashes=[2, 2], zorder=11)

    capitals = [
        ('Berlin', 52.518611, 13.408333),
        ('Madrid', 40.4125, -3.703889),
        ('London', 51.50539, -0.11832),
        ('Roma', 41.883333, 12.483333),
        ('Paris', 48.85666, 2.351667),
        ('Ankara', 39.916667, 32.85),
        ('Tunis', 36.800833, 10.18),
        ('Kiev', 50.45, 30.5),
        ('Moskow', 55.75, 37.616667),
        ('Cairo', 30.056111, 31.239444),
        ('Athens', 37.97778, 23.727778),
        ('Belgrade', 44.820556, 20.462222),
        ('Stockholm', 59.302330, 18.081357),
        ('Helsinki', 60.177497, 24.954700),
        ('Wien', 48.154643, 16.373658),
        ('Rabat', 33.890408, -6.796158),
    ]
    for city, lat, lon in capitals:
        printCity(city, lat, lon, "eu", m, ax)


def getIndexOfCoord(lat,lon,lats,lons):
    """Find the grid point nearest to (lat, lon).

    Nearness is the sum of absolute latitude and longitude differences in
    degrees (an L1 metric, not a great-circle distance).  Longitude
    differences are wrapped into [-180, 180), so points on either side of the
    antimeridian compare correctly.  The former "+360 on both sides" cancelled
    out and did not achieve this.

    Parameters
    ----------
    lat, lon : target coordinate (scalars)
    lats, lons : numpy arrays of grid coordinates with identical shape

    Returns
    -------
    (str(lat), str(lon), index) where *index* is the index tuple of the
    nearest point in ``lats``/``lons``.
    """
    dlon = np.absolute((lons - lon + 180.0) % 360.0 - 180.0)
    dlat = np.absolute(lats - lat)
    distance = dlon + dlat
    return str(lat), str(lon), np.unravel_index(np.argmin(distance), distance.shape)

#-----------------------------------------------------
# write Var out to file
#-----------------------------------------------------
def writeVarOutToNCFile(ncfile, variableArray, timesteps, nlayers, nlats, nlons):
    """Write Var records into an open netCDF dataset, then close it.

    Creates dimensions lat(nlats), lon(nlons), nlayer(nlayers) and a
    fixed-length time(timesteps).  It also creates the coordinate variables
    lat/lon/time.  Only time values are written: 1, 2, ..., timesteps in
    "hours since midnight".  Each record (an object with ``varName``,
    ``units`` and ``data``) becomes a float32 variable
    (time, nlayer, lat, lon), zlib-compressed with least_significant_digit=6.

    The dataset is closed even if writing fails.
    """
    f32 = dtype('float32').char
    try:
        ncfile.createDimension('lat', nlats)
        ncfile.createDimension('lon', nlons)
        ncfile.createDimension('nlayer', nlayers)
        # fixed-size time dimension: an explicit length is passed, so it is not unlimited
        ncfile.createDimension('time', timesteps)
        lat = ncfile.createVariable('lat', f32, ('lat',))
        lon = ncfile.createVariable('lon', f32, ('lon',))
        time_var = ncfile.createVariable('time', f32, ('time',))
        lat.units = 'degrees_north'
        # CF convention: longitudes are degrees_east (west of Greenwich is negative)
        lon.units = 'degrees_east'
        time_var.units = 'hours since midnight'
        # NOTE: lat/lon values are intentionally not written (as before).  The
        # regular grid previously prepared for them was
        #   lat = 21.145 + 0.109740566 * i,  lon = -10.3734 + 0.109747706 * j
        time_var[:] = 1.0 + 1.0 * arange(timesteps, dtype='float32')
        for record in variableArray:
            print(record.varName)
            out_var = ncfile.createVariable(record.varName, f32, ('time', 'nlayer', 'lat', 'lon'),
                                            zlib=True, least_significant_digit=6)
            out_var.units = record.units
            out_var[:] = np.array(record.data)
    finally:
        ncfile.close()
    print('*** SUCCESS writing file')


#-----------------------------------------------------
# postprocessing routines
#-----------------------------------------------------

#local (on 2d data)

def k2c(dat, offset=273.15):
    """Convert temperature from Kelvin to degrees Celsius.

    Arrays are converted in place and returned; scalars are returned converted.
    The default offset is 0 degC = 273.15 K.  The former 273.16 K is the
    triple point of water, and it biased every result by -0.01 K.
    """
    dat -= offset
    return dat

def sec2h(dat):
    """Turn a per-second rate into a per-hour rate (multiply by 3600).

    Arrays are scaled in place; the (same) object is returned.
    """
    seconds_per_hour = 3600
    dat *= seconds_per_hour
    return dat

def cutm1000(dat):
    """Return *dat* as a masked array with every value below -1000 masked.

    Presumably this hides fill/missing values; the threshold is fixed at -1000.
    """
    return ma.masked_less(dat, -1000)

def pa2hpa(dat):
    """Convert pressure from Pa to hPa (divide by 100).

    Float arrays are converted in place and returned.  The former floor
    division (``//=``) silently truncated to whole hPa, unlike slpMet, which
    divides exactly.  Integer arrays cannot hold the fractional result in
    place, so a new float array is returned for them.
    """
    try:
        dat /= 100.0
    except TypeError:  # e.g. numpy refuses in-place true division of int arrays
        dat = dat / 100.0
    return dat

def gauss2(dat, sigma=2):
    """Smooth *dat* with a Gaussian filter of standard deviation *sigma* (grid cells).

    Uses ``scipy.ndimage.gaussian_filter``.  The ``scipy.ndimage.filters``
    namespace used before is deprecated.
    """
    return scipy.ndimage.gaussian_filter(dat, sigma)

def masku001(dat):
    """Return *dat* masked wherever it is below 0.01."""
    threshold = 0.01
    return ma.masked_less(dat, threshold)
def masko0(dat):
    """Return *dat* masked wherever it is >= 0 (only negative values stay visible)."""
    threshold = 0.0
    return ma.masked_greater_equal(dat, threshold)

def masku0(dat):
    """Return *dat* masked wherever it is negative."""
    threshold = 0.0
    return ma.masked_less(dat, threshold)

def rel2proc(dat):
    """Turn a fraction into a percentage (multiply by 100); arrays are scaled in place."""
    percent = 100
    dat *= percent
    return dat



#global (on Var)

def gCutNbound(v):
    """Strip a 4-cell border from the last two axes of ``v.data`` (rebinds v.data).

    Presumably this removes the model's lateral boundary zone; the width is fixed at 4.
    """
    border = 4
    v.data = v.data[:, :, border:-border, border:-border]

def gCut0(v):
    """Clamp every non-positive entry of the 4-D ``v.data`` (in place) to machine epsilon."""
    field = v.data
    field[field[:, :, :, :] <= 0] = np.finfo(float).eps

def gMask0(v):
    """Replace ``v.data`` with a masked array hiding values <= machine epsilon."""
    threshold = np.finfo(float).eps
    v.data = ma.masked_less_equal(v.data, threshold)
            
def gAccum(v):
    """Turn the first ``v.numSteps`` steps of ``v.data`` into a running sum along axis 0 (in place)."""
    for step in range(1, v.numSteps):
        v.data[step, :, :, :] += v.data[step - 1, :, :, :]


#complex variable calculations

def unCumAdd(v1,v2):
    """Per-step increments of the summed accumulated fields v1 + v2.

    Returns an array shaped like ``v1.data[1:]`` where entry i is
    ``(v1+v2)[i+1] - (v1+v2)[i]`` for i < v1.numSteps - 1.
    """
    total = v1.data + v2.data
    increments = np.empty_like(v1.data[1:])
    for step in range(v1.numSteps - 1):
        increments[step] = total[step + 1] - total[step]
    return increments

def merge(v1,v2):
    """Pair the first-level fields of two Vars per timestep.

    Returns ``np.array`` of ``[[v1.data[t, 0], v2.data[t, 0]] for each t]``.
    """
    pairs = [[v1.data[t, 0], v2.data[t, 0]] for t in range(v1.data.shape[0])]
    return np.array(pairs)

def magnitude(v1,v2):
    """Elementwise Euclidean magnitude sqrt(a*a + b*b) of two 4-D component fields."""
    first = v1.data[:, :, :, :]
    second = v2.data[:, :, :, :]
    return np.sqrt(first * first + second * second)

def colStorChange(vp,vs,vpo,dzmult,dz):
    """Change of total column water storage relative to the first timestep, in mm.

    Parameters
    ----------
    vp : Var with pressure data [time, z, x, y]
    vs : Var with saturation data [time, z, x, y]
    vpo : Var with porosity data [1, z, x, y]
    dzmult : per-layer thickness multipliers; the last layer is the top (ParFlow)
    dz : constant layer thickness

    Storage is the top-layer pressure (ponding) where the top layer is fully
    saturated (else 0), plus sum(saturation * dzmult * porosity * dz).
    Returns [time-1, 1, x, y].
    """
    top = len(dzmult) - 1
    storage = vp.data[:, top, :, :].copy()
    storage[vs.data[:, top, :, :] != 1] = 0   # no ponding unless fully saturated
    for layer in range(top + 1):
        storage += vs.data[:, layer, :, :] * dzmult[layer] * vpo.data[0, layer, :, :] * dz
    change = (storage[1:, :, :] - storage[0, :, :]) * 1000   # m -> mm
    return change[:, np.newaxis, :, :]

def wtdc(v):
    """Change of bottom-layer pressure vs. the first timestep, scaled by -1000.

    Entries whose later pressure is negative are set to 0.
    Returns [time-1, 1, x, y].
    """
    bottom = v.data[:, 0, :, :]
    later = bottom[1:, :, :]
    change = (later - bottom[0, :, :]) * -1000
    change[later < 0] = 0
    # Optional clipping, e.g. change[change > 100] = 100; change[change < -100] = -100
    return change[:, np.newaxis, :, :]

#Approximate Water Table depth:
def wtdapprox(vp,dzmult,dz): # assumes hydrostatic pressure
    """Approximate water-table quantity from the bottom-layer pressure head.

    Computes ``p_bottom + sum(dzmult)*dz - 0.5*dz*dzmult[0]`` for every timestep.
    Returns [time, 1, x, y].

    NOTE(review): the original comment described "subtract bottom layer z
    midpoint and add top layer z midpoint"; the formula instead adds the
    whole column height.  Confirm which is intended.
    """
    bottom_midpoint = 0.5 * dz * dzmult[0]
    column_height = sum(dzmult) * dz
    head = vp.data[:, 0, :, :]
    approx = head + column_height - bottom_midpoint
    return approx[:, np.newaxis, :, :]


#Absolute Water Table depth - This code needs to be checked for correctness and performance (very slow)
def wtd(vp,dzmult,dz,dzsum): # vp=pressure[t,z,x,y]; dzmult=per-layer multiplier; dz=constant layer thickness; dzsum=model height [m]
    """Water-table depth below the model top for every column and timestep.

    The groundwater-table layer ``gwt`` of a column is the lowest layer d with
    ``p[d] > 0`` and ``p[d+1] < 0``.  Masked pressures never satisfy this.
    If no such layer exists, ``gwt = nz - 2``.  The depth is
    ``dzsum - zz[gwt]``, with ``zz`` the cell-centre elevation from the bottom.

    Fixes vs. the former loop version:
    * the result was assigned outside the y loop, so only the last y of
      every x column was written (all others stayed 0);
    * ``zz`` inherited an integer dtype from integer ``dzmult`` and truncated
      the half-layer elevations.

    Returns [time-1, 1, x, y] (the first timestep is dropped).
    """
    p = vp.data
    nt, nz = p.shape[0], p.shape[1]
    thickness = np.array(dzmult) * dz

    # cell-centre elevation, accumulated bottom-up (float storage)
    zz = np.zeros(len(thickness))
    z = 0.0
    for d in range(nz):
        z += 0.5 * thickness[d]
        zz[d] = z
        z += 0.5 * thickness[d]

    # groundwater-table layer per (t, x, y), vectorised search for the first crossing
    gwt = np.full((nt,) + p.shape[2:], nz - 2, dtype=int)
    if nz >= 2:
        crossing = ma.filled((p[:, :-1, :, :] > 0) & (p[:, 1:, :, :] < 0), False)
        found = crossing.any(axis=1)
        gwt[found] = crossing.argmax(axis=1)[found]

    depth = np.zeros_like(p[:, 0, :, :])
    depth[...] = dzsum - zz[gwt]
    return depth[1:, np.newaxis, :, :]

#Total Water storage in column 
def colStor(vp,vs,vpo,dzmult,dz):
    """Total column water storage in mm for every timestep except the first.

    Parameters
    ----------
    vp : Var with pressure data [time, z, x, y]
    vs : Var with saturation data [time, z, x, y]
    vpo : Var with porosity data [1, z, x, y]
    dzmult : per-layer thickness multipliers; the last layer is the top (ParFlow)
    dz : constant layer thickness

    Storage is the top-layer pressure (ponding) where the top layer is fully
    saturated (else 0), plus sum(saturation * dzmult * porosity * dz).
    Returns [time-1, 1, x, y].
    """
    top = len(dzmult) - 1
    stor = vp.data[:, top, :, :].copy()
    stor[vs.data[:, top, :, :] != 1] = 0   # no ponding unless fully saturated
    for layer in range(top + 1):
        stor += vs.data[:, layer, :, :] * dzmult[layer] * vpo.data[0, layer, :, :] * dz
    stor *= 1000   # m -> mm
    # drop the first timestep (absolute values, no difference is taken)
    return stor[1:, :, :][:, np.newaxis, :, :]

#Absolute Ground Water Recharge or Discharge depending on the sign - This code needs to be checked for performance (very slow)
def recharge(vp,vks,vvg,dzmult,dz): # vp=pressure[t,z,x,y] ; vks=permz[1,z,x,y] ; vvg=vanGen[a/n,z,x,y] ; dzmult=multiplier for each z layer, dz=equal base thickness of each layer
  """Groundwater recharge/discharge flux across the groundwater-table layer.

  NOTE(review): this definition is dead code.  A second
  ``def recharge(vp,vks,vvg,dzmult,kdz,backz)`` further down this module
  rebinds the name at import time, so this version is never reachable as
  ``recharge``.  Consider deleting or renaming it.

  For every (t, x, y) the groundwater-table layer ``gwt`` is the lowest d with
  p[d] > 0 and p[d+1] < 0 (default: nz-2).  Then:
    k    = thickness-weighted harmonic mean of vks over gwt and gwt+1
    diff = (p[gwt] + zz[gwt]) - (p[gwt+1] + zz[gwt+1])
    kr   = Mualem-van Genuchten relative permeability of the upwind layer
           (gwt if diff > 0, else gwt+1), alpha = vvg[0], n = vvg[1]
    q    = -1000 * k * kr * diff
  Returns q without the first timestep: [t-1, 1, x, y].

  NOTE(review): zz = np.zeros_like(dzmult) takes dzmult's dtype, so integer
  multipliers truncate the half-layer elevations.  The head difference is
  also not divided by the distance between cell centres.
  """
  #initialization
  z=0                                                                           # running elevation from the bottom
  dzmult = np.array(dzmult)                                                     # convert dzmult to numpy array
  dz=dzmult[:] * dz                                                             # actual layer thickness = dzmult * dz
  zz=np.zeros_like(dzmult)                                                      # zz = cell-centre elevation [z] (dtype of dzmult!)
  k=np.zeros_like(vp.data[:,0,:,:])                                             # k = harmonic mean of permz [t,x,y]
  kr=np.zeros_like(vp.data[:,0,:,:])                                            # kr = upwinded van Genuchten [t,x,y]
  diff=np.zeros_like(vp.data[:,0,:,:])                                          # diff = head difference [t,x,y]
  #calculate elevation
  for d in range(0,vp.data.shape[1]):                                           # loop over z, bottom-up
      z +=  0.5 *  dz[d];                                                       # add lower half of the layer (cell centre)
      zz[d]=z;                                                                  # elevation of this cell centre
      z +=  0.5 *  dz[d];                                                       # add upper half of the layer

  for t in range(0,vp.data.shape[0]):                                           # loop over time
    print "Step: = "+str(t)
    for x in range(0,vp.data.shape[2]):                                         # loop over x
      for y in range(0,vp.data.shape[3]):                                       # loop over y
        gwt=vp.data.shape[1]-2                                                  # default when no crossing is found: layer below the top
        for d in range(0,vp.data.shape[1]-1):                                   # loop over z, bottom-up
            #Layer with GWT:
            if(vp.data[t,d,x,y]>0 and vp.data[t,d+1,x,y]<0 ):                   # gwt is where press(d)>0 and press(d+1)<0
                gwt=d
                break
        #Harmonic mean of Ksat: k = (dz(gwt)+dz(gwt+1)) / (dz(gwt)/permz(gwt) + dz(gwt+1)/permz(gwt+1))
        k[t,x,y]=(dz[gwt]+dz[gwt+1])/((dz[gwt]/vks.data[0,gwt,x,y])+(dz[gwt+1]/vks.data[0,gwt+1,x,y]))

        #upwind VanGenuchten: head difference lower minus upper cell
        diff[t,x,y]= (vp.data[t,gwt,x,y] + zz[gwt]) - (vp.data[t,gwt+1,x,y] + zz[gwt+1])
        if diff[t,x,y] > 0:                                                     # diff>0: take van Genuchten parameters of gwt
          off=0
        else:                                                                   # else take those of gwt+1
          off=1
        #VanGenuchten parameters: vvg.data[0] = alpha, vvg.data[1] = n (m = 1 - 1/n)
        kr[t,x,y]= pow( 1 - pow(vvg.data[0,gwt+off,x,y]*abs(vp.data[t,gwt+off,x,y]),vvg.data[1,gwt+off,x,y]-1) /   pow(1+pow(vvg.data[0,gwt+off,x,y]*abs(vp.data[t,gwt+off,x,y]),vvg.data[1,gwt+off,x,y]),1-1/vvg.data[1,gwt+off,x,y]) , 2)  / pow( 1 + pow(vvg.data[0,gwt+off,x,y]*abs(vp.data[t,gwt+off,x,y]),vvg.data[1,gwt+off,x,y])    ,(1-1/vvg.data[1,gwt+off,x,y])/2)
                                                                                # kr = ( 1 - (ah)**(n-1) / (1+(ah)**n)**m )**2
                                                                                #      / ( 1 + (ah)**n )**(m/2),   ah = alpha*|p|
  #calculate term
  q=-1 *k*kr*diff * 1000                                                        # flux; sign switched; x1000 (original comment: mm/h)
  return q[1:,np.newaxis,:,:] 


def runsatstor(vp,vs,vpo,dzmult,dz):
    """Unsaturated (still fillable) pore storage per column, for timesteps 1..end.

    full  = sum_i dzmult[i] * porosity_i * dz       (column pore volume)
    stor  = sum_i sat_i * dzmult[i] * porosity_i * dz
    Returns ``full - stor`` with shape [time-1, 1, x, y].  *vp* only supplies
    the output shape.

    The subtraction used to be recomputed inside the loop, and an empty
    *dzmult* raised UnboundLocalError.  Empty *dzmult* now yields zeros.
    """
    full_stor = np.zeros_like(vp.data[1:, 0, :, :])
    stor = np.zeros_like(vp.data[1:, 0, :, :])
    for i in range(len(dzmult)):
        full_stor += dzmult[i] * vpo.data[0, i, :, :] * dz
        stor += vs.data[1:, i, :, :] * dzmult[i] * vpo.data[0, i, :, :] * dz
    unsat = full_stor - stor
    return unsat[:, np.newaxis, :, :]

def plantAvailStor(vp,vs,vpo,dzmult,dz,th=-10,ntop=10):
    """Change of storage in the *ntop* topmost layers vs. the first timestep, in mm.

    Storage is the top-layer ponding plus, for each of the top layers whose
    pressure is >= *th*, sat * dzmult * porosity * dz.  Ponding is the
    top-layer pressure where the top layer is fully saturated, else 0.
    Layers with pressure below *th* contribute nothing at that point and
    time.

    The loop now stops at layer 0.  The former ``range(layers-1, layers-11, -1)``
    ran into negative indices with fewer than 10 layers: it double-counted
    top layers for 5-9 layers and raised IndexError below 5.

    Returns [time-1, 1, x, y].
    """
    layers = len(dzmult)

    totStor = vp.data[:, layers - 1, :, :].copy()       # init with pressure (ponding)
    totStor[vs.data[:, layers - 1, :, :] != 1] = 0      # back to 0 where not ponding

    lowest = max(layers - ntop, 0)
    for i in range(layers - 1, lowest - 1, -1):          # top layers, top-down
        mask = ma.masked_less(vp.data[:, i, :, :], th).mask   # press below threshold
        iStor = ma.array(vs.data[:, i, :, :] * dzmult[i] * vpo.data[0, i, :, :] * dz, mask=mask)
        totStor += ma.filled(iStor, 0)                   # masked points contribute 0
    dStor = (totStor[1:, :, :] - totStor[0, :, :]) * 1000   # delta, m -> mm

    return dStor[:, np.newaxis, :, :]


def surfaceRunOff(vp,vsl,dx,dy,M):
    """Change of Manning-type overland discharge vs. the first timestep.

    With top-layer pressure head h = vp.data[:, -1] and slopes
    Sx = vsl.data[0, 0], Sy = vsl.data[1, 0]:

        Q = (dy*sqrt|Sx|/M * h**(5/3) + dx*sqrt|Sy|/M * h**(5/3)) / 3600

    Q is set to 0 where h <= 0.  The /3600 presumably converts per hour to
    per second.  This is not the same as groundwater discharge.
    Returns Q[1:] - Q[0] with shape [time-1, 1, x, y].

    The exponent is written 5.0/3.0: under Python 2 the former ``5/3`` was
    integer division (= 1), which made the law linear in h.
    """
    head = vp.data[:, -1, :, :]
    with np.errstate(invalid='ignore'):   # negative heads give NaN here; zeroed below
        head_term = head ** (5.0 / 3.0)
    qx = dy * np.sqrt(np.absolute(vsl.data[0, 0, :, :])) / M * head_term
    qy = dx * np.sqrt(np.absolute(vsl.data[1, 0, :, :])) / M * head_term
    discharge = (qx + qy) / 3600
    discharge[head <= 0] = 0
    ddischarge = discharge[1:, :, :] - discharge[0, :, :]
    return ddischarge[:, np.newaxis, :, :]



def recharge(vp,vks,vvg,dzmult,kdz,backz): # vp=pressure[t,z,x,y] ; vks=permz[1,z,x,y] ; vvg=vanGen[a/n,z,x,y] ; dzmult=dz-multiplier[z] kdz=constant dz; backz=z-offset (unused)
    """Groundwater recharge/discharge flux across the groundwater-table layer.

    For every (t, x, y) the groundwater-table layer ``gwt`` is the lowest layer
    d with ``p[d] > 0`` and ``p[d+1] < 0``.  Masked pressures never satisfy
    this.  The default is ``nz - 2``.  Between ``gwt`` and ``gwt+1``:

        k    = thickness-weighted harmonic mean of vks.data[0]
        diff = (p[gwt] + zz[gwt]) - (p[gwt+1] + zz[gwt+1])
        kr   = Mualem-van Genuchten relative permeability of the upwind layer
               (gwt if diff > 0, else gwt+1) with alpha = vvg.data[0],
               n = vvg.data[1], m = 1 - 1/n
        q    = -1 * k * kr * diff * 1000

    ``zz`` is the cell-centre elevation from the bottom, built from the layer
    thickness ``dzmult * kdz``.  *backz* is accepted but not used.
    Returns q without the first timestep, shape [t-1, 1, x, y].

    Changes vs. the former per-cell triple loop (same formula):
    * ``zz`` is now float.  ``np.zeros_like(dzmult)`` was an integer array for
      integer multipliers and truncated the half-layer elevations.
    * the computation is vectorised with numpy (no per-step printing).

    NOTE(review): the head difference is not divided by the distance between
    the two cell centres, and "mm/h" is the original author's unit claim;
    confirm both.
    """
    p = vp.data
    nt, nz, nx, ny = p.shape
    dz = np.array(dzmult) * kdz   # actual layer thickness

    # cell-centre elevation, accumulated bottom-up
    zz = np.zeros(len(dz))
    z = 0.0
    for d in range(nz):
        z += 0.5 * dz[d]
        zz[d] = z
        z += 0.5 * dz[d]

    # groundwater-table layer per (t, x, y): first d with p[d] > 0 > p[d+1]
    gwt = np.full((nt, nx, ny), nz - 2, dtype=int)
    if nz >= 2:
        crossing = ma.filled((p[:, :-1, :, :] > 0) & (p[:, 1:, :, :] < 0), False)
        found = crossing.any(axis=1)
        gwt[found] = crossing.argmax(axis=1)[found]
    above = gwt + 1

    # broadcastable index grids for gathering per-column layer values
    t_i, x_i, y_i = np.ogrid[:nt, :nx, :ny]
    ks = vks.data[0]

    # harmonic mean of Ksat: (dz_lo + dz_hi) / (dz_lo/K_lo + dz_hi/K_hi)
    k = (dz[gwt] + dz[above]) / ((dz[gwt] / ks[gwt, x_i, y_i]) + (dz[above] / ks[above, x_i, y_i]))

    # hydraulic head difference, lower minus upper cell
    diff = (p[t_i, gwt, x_i, y_i] + zz[gwt]) - (p[t_i, above, x_i, y_i] + zz[above])

    # upwinding: parameters of gwt if diff > 0, otherwise of gwt+1
    upwind = np.where(ma.filled(diff > 0, False), gwt, above)
    alpha = vvg.data[0][upwind, x_i, y_i]
    n = vvg.data[1][upwind, x_i, y_i]
    ah = alpha * np.absolute(p[t_i, upwind, x_i, y_i])
    m = 1 - 1 / n
    # Mualem-van Genuchten: (1 - ah**(n-1) / (1+ah**n)**m)**2 / (1+ah**n)**(m/2)
    kr = (1 - ah ** (n - 1) / (1 + ah ** n) ** m) ** 2 / (1 + ah ** n) ** (m / 2)

    q = -1 * k * kr * diff * 1000   # sign switched; x1000 (original comment: mm/h)
    return q[1:, np.newaxis, :, :]





#Meteograms


def precipMet(vr,vs,xj,yj,initDate,steps,sps):
    """Precipitation meteogram series at grid point (xj, yj).

    Returns ``np.array([[[dates, vs_values, vr_values]]])``.  Here
    dates[k] = initDate + (k+1)*sps seconds, vs_values[k] = vs.data[k, 0, xj, yj]
    and vr_values[k] = vr.data[k, 0, xj, yj] for k < steps.  The bar2met plot
    type draws the second series as snow and the third as rain.
    """
    step = dt.timedelta(seconds=sps)
    dates = [initDate + step * (k + 1) for k in range(steps)]
    snow = [vs.data[k, 0, xj, yj] for k in range(steps)]
    rain = [vr.data[k, 0, xj, yj] for k in range(steps)]
    return np.array([[[dates, snow, rain]]])

def tempMet(v,xj,yj,initDate,steps,sps):
    """Temperature meteogram series (degrees Celsius) at grid point (xj, yj).

    Returns ``np.array([[[dates, temps]]])`` with
    dates[k] = initDate + (k+1)*sps seconds and
    temps[k] = v.data[k, 0, xj, yj] - 273.15 for k < steps.
    The Kelvin offset is 273.15.  The former 273.16 is the triple point of
    water and shifted every value by -0.01 K.
    """
    step = dt.timedelta(seconds=sps)
    dates = [initDate + step * (k + 1) for k in range(steps)]
    temps = [v.data[k, 0, xj, yj] - 273.15 for k in range(steps)]
    return np.array([[[dates, temps]]])

def relhumMet(v,xj,yj,initDate,steps,sps):
    """Relative-humidity meteogram series at grid point (xj, yj).

    Returns ``np.array([[[dates, values]]])`` with
    dates[k] = initDate + (k+1)*sps seconds and
    values[k] = v.data[k, 0, xj, yj] (unchanged) for k < steps.
    """
    step = dt.timedelta(seconds=sps)
    dates = [initDate + step * (k + 1) for k in range(steps)]
    values = [v.data[k, 0, xj, yj] for k in range(steps)]
    return np.array([[[dates, values]]])

def slpMet(v,xj,yj,initDate,steps,sps):
    """Sea-level-pressure meteogram series at grid point (xj, yj), divided by 100.

    Returns ``np.array([[[dates, values]]])`` with
    dates[k] = initDate + (k+1)*sps seconds and
    values[k] = v.data[k, 0, xj, yj] / 100 (presumably Pa -> hPa) for k < steps.
    """
    step = dt.timedelta(seconds=sps)
    dates = [initDate + step * (k + 1) for k in range(steps)]
    values = [v.data[k, 0, xj, yj] / 100 for k in range(steps)]
    return np.array([[[dates, values]]])


def windMet(v1,v2,xj,yj,initDate,steps,sps):
    """Wind-speed meteogram series at grid point (xj, yj) from two components.

    Returns ``np.array([[[dates, speeds]]])`` with
    dates[k] = initDate + (k+1)*sps seconds and
    speeds[k] = sqrt(a*a + b*b), where a = v1.data[k, 0, xj, yj] and
    b = v2.data[k, 0, xj, yj], for k < steps.
    """
    step = dt.timedelta(seconds=sps)
    dates = []
    speeds = []
    for k in range(steps):
        a = v1.data[k, 0, xj, yj]
        b = v2.data[k, 0, xj, yj]
        dates.append(initDate + step * (k + 1))
        speeds.append(np.sqrt(a * a + b * b))
    return np.array([[[dates, speeds]]])
