!========================================================================
! get_fields(u,v,h,Phi,lat,lon,time,k,n,amp,wave_type) - evaluates the  
!    analytic solutions for the proposed test case on arbitrary lat x lon 
!    grids.
! 
! Inputs:
!     lat        - 1D array of desired latitudes (radians) 
!     lon        - 1D array of desired longitudes (radians)
!     time       - 1D array of desired times (sec)(array even for a single time)
!     k          - spherical wave-number (dimensionless)
!     n          - wave-mode (dimensionless)
!     amp        - wave-amplitude (m/sec)
!     wave_type  - -1=WIG, 0=Rossby, 1=EIG
! 
! Outputs:
!     u          - 3D array (time, lon, lat): zonal velocity (m/sec)
!     v          - 3D array (time, lon, lat): meridional velocity (m/sec)
!     h          - 3D array (time, lon, lat): free-surface height anomaly (m)
!     Phi        - 3D array (time, lon, lat): geopotential anomaly, g*h (m^2/sec^2)
!        
! get_omega(freq,k,n,wave_type) - evaluates the wave-frequency.
! 
! Inputs:
!     k          - spherical wave-number (dimensionless)
!     n          - wave-mode (dimensionless)
!     wave_type  - -1=WIG, 0=Rossby, 1=EIG
! 
! Outputs:
!     freq       - wave-frequency (rad/sec)
!
! *This code is only valid for wave-numbers k>=1 and wave-modes n>=1.
! Special treatment is required for k=0 and for n=-1 or n=0.*
!========================================================================

module matsuno

  implicit none

  integer, parameter :: dp = selected_real_kind(15,307)

!========================================================================
! PARAMETERS
!========================================================================
  real(kind=dp), parameter :: OMEGA  = 7.29212e-5_dp            ! Earth's angular frequency (rad/sec)
  real(kind=dp), parameter :: G      = 9.80616_dp               ! Earth's gravitational acceleration (m/sec^2)
  real(kind=dp), parameter :: A      = 6371220.0_dp             ! Earth's mean radius (m)
  real(kind=dp), parameter :: H0     = 5.0e-1_dp                ! Layer's mean depth (m)
  real(kind=dp), parameter :: PI     = 3.14159265358979323_dp   ! pi
  real(kind=dp), parameter :: C_GRAV = sqrt(G*H0)               ! gravity-wave speed sqrt(g*H0) (m/sec)

  complex(kind=dp), parameter :: I_UNIT = (0.0_dp, 1.0_dp)      ! imaginary unit

  ! Lamb's parameter (dimensionless).  This name is public and hides the
  ! intrinsic epsilon() inside this module; it is kept for compatibility.
  real(kind=dp), parameter :: &
        EPSILON  = (2.0_dp*OMEGA*A)**2/(G*H0)

  private :: dp, OMEGA, G, A, H0, PI, C_GRAV, I_UNIT

  private :: get_hermite_polynomial,  &
             get_psi,                 &
             get_amplitudes

  public  :: get_omega,  &
             get_fields

contains


!========================================================================
! FREQUENCY - Evaluates the wave-frequency.
!
! The frequency is a root of the depressed cubic
!
!    w^3 - (g*H0*(k/A)^2 + 2*OMEGA*sqrt(g*H0)*(2n+1)/A) * w
!        - 2*OMEGA*g*H0*k/A^2 = 0,
!
! whose three roots are obtained with the general cubic formula.  The
! root matching wave_type is returned in freq (rad/sec):
!    wave_type =  1 (EIG)    -> largest root
!    wave_type = -1 (WIG)    -> smallest (most negative) root
!    wave_type =  0 (Rossby) -> root of smallest magnitude, returned
!                               with a negative sign
! Any other wave_type stops execution with an error message.
!========================================================================
  subroutine get_omega(freq,k,n,wave_type)

    integer,       intent(in)  :: k           ! spherical wave-number
    integer,       intent(in)  :: n           ! wave-mode
    integer,       intent(in)  :: wave_type   ! -1=WIG, 0=Rossby, 1=EIG
    real(kind=dp), intent(out) :: freq        ! wave-frequency (rad/sec)

    real(kind=dp), parameter :: THIRD = 1.0_dp/3.0_dp

    real(kind=dp)    :: delta0      ! Delta_0 of the cubic formula
    real(kind=dp)    :: delta1      ! Delta_1 of the cubic formula
    complex(kind=dp) :: disc_root   ! sqrt(Delta_1^2 - 4*Delta_0^3)
    complex(kind=dp) :: cbrt_base   ! principal cube root C
    complex(kind=dp) :: c_j         ! C rotated by the j-th cube root of unity
    real(kind=dp)    :: roots(3)
    integer          :: j

    delta0 =   3.0_dp * (G*H0*(k/A)**2 + 2.0_dp*OMEGA*C_GRAV/A*(2*n+1))
    delta1 = -54.0_dp * OMEGA * G*H0 * k / A**2

    ! The discriminant delta1^2 - 4*delta0^3 is negative when all three
    ! roots are real, so its square root is taken in complex arithmetic.
    ! C does not depend on j, so it is evaluated once, outside the loop.
    disc_root = sqrt(cmplx(delta1**2 - 4.0_dp*delta0**3, 0.0_dp, kind=dp))
    cbrt_base = (0.5_dp*(delta1 + disc_root))**THIRD

    do j = 1, 3
       c_j      = cbrt_base * exp(2.0_dp*PI*I_UNIT*j*THIRD)
       roots(j) = real(-THIRD*(c_j + delta0/c_j), kind=dp)
    end do

    select case (wave_type)
    case (0)
       freq = -minval(abs(roots))
    case (1)
       freq = maxval(roots)
    case (-1)
       freq = minval(roots)
    case default
       ! without a valid wave_type no root can be selected; stop rather
       ! than return an undefined frequency
       error stop "get_omega: wave_type must be -1 (WIG), 0 (Rossby) or 1 (EIG)"
    end select

  end subroutine get_omega


!========================================================================
! HERMITE POLYNOMIAL
! Evaluates the normalized Hermite polynomial of degree n,
!    H_n(x) / sqrt(2^n * n! * sqrt(pi)),
! at every element of x, using the three-term recurrence
!    P_m = sqrt(2/m) * x * P_(m-1) - sqrt((m-1)/m) * P_(m-2).
! The recurrence is run iteratively, so each degree is computed once
! (O(n) work).  Negative n yields zeros.
!========================================================================
  pure function get_hermite_polynomial(x,n) result(H_n)

    integer,       intent(in) :: n        ! polynomial degree
    real(kind=dp), intent(in) :: x(:)     ! evaluation points
    real(kind=dp)             :: H_n(size(x))

    real(kind=dp) :: H_prev(size(x))      ! degree m-1
    real(kind=dp) :: H_prev2(size(x))     ! degree m-2
    integer       :: m

    if (n < 0) then
       H_n = 0.0_dp
       return
    end if

    ! degree 0
    H_n = 1.0_dp/PI**0.25_dp
    if (n == 0) return

    ! degree 1
    H_prev = H_n
    H_n    = (4.0_dp/PI)**0.25_dp * x

    do m = 2, n
       H_prev2 = H_prev
       H_prev  = H_n
       H_n     = sqrt(2.0_dp/m) * x * H_prev - sqrt((m-1.0_dp)/m) * H_prev2
    end do

  end function get_hermite_polynomial


!========================================================================
! EIGENFUNCTIONS - Evaluates the eigenfunction psi_n.
!
!    psi_n(lat) = amp * exp(-y^2/2) * P_n(y),   y = EPSILON**(1/4) * lat
!
! where P_n is the normalized Hermite polynomial of degree n.
!========================================================================
  subroutine get_psi(psi_n,lat,n,amp)

    integer,       intent(in)  :: n          ! wave-mode
    real(kind=dp), intent(in)  :: lat(:)     ! latitudes (radians)
    real(kind=dp), intent(in)  :: amp        ! amplitude (m/sec)
    real(kind=dp), intent(out) :: psi_n(:)   ! eigenfunction at each latitude

    real(kind=dp) :: y(size(lat))

    ! re-scale latitude to the non-dimensional meridional coordinate
    y = EPSILON**0.25_dp * lat

    ! Gaussian envelope times the normalized Hermite polynomial
    psi_n = amp * exp(-0.5_dp * y**2) * get_hermite_polynomial(y,n)

  end subroutine get_psi


!========================================================================
! AMPLITUDES - Evaluates the latitude-dependent complex amplitudes of
! the meridional velocity (v_hat), zonal velocity (u_hat) and
! geopotential (p_hat) for the wave selected by k, n and wave_type.
! v_hat is the eigenfunction psi_n; u_hat and p_hat combine psi_(n+1)
! and psi_(n-1) and share the pre-factor
!    g*H0*EPSILON**(1/4) / (i*A*(freq^2 - g*H0*(k/A)^2)).
!========================================================================
  subroutine get_amplitudes(u_hat,v_hat,p_hat,lat,k,n,amp,wave_type)

    integer,          intent(in)  :: k
    integer,          intent(in)  :: n
    integer,          intent(in)  :: wave_type
    real(kind=dp),    intent(in)  :: lat(:)
    real(kind=dp),    intent(in)  :: amp
    complex(kind=dp), intent(out) :: u_hat(:)
    complex(kind=dp), intent(out) :: v_hat(:)
    complex(kind=dp), intent(out) :: p_hat(:)

    real(kind=dp)    :: freq
    real(kind=dp)    :: kx            ! dimensional zonal wave-number k/A (1/m)
    real(kind=dp)    :: coef_plus     ! sqrt((n+1)/2)
    real(kind=dp)    :: coef_minus    ! sqrt(n/2)
    complex(kind=dp) :: prefactor
    real(kind=dp)    :: psi_n(size(lat))
    real(kind=dp)    :: psi_plus(size(lat))
    real(kind=dp)    :: psi_minus(size(lat))

    call get_omega(freq,k,n,wave_type)
    call get_psi(psi_n,    lat,n,  amp)
    call get_psi(psi_plus, lat,n+1,amp)
    call get_psi(psi_minus,lat,n-1,amp)

    kx         = k/A
    coef_plus  = sqrt((n+1)/2.0_dp)
    coef_minus = sqrt(n/2.0_dp)

    v_hat = psi_n

    u_hat = - coef_plus  * (freq/C_GRAV + kx) * psi_plus  &
            - coef_minus * (freq/C_GRAV - kx) * psi_minus

    p_hat = - coef_plus  * (freq + C_GRAV*kx) * psi_plus  &
            + coef_minus * (freq - C_GRAV*kx) * psi_minus

    ! pre-factor shared by u_hat and p_hat
    prefactor = G*H0*EPSILON**0.25_dp / (I_UNIT*A*(freq**2 - G*H0*kx**2))
    u_hat     = prefactor * u_hat
    p_hat     = prefactor * p_hat

  end subroutine get_amplitudes


!========================================================================
! FIELDS - Evaluates the fields
!    q(t,lon,lat) = Re[ q_hat(lat) * exp(i*(k*lon - freq*t)) ]
! for q = u, v, Phi, and the free-surface height anomaly h = Phi/g.
! Every output must have extents of at least
! (size(time), size(lon), size(lat)); execution stops otherwise.
!========================================================================
  subroutine get_fields(u,v,h,Phi,lat,lon,time,k,n,amp,wave_type)

    integer,       intent(in)    :: k
    integer,       intent(in)    :: n
    integer,       intent(in)    :: wave_type
    real(kind=dp), intent(in)    :: lat(:)
    real(kind=dp), intent(in)    :: lon(:)
    real(kind=dp), intent(in)    :: time(:)
    real(kind=dp), intent(in)    :: amp
    real(kind=dp), intent(inout) :: u(:,:,:)
    real(kind=dp), intent(inout) :: v(:,:,:)
    real(kind=dp), intent(inout) :: h(:,:,:)
    real(kind=dp), intent(inout) :: Phi(:,:,:)

    real(kind=dp)    :: freq
    complex(kind=dp) :: u_hat(size(lat))
    complex(kind=dp) :: v_hat(size(lat))
    complex(kind=dp) :: p_hat(size(lat))
    complex(kind=dp) :: phase
    integer          :: required(3)
    integer          :: ilat, ilon, it

    ! guard against writing past the end of undersized output arrays
    required = [size(time), size(lon), size(lat)]
    if (any(shape(u) < required) .or. any(shape(v)   < required) .or. &
        any(shape(h) < required) .or. any(shape(Phi) < required)) then
       error stop "get_fields: u, v, h and Phi must be at least (size(time), size(lon), size(lat))"
    end if

    call get_omega(freq,k,n,wave_type)
    call get_amplitudes(u_hat,v_hat,p_hat,lat,k,n,amp,wave_type)

    ! time is the first (fastest-varying) index, so it is the innermost loop
    do ilat = 1, size(lat)
       do ilon = 1, size(lon)
          do it = 1, size(time)
             ! the zonal/temporal phase factor is shared by all three fields
             phase = exp(I_UNIT * (k*lon(ilon) - freq*time(it)))
             u(it,ilon,ilat)   = real(u_hat(ilat) * phase, kind=dp)
             v(it,ilon,ilat)   = real(v_hat(ilat) * phase, kind=dp)
             Phi(it,ilon,ilat) = real(p_hat(ilat) * phase, kind=dp)
          end do
       end do
    end do

    ! transform to free-surface height anomaly
    h = Phi/G

  end subroutine get_fields

end module matsuno