#!/usr/bin/env python
import os
import sys

import numpy as np
#######################################
# Distribute erosion per PFT
#######################################
# For each vegetation group (bare soil, trees, grasses, crops) the gridded soil
# erosion rate of that group is divided by the group's total vegetation
# fraction, and the result is stored for every PFT of the group in one
# (pft, lat, lon) field written to <EPath>/E_PFT/.
#
# NOTE(review): `Dataset` is used below but never imported in this file. The
# calls (.variables, createDimension, createVariable) match the netCDF4-python
# API, so presumably `from netCDF4 import Dataset` is missing -- TODO add it.

OrcPath = ' '  # Enter here the path to the vegetation fractions at 5 arc-minute resolution
EPath = ' '  # Enter here the directory holding the erosion files
pfts = range(0, 13)
years = range(1851, 2006)
kols = 4322
rows = 1680
simulation = "equilibrium"  # choose from "equilibrium" or "transient"

# Input erosion is in t/ha/y; 1 ha = 1e4 m2, so multiply by 1e-4 to get t/m2/y.
T_HA_TO_T_M2 = 0.0001
# Erosion rates (t/m2/y) above this are set to 0 (presumably outliers/fill values).
MAX_EROSION = 0.01
# maxvegetfrac values above this are set to 0 (presumably fill values).
MAX_VEGET_FRAC = 1000.

# (group name used in the erosion file names, first PFT index, last PFT index + 1)
PFT_GROUPS = (
    ('bare', 0, 1),
    ('tree', 1, 9),
    ('grass', 9, 11),
    ('crop', 11, 13),
)

VEGET_FILE = 'PFTmap_LUHv2_BM3_HoughtonCountryForestarea_withoutNoBio_%i_5m.nc'


def erosion_filename(group, year=None):
    """Return the basename of the erosion file of a vegetation group.

    ``year=None`` selects the equilibrium file ``E_<group>_eq_grid.nc``;
    otherwise the transient file ``E_<group>_<YYYY>_grid.nc`` for that year.
    """
    tag = 'eq' if year is None else '%04i' % year
    return 'E_%s_%s_grid.nc' % (group, tag)


def distribute_erosion(erosion_t_ha, veget):
    """Spread a group's erosion rate over the group's vegetation fraction.

    Parameters
    ----------
    erosion_t_ha : ndarray, shape (rows, kols)
        Soil erosion rate in t/ha/y. It is not modified.
    veget : ndarray
        Either 2-D (rows, kols) for a single-PFT group, used as the divisor
        directly, or 3-D (n_pft, rows, kols), summed over the PFT axis
        ignoring NaNs.

    Returns
    -------
    ndarray, shape (rows, kols)
        Erosion in t/m2/y per unit of group fraction. Rates above MAX_EROSION
        are zeroed before dividing, and negative results become NaN. Cells
        with zero fraction yield inf/NaN, which the caller replaces with 0.
    """
    erosion = erosion_t_ha * T_HA_TO_T_M2  # new array: the input stays untouched
    erosion[erosion > MAX_EROSION] = 0.
    cover = veget if veget.ndim == 2 else np.nansum(veget, axis=0)
    # Zero cover is expected (e.g. no vegetation of this group in the cell).
    with np.errstate(divide='ignore', invalid='ignore'):
        per_pft = erosion / cover
    per_pft[per_pft < 0.] = np.nan
    return per_pft


def _read_vegetation(vegetnc, first, last):
    """Read maxvegetfrac at time 0 for PFT indices first..last-1.

    Returns a (rows, kols) array for a single PFT, else an
    (last - first, rows, kols) array; values above MAX_VEGET_FRAC are zeroed.
    """
    var = vegetnc.variables['maxvegetfrac']  # shape(time,pft,lat,lon)
    if last - first == 1:
        veget = var[0, first, :, :].reshape(rows, kols)
    else:
        veget = var[0, first:last, :, :].reshape(last - first, rows, kols)
    veget = 1. * veget  # float copy, so the in-place zeroing below is safe
    veget[veget > MAX_VEGET_FRAC] = 0.
    return veget


def _pft_erosion(vegetnc, year=None):
    """Build the (len(pfts), rows, kols) erosion-per-PFT field for one step.

    ``year=None`` reads the equilibrium erosion files, otherwise that year's.
    """
    E = np.zeros((len(pfts), rows, kols))
    for group, first, last in PFT_GROUPS:
        veget = _read_vegetation(vegetnc, first, last)
        path = os.path.join(EPath, erosion_filename(group, year))
        with Dataset(path, 'r') as enc:  # soil erosion rate in t/ha/y, shape(rows,kols)
            erosion_t_ha = enc.variables['E'][:]
        # Every PFT of a group receives the same per-fraction rate, so the
        # group is read and computed once and broadcast over its PFT slots.
        E[first:last] = distribute_erosion(erosion_t_ha, veget)
    # Zero/missing fractions leave NaN or inf; store 0 for those cells.
    E[~np.isfinite(E)] = 0.
    return E


def _write_pft_erosion(path, E):
    """Write E as double variable 'E' with dims (pft, lat, lon); file opened in 'w' mode."""
    with Dataset(path, 'w') as output:
        output.createDimension('lat', rows)
        output.createDimension('lon', kols)
        output.createDimension('pft', len(pfts))
        output.createVariable('E', 'd', ('pft', 'lat', 'lon',))
        output.variables['E'][:] = E


def _run_step(veget_year, erosion_year, out_name):
    """Distribute one step's erosion and write it to <EPath>/E_PFT/<out_name>."""
    veget_path = os.path.join(OrcPath, VEGET_FILE % veget_year)
    with Dataset(veget_path, 'r') as vegetnc:
        E = _pft_erosion(vegetnc, erosion_year)
    _write_pft_erosion(os.path.join(EPath, 'E_PFT', out_name), E)


def main():
    """Distribute erosion per PFT for the configured ``simulation`` mode.

    "equilibrium": one output, E_PFT/E_pft_eq.nc, from the 1851 vegetation
    map and the *_eq_grid.nc erosion files.
    "transient": one output per year in ``years``, E_PFT/E_pft_<YYYY>.nc,
    from that year's vegetation map and erosion files.

    Raises
    ------
    ValueError
        If ``simulation`` is neither "equilibrium" nor "transient".
    """
    if simulation == 'equilibrium':
        _run_step(1851, None, 'E_pft_eq.nc')
    elif simulation == 'transient':
        for year in years:
            _run_step(year, year, 'E_pft_%04i.nc' % year)
    else:
        raise ValueError(
            'unknown simulation %r; choose "equilibrium" or "transient"'
            % (simulation,))


if __name__ == "__main__":
    main()

