function out = spartacus_sw_vegetation(in)
%SPARTACUS_SW_VEGETATION  Compute flux profile in vegetation canopy
%  out=spartacus_sw_vegetation(in) performs the computation, where 
%  "in" is a structure with the following elements:
%    sza              solar zenith angle (degrees)
%    m                number of regions: 2 (clear/veg) or 3 (clear/veg1/veg2)
%    n                number of layers
%    z(n+1)           height (m) of layer interfaces starting at the surface
%    od(m,n)          optical depth
%    leaf_r(m,n)      leaf reflectance
%    leaf_t(m,n)      leaf transmittance
%    frac             vegetation fraction: a scalar (same in every layer)
%                     or a row or column vector with one value per layer
%    surf_albedo      surface albedo (scalar, applied to every region)
%    top_flux         top of canopy direct flux (W m-2); note that the
%                     direct downwelling flux output at the top of the
%                     canopy is cosd(sza)*top_flux
%    eff_diam         effective vegetation-element diameter (m): a scalar
%                     or a vector with one value per layer
%    fast             optional: do we use fast sub-canopy-layer calc?
%                     (0 or 1). If nonzero, the albedos above the lowest
%                     layer are computed assuming that layer contains no
%                     vegetation; this calls fast_expm, which is not
%                     defined in this function
%  The output structure "out" contains the following elements, each a
%  1x(n+1) row vector ordered from the surface (first element) upwards:
%    flux_up          upwelling flux at layer interfaces (W m-2)
%    flux_dn          downwelling flux at layer interfaces (W m-2)
%    flux_dn_diffuse  diffuse downwelling flux at layer interfaces (W m-2)
%    flux_dn_direct   direct downwelling flux at layer interfaces (W m-2)
%
%  Author: Robin Hogan <r.j.hogan@ecmwf.int>
%  Copyright: 2017 European Centre for Medium Range Weather Forecasts
%
%  Copying and distribution of this file, with or without modification,
%  are permitted in any medium without royalty provided the copyright
%  notice and this notice are preserved.

if in.m > 3 || in.m < 2
  error('Can only deal with 2 or 3 regions')
end

% Cosine of effective diffuse zenith angle for vertical transport
mu1 = 0.5;
% Tangent of effective diffuse zenith angle for horizontal transport
% (see Schaefer et al. 2016 for why this is not quite consistent with
% mu1)
tan_theta_lat = pi/2;

% Cosine and tangent of the solar zenith angle (direct beam)
mu0 = cosd(in.sza);
tan_theta_lat0 = tand(in.sza);

% Convert leaf reflectance/transmittance into single scattering albedo
% (ssa) and upscatter fractions for diffuse and direct radiation. The
% ssa is floored so that black leaves do not cause a division by zero.
ssa = in.leaf_r+in.leaf_t;
ssa(ssa<=1e-10) = 1e-10;
upscatter_diff = 0.5 + mu1.*(in.leaf_r-in.leaf_t)./(3.*ssa);
upscatter_dir  = 0.5 + mu0.*(in.leaf_r-in.leaf_t)./(3.*ssa);

% Layer thickness (m)
dz = abs(diff(in.z));

% Extinction coefficient (m-1)
ext = in.od ./ (ones(in.m,1)*dz(:)');

% Two-stream coefficients
gamma1 = (1-ssa.*(1-upscatter_diff))./mu1;
gamma2 = ssa.*upscatter_diff./mu1;
gamma3 = upscatter_dir;
gamma4 = 1.0 - upscatter_dir;

% Vegetation fraction and effective diameter as 1xn row vectors. A
% scalar is applied to every layer, and row or column vectors are both
% accepted, so each layer's value can be indexed as veg_frac(ii).
veg_frac = zeros(1,in.n) + reshape(in.frac,1,[]);
eff_diam = zeros(1,in.n) + reshape(in.eff_diam,1,[]);

% Length of interface between regions a and b
Lab = (4.0/pi) .* veg_frac ./ eff_diam;
%Lab = (4.0/pi) .* veg_frac .* (1.0 - veg_frac) ./ eff_diam;

% Rate at which radiation travelling at a zenith angle with tangent
% tan_theta leaves a region of areal fraction "area" through an interface
% of length L. Where the interface does not exist ("exists" false) the
% rate is zero. Adding ~exists to the denominator avoids the 0/0 (NaN) or
% x/0 (Inf) that would otherwise contaminate the whole solution. Where the
% interface exists this is exactly L.*tan_theta./area.
exchange_rate = @(L, tan_theta, area, exists) ...
  (exists .* L .* tan_theta) ./ (area + ~exists);

% Regions a and b share an interface only if both have nonzero area
has_ab = veg_frac > 0 & veg_frac < 1;

% Region fractions in each layer
frac = zeros(in.m,in.n);

if in.m == 2
  % Two-region description: region a (index 1) is clear, region b
  % (index 2) is vegetated
  frac(1,:) = 1.0 - veg_frac;
  frac(2,:) = veg_frac;

  % Rates of radiation exchange between regions for diffuse radiation
  fab = exchange_rate(Lab, tan_theta_lat, 1.0 - veg_frac, has_ab);
  fba = exchange_rate(Lab, tan_theta_lat, veg_frac, has_ab);
  % Rates of radiation exchange between regions for direct radiation
  fab0 = exchange_rate(Lab, tan_theta_lat0, 1.0 - veg_frac, has_ab);
  fba0 = exchange_rate(Lab, tan_theta_lat0, veg_frac, has_ab);

else % in.m == 3
  % Three-region description: region a (index 1) is clear, region b
  % (index 2) is lower density vegetation and region c (index 3) is
  % higher density vegetation
  frac(1,:) = 1.0 - veg_frac;
  frac(2,:) = 0.5.*veg_frac;
  frac(3,:) = 0.5.*veg_frac;

  % Treat effective diameter of the optically thick (denser) part of
  % the vegetation as sqrt(2) smaller, and since fraction of optically
  % thick part is half the vegetation fraction, we have
  Lbc = Lab ./ sqrt(2);
  % Regions b and c each occupy half the vegetation fraction, so they
  % share an interface whenever there is any vegetation
  has_bc = veg_frac > 0;

  % Rates of radiation exchange between regions for diffuse radiation
  fab = exchange_rate(Lab, tan_theta_lat, 1.0 - veg_frac, has_ab);
  fba = exchange_rate(Lab, tan_theta_lat, 0.5 .* veg_frac, has_ab);
  fbc = exchange_rate(Lbc, tan_theta_lat, 0.5 .* veg_frac, has_bc);
  fcb = fbc; % because the areal fractions are equal
  % Rates of radiation exchange between regions for direct radiation
  fab0 = exchange_rate(Lab, tan_theta_lat0, 1.0 - veg_frac, has_ab);
  fba0 = exchange_rate(Lab, tan_theta_lat0, 0.5 .* veg_frac, has_ab);
  fbc0 = exchange_rate(Lbc, tan_theta_lat0, 0.5 .* veg_frac, has_bc);
  fcb0 = fbc0; % because the areal fractions are equal

end

% Initialize the matrices representing transmission (T), reflection
% (R), scattering of the direct beam up (Sup), scattering of the
% direct beam down (Sdn) and direct unscattered transmission (Ess)
T   = zeros(in.m,in.m,in.n);
R   = T;
Sup = T;
Sdn = T;
Ess = T;

% Indices for the upwelling (iu), downwelling diffuse (iv) and
% downwelling direct (is) parts of the larger Gamma matrix
iu = 1:in.m;
iv = in.m+1:2*in.m;
is = 2*in.m+1:3*in.m;

% Loop over layers and compute the T, R, Sup, Sdn and Ess matrices
for ii = 1:in.n
  % First the two-stream parts of the Gamma sub-matrices
  Gamma0 = diag(-ext(:,ii)./mu0);
  Gamma1 = diag(-ext(:,ii).*gamma1(:,ii));
  Gamma2 = diag(ext(:,ii).*gamma2(:,ii));
  Gamma3 = diag(ext(:,ii).*ssa(:,ii).*gamma3(:,ii));
  Gamma4 = diag(ext(:,ii).*ssa(:,ii).*gamma4(:,ii));
  % Then the 3D exchange components of the Gamma sub-matrices: each
  % loss from one region on the diagonal is matched by an equal gain in
  % the other region off the diagonal, so exchange conserves energy
  Gamma0(1,1) = Gamma0(1,1) - fab0(ii);
  Gamma0(2,2) = Gamma0(2,2) - fba0(ii);
  Gamma0(1,2) = Gamma0(1,2) + fba0(ii);
  Gamma0(2,1) = Gamma0(2,1) + fab0(ii);
  Gamma1(1,1) = Gamma1(1,1) - fab(ii);
  Gamma1(2,2) = Gamma1(2,2) - fba(ii);
  Gamma1(1,2) = Gamma1(1,2) + fba(ii);
  Gamma1(2,1) = Gamma1(2,1) + fab(ii);
  if in.m == 3
    % Additional terms in the case of three regions for exchange
    % between regions b and c
    Gamma0(2,2) = Gamma0(2,2) - fbc0(ii);
    Gamma0(3,3) = Gamma0(3,3) - fcb0(ii);
    Gamma0(2,3) = Gamma0(2,3) + fcb0(ii);
    Gamma0(3,2) = Gamma0(3,2) + fbc0(ii);
    Gamma1(2,2) = Gamma1(2,2) - fbc(ii);
    Gamma1(3,3) = Gamma1(3,3) - fcb(ii);
    Gamma1(2,3) = Gamma1(2,3) + fcb(ii);
    Gamma1(3,2) = Gamma1(3,2) + fbc(ii);
  end

  % Construct full Gamma matrix acting on the state [u; v; s]
  Gamma = [-Gamma1 -Gamma2 -Gamma3; ...
           Gamma2  Gamma1  Gamma4; ...
           zeros(in.m,2*in.m) Gamma0];

  % Matrix exponential over the layer thickness. Its (is,is) block is
  % used below as the direct transmission from the top of the layer to
  % its base.
  expGamma = expm(Gamma.*dz(ii));

  % Compute the matrices we need from the result. R and Sup follow from
  % requiring zero upwelling at the base of an isolated layer.
  R(:,:,ii)   = -expGamma(iu,iu)\expGamma(iu,iv);
  T(:,:,ii)   = expGamma(iv,iu)*R(:,:,ii) + expGamma(iv,iv);
  Sup(:,:,ii) = -expGamma(iu,iu)\expGamma(iu,is);
  Sdn(:,:,ii) = expGamma(iv,iu)*Sup(:,:,ii) + expGamma(iv,is);
  Ess(:,:,ii) = expGamma(is,is);
end

% Diffuse (Adiff) and direct (Adir) albedo matrices at each half level,
% starting at the surface (index 1)
Adiff = zeros(in.m,in.m,in.n+1);
Adiff(:,:,1) = in.surf_albedo * eye(in.m);
Adir  = mu0.*Adiff;

% Optionally use a simpler scheme for the lowest layer, taking advantage
% of the simplification of the matrices when there is no vegetation
% present (same truth test as "if in.fast")
use_fast = isfield(in,'fast') && ~isempty(in.fast) && all(in.fast(:) ~= 0);

% Adding method: work up from the surface, combining each layer with the
% albedos of everything beneath it. The same formula applies to every
% layer; in particular, direct radiation scattered down by the layer
% (Sdn) and then reflected from below (Adiff) contributes to Adir.
for ii = 1:in.n
  if ii == 1 && use_fast
    % In the case that there is no vegetation in the lowest layer, the
    % mathematics becomes simpler: only lateral exchange remains
    if in.m == 2
      Tfast    = fast_expm([fab(1) fba(1)].*dz(1));
      Ess_fast = fast_expm([fab0(1) fba0(1)].*dz(1));
    else
      Tfast    = fast_expm([fab(1) fba(1) fbc(1)].*dz(1));
      Ess_fast = fast_expm([fab0(1) fba0(1) fbc0(1)].*dz(1));
    end
    Adiff(:,:,2) = Tfast*Adiff(:,:,1)*Tfast;
    Adir(:,:,2)  = Tfast*Adir(:,:,1)*Ess_fast;
  else
    % (I - Adiff*R) accounts for multiple reflections between this
    % layer and everything beneath it; solve with "\" rather than inv
    denom = eye(in.m) - Adiff(:,:,ii)*R(:,:,ii);
    Adiff(:,:,ii+1) = R(:,:,ii) ...
      + T(:,:,ii)*(denom \ (Adiff(:,:,ii)*T(:,:,ii)));
    Adir(:,:,ii+1)  = Sup(:,:,ii) ...
      + T(:,:,ii)*(denom \ (Adir(:,:,ii)*Ess(:,:,ii) + Adiff(:,:,ii)*Sdn(:,:,ii)));
  end
end

% Direct downwelling (s), upwelling (u) and diffuse downwelling (v)
% fluxes in each region at each half level
s = zeros(in.m,in.n+1);
u = zeros(in.m,in.n+1);
v = zeros(in.m,in.n+1);

% At top of canopy we have the incoming direct flux, and assume zero
% downwelling diffuse flux
s(:,end) = in.top_flux*frac(:,end);

% Upwelling at top is reflected from the incoming direct
u(:,end) = Adir(:,:,end)*s(:,end);

% Work down through the canopy computing fluxes at each half level
for ii = in.n:-1:1
  s(:,ii) = Ess(:,:,ii)*s(:,ii+1);
  v(:,ii) = (eye(in.m) - R(:,:,ii)*Adiff(:,:,ii)) ...
    \ (T(:,:,ii)*v(:,ii+1) + R(:,:,ii)*Adir(:,:,ii)*s(:,ii) + Sdn(:,:,ii)*s(:,ii+1));
  u(:,ii) = Adiff(:,:,ii)*v(:,ii) + Adir(:,:,ii)*s(:,ii);
end

% Sum the region fluxes to get the domain-averaged fluxes at each half
% level; s is scaled by mu0 to give the direct flux through a horizontal
% surface
out.flux_up = sum(u,1);
out.flux_dn_diffuse = sum(v,1);
out.flux_dn_direct = mu0.*sum(s,1);
% Total downwelling is sum of diffuse and direct
out.flux_dn = out.flux_dn_diffuse + out.flux_dn_direct;

