/* sundials.hpp
 * 
 * Copyright (C) 2010 Sylwester Arabas
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */

#ifndef SUNDIALS_HPP
#  define SUNDIALS_HPP

#  include "drops.hpp"
#  include <nvector/nvector_serial.h>

#  define self SolverCVODES
class self : public Solver
{

  private: N_Vector y0, y1, p;
  protected: void *cvode_mem;
  private: Model* model;
  private: Output* output;
  private: SpectraMemLayout* ml;
  private: ModelParams* params;
  private: quantity<si::time> tmax;
  private: Tolerances *tol;
  private: InitBinLayout *ibl;

  public: virtual int getLinearMultistepMethod() = 0;
  public: virtual int getNonlinearSolverIteration() = 0;

  // this cannot be done in a constructor, as we need to call implementations of the 
  // above pure virtual functions, which is not allowed inside a constructor in C++
  public: void construct(Model *model_, Output *output_, ModelParams *params_, 
    InitpTq *iptq, vector<InitSpectrum*>* ispecs, InitBinLayout *ibl_, SpectraMemLayout *ml_,
    SatVapPresMltplr *svpmltp, Tolerances* tol_, quantity<si::time> tmax_)
  {
    model = model_;
    output = output_;
    ml = ml_;
    ibl = ibl_;
    params = params_;
    tol = tol_;
    tmax = tmax_;

    y0 = N_VNew_Serial(model->isBulk() ? 3 : 3 + ml->max_length());
    y1 = N_VClone_Serial(y0);
    NV_LENGTH_S(y0) = model->isBulk() ? 3 : 3 + ml->length();
    NV_LENGTH_S(y1) = model->isBulk() ? 3 : 3 + ml->length();

    if (!model->isBulk())
    {
      p = N_VNew_Serial(ml->max_length());
      NV_LENGTH_S(p) = ml->length();
    }

    NV_Ith_S(y0, model->getStateVectorIndexPressure()) = 
      iptq->getPressure() / quantity<si::pressure>(1. * si::pascals);
    NV_Ith_S(y0, model->getStateVectorIndexTemperature()) = 
      iptq->getTemperature() / si::kelvins;
    NV_Ith_S(y0, model->getStateVectorIndexSpecificHumidity()) = 
      iptq->getSpecificHumidity(); 

    if (!model->isBulk())
    {

      for (int s = 0; s < ml->n_specs(); ++s)
      {
        int n = ml->n_bins(s);

        for (int b = 0; b < n; ++b)
        {
          // concentration
          quantity<si::length> r_d_l = ibl->getLeftEdge(b), r_d_r = ibl->getRightEdge(b);
          NV_Ith_S(p, ml->N_ix(s, b)) = (
            ispecs->at(s)->n_n(r_d_l + .5 * (r_d_r - r_d_l)) * (r_d_r - r_d_l) / ( // TODO: .5 -> (InBinLayout->charactericRadii())?
              iptq->getPressure() / iptq->getTemperature() / params->R(iptq->getSpecificHumidity())
            )
          ) * si::kilogramme; 
        }

        for (int b = 0; b < n + 1; ++b)
        {
          // drop temperatures (equal to air temperature)
          NV_Ith_S(y0, model->getStateVectorIndexRadii() + ml->To_l_ix(s, b)) = 
            iptq->getTemperature() / si::kelvins;

          quantity<si::dimensionless> svpm_eq = iptq->getPressure() / (
            constants::p_v_s(iptq->getTemperature()) * (
              constants::epsilon * (pow<-1>(iptq->getSpecificHumidity()) - 1.) + 1.
            )
          );

          if (ispecs->at(s)->isDryNotWet())
          {
            // dry radii
            NV_Ith_S(p, ml->r_d_l_ix(s, b)) = ibl->getLeftEdge(b) / si::metres;

            // wet radii (at equilibrium with initial thermod. conditions)
            NV_Ith_S(y0, model->getStateVectorIndexRadii() + ml->r_l_ix(s, b)) = svpmltp->r_eq(
              svpm_eq, iptq->getTemperature(), ibl->getLeftEdge(b), params->S->at(s)) / si::metres;
          }
          else 
          {
            // wet radii (at equilibrium with initial thermod. conditions)
            NV_Ith_S(y0, model->getStateVectorIndexRadii() + ml->r_l_ix(s, b)) = ibl->getLeftEdge(b) / si::metres;

            // dry radii
            NV_Ith_S(p, ml->r_d_l_ix(s, b)) = svpmltp->rd_eq(
              svpm_eq, iptq->getTemperature(), ibl->getLeftEdge(b), params->S->at(s)) / si::metres;
          }
        }
      }
    }
  }

  private: void init(realtype t0)
  {
    cvode_mem = CVodeCreate(
      this->getLinearMultistepMethod(), 
      this->getNonlinearSolverIteration()
    );
    if (cvode_mem == NULL)
      throw exception();

    params->pp = p;
    params->cvode_mem = cvode_mem;

    if (CV_SUCCESS != CVodeSetUserData(cvode_mem, params))
      throw exception();

    if (CV_SUCCESS != CVodeInit(cvode_mem, model->getODERhsFnPtr(), t0, y0))
      throw exception();

    tol->setTolerances(cvode_mem);

    postinit();
  }

  public: void run() 
  {
    output->head();
    output->record(0. * si::seconds, y0, p);
    realtype t0 = 0, t1;
    N_Vector tmp;

    init(t0);
    while (t0 * si::seconds < tmax)
    {
      // move just one step ...
      if (CV_SUCCESS != CVode(cvode_mem, tmax / si::seconds, y1, &t1, CV_ONE_STEP)) 
        throw exception();

      // ... and trigger the refinement logic after each step
      if (!model->isBulk() && refineSpectrum(y0, y1, p)) 
      {
        // reinitialize CVODE in case new bins were added
        CVodeFree(&cvode_mem);
        init(t0);
        NV_LENGTH_S(y1) = NV_LENGTH_S(y0);
        continue;
      }

      // if CV_ONE_STEP went too far ...
      if (t1 * si::seconds > tmax)
      {
        // ... ask CVODE for interpolatiion @tmax 
        if (CV_SUCCESS != CVodeGetDky(cvode_mem, t1 = (tmax / si::seconds), 0, y1))
          throw exception();
      }

      // report some house-keeping info
      long int nfevals;
      if (CV_SUCCESS != CVodeGetNumRhsEvals(cvode_mem, &nfevals))
        throw exception();
      int qlast;
      if (CV_SUCCESS != CVodeGetLastOrder(cvode_mem, &qlast))
        throw exception();
      realtype hlast;
      if (CV_SUCCESS != CVodeGetLastStep(cvode_mem, &hlast))
        throw exception();
      long int njevals = 0;
      if (this->getNonlinearSolverIteration() == CV_NEWTON)
      {
        if (CV_SUCCESS != CVDlsGetNumJacEvals(cvode_mem, &njevals))
          throw exception();
      }
      long int nniters;
      if (CV_SUCCESS != CVodeGetNumNonlinSolvIters(cvode_mem, &nniters))
        throw exception();
      
      cerr << msgprefix << "solver reached t = " << t1 * si::seconds << " ("
        << "nfevals=" << nfevals
        << " qlast=" << qlast
        << " hlast=" << hlast
        << " njevals=" << njevals
        << " nniters=" << nniters
        << " njevals/nniters=" << realtype(njevals)/nniters
        << ")" << endl;

      // trigger recording
      output->record(t1 * si::seconds, y1, p); // TODO: ..., CVodeGetNumSteps, CVodeGetNumRhsEvals);
 
      // swap t0 with t1 so the solver takes current output as intput in the next loop pass
      tmp = y0; y0 = y1; y1 = tmp;
      t0 = t1;
    } 

    output->foot(t0 * si::seconds, y0, p);
    CVodeFree(&cvode_mem);
  }

  protected: virtual bool refineSpectrum(N_Vector y0, N_Vector y1, N_Vector p) 
  { 
    bool modified = false;
    for (int s = 0; s < ml->n_specs(); ++s)    
    {
      for (int b = 0; ml->bin_r(s, b) != -1; b = ml->bin_r(s, b))
      {
#  define o model->getStateVectorIndexRadii()
        // TODO: units?
        realtype r_l_0 = NV_Ith_S(y0, o + ml->r_l_ix(s, b));
        realtype r_r_0 = NV_Ith_S(y0, o + ml->r_r_ix(s, b));
        realtype r_l_1 = NV_Ith_S(y1, o + ml->r_l_ix(s, b));
        realtype r_r_1 = NV_Ith_S(y1, o + ml->r_r_ix(s, b));
        realtype N = NV_Ith_S(p, ml->N_ix(s, b));
        if (ibl->binNeedsSplitting(r_l_1 * si::metres, r_r_1 * si::metres, 
          r_l_0 * si::metres, r_r_0 * si::metres, N / si::kilogrammes, b)) 
        {
          realtype T_l = NV_Ith_S(y0, o + ml->To_l_ix(s, b));
          realtype T_r = NV_Ith_S(y0, o + ml->To_r_ix(s, b));
          realtype r_d_l = NV_Ith_S(p, ml->r_d_l_ix(s, b));
          realtype r_d_r = NV_Ith_S(p, ml->r_d_r_ix(s, b));
          realtype r_l = NV_Ith_S(y0, o + ml->r_l_ix(s, b));
          realtype r_r = NV_Ith_S(y0, o + ml->r_r_ix(s, b));
          realtype dr = r_r - r_l;
 
          size_t m = ibl->splitInto(N / si::kilogrammes);
          cerr << msgprefix << "splitting bin " << b << " into " << m << " bins" << endl;
          ml->split(s, b, m);
          NV_LENGTH_S(y0) = 3 + ml->length();
          NV_LENGTH_S(p) = ml->length();

          for (int i = 0; i < m; ++i)
          {
            realtype r_left = r_l + i / realtype(m) * (r_r - r_l);
            realtype r_rght = r_l + (i + 1) / realtype(m) * (r_r - r_l);

            // radii
            NV_Ith_S(y0, o + ml->r_l_ix(s, b)) = r_left;
            NV_Ith_S(y0, o + ml->r_r_ix(s, b)) = r_rght;

            // conc
            NV_Ith_S(p, ml->N_ix(s, b)) = N * (r_rght - r_left) / dr;

            // temp.
            NV_Ith_S(y0, o + ml->To_l_ix(s, b)) = T_l + (T_r - T_l) * (r_left - r_l) / dr;
            NV_Ith_S(y0, o + ml->To_r_ix(s, b)) = T_l + (T_r - T_l) * (r_rght - r_l) / dr;

            // dry radii 
            NV_Ith_S(p, ml->r_d_l_ix(s, b)) = r_d_l + (r_d_r - r_d_l) * (r_left - r_l) / dr;
            NV_Ith_S(p, ml->r_d_r_ix(s, b)) = r_d_l + (r_d_r - r_d_l) * (r_rght - r_l) / dr;

            if (i != m - 1) b = ml->bin_r(s, b);
          }
          modified = true;
        }
#  undef o
      }
    }
    return modified;
  };

  public: ~self()
  {
    N_VDestroy_Serial(y0);
    N_VDestroy_Serial(y1);
    if (!model->isBulk()) N_VDestroy_Serial(p);
    //CVodeFree(&cvode_mem); // TODO...
  }

};
#  undef self

#endif  
