/* This file implements functions for doing MCMC */

/* Complete state of one MCMC chain. Unless stated otherwise, arrays
   have npar elements and are set up by initialize_mcmcstate(). */
struct mcmcstate {
  /* Cost function; mcmc() calls it as costfun(x, aux) for in-bounds proposals */
  double (*costfun)(double*, void*);
  /* Per-parameter upper bounds; caller-provided pointer, not copied */
  double *x_high;
  /* Per-parameter lower bounds; caller-provided pointer, not copied */
  double *x_low;
  /* Current proposal; heap-allocated and replaced by propose() on every call */
  double *x;
  /* Number of parameters */
  uint npar;
  /* One row of npar + 2 doubles per iteration: proposed point, cost
     function value, accept flag (1 or 0) */
  double *chain;
  /* One row of npar doubles per iteration: the most recently accepted
     point as of that iteration */
  double *chain_only_accepted;
  /* npar x npar initial proposal covariance (only the diagonal is set
     by initialize_mcmcstate()) */
  double *initial_propcov;
  /* npar x npar proposal covariance currently used by propose() */
  double *propcov;
  double *propcovupdate; /* array for doing partial updates; receives the output of sample_covariance() in propose() */
  /* Number of completed iterations; incremented by mcmc() */
  size_t iter;
  /* Not referenced by the functions in this part of the file */
  size_t last_cov_update_iter;
  /* Cost function value stored at the last acceptance */
  double prev_acc_cfval;
  /* Last accepted point. Aliases the x array passed to
     initialize_mcmcstate(); mh() overwrites it on acceptance */
  double *prev_acc_x;
  /* Factor the cost function value is multiplied with in mcmc() */
  double scalefactor;
  int recompute_acc_cf; /* for pseudo-marginal sampling with low probability; set by propose(), read by mh() */
  /* Random number generator state passed to normal() */
  struct boxmullerstruct *BM;
};

/* Allocate a zero-filled array of nrows*ncols doubles. The two
   factors are handed to calloc() separately so that calloc() itself
   can detect an overflowing product. Terminates the program if a
   non-empty request cannot be satisfied. */
static double *mcmc_zeroed_doubles(size_t nrows, size_t ncols) {
  double *buf = calloc(nrows, ncols*sizeof *buf);

  /* calloc() may legitimately return NULL for a zero-sized request */
  if (buf == NULL && nrows != 0 && ncols != 0) {
    fprintf(stderr, "initialize_mcmcstate: out of memory\n");
    exit(EXIT_FAILURE);
  }
  return buf;
}

/* Set up mcmcs for a chain of at most niter iterations.

   mcmcs    state to initialize; every field is assigned
   npar     number of parameters
   costfun  cost function, stored as is
   x_low    lower bounds (npar values); the pointer is stored, not copied
   x_high   upper bounds (npar values); the pointer is stored, not copied
   x        starting point (npar values). The pointer is stored as
            prev_acc_x, so mh() writes accepted points back into the
            caller's array. A private copy becomes the first mcmcs->x.
   niter    number of rows allocated for chain and chain_only_accepted

   The arrays allocated here are owned by mcmcs. */
void initialize_mcmcstate(struct mcmcstate *mcmcs, uint npar, double (*costfun)(double*, void*),
			  double *x_low, double *x_high, double *x, size_t niter) {

  const size_t n = npar;
  size_t i;

  mcmcs->costfun = costfun;
  mcmcs->npar = npar;
  mcmcs->x_high = x_high;
  mcmcs->x_low = x_low;
  mcmcs->prev_acc_x = x;

  /* Each chain row holds the npar parameters plus the cost function
     value and the accept flag */
  mcmcs->chain = mcmc_zeroed_doubles(niter, n + 2);
  mcmcs->chain_only_accepted = mcmc_zeroed_doubles(niter, n);
  mcmcs->propcov = mcmc_zeroed_doubles(n, n);
  mcmcs->initial_propcov = mcmc_zeroed_doubles(n, n);
  mcmcs->propcovupdate = mcmc_zeroed_doubles(n, n);

  /* Copy x so that we get a reasonable proposal at first round */
  mcmcs->x = mcmc_zeroed_doubles(n, 1);
  for (i = 0; i < n; i++) {
    mcmcs->x[i] = x[i];
  }

  /* Leave no field indeterminate */
  mcmcs->iter = 0;
  mcmcs->last_cov_update_iter = 0;
  mcmcs->recompute_acc_cf = 0;

  /* Large starting value, but below the 1e10 that mcmc() assigns to
     out-of-bounds proposals */
  mcmcs->prev_acc_cfval = 1e9;
  mcmcs->BM = initialize_boxmullerstruct();

  /* Diagonal initial proposal covariance: the standard deviation of
     each parameter is 0.5 % of its allowed range. Element i*(n + 1)
     is the i:th diagonal element of the row-major n x n matrix. */
  for (i = 0; i < n; i++) {
    const size_t diag = i*(n + 1);
    mcmcs->initial_propcov[diag] = pow(0.005*(mcmcs->x_high[i] - mcmcs->x_low[i]), 2);
    mcmcs->propcov[diag] = mcmcs->initial_propcov[diag];
  }

  /* Scaling factor to scale down the cost function. mcmc() overwrites
     this value before sampling. */
  mcmcs->scalefactor = 1;
}

void propose(struct mcmcstate *mcmcs) {

  int i, j;
  double *tmp = malloc(mcmcs->npar*sizeof(double));
  double *res = malloc(mcmcs->npar*sizeof(double));
  double *cov_copy = malloc(mcmcs->npar*mcmcs->npar*sizeof(double));
  size_t start, span;
  double ini_cov_weight = 1000; /* We only slowly give away the initial covariance */
  float burnin_propcovscaling = 0;
  size_t full_propcov_iter = 10000; /* At this point we have full proposal covariance as in standard AM */

  burnin_propcovscaling = fmin(1.*mcmcs->iter/full_propcov_iter, 1);



  if ((mcmcs->iter > 500) && (sqrt(mcmcs->iter) - (int) sqrt(mcmcs->iter) < 1e-9)) {
    /* Use only half of the chain for proposal generation */
    start = mcmcs->iter/2;
    span = mcmcs->iter - start;
    sample_covariance(mcmcs->propcovupdate, &mcmcs->chain_only_accepted[start*mcmcs->npar], mcmcs->npar, span, 2.4*2.4/mcmcs->npar);
    for (i=0; i<mcmcs->npar*mcmcs->npar; i++) {
      mcmcs->propcov[i] = 1./(span +  ini_cov_weight) *
	(ini_cov_weight*mcmcs->initial_propcov[i] + span*mcmcs->propcovupdate[i]*burnin_propcovscaling);
    }
  }

  // printf(" first par in chain_only_accepted: %g", mcmcs->chain_only_accepted[mcmcs->iter*mcmcs->npar]);

  for (i=0; i<mcmcs->npar; i++) {
    tmp[i] = (double) normal(0, 1, mcmcs->BM);
    res[i] = mcmcs->prev_acc_x[i];
  }

 /* With this probability we propose with a very small covariance and
    reset mcmcs->prev_acc_cf to that value, in order to get out if
    chain gets stuck. */
  float recompute_acc_cf_prob = 0.001; /* FIXME make configurable. */
  mcmcs->recompute_acc_cf = (random() < recompute_acc_cf_prob*RAND_MAX) ? 1 : 0;
  float shrinkage = (mcmcs->recompute_acc_cf) ? 0.0001 : 1.;

  /* FIXME cholesky needs to be calculated only when we update
     proposal covariance */
  for (j=0; j<mcmcs->npar*mcmcs->npar; j++) {
    cov_copy[j] = shrinkage*mcmcs->propcov[j];
  }
  /* Cholesky of the covariance matrix */
  LAPACKE_dpotrf(LAPACK_ROW_MAJOR, 'U', (lapack_int)mcmcs->npar,
		 cov_copy, (lapack_int)mcmcs->npar);
  cblas_dtrmv(CblasRowMajor, CblasUpper, CblasTrans, CblasNonUnit, (int) mcmcs->npar,
	      cov_copy, (int) mcmcs->npar, tmp, 1);

  for (i=0; i<mcmcs->npar; i++) {
    res[i] += tmp[i];
  }

  /* Free temporary arrays and set new x to point to the correct
     array */
  free(tmp);
  free(mcmcs->x);
  free(cov_copy);
  mcmcs->x = res;
}

/* Return 1 if every coordinate of the current proposal mcmcs->x lies
   within [x_low, x_high], and 0 as soon as one coordinate is strictly
   above its upper or strictly below its lower bound. */
int inside_bounds(struct mcmcstate *mcmcs) {
  const double *point = mcmcs->x;
  const double *lower = mcmcs->x_low;
  const double *upper = mcmcs->x_high;
  int within = 1;
  int k = 0;

  while (within && k < mcmcs->npar) {
    if ((point[k] < lower[k]) || (point[k] > upper[k])) {
      within = 0;
    }
    k++;
  }
  return within;
}

/* Metropolis-Hastings accept/reject step for the proposal in
   mcmcs->x, followed by bookkeeping of the chain.

   cfval  cost function value of the proposal mcmcs->x
   mcmcs  chain state; on acceptance prev_acc_cfval and prev_acc_x are
          updated. Row mcmcs->iter of chain and chain_only_accepted is
          always written.

   The chain row holds the proposed point, then cfval, then the accept
   flag (1 or 0). The chain_only_accepted row holds the last accepted
   point after this step. */
void mh(double cfval, struct mcmcstate *mcmcs) {
  const size_t npar = mcmcs->npar;
  double *chain_row = &mcmcs->chain[mcmcs->iter*(npar + 2)];
  double *accepted_row = &mcmcs->chain_only_accepted[mcmcs->iter*npar];
  double accept = 0;
  size_t i;

  /* Exactly one uniform variate is drawn per call, whatever the
     outcome. A draw of 0 gives log_u = -inf, which passes the
     Metropolis test for any finite cfval. */
  const double log_u = log(1.*random()/RAND_MAX);
  const int metropolis_ok = .5*cfval < .5*mcmcs->prev_acc_cfval - log_u;

  /* Accept either by the Metropolis criterion or, when
     recompute_acc_cf is set, unconditionally. The latter randomly
     recomputes the previous cost function by accepting a proposal
     that is very very close to the previous accepted point, since
     sometimes there are local minima that we do not want to fall
     into. This is not strictly correct but with a low probability and
     a very small difference to previous accepted point it should not
     affect statistics in a perceptible way.

     In both cases the proposal must be inside bounds. Otherwise an
     out-of-bounds point could become the center of all subsequent
     proposals, either through a zero uniform draw or because
     prev_acc_cfval exceeds the cost assigned to out-of-bounds
     proposals. */
  if ((metropolis_ok || mcmcs->recompute_acc_cf) && inside_bounds(mcmcs)) {
    accept = 1;
    mcmcs->prev_acc_cfval = cfval;
    for (i = 0; i < npar; i++) {
      mcmcs->prev_acc_x[i] = mcmcs->x[i];
    }
  }

  /* CHECKME Could there be an omp barrier thing missing making it so
     that not all contributions from cf are calculated? */

  for (i = 0; i < npar; i++) {
    chain_row[i] = mcmcs->x[i];
    accepted_row[i] = mcmcs->prev_acc_x[i];
  }
  chain_row[npar] = cfval;
  chain_row[npar + 1] = accept;
}

/* Run the chain until mcmcs->iter reaches niter.

   mcmcs  chain state set up with initialize_mcmcstate(); its chain
          arrays must have room for niter rows
   niter  iteration count to stop at
   aux    opaque pointer handed to the cost function

   Every iteration proposes a point, evaluates the scaled cost
   function (out-of-bounds proposals get the fixed cost 1e10 without
   calling the cost function) and performs the accept/reject step.
   Finished chain rows are passed to write_1d_array_to_txt_double()
   for "mcmcresults.txt" whenever iter reaches a multiple of 100, and
   once more after the loop so that the trailing rows are not lost
   when niter is not a multiple of 100. */
void mcmc(struct mcmcstate *mcmcs, size_t niter, void *aux) {

  double cfval;
  const size_t write_interval = 100;
  float nrefpoints = 12.; // FIXME Get this from config
  float max_ksize = 128.;
  /* First chain row produced by this call that has not been written yet */
  size_t first_unwritten = mcmcs->iter;

  /* Arbitrary scaling. Change this at your will to get the chain to
     explore/mix properly: remember that we are just interested in the
     expected values. */
  mcmcs->scalefactor = 10./nrefpoints/sqrt(max_ksize);

  while (mcmcs->iter < niter) {
    printf("%zu/%zu\n", mcmcs->iter, niter);
    propose(mcmcs);
    cfval = (inside_bounds(mcmcs)) ? mcmcs->costfun(mcmcs->x, aux)*mcmcs->scalefactor : 1e10;
    mh(cfval, mcmcs);
    mcmcs->iter++;
    if (!(mcmcs->iter%write_interval)) {
      write_1d_array_to_txt_double("mcmcresults.txt", &mcmcs->chain[(mcmcs->npar + 2)*first_unwritten], mcmcs->iter - first_unwritten, mcmcs->npar + 2, 1);
      first_unwritten = mcmcs->iter;
    }
  }

  /* Write the rows of the last, incomplete interval.
     NOTE(review): assumes write_1d_array_to_txt_double() accepts any
     row count and that its last argument selects appending, as the
     repeated calls above suggest - confirm against its definition. */
  if (first_unwritten < mcmcs->iter) {
    write_1d_array_to_txt_double("mcmcresults.txt", &mcmcs->chain[(mcmcs->npar + 2)*first_unwritten], mcmcs->iter - first_unwritten, mcmcs->npar + 2, 1);
  }
}
