#!/usr/bin/env python
import sys
sys.path.append('../../')
import os
import numpy as np
import string
import datetime as dt
import logging
import da.tools.io4 as io
from da.tools.general import CreateDirs

# Default font size.
# NOTE(review): not referenced by any function visible in this file;
# presumably a leftover from plotting helpers -- confirm before removing.
fontsize = 10

def nice_lat(cls):
    #
    # Convert a latitude in decimal degrees to a "DD MM'H" string, where H is
    # 'N' for positive values and 'S' otherwise (so 0 is labelled 'S'),
    # e.g. 52.5 -> "52 30'N", -33.25 -> "33 15'S". Leading blanks are stripped.
    #
    import math

    if cls > 0:
        h = 'N'
    else:
        h = 'S'

    dec, deg = math.modf(cls)
    degrees = abs(deg)
    minutes = round(abs(60 * dec), 0)
    # Rounding can turn e.g. 59.994' into 60'; carry it into the degrees so
    # the result reads "46  0'N" rather than "45 60'N".
    if minutes >= 60:
        degrees += 1
        minutes = 0

    # str.strip() rather than string.strip(), which only exists in Python 2.
    return ('%2d %2d\'%s' % (degrees, minutes, h)).strip()

def nice_lon(cls):
    #
    # Convert a longitude in decimal degrees to a "DDD MM'H" string, where H is
    # 'E' for positive values and 'W' otherwise (so 0 is labelled 'W'),
    # e.g. -105.5 -> "105 30'W", 4.9 -> "4 54'E". Leading blanks are stripped.
    #
    import math

    if cls > 0:
        h = 'E'
    else:
        h = 'W'

    dec, deg = math.modf(cls)
    degrees = abs(deg)
    minutes = round(abs(60 * dec), 0)
    # Rounding can turn e.g. 59.994' into 60'; carry it into the degrees so
    # the result reads "180  0'E" rather than "179 60'E".
    if minutes >= 60:
        degrees += 1
        minutes = 0

    # str.strip() rather than string.strip(), which only exists in Python 2.
    return ('%3d %2d\'%s' % (degrees, minutes, h)).strip()

def nice_alt(cls):
    #
    # Format an elevation/altitude as "<value> masl": the value is rounded to
    # the nearest 10 (round(cls, -1)) and shown with one decimal, e.g.
    # 1234 -> "1230.0 masl". Leading blanks are stripped.
    #
    # str.strip() rather than string.strip(), which only exists in Python 2.
    return ('%10.1f masl' % round(cls, -1)).strip()


def SummarizeObs(DaCycle,printfmt='html'):
    """Write an HTML table summarizing each observation file of the analysis.

    Call directly from a python script as::

        q = SummarizeObs(DaCycle, printfmt='html')

    Every ``*.nc`` file in ``<DaCycle['dir.analysis']>/data_molefractions``
    (processed in sorted order) contributes one table row. Each row holds:

    * the site metadata;
    * the number of model-data-mismatch values;
    * the mean of the 'modeldatamismatch' values;
    * the RMS of the forecast sample standard deviation;
    * mean+-std of simulated-minus-observed over all samples, over JJAS and
      over NDJFMA.

    The table is written to
    ``<DaCycle['dir.analysis']>/summary/site_table.html``; the ``summary``
    directory is created when missing. The header row is repeated every 15
    rows.

    Parameters
    ----------
    DaCycle : cycle object
        Only ``DaCycle['dir.analysis']`` is read.
    printfmt : str
        Output format. Only ``'html'`` is implemented. ``'tex'`` and ``'scr'``
        never produced output: their row tuples and table header were never
        built, so they failed with NameError. They are now rejected with an
        error log instead.

    Returns
    -------
    None. Returns early, after logging an error, when *printfmt* is not
    ``'html'`` or when the input directory does not exist.
    """
    if printfmt != 'html':
        logging.error("printfmt '%s' is not supported (only 'html' is implemented), exiting... " % printfmt)
        return None

    sumdir = os.path.join(DaCycle['dir.analysis'], 'summary')
    if not os.path.exists(sumdir):
        logging.info("Creating new directory " + sumdir)
        os.makedirs(sumdir)

    mrdir = os.path.join(DaCycle['dir.analysis'], 'data_molefractions')
    if not os.path.exists(mrdir):
        logging.error("Input directory does not exist (%s), exiting... " % mrdir)
        return None

    infiles = sorted(os.path.join(mrdir, fname) for fname in os.listdir(mrdir) if fname.endswith('.nc'))

    tablehead = \
              "<TR>\n <TH> Site code </TH> \
                   <TH> Sampling Type </TH> \
                   <TH> Lab. </TH> \
                   <TH> Country </TH> \
                   <TH> Lat, Lon, Elev. (m ASL) </TH> \
                   <TH> No. Obs. Avail. </TH>  \
                   <TH> &#8730;R (&mu;mol mol<sup>-1</sup>) </TH>  \
                   <TH> &#8730;HPH  (&mu;mol mol<sup>-1</sup>) </TH> \
                   <TH> H(x)-y (&mu;mol mol<sup>-1</sup>) </TH> \
                   <TH> H(x)-y (JJAS) (&mu;mol mol<sup>-1</sup>) </TH>  \
                   <TH> H(x)-y (NDJFMA) (&mu;mol mol<sup>-1</sup>) </TH> \n \
               </TR>\n"

    fmt = """<TR> \n \
                <TD><a href='javascript:LoadCO2Tseries("%s")'>%s </a></TD>\
                <TD>%s</TD>\
                <TD>%s</TD>\
                <TD>%40s</TD>\
                <TD>%s</TD>\
                <TD>%d</TD>\
                <TD>%+5.2f</TD>\
                <TD>%+5.2f</TD>\
                <TD>%+5.2f&plusmn;%5.2f</TD>\
                <TD>%+5.2f&plusmn;%5.2f</TD>\
                <TD>%+5.2f&plusmn;%5.2f</TD>\n \
               </TR>\n"""

    table = []
    for infile in infiles:
        logging.debug(infile)
        ncf = io.CT_CDF(infile, 'read')
        try:
            table.append(_summarize_obs_file(ncf))
        finally:
            # Close the file even when a variable/attribute is missing.
            ncf.close()

    saveas = os.path.join(sumdir, 'site_table.html')
    with open(saveas, 'w') as outf:
        outf.write("<meta http-equiv='content-type' content='text/html;charset=utf-8' />\n")
        outf.write("<table border=1 cellpadding=2 cellspacing=2 width='100%' bgcolor='#EEEEEE'>\n")
        outf.write(tablehead)
        for i, row in enumerate(table):
            outf.write(fmt % row)
            # Repeat the header every 15 rows so long tables stay readable.
            if (i + 1) % 15 == 0:
                outf.write(tablehead)
        outf.write("\n</table>")

    logging.info("File written with summary: %s" % saveas)


def _summarize_obs_file(ncf):
    """Return the 15-field HTML row tuple for one open observation file.

    ``ncf`` is an open ``io.CT_CDF`` handle. When the mean of 'value' is
    positive, the positive records are used as CO2 and scaled by 1e6.
    Otherwise the negative (c13) records are used unscaled.
    """
    date          = ncf.GetVariable('time')
    obs           = ncf.GetVariable('value')
    mdm           = ncf.GetVariable('modeldatamismatch')
    simulated     = ncf.GetVariable('modelsamplesmean_forecast')
    simulated_std = ncf.GetVariable('modelsamplesstandarddeviation_forecast')

    if obs.mean() > 0:
        mdm_sel = mdm.compress(mdm > 0) * 1e6
        std_sel = simulated_std.compress(simulated_std > 0) * 1e6
        diff = simulated.compress(simulated > 0) * 1e6 - obs.compress(obs > 0) * 1e6
    else:
        mdm_sel = mdm.compress(mdm < 0)
        std_sel = simulated_std.compress(simulated_std < 0)
        sim_sel = simulated.compress(simulated < 0)
        # Only the first len(sim_sel) observations are compared. Presumably
        # the two compressed c13 series can differ in length -- confirm.
        diff = sim_sel - obs.compress(obs < 0)[:sim_sel.shape[0]]

    # NOTE(review): the season indices refer to the full 'time' axis but are
    # applied to the sign-compressed difference series; they only line up
    # when compression dropped nothing before them.
    pydates = [dt.datetime(1970, 1, 1) + dt.timedelta(seconds=int(d)) for d in date]
    summer = [k for k, d in enumerate(pydates) if d.month in (6, 7, 8, 9)]
    winter = [k for k, d in enumerate(pydates) if d.month in (11, 12, 1, 2, 3, 4)]

    location = nice_lat(ncf.site_latitude) + ', ' + nice_lon(ncf.site_longitude) + ', ' + nice_alt(ncf.site_elevation)
    code = ncf.site_code.upper()
    diff_summer = diff.take(summer)
    diff_winter = diff.take(winter)

    return (code,
            code,
            ncf.dataset_project,
            ncf.lab_abbr,
            ncf.site_country,
            location,
            len(np.ma.compressed(mdm_sel)),
            mdm_sel.mean(),
            np.sqrt((std_sel ** 2).mean()),
            diff.mean(), diff.std(),
            diff_summer.mean(), diff_summer.std(),
            diff_winter.mean(), diff_winter.std())

def SummarizeStats(DaCycle):
    """
    Summarize the observation statistics of one cycle as HTML tables.

    Reads ``<dir.output>/optimizer.<YYYYMMDD>.nc`` for ``DaCycle['time.start']``
    and splits the records by the sign of the prior mean sample. Positive
    records are CO2, scaled by 1e6 (variances by 1e12). Negative records are
    c13, left unscaled.

    For each tracer it writes one row per site to
    ``<dir.analysis>/summary/x2_table_<stamp>.html`` (CO2) or
    ``x2c13_table_<stamp>.html`` (c13). Each row holds N obs, N rejected
    (flag == 2), sqrt(R), sqrt of the ensemble HPH^T, mean+-std of H(x)-y,
    and mean X2.

    For cycles starting on/after 2008-12-29 it also writes
    ``<site>_x2co2.html`` / ``<site>_x2c13.html``. These collect the lines
    mentioning that site from every per-cycle table in the summary directory,
    in file-name order.

    Side effect: sets ``DaCycle['time.sample.stamp']``.

    Raises IOError if the optimizer file does not exist.
    """
    sumdir = os.path.join(DaCycle['dir.analysis'], 'summary')
    if not os.path.exists(sumdir):
        logging.info("Creating new directory " + sumdir)
        os.makedirs(sumdir)

    # get forecast data from optimizer.ddddd.nc
    startdate = DaCycle['time.start']
    DaCycle['time.sample.stamp'] = "%s" % (startdate.strftime("%Y%m%d"),)
    stamp = DaCycle['time.sample.stamp']
    infile = os.path.join(DaCycle['dir.output'], 'optimizer.%s.nc' % stamp)
    logging.debug(infile)
    if not os.path.exists(infile):
        logging.error("File not found: %s" % infile)
        raise IOError("File not found: %s" % infile)

    ncf = io.CT_CDF(infile, 'read')
    try:
        sites = ncf.GetVariable('sitecode')
        y0    = ncf.GetVariable('observed')
        hx    = ncf.GetVariable('modelsamplesmean_prior')
        dF    = ncf.GetVariable('modelsamplesdeviations_prior')
        HPHTR = ncf.GetVariable('totalmolefractionvariance')
        R     = ncf.GetVariable('modeldatamismatchvariance')
        flags = ncf.GetVariable('flag')
    finally:
        ncf.close()

    y0co2 = y0.compress(y0 > 0) * 1e6
    y0c13 = y0.compress(y0 < 0)
    hxco2 = hx.compress(hx > 0) * 1e6
    hxc13 = hx.compress(hx < 0)

    nmembers = dF.shape[1]
    dFco2, dFc13, sitesco2, sitesc13 = [], [], [], []
    for i in range(dF.shape[0]):
        if hx[i] < 0:
            dFc13.append(dF[i, :])
            sitesc13.append(sites[i, :])
        if hx[i] > 0:
            dFco2.append(dF[i, :])
            sitesco2.append(sites[i, :])
    # Force a (nobs, nmembers) shape. A tracer with no observations then gives
    # an empty 2-D array rather than a 1-D one, and the variance step works.
    dFco2 = np.array(dFco2).reshape(len(dFco2), nmembers) * 1e6
    dFc13 = np.array(dFc13).reshape(len(dFc13), nmembers)

    HPHTRco2 = HPHTR.compress(hx > 0) * 1e6 * 1e6
    HPHTRc13 = HPHTR.compress(hx < 0)

    Rco2 = R.compress(hx > 0) * 1e6 * 1e6
    Rc13 = R.compress(hx < 0)

    flagsco2 = flags.compress(hx > 0)
    flagsc13 = flags.compress(hx < 0)

    # diag(dF dF^T) / (nmembers - 1), computed row-wise. This avoids building
    # the nobs x nobs product only to keep its diagonal.
    HPHTco2 = (dFco2 ** 2).sum(axis=1) / (nmembers - 1.0)
    HPHTc13 = (dFc13 ** 2).sum(axis=1) / (nmembers - 1.0)

    rejectedco2 = (flagsco2 == 2.0)
    rejectedc13 = (flagsc13 == 2.0)

    sitecodesco2 = [''.join(s.compressed()).strip() for s in sitesco2]
    sitecodesc13 = [''.join(s.compressed()).strip() for s in sitesc13]

    # calculate X2 per observation for this time step
    x2co2 = [(y0co2[k] - hxco2[k]) ** 2 / HPHTRco2[k] for k in range(len(sitecodesco2))]
    x2c13 = [(y0c13[k] - hxc13[k]) ** 2 / (HPHTc13[k] + Rc13[k]) for k in range(len(sitecodesc13))]

    x2co2 = np.ma.masked_where(HPHTRco2 == 0.0, x2co2)
    # NOTE(review): the c13 X2 divides by HPHTc13 + Rc13 but is masked on the
    # file's HPHTRc13 -- confirm this asymmetry is intended.
    x2c13 = np.ma.masked_where(HPHTRc13 == 0.0, x2c13)

    meta = "<meta http-equiv='content-type' content='text/html;charset=utf-8' />\n"
    tableopen = "<table border=1 cellpadding=2 cellspacing=2 width='100%' bgcolor='#EEEEEE'>\n"

    tablehead = \
          "<TR>\n <TH> Site code </TH> \
               <TH> N<sub>obs</sub> </TH>  \
               <TH> N<sub>rejected</sub> </TH>  \
               <TH> &#8730;R (&mu;mol mol<sup>-1</sup>) </TH>  \
               <TH> &#8730;HPH<sup>T</sup> (&mu;mol mol<sup>-1</sup>) </TH>  \
               <TH> H(x)-y (&mu;mol mol<sup>-1</sup>) </TH> \n \
               <TH> X2 </TH> \n \
           </TR>\n"

    fmt = """<TR> \n \
            <TD>%s</TD>\
            <TD>%d</TD>\
            <TD>%d</TD>\
            <TD>%+5.2f</TD>\
            <TD>%+5.2f</TD>\
            <TD>%+5.2f&plusmn;%5.2f</TD>\
            <TD>%5.2f</TD>\n \
           </TR>\n"""

    def write_cycle_table(path, sitecodes, rejected, Rsel, HPHTsel, hxsel, y0sel, x2):
        """Write one per-cycle table (one row per site) to *path*.

        Returns the sorted list of unique site codes.
        """
        # Group observation indices by site in one pass; indices stay ascending.
        groups = {}
        for j, code in enumerate(sitecodes):
            groups.setdefault(code, []).append(j)
        site_list = sorted(groups)
        resid = (hxsel - y0sel) if site_list else None
        with open(path, 'w') as out:
            out.write(meta)
            out.write(tableopen)
            out.write(tablehead)
            # A distinct row counter: in Python 2 the old list-comprehension
            # variable overwrote the enumerate index used for the header repeat.
            for row, site in enumerate(site_list):
                sel = groups[site]
                ss = (site, len(sel), rejected.take(sel).sum(),
                      np.sqrt(Rsel.take(sel)[0]), np.sqrt(HPHTsel.take(sel).mean()),
                      resid.take(sel).mean(), resid.take(sel).std(),
                      x2.take(sel).mean(),)
                out.write(fmt % ss)
                if (row + 1) % 15 == 0:
                    out.write(tablehead)
            out.write("\n</table>")
        return site_list

    # calculate X2 per site
    saveas = os.path.join(sumdir, 'x2_table_%s.html' % stamp)
    saveasc13 = os.path.join(sumdir, 'x2c13_table_%s.html' % stamp)
    logging.info("Writing HTML tables for this cycle (%s)" % saveas)
    logging.info("Writing HTML tables for this cycle (%s)" % saveasc13)
    set_sitesco2 = write_cycle_table(saveas, sitecodesco2, rejectedco2, Rco2, HPHTco2, hxco2, y0co2, x2co2)
    set_sitesc13 = write_cycle_table(saveasc13, sitecodesc13, rejectedc13, Rc13, HPHTc13, hxc13, y0c13, x2c13)

    # Now summarize for each site across time steps
    # NOTE(review): hard-coded cut-off date; its origin is not visible here.
    if not DaCycle['time.start'] >= dt.datetime(2008, 12, 29):
        return

    sitehead = \
              "<TR>\n <TH> From File </TH> \
                   <TH> Site </TH>  \
                   <TH> N<sub>obs</sub> </TH>  \
                   <TH> N<sub>rejected</sub> </TH>  \
                   <TH> &#8730;R (&mu;mol mol<sup>-1</sup>) </TH>  \
                   <TH> &#8730;HPH<sup>T</sup> (&mu;mol mol<sup>-1</sup>) </TH>  \
                   <TH> H(x)-y (&mu;mol mol<sup>-1</sup>) </TH> \n \
                   <TH> X2 </TH> \n \
               </TR>\n"

    def write_site_tables(site_list, prefixes, suffix):
        """For every site, write <site>_<suffix>.html.

        The file collects the site's lines from all summary-directory tables
        whose names start with one of *prefixes*.
        """
        # The per-site files written below start with the site code, not with
        # one of the prefixes, so listing the directory once is sufficient.
        tables = sorted(fn for fn in os.listdir(sumdir) if fn.startswith(prefixes))
        for site in site_list:
            path = os.path.join(sumdir, '%s_%s.html' % (site, suffix))
            logging.debug(path)
            with open(path, 'w') as out:
                out.write(meta)
                out.write(tableopen)
                out.write(sitehead)
                for htmlfile in tables:
                    for line in grep(site, os.path.join(sumdir, htmlfile)):
                        out.write('<TR>\n')
                        out.write('<TD>' + htmlfile + '</TD>')
                        out.write(line + '\n')
                        out.write('</TR>\n')
                out.write("\n</table>")

    logging.info("Writing HTML tables for each site")
    # Per-cycle CO2 tables are written above as x2_table_<stamp>.html; the
    # 'x2co2' prefix alone matched none of them. Accept both.
    write_site_tables(set_sitesco2, ('x2_table', 'x2co2'), 'x2co2')
    write_site_tables(set_sitesc13, ('x2c13',), 'x2c13')



import re
def grep(pattern,file):
    """Return the lines of *file* that contain a match for regex *pattern*.

    *pattern* is a regular expression, searched anywhere in each line as with
    ``re.search``. Matching lines are returned in file order with their
    trailing newline intact. The file is closed before returning; the original
    version leaked the handle.
    """
    regex = re.compile(pattern)
    with open(file, 'r') as handle:
        return [line for line in handle if regex.search(line)]

# main body if called as script

if __name__ == '__main__':    # started as script

    # Script entry: build the cycle/system objects from the rc files below and
    # write the SummarizeStats tables for every cycle until 'time.finish'.
    # ('../../' is already on sys.path: the module appends it at import time.)
    from da.ctc13.initexit_offline import CycleControl
    from da.ctc13.dasystem import CtDaSystem
    from da.ct.statevector import CtStateVector

    logging.root.setLevel(logging.DEBUG)

    DaCycle = CycleControl(args={'rc':'../../ctdas_co2c13_terdisbcbmanscale_neweps449_test.rc'})
    DaCycle.Initialize()
    DaCycle.ParseTimes()

    DaSystem = CtDaSystem('../rc/carbontracker_co2c13_eps449_zeus.rc')
    #DaSystem.Initialize()

    DaCycle.DaSystem = DaSystem

    # The per-file observation table is disabled; to produce it, call:
    # SummarizeObs(DaCycle)

    while DaCycle['time.start'] < DaCycle['time.finish']:
        SummarizeStats(DaCycle)
        DaCycle.AdvanceCycleTimes()

    sys.exit(0)


