#!/usr/bin/env python3

import sys
import os
import pst
import numpy as np
from matplotlib import pyplot as plt
from scipy import optimize
from dist import *

# Line colors handed to plot(), selected by the position of each distribution
# on the command line.
COL = ['#0084C8', '#DC0000', '#009100', '#FFC022']

def plot(type_, reff, sigmaeff, col):
	"""Plot one particle size distribution on the current matplotlib axes.

	The distribution is evaluated on the module-level radius grid `r`
	(created in the `__main__` section before this function is called),
	normalized by its maximum, and drawn as a labeled line. A line of
	diagnostics comparing the requested and numerically recovered
	parameters is printed to stdout.

	Parameters:
		type_: Distribution kind as bytes, either b'lognorm' or b'gamma'.
		reff: Effective radius (labeled in micrometres on the plot).
		sigmaeff: Effective standard deviation (same unit as reff).
		col: Line color passed to matplotlib.

	Raises:
		ValueError: If type_ is neither b'lognorm' nor b'gamma'.

	NOTE(review): the lognorm_*, gamma_pdf, reff_num, sigma2eff_num,
	calc_mean and calc_sd helpers presumably come from `from dist import *`;
	their exact semantics are not visible here.
	"""
	sigma2eff = sigmaeff**2.

	if type_ == b'lognorm':
		#title = 'Particle size distribution (log-normal), r$_\\mathrm{eff}$ = %d $\mu$m' % mean
		label = 'lognorm r$_\\mathrm{eff}$ = %.0f $\\mu$m, $\\sigma_\\mathrm{eff}$ = %.0f $\\mu$m' % (reff, sigmaeff)
		# Convert the requested effective parameters into the (mu, sigma^2)
		# pair that lognorm_pdf is called with.
		mu, sigma2 = lognorm_find_mu_sigma2(reff, sigma2eff)
		n = lognorm_pdf(r, mu, np.sqrt(sigma2))
		print('reff: %f sigmaeff: %f mu: %f sigma: %f reff_sol: %f sigmaeff_sol: %f reff_num: %f sigmaeff_num: %f mean: %f sd: %f' % (
			reff,
			sigmaeff,
			mu,
			np.sqrt(sigma2),
			lognorm_reff(mu, sigma2),
			np.sqrt(lognorm_sigma2eff(mu, sigma2)),
			reff_num(r, n),
			np.sqrt(sigma2eff_num(r, n, reff)),
			calc_mean(r, n),
			calc_sd(r, n),
		))
	elif type_ == b'gamma':
		label = 'gamma r$_\\mathrm{eff}$ = %.0f $\\mu$m, $\\sigma_\\mathrm{eff}$ = %.0f $\\mu$m' % (reff, sigmaeff)
		n = gamma_pdf(r, reff, sigmaeff)
		print('reff: %f sigmaeff: %f reff_num: %f sigmaeff_num: %f mean: %f sd: %f' % (
			reff,
			sigmaeff,
			reff_num(r, n),
			np.sqrt(sigma2eff_num(r, n, reff)),
			calc_mean(r, n),
			calc_sd(r, n),
		))
	else:
		# Wrap in a tuple so a tuple-valued type_ cannot break the % operator.
		raise ValueError('Invalid type: %s' % (type_,))

	# Normalize to a peak of 1 (ignoring NaNs) so different distributions
	# are comparable on the same axes.
	plt.plot(r, n/np.nanmax(n), lw=1, color=col, label=label)
	# plt.axvline(reff, linestyle='dashed', color=col, lw=0.7)
	# plt.axvline(reff - sigmaeff, linestyle='dotted', color=col, lw=0.7)
	# plt.axvline(reff + sigmaeff, linestyle='dotted', color=col, lw=0.7)
	# The backslash is escaped explicitly: a bare "\m" is an invalid escape
	# sequence and triggers a warning on current Python versions.
	plt.xlabel('r ($\\mu$m)')
	plt.ylabel('n(r)')
	#plt.title(title)

if __name__ == '__main__':
	# Binary wrapper around stderr: the usage message below is a bytes
	# template formatted with args[0].
	stderr = os.fdopen(sys.stderr.fileno(), 'wb')
	args, opts = pst.decode_argv(sys.argv)

	# args[0] is the program name; at least one distribution group and the
	# output path have to follow it.
	if len(args) < 3:
		stderr.write(b'Usage: %s { <type> <reff> <sigmaeff> }... <output> num: <num>\n' % args[0])
		# Flush explicitly so the buffered message is not lost on exit.
		stderr.flush()
		sys.exit(1)

	output = args[-1]
	num = opts.get(b'num', 100000)
	# Radius grid shared with plot(), which reads `r` as a module global.
	r = np.linspace(1e-3, 50., num)
	plt.figure(figsize=(6, 6))
	# Everything between the program name and the output path is one
	# distribution group: (type, reff, sigmaeff).
	for i, arg in enumerate(args[1:-1]):
		type_ = arg[0]
		reff = arg[1]
		sigmaeff = arg[2]
		# Cycle through the palette so more distributions than colors
		# does not raise IndexError.
		plot(type_, reff, sigmaeff, col=COL[i % len(COL)])
	plt.legend(fontsize=9)
	plt.savefig(os.fsdecode(output), bbox_inches='tight')
