#include <stdio.h>
#include <stdlib.h>

#include <mcmc.h>
#include <nlopt.h>

/* Find the parameters of a Gaussian Process that maximize the
   likelihood p(x|theta), where theta are the parameters defining the
   GP kernel. This method uses only the training data; unlike cross
   validation, it does not need separate testing data.
*/

struct mleoptstruct {
  /* Helper struct to pass needed data to the optimizer function.
     loss_nlopt() and loss_mcmc() receive a pointer to it through their
     void * auxiliary-data argument. */
  struct config *E; /* configuration updated via set_cfpars_based_on_argidx()
                       on every evaluation, then passed to loss() */
  struct state *S; /* state (holding the observations) passed to loss() */
  uint *argidx; /* passed to set_cfpars_based_on_argidx() to map the optimized
                   vector onto E; presumably indices of the free parameters
                   as filled in by get_cfc_idx_and_limits() -- confirm */
  int npar; /* number of parameters being optimized (length of x) */
  float *results; /* For keeping results for post-mortem: one row of
                     npar + 1 floats per evaluation (the npar parameter
                     values, then the loss). NULL disables logging (e.g.
                     when doing MCMC). Whoever allocates it must make it
                     large enough for every evaluation that can happen. */
  size_t iter; /* For knowing how many iterations we have done: number of
                  rows written to results so far (only advanced while
                  results is non-NULL) */
};

double loss_nlopt(uint n, const double *x, double *grad,  void *a){
  /* Objective callback with NLOpt's nlopt_func signature: evaluates the
     GP loss at the parameter vector x.

     n    - dimensionality supplied by NLOpt; unused, aux->npar is used
            instead.
     x    - parameter values, aux->npar of them.
     grad - ignored. This callback never fills in a gradient, so it is
            only valid with derivative-free NLOpt algorithms (which pass
            grad == NULL).
     a    - "a" stands for auxiliary data: a struct mleoptstruct * that
            carries E, S, the argidx mapping and the optional results log.

     Side effects: writes x into aux->E via set_cfpars_based_on_argidx()
     and prints the evaluation to stdout. When aux->results is non-NULL,
     stores one row (npar parameters followed by the loss) at row
     aux->iter and increments aux->iter; the caller is responsible for
     sizing results for every evaluation.

     Returns loss(aux->S, aux->E). Aborts if the temporary conversion
     buffer cannot be allocated. */
  struct mleoptstruct *aux = (struct mleoptstruct*) a;
  float *row = NULL; /* this evaluation's row in aux->results, if logging */
  float *xf;         /* x converted to float for set_cfpars_based_on_argidx() */
  double lx;         /* loss function at x */
  int i;

  (void) n;
  (void) grad;

  xf = malloc((size_t) aux->npar * sizeof *xf);
  /* malloc(0) may legitimately return NULL, so only npar > 0 is an error. */
  if (!xf && aux->npar > 0) {
    fprintf(stderr, "loss_nlopt: out of memory\n");
    abort();
  }
  if (aux->results) { /* NULL if doing MCMC */
    row = aux->results + aux->iter * (size_t) (aux->npar + 1);
  }

  printf("loss(");
  for (i = 0; i < aux->npar; i++) {
    xf[i] = (float) x[i];
    printf("%f ", x[i]);
    if (row) {
      row[i] = xf[i];
    }
  }
  set_cfpars_based_on_argidx(xf, aux->argidx, aux->npar, aux->E);
  free(xf);

  lx = loss(aux->S, aux->E);
  printf(") = %f\n", lx);
  if (row) {
    row[aux->npar] = (float) lx; /* last column of each row is the loss */
    aux->iter++;
  }
  return lx;
}

double loss_mcmc(double *x, void *a) {
  struct mleoptstruct *aux = (struct mleoptstruct*) a;
  return loss_nlopt(aux->npar, x, NULL, a);
}

/* Maximum number of objective evaluations the NLOpt calibration may use.
   The evaluation log (struct mleoptstruct.results) is sized from this, and
   NLOpt is told to stop after this many evaluations, so loss_nlopt() can
   never write past the end of the log. */
enum { GP_NLOPT_MAXEVAL = 1000000 };

/* Runs the NLOpt (ISRES) calibration over the box [xld, xhd] starting from
   xd. Every evaluation is logged into auxpars->results, and the log is
   written to optimization_results.txt if NLOpt succeeds. auxpars must
   already be filled in; results is allocated and freed here (it is NULL
   again on return). On return xd holds the best point NLOpt reported. */
static void calibrate_gp_nlopt(struct mleoptstruct *auxpars, uint npar,
                               const double *xld, const double *xhd,
                               double *xd) {
  double xopt; /* final objective function value */
  nlopt_opt opt;
  int status;
  uint i;

  printf("Calibrating GP with NLOpt.\n");

  /* One row per evaluation: npar parameter values followed by the loss. */
  auxpars->results = calloc((size_t) GP_NLOPT_MAXEVAL * (npar + 1),
                            sizeof *auxpars->results);
  if (!auxpars->results) {
    fprintf(stderr, "find_GP_parameters: cannot allocate NLOpt results log\n");
    return;
  }

  /* Other algorithms that have been tried here: NLOPT_LN_COBYLA,
     NLOPT_LN_SBPLX, NLOPT_LN_NELDERMEAD, and NLOPT_G_MLSL_LDS with an
     NLOPT_LN_BOBYQA local optimizer. Gradient-based algorithms cannot be
     used because loss_nlopt() never fills in the gradient. */
  opt = nlopt_create(NLOPT_GN_ISRES, npar); /* algorithm and dimensionality */
  if (!opt) {
    fprintf(stderr, "find_GP_parameters: nlopt_create failed\n");
    goto out_results;
  }

  nlopt_set_lower_bounds(opt, xld);
  nlopt_set_upper_bounds(opt, xhd);
  nlopt_set_min_objective(opt, loss_nlopt, auxpars);
  nlopt_set_ftol_rel(opt, 1e-4);
  /* Never evaluate more often than the results log has rows for. */
  nlopt_set_maxeval(opt, GP_NLOPT_MAXEVAL);

  status = nlopt_optimize(opt, xd, &xopt);
  printf("status: %d\n", status);
  if (status < 0) {
    printf("NLOpt failed! If status is -2, check that e.g. your high limits are above low etc. \n");
  } else {
    printf("found minimum at f(");
    for (i = 0; i + 1 < npar; i++) {
      printf("%g, ", xd[i]);
    }
    printf("%g) = %g\n", xd[npar-1], xopt);
    write_1d_array_to_txt("optimization_results.txt", auxpars->results, auxpars->iter, auxpars->npar + 1, 1);
  }
  nlopt_destroy(opt);

out_results:
  free(auxpars->results);
  auxpars->results = NULL;
}

void find_GP_parameters(struct config *E, struct state *S, size_t opttype) {
  /* Finds the optimum parameters of the Gaussian process. The
     configuration is already set in E and S, and the list of
     parameters to be optimized over is set by setting low and high
     limits of parameters to differ from each other. Before calling,
     observations - real or fake - need to already have been added to
     S.

     With opttype == 0, an NLOpt calibration will be done: the optimum
     is printed and every evaluation is written to
     optimization_results.txt. Otherwise an MCMC calibration will be
     performed instead, and opttype is forwarded to
     initialize_mcmcstate() and mcmc().

     NOTE(review): each evaluation writes its parameters into E (see
     loss_nlopt()), so on return E holds the last evaluated point, which
     is not necessarily the optimum NLOpt reported -- confirm this is
     what callers expect.

     Errors (allocation failure, nothing to calibrate) are reported and
     the function returns without calibrating. */

  /* Capacity of the arrays filled by get_cfc_idx_and_limits().
     NOTE(review): this expression is kept from the original code; it only
     gives a sensible size if covftype is negative -- confirm. */
  size_t maxpar = -6*E->cfc->covftype;

  /* args for get_cfc_idx_and_limits() */
  float *x_low = calloc(maxpar, sizeof *x_low);
  float *x_high = calloc(maxpar, sizeof *x_high);
  float *x0 = calloc(maxpar, sizeof *x0);
  uint *argidx = calloc(maxpar, sizeof *argidx);
  double *xld = NULL;
  double *xhd = NULL;
  double *xd = NULL;
  /* Auxiliary parameter struct that needs to be passed to loss_nlopt() */
  struct mleoptstruct *auxpars = NULL;
  uint npar = 0;
  uint i;

  if (maxpar > 0 && (!x_low || !x_high || !x0 || !argidx)) {
    fprintf(stderr, "find_GP_parameters: out of memory\n");
    goto cleanup;
  }

  get_cfc_idx_and_limits(E->cfc, &npar, argidx, x_low, x_high, 1, x0);
  if (npar == 0) {
    /* Nothing to optimize; also avoids npar - 1 wrapping around below. */
    printf("No GP parameters to calibrate (no parameter has differing low and high limits).\n");
    goto cleanup;
  }

  /* NLOpt requires doubles, so we convert */
  xld = calloc(npar, sizeof *xld);
  xhd = calloc(npar, sizeof *xhd);
  xd = calloc(npar, sizeof *xd);
  auxpars = malloc(sizeof *auxpars);
  if (!xld || !xhd || !xd || !auxpars) {
    fprintf(stderr, "find_GP_parameters: out of memory\n");
    goto cleanup;
  }
  for (i = 0; i < npar; i++) {
    xld[i] = (double) x_low[i];
    xhd[i] = (double) x_high[i];
    xd[i] = (double) x0[i];
  }

  auxpars->S = S;
  auxpars->E = E;
  auxpars->argidx = argidx;
  auxpars->npar = npar;
  auxpars->results = NULL; /* only the NLOpt path logs evaluations */
  auxpars->iter = 0;

  if (!opttype) { /* NLOpt optimization */
    calibrate_gp_nlopt(auxpars, npar, xld, xhd, xd);
  } else { /* MCMC */
    struct mcmcstate mcmcs;

    printf("Calibrating GP with MCMC.\n");
    initialize_mcmcstate(&mcmcs, npar, &loss_mcmc, xld, xhd, xd, opttype);
    mcmc(&mcmcs, opttype, auxpars);
  }

cleanup:
  free(argidx);
  free(x_low);
  free(x_high);
  free(x0);
  free(xld);
  free(xhd);
  free(xd);
  free(auxpars);
}
