#!/usr/bin/env python3

import os
import sys
import numpy as np
import pst
from dist import *
from matplotlib import pyplot as plt
import ds_format as ds

def read(filename):
	"""Read a scattering table from `filename`.

	The file is opened in binary mode. Its first line is a pst-encoded
	header from which the `wavelength_nm` and `crefin` entries are taken,
	the second line is skipped, and every following non-blank line must
	hold four whitespace-separated numbers: r, qext, qsca and p180.

	Returns a dictionary with the header values under 'wavelength_nm' and
	'crefin', and float64 arrays (one element per data row) under 'r',
	'qext', 'qsca' and 'p180'.

	Raises ValueError if a data row does not contain exactly four numbers.
	"""
	columns = ('r', 'qext', 'qsca', 'p180')
	d = {name: [] for name in columns}
	# Open the file that was passed in, not the script-level global
	# `input_`, so that the function is usable on its own.
	with open(filename, 'rb') as f:
		meta = pst.decode(f.readline())
		d['wavelength_nm'] = meta[b'wavelength_nm']
		d['crefin'] = meta[b'crefin']
		# The second line is not used.
		f.readline()
		for line in f:
			fields = line.split()
			if not fields:
				# Tolerate blank lines (e.g. a trailing newline).
				continue
			r, qext, qsca, p180 = [float(x) for x in fields]
			d['r'].append(r)
			d['qext'].append(qext)
			d['qsca'].append(qsca)
			d['p180'].append(p180)
	for name in columns:
		d[name] = np.array(d[name], dtype=np.float64)
	return d

def calc_k(r, qext, qsca, p180, n):
	"""Return the ratio of the n-weighted backscatter and extinction sums.

	The numerator is the sum of qsca*r^2*p180/(4*pi)*n and the denominator
	is the sum of qext*r^2*n, both taken over all elements of the inputs.
	"""
	r_squared = r**2.
	backscatter_terms = qsca*r_squared*p180/(4.*np.pi)*n
	extinction_terms = qext*r_squared*n
	return np.sum(backscatter_terms)/np.sum(extinction_terms)

if __name__ == '__main__':
	# Command-line entry point: read the scattering table <input>, weight it
	# by a family of size distributions of the given <type> (one per
	# effective radius) and write the resulting backscatter-to-extinction
	# ratio and lidar ratio to the NetCDF file <output>.
	stderr = os.fdopen(sys.stderr.fileno(), 'wb')
	args, opts = pst.decode_argv(sys.argv)
	if len(args) != 4:
		stderr.write(b'Usage: %s <input> <type> <output> [sigmaeff_ratio: <sigmaeff_ratio>]\n' % args[0])
		# The writer is buffered; flush so the message is emitted before exit.
		stderr.flush()
		sys.exit(1)
	input_ = args[1]
	type_ = args[2]
	output = os.fsdecode(args[3])
	# Ratio of effective standard deviation to effective radius.
	sigmaeff_ratio = opts.get(b'sigmaeff_ratio', 0.25)
	d = read(input_)

	reff_min = 5 # um
	reff_max = 50 # um
	reff_n = 500

	reff = np.linspace(reff_min, reff_max, reff_n)
	sigmaeff = reff*sigmaeff_ratio
	k = np.empty(len(reff), dtype=np.float64)

	# Select the function that evaluates the size distribution on the
	# radius grid for a given effective radius and standard deviation.
	if type_ == b'lognorm':
		def size_pdf(r, reff0, sigmaeff0):
			# The log-normal PDF is parametrised by mu and sigma, which are
			# derived from the effective radius and effective variance.
			mu, sigma2 = lognorm_find_mu_sigma2(reff0, sigmaeff0**2.)
			return lognorm_pdf(r, mu, np.sqrt(sigma2))
	elif type_ == b'gamma':
		size_pdf = gamma_pdf
	else:
		raise ValueError('Invalid type "%s"' % type_)

	for i, (reff0, sigmaeff0) in enumerate(zip(reff, sigmaeff)):
		n = size_pdf(d['r'], reff0, sigmaeff0)
		# Normalise the weights so that they sum to one.
		n = n/np.sum(n)
		k[i] = calc_k(d['r'], d['qext'], d['qsca'], d['p180'], n)

	# Lengths are converted to metres on output (wavelength from nm, radii
	# from um); the lidar ratio is the reciprocal of k.
	ds.to_netcdf(output, {
		'wavelength': d['wavelength_nm']*1e-9,
		'reff': reff*1e-6,
		'sigmaeff': sigmaeff*1e-6,
		'k': k,
		'lr': 1./k,
		'.': {
			'.': {
				'distribution': type_,
			},
			'wavelength': {
				'.dims': [],
				'long_name': 'wavelength',
				'units': 'm',
			},
			'reff': {
				'.dims': ['reff'],
				'long_name': 'effective_radius',
				'units': 'm',
			},
			'sigmaeff': {
				'.dims': ['reff'],
				'long_name': 'effective_standard_deviation',
				'units': 'm',
			},
			'k': {
				'.dims': ['reff'],
				'long_name': 'backscatter_to_extinction_ratio',
				'units': 'sr-1',
			},
			'lr': {
				'.dims': ['reff'],
				'long_name': 'lidar_ratio',
				'units': 'sr',
			}
		}

	})
