! Copyright (c) 2013,  Los Alamos National Security, LLC (LANS)
! and the University Corporation for Atmospheric Research (UCAR).
!
! Unless noted otherwise source code is licensed under the BSD license.
! Additional copyright and license information can be found in the LICENSE file
! distributed with this code, or at http://mpas-dev.github.com/license.html
!
module sw_time_integration

   use mpas_vector_reconstruction
   use mpas_grid_types
   use mpas_configure
   use mpas_constants
   use mpas_dmpar

   use sw_constants

   contains


   subroutine sw_timestep(domain, dt, timeStamp)
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
   ! Advance the model state by one time step of length dt and stamp the result.
   !
   ! The integration scheme is chosen from the 'config_time_integration' option;
   ! only 'RK4' is accepted. Any other value writes a message to unit 0 and stops.
   ! After integrating, timeStamp is stored in time level 2 of the 'xtime' array
   ! of every block's 'state' pool.
   !
   ! Input:  domain    - current model state in time level 1 plus grid meta-data
   !         dt        - time step length (seconds)
   !         timeStamp - time string written to time level 2 of 'xtime'
   ! Output: domain    - time level 2 holds the model state advanced by dt
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 

      implicit none

      type (domain_type), intent(inout) :: domain
      real (kind=RKIND), intent(in) :: dt
      character(len=*), intent(in) :: timeStamp

      type (block_type), pointer :: curBlock
      type (mpas_pool_type), pointer :: blockState

      character (len=StrKIND), pointer :: blockTime
      character (len=StrKIND), pointer :: integrationScheme

      call mpas_pool_get_config(domain % configs, 'config_time_integration', integrationScheme)

      ! Character selectors compare blank-padded, exactly like trim(x) == 'RK4'.
      select case (trim(integrationScheme))
      case ('RK4')
         call sw_rk4(domain, dt)
      case default
         write(0,*) 'Unknown time integration option '//trim(integrationScheme)
         write(0,*) 'Currently, only ''RK4'' is supported.'
         stop
      end select

      ! Stamp the freshly computed state (time level 2) of every block.
      curBlock => domain % blocklist
      do
         if (.not. associated(curBlock)) exit

         call mpas_pool_get_subpool(curBlock % structs, 'state', blockState)
         call mpas_pool_get_array(blockState, 'xtime', blockTime, 2)
         blockTime = timeStamp

         curBlock => curBlock % next
      end do

   end subroutine sw_timestep


   subroutine sw_rk4(domain, dt)
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
   ! Advance model state forward in time by the specified time step using 
   !   4th order Runge-Kutta
   !
   ! Input: domain - current model state in time level 1 (e.g., time_levs(1)%state%h(:,:)) 
   !                 plus grid meta-data
   ! Output: domain - upon exit, time level 2 (e.g., time_levs(2)%state%h(:,:)) contains 
   !                  model state advanced forward in time by dt seconds
   !
   ! Implementation notes:
   !   * A temporary 'provis_state' subpool (cloned from 'state' with a single time
   !     level) holds the intermediate RK stage state. It is added to every block on
   !     entry, linked to the neighbouring blocks' provisional pools, and destroyed and
   !     removed again before returning.
   !   * While the stages run, time level 2 of 'tracers' in 'state' holds h*tracer
   !     (coupled form). The division by h is done in the final clean-up loop.
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 

      implicit none

      type (domain_type), intent(inout) :: domain
      real (kind=RKIND), intent(in) :: dt

      integer :: iCell, k
      type (block_type), pointer :: block
      type (mpas_pool_type), pointer :: statePool
      type (mpas_pool_type), pointer :: meshPool
      type (mpas_pool_type), pointer :: tendPool
      type (mpas_pool_type), pointer :: provisStatePool
      type (mpas_pool_type), pointer :: prevProvisPool, nextProvisPool

      integer :: rk_step

      ! rk_weights:         weights applied to each stage's tendency in the final update
      ! rk_substep_weights: step lengths used to build the next stage's provisional state
      real (kind=RKIND), dimension(4) :: rk_weights, rk_substep_weights

      integer, pointer :: nCells, nVertLevels

      real (kind=RKIND), dimension(:,:), pointer :: uOld, uNew, uProvis, uTend
      real (kind=RKIND), dimension(:,:), pointer :: hOld, hNew, hProvis, hTend
      real (kind=RKIND), dimension(:,:), pointer :: uReconstructX, uReconstructY, uReconstructZ, &
                                                    uReconstructZonal, uReconstructMeridional
      real (kind=RKIND), dimension(:,:,:), pointer :: tracersOld, tracersNew, tracersProvis, tracersTend

      type (field2DReal), pointer :: pvEdgeField, divergenceField, vorticityField, uField, hField
      type (field3DReal), pointer :: tracersField

      integer, pointer :: config_test_case
      real (kind=RKIND), pointer :: config_h_mom_eddy_visc4


      call mpas_pool_get_config(domain % configs, 'config_test_case', config_test_case)
      call mpas_pool_get_config(domain % configs, 'config_h_mom_eddy_visc4', config_h_mom_eddy_visc4)

      !
      ! Create the provisional state for every block.
      ! Initialize time level 2 with the state at the current time.
      ! Couple the time level 2 tracers with h.
      !
      block => domain % blocklist
      do while (associated(block))
         call mpas_pool_get_subpool(block % structs, 'mesh', meshPool)
         call mpas_pool_get_subpool(block % structs, 'state', statePool)

         allocate(provisStatePool)
         call mpas_pool_create_pool(provisStatePool)
         call mpas_pool_clone_pool(statePool, provisStatePool, 1)

         call mpas_pool_add_subpool(block % structs, 'provis_state', provisStatePool)

         ! Copy time level 1 into the later time levels of 'state' BEFORE coupling the
         ! tracers below. mpas_pool_initialize_time_levels copies time level 1 over
         ! time level 2. Calling it after the coupling (as this code previously did)
         ! replaced h*tracer with the uncoupled tracer, which broke both the RK
         ! accumulation and the final division by h.
         call mpas_pool_initialize_time_levels(statePool)

         call mpas_pool_get_dimension(meshPool, 'nCells', nCells)
         call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels)

         call mpas_pool_get_array(statePool, 'u', uOld, 1)
         call mpas_pool_get_array(statePool, 'u', uNew, 2)
         call mpas_pool_get_array(statePool, 'h', hOld, 1)
         call mpas_pool_get_array(statePool, 'h', hNew, 2)
         call mpas_pool_get_array(statePool, 'tracers', tracersOld, 1)
         call mpas_pool_get_array(statePool, 'tracers', tracersNew, 2)

         uNew(:,:) = uOld(:,:)
         hNew(:,:) = hOld(:,:)
         do iCell = 1, nCells  ! couple tracers to h
            do k = 1, nVertLevels
               tracersNew(:,k,iCell) = tracersOld(:,k,iCell) * hOld(k,iCell)
            end do
         end do

         block => block % next
      end do

      !
      ! Link each block's provisional pool to those of its previous/next blocks
      ! and to the block's parallel information.
      !
      block => domain % blocklist
      do while(associated(block))
         if (associated(block % prev)) then
            call mpas_pool_get_subpool(block % prev % structs, 'provis_state', prevProvisPool)
         else
            nullify(prevProvisPool)
         end if

         if (associated(block % next)) then
            call mpas_pool_get_subpool(block % next % structs, 'provis_state', nextProvisPool)
         else
            nullify(nextProvisPool)
         end if

         call mpas_pool_get_subpool(block % structs, 'provis_state', provisStatePool)

         if (associated(prevProvisPool) .and. associated(nextProvisPool)) then
            call mpas_pool_link_pools(provisStatePool, prevProvisPool, nextProvisPool)
         else if (associated(prevProvisPool)) then
            call mpas_pool_link_pools(provisStatePool, prevProvisPool)
         else if (associated(nextProvisPool)) then
            call mpas_pool_link_pools(provisStatePool, nextPool=nextProvisPool)
         else
            call mpas_pool_link_pools(provisStatePool)
         end if

         call mpas_pool_link_parinfo(block, provisStatePool)

         block => block % next
      end do


      rk_weights(1) = dt/6.
      rk_weights(2) = dt/3.
      rk_weights(3) = dt/3.
      rk_weights(4) = dt/6.

      rk_substep_weights(1) = dt/2.
      rk_substep_weights(2) = dt/2.
      rk_substep_weights(3) = dt
      rk_substep_weights(4) = 0.


      !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
      ! BEGIN RK loop 
      !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
      do rk_step = 1, 4

         ! --- update halos for diagnostic variables of the provisional state
         !     (time level 1 of 'provis_state' is what sw_compute_tend reads below)
         call mpas_pool_get_subpool(domain % blocklist % structs, 'provis_state', provisStatePool)

         call mpas_pool_get_field(provisStatePool, 'pv_edge', pvEdgeField, 1)

         call mpas_dmpar_exch_halo_field(pvEdgeField)

         if (config_h_mom_eddy_visc4 > 0.0) then
            ! The del4 term reads divergence/vorticity from the provisional pool,
            ! so exchange those fields. The old code used 'statePool' here, which was
            ! left pointing at the last block's 'state' by an earlier loop, at time level 2.
            call mpas_pool_get_field(provisStatePool, 'divergence', divergenceField, 1)
            call mpas_pool_get_field(provisStatePool, 'vorticity', vorticityField, 1)
            call mpas_dmpar_exch_halo_field(divergenceField)
            call mpas_dmpar_exch_halo_field(vorticityField)
         end if

         ! --- compute tendencies

         block => domain % blocklist
         do while (associated(block))
            call mpas_pool_get_subpool(block % structs, 'tend', tendPool)
            call mpas_pool_get_subpool(block % structs, 'provis_state', provisStatePool)
            call mpas_pool_get_subpool(block % structs, 'mesh', meshPool)

            call sw_compute_tend(tendPool, provisStatePool, meshPool, 1)
            call sw_compute_scalar_tend(tendPool, provisStatePool, meshPool, 1)
            call sw_enforce_boundary_edge(tendPool, meshPool)
            block => block % next
         end do

         ! --- update halos for prognostic variables

         call mpas_pool_get_subpool(domain % blocklist % structs, 'tend', tendPool)

         call mpas_pool_get_field(tendPool, 'u', uField)
         call mpas_pool_get_field(tendPool, 'h', hField)
         call mpas_pool_get_field(tendPool, 'tracers', tracersField)

         call mpas_dmpar_exch_halo_field(uField)
         call mpas_dmpar_exch_halo_field(hField)
         call mpas_dmpar_exch_halo_field(tracersField)

         ! --- compute next substep state

         if (rk_step < 4) then
            block => domain % blocklist
            do while (associated(block))
               call mpas_pool_get_subpool(block % structs, 'state', statePool)
               call mpas_pool_get_subpool(block % structs, 'mesh', meshPool)
               call mpas_pool_get_subpool(block % structs, 'tend', tendPool)
               call mpas_pool_get_subpool(block % structs, 'provis_state', provisStatePool)

               call mpas_pool_get_dimension(meshPool, 'nCells', nCells)
               call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels)

               call mpas_pool_get_array(provisStatePool, 'u', uProvis)
               call mpas_pool_get_array(provisStatePool, 'h', hProvis)
               call mpas_pool_get_array(provisStatePool, 'tracers', tracersProvis)

               call mpas_pool_get_array(statePool, 'u', uOld, 1)
               call mpas_pool_get_array(statePool, 'h', hOld, 1)
               call mpas_pool_get_array(statePool, 'tracers', tracersOld, 1)

               call mpas_pool_get_array(tendPool, 'u', uTend)
               call mpas_pool_get_array(tendPool, 'h', hTend)
               call mpas_pool_get_array(tendPool, 'tracers', tracersTend)

               uProvis(:,:) = uOld(:,:) + rk_substep_weights(rk_step) * uTend(:,:)
               hProvis(:,:) = hOld(:,:) + rk_substep_weights(rk_step) * hTend(:,:)
               ! Provisional tracers are stored uncoupled: advance h*tracer, then divide by h.
               do iCell = 1, nCells
                  do k = 1, nVertLevels
                     tracersProvis(:,k,iCell) = ( hOld(k,iCell) * tracersOld(:,k,iCell)  &
                                                + rk_substep_weights(rk_step) * tracersTend(:,k,iCell) &
                                                ) / hProvis(k,iCell)
                  end do
               end do
               if (config_test_case == 1) then    ! For case 1, wind field should be fixed
                  uProvis(:,:) = uOld(:,:)
               end if
               call sw_compute_solve_diagnostics(dt, provisStatePool, meshPool)
               block => block % next
            end do
         end if

         !--- accumulate update (for RK4)

         block => domain % blocklist
         do while (associated(block))
            call mpas_pool_get_subpool(block % structs, 'state', statePool)
            call mpas_pool_get_subpool(block % structs, 'tend', tendPool)
            call mpas_pool_get_subpool(block % structs, 'mesh', meshPool)

            call mpas_pool_get_dimension(meshPool, 'nCells', nCells)
            call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels)

            call mpas_pool_get_array(statePool, 'u', uOld, 1)
            call mpas_pool_get_array(statePool, 'h', hOld, 1)
            call mpas_pool_get_array(statePool, 'tracers', tracersOld, 1)

            call mpas_pool_get_array(statePool, 'u', uNew, 2)
            call mpas_pool_get_array(statePool, 'h', hNew, 2)
            call mpas_pool_get_array(statePool, 'tracers', tracersNew, 2)

            call mpas_pool_get_array(tendPool, 'u', uTend)
            call mpas_pool_get_array(tendPool, 'h', hTend)
            call mpas_pool_get_array(tendPool, 'tracers', tracersTend)

            uNew(:,:) = uNew(:,:) + rk_weights(rk_step) * uTend(:,:) 
            hNew(:,:) = hNew(:,:) + rk_weights(rk_step) * hTend(:,:) 
            ! tracersNew is in coupled (h*tracer) form during the RK stages
            do iCell = 1, nCells
               do k = 1, nVertLevels
                  tracersNew(:,k,iCell) = tracersNew(:,k,iCell) + rk_weights(rk_step) * tracersTend(:,k,iCell)
               end do
            end do
            block => block % next
         end do

      end do
      !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
      ! END RK loop 
      !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 


      !
      !  A little clean up at the end: decouple new scalar fields and compute diagnostics for new state
      !
      block => domain % blocklist
      do while (associated(block))
         call mpas_pool_get_subpool(block % structs, 'mesh', meshPool)
         call mpas_pool_get_subpool(block % structs, 'state', statePool)

         call mpas_pool_get_dimension(meshPool, 'nCells', nCells)
         call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels)

         call mpas_pool_get_array(statePool, 'u', uOld, 1)
         call mpas_pool_get_array(statePool, 'u', uNew, 2)
         call mpas_pool_get_array(statePool, 'h', hNew, 2)
         call mpas_pool_get_array(statePool, 'tracers', tracersNew, 2)

         call mpas_pool_get_array(statePool, 'uReconstructX', uReconstructX, 2)
         call mpas_pool_get_array(statePool, 'uReconstructY', uReconstructY, 2)
         call mpas_pool_get_array(statePool, 'uReconstructZ', uReconstructZ, 2)
         call mpas_pool_get_array(statePool, 'uReconstructZonal', uReconstructZonal, 2)
         call mpas_pool_get_array(statePool, 'uReconstructMeridional', uReconstructMeridional, 2)

         do iCell = 1, nCells
            do k = 1, nVertLevels
               tracersNew(:,k,iCell) = tracersNew(:,k,iCell) / hNew(k,iCell)
            end do
         end do

         if (config_test_case == 1) then    ! For case 1, wind field should be fixed
            uNew(:,:) = uOld(:,:)
         end if

         call sw_compute_solve_diagnostics(dt, statePool, meshPool, 2)

         call mpas_reconstruct(meshPool, uNew,          &
                          uReconstructX, uReconstructY, uReconstructZ, &
                          uReconstructZonal, uReconstructMeridional )

         block => block % next
      end do

      ! Tear down the temporary provisional state created at entry.
      block => domain % blocklist
      do while(associated(block))
         call mpas_pool_get_subpool(block % structs, 'provis_state', provisStatePool)

         call mpas_pool_destroy_pool(provisStatePool)

         call mpas_pool_remove_subpool(block % structs, 'provis_state')
         block => block % next
      end do

   end subroutine sw_rk4


   subroutine sw_compute_tend(tendPool, statePool, meshPool, timeLevelIn)
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
   ! Compute height and normal-velocity tendencies from a model state.
   ! (No diagnostic variables are computed here; they are read from statePool.)
   !
   ! Input: statePool   - model state (h, u, h_edge, ke, pv_edge, divergence,
   !                      vorticity, ...) to evaluate tendencies from
   !        meshPool    - mesh metadata and dimensions
   !        timeLevelIn - optional time level of statePool to read (default 1)
   !
   ! Output: tendPool   - 'h' and 'u' tendencies.
   !                      tend_h is divided by cell area only for owned cells
   !                      (1..nCellsSolve).
   !                      tend_u is set only for owned edges (1..nEdgesSolve), except
   !                      for the wind-stress and bottom-drag terms, which loop over
   !                      all nEdges.
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 

      implicit none

      type (mpas_pool_type), intent(inout) :: tendPool
      type (mpas_pool_type), intent(in) :: statePool
      type (mpas_pool_type), intent(in) :: meshPool
      integer, intent(in), optional :: timeLevelIn

      integer :: iEdge, iCell, iVertex, k, cell1, cell2, vertex1, vertex2, eoe, j, timeLevel
      real (kind=RKIND) :: flux, workpv, q

      integer, pointer :: nCells, nEdges, nVertices, nVertLevels, nCellsSolve, nEdgesSolve, nVerticesSolve
      real (kind=RKIND), dimension(:), pointer :: h_s, fVertex, fEdge, dvEdge, dcEdge, areaCell, areaTriangle, &
                                                  meshScalingDel2, meshScalingDel4
      real (kind=RKIND), dimension(:,:), pointer :: vh, weightsOnEdge, kiteAreasOnVertex, h_edge, h, u, v, tend_h, tend_u, &
                                                    circulation, vorticity, ke, pv_edge, divergence
      integer, dimension(:,:), pointer :: cellsOnEdge, cellsOnVertex, verticesOnEdge, edgesOnCell, edgesOnEdge, edgesOnVertex
      integer, dimension(:), pointer :: nEdgesOnCell, nEdgesOnEdge
      real (kind=RKIND) :: r, u_diffusion

      real (kind=RKIND), allocatable, dimension(:,:) :: delsq_divergence
      real (kind=RKIND), allocatable, dimension(:,:) :: delsq_u
      real (kind=RKIND), allocatable, dimension(:,:) :: delsq_circulation, delsq_vorticity

      real (kind=RKIND), dimension(:,:), pointer :: u_src
      real (kind=RKIND), parameter :: rho_ref = 1000.0
      ! Unitless coefficient c in the bottom drag term -c |u| u. The value is unchanged
      ! from the literal previously written inline.
      real (kind=RKIND), parameter :: bottom_drag_coef = 1.0e-3
      real (kind=RKIND) :: ke_edge

      logical, pointer :: config_wind_stress, config_bottom_drag
      real (kind=RKIND), pointer :: config_h_mom_eddy_visc2, config_h_mom_eddy_visc4

      if (present(timeLevelIn)) then
         timeLevel = timeLevelIn
      else
         timeLevel = 1
      end if

      call mpas_pool_get_config(swConfigs, 'config_bottom_drag', config_bottom_drag)
      call mpas_pool_get_config(swConfigs, 'config_wind_stress', config_wind_stress)
      call mpas_pool_get_config(swConfigs, 'config_h_mom_eddy_visc2', config_h_mom_eddy_visc2)
      call mpas_pool_get_config(swConfigs, 'config_h_mom_eddy_visc4', config_h_mom_eddy_visc4)

      call mpas_pool_get_array(statePool, 'h', h, timeLevel)
      call mpas_pool_get_array(statePool, 'u', u, timeLevel)
      call mpas_pool_get_array(statePool, 'v', v, timeLevel)
      call mpas_pool_get_array(statePool, 'h_edge', h_edge, timeLevel)
      call mpas_pool_get_array(statePool, 'circulation', circulation, timeLevel)
      call mpas_pool_get_array(statePool, 'vorticity', vorticity, timeLevel)
      call mpas_pool_get_array(statePool, 'divergence', divergence, timeLevel)
      call mpas_pool_get_array(statePool, 'ke', ke, timeLevel)
      call mpas_pool_get_array(statePool, 'pv_edge', pv_edge, timeLevel)
      call mpas_pool_get_array(statePool, 'vh', vh, timeLevel)

      call mpas_pool_get_array(meshPool, 'weightsOnEdge', weightsOnEdge)
      call mpas_pool_get_array(meshPool, 'kiteAreasOnVertex', kiteAreasOnVertex)
      call mpas_pool_get_array(meshPool, 'cellsOnEdge', cellsOnEdge)
      call mpas_pool_get_array(meshPool, 'cellsOnVertex', cellsOnVertex)
      call mpas_pool_get_array(meshPool, 'verticesOnEdge', verticesOnEdge)
      call mpas_pool_get_array(meshPool, 'nEdgesOnCell', nEdgesOnCell)
      call mpas_pool_get_array(meshPool, 'edgesOnCell', edgesOnCell)
      call mpas_pool_get_array(meshPool, 'nEdgesOnEdge', nEdgesOnEdge)
      call mpas_pool_get_array(meshPool, 'edgesOnEdge', edgesOnEdge)
      call mpas_pool_get_array(meshPool, 'edgesOnVertex', edgesOnVertex)
      call mpas_pool_get_array(meshPool, 'dcEdge', dcEdge)
      call mpas_pool_get_array(meshPool, 'dvEdge', dvEdge)
      call mpas_pool_get_array(meshPool, 'areaCell', areaCell)
      call mpas_pool_get_array(meshPool, 'areaTriangle', areaTriangle)
      call mpas_pool_get_array(meshPool, 'h_s', h_s)
      call mpas_pool_get_array(meshPool, 'fVertex', fVertex)
      call mpas_pool_get_array(meshPool, 'fEdge', fEdge)
      call mpas_pool_get_array(meshPool, 'u_src', u_src)
      call mpas_pool_get_array(meshPool, 'meshScalingDel2', meshScalingDel2)
      call mpas_pool_get_array(meshPool, 'meshScalingDel4', meshScalingDel4)

      call mpas_pool_get_array(tendPool, 'h', tend_h)
      call mpas_pool_get_array(tendPool, 'u', tend_u)

      call mpas_pool_get_dimension(meshPool, 'nCells', nCells)
      call mpas_pool_get_dimension(meshPool, 'nCellsSolve', nCellsSolve)
      call mpas_pool_get_dimension(meshPool, 'nEdges', nEdges)
      call mpas_pool_get_dimension(meshPool, 'nEdgesSolve', nEdgesSolve)
      call mpas_pool_get_dimension(meshPool, 'nVertices', nVertices)
      call mpas_pool_get_dimension(meshPool, 'nVerticesSolve', nVerticesSolve)
      call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels)

      !
      ! Compute height tendency for each cell: minus the divergence of the thickness
      ! flux u * h_edge through each edge. Positive u carries mass from cell1 to cell2.
      !
      tend_h(:,:) = 0.0
      do iEdge = 1, nEdges
         cell1 = cellsOnEdge(1,iEdge)
         cell2 = cellsOnEdge(2,iEdge)
         do k = 1, nVertLevels
            flux = u(k,iEdge) * dvEdge(iEdge) * h_edge(k,iEdge)
            tend_h(k,cell1) = tend_h(k,cell1) - flux
            tend_h(k,cell2) = tend_h(k,cell2) + flux
         end do
      end do
      ! Only owned cells are normalized by area. Halo entries keep raw flux sums;
      ! presumably the caller overwrites them with a halo exchange (TODO confirm).
      do iCell = 1, nCellsSolve
         do k = 1, nVertLevels
            tend_h(k,iCell) = tend_h(k,iCell) / areaCell(iCell)
         end do
      end do


      !
      ! Compute u (normal) velocity tendency for each edge (cell face):
      ! a nonlinear Coriolis/PV term q, minus the gradient of ke + g*(h + h_s).
      !
      tend_u(:,:) = 0.0
      do iEdge = 1, nEdgesSolve
         cell1 = cellsOnEdge(1,iEdge)
         cell2 = cellsOnEdge(2,iEdge)
         vertex1 = verticesOnEdge(1,iEdge)
         vertex2 = verticesOnEdge(2,iEdge)

         do k = 1, nVertLevels
            q = 0.0
            do j = 1, nEdgesOnEdge(iEdge)
               eoe = edgesOnEdge(j,iEdge)
               workpv = 0.5 * (pv_edge(k,iEdge) + pv_edge(k,eoe))
               q = q + weightsOnEdge(j,iEdge) * u(k,eoe) * workpv * h_edge(k,eoe) 
            end do

            tend_u(k,iEdge) =       &
                              q     &
                              - (   ke(k,cell2) - ke(k,cell1) + &
                                    gravity * (h(k,cell2) + h_s(cell2) - h(k,cell1) - h_s(cell1)) &
                                  ) / dcEdge(iEdge)
         end do
      end do


      ! Compute diffusion, computed as \nabla divergence - k \times \nabla vorticity
      !                    only valid for visc == constant
      if (config_h_mom_eddy_visc2 > 0.0) then
         do iEdge = 1, nEdgesSolve
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)
            vertex1 = verticesOnEdge(1,iEdge)
            vertex2 = verticesOnEdge(2,iEdge)

            do k = 1, nVertLevels
               u_diffusion =   ( divergence(k,cell2)  -  divergence(k,cell1) ) / dcEdge(iEdge) &
                    -(vorticity(k,vertex2)  - vorticity(k,vertex1) ) / dvEdge(iEdge)
               u_diffusion = meshScalingDel2(iEdge) * config_h_mom_eddy_visc2 * u_diffusion
               tend_u(k,iEdge) = tend_u(k,iEdge) + u_diffusion
            end do
         end do
      end if

      !
      ! velocity tendency: del4 dissipation, -\nu_4 \nabla^4 u
      !   computed as \nabla^2 u = \nabla divergence + k \times \nabla vorticity
      !   applied recursively.
      !   strictly only valid for h_mom_eddy_visc4 == constant
      !
      if (config_h_mom_eddy_visc4 > 0.0) then
         ! The extra (+1) column holds the index used by edges that point outside
         ! the local mesh.
         allocate(delsq_divergence(nVertLevels, nCells+1))
         allocate(delsq_u(nVertLevels, nEdges+1))
         allocate(delsq_circulation(nVertLevels, nVertices+1))
         allocate(delsq_vorticity(nVertLevels, nVertices+1))

         delsq_u(:,:) = 0.0

         ! Compute \nabla^2 u = \nabla divergence + k \times \nabla vorticity
         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)
            vertex1 = verticesOnEdge(1,iEdge)
            vertex2 = verticesOnEdge(2,iEdge)

            do k = 1, nVertLevels

               delsq_u(k,iEdge) = ( divergence(k,cell2)  - divergence(k,cell1) ) / dcEdge(iEdge)  &
                    -( vorticity(k,vertex2) - vorticity(k,vertex1)) / dvEdge(iEdge)

            end do
         end do

         ! vorticity using \nabla^2 u
         delsq_circulation(:,:) = 0.0
         do iEdge = 1, nEdges
            vertex1 = verticesOnEdge(1,iEdge)
            vertex2 = verticesOnEdge(2,iEdge)
            do k=1,nVertLevels
               delsq_circulation(k,vertex1) = delsq_circulation(k,vertex1) &
                    - dcEdge(iEdge) * delsq_u(k,iEdge)
               delsq_circulation(k,vertex2) = delsq_circulation(k,vertex2) &
                    + dcEdge(iEdge) * delsq_u(k,iEdge)
            end do
         end do
         ! Zero first: the loop below fills only columns 1..nVertices, so without this
         ! the extra column nVertices+1 would be read uninitialized if any edge
         ! referenced it.
         delsq_vorticity(:,:) = 0.0
         do iVertex = 1, nVertices
            r = 1.0 / areaTriangle(iVertex)
            do k = 1, nVertLevels
               delsq_vorticity(k,iVertex) = delsq_circulation(k,iVertex) * r
            end do
         end do

         ! Divergence using \nabla^2 u
         delsq_divergence(:,:) = 0.0
         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)
            do k=1,nVertLevels
               delsq_divergence(k,cell1) = delsq_divergence(k,cell1) &
                    + delsq_u(k,iEdge)*dvEdge(iEdge)
               delsq_divergence(k,cell2) = delsq_divergence(k,cell2) &
                    - delsq_u(k,iEdge)*dvEdge(iEdge)
            end do
         end do
         do iCell = 1,nCells
            r = 1.0 / areaCell(iCell)
            do k = 1, nVertLevels
               delsq_divergence(k,iCell) = delsq_divergence(k,iCell) * r
            end do
         end do

         ! Compute - \kappa \nabla^4 u 
         ! as  \nabla div(\nabla^2 u) + k \times \nabla ( k \cross curl(\nabla^2 u) )
         do iEdge = 1, nEdgesSolve
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)
            vertex1 = verticesOnEdge(1,iEdge)
            vertex2 = verticesOnEdge(2,iEdge)

            do k = 1, nVertLevels

               u_diffusion = (  delsq_divergence(k,cell2) &
                    - delsq_divergence(k,cell1) ) / dcEdge(iEdge)  &
                    -(  delsq_vorticity(k,vertex2) &
                    - delsq_vorticity(k,vertex1) ) / dvEdge(iEdge)

               u_diffusion = meshScalingDel4(iEdge) * config_h_mom_eddy_visc4 * u_diffusion
               tend_u(k,iEdge) = tend_u(k,iEdge) - u_diffusion

            end do
         end do

         deallocate(delsq_divergence)
         deallocate(delsq_u)
         deallocate(delsq_circulation)
         deallocate(delsq_vorticity)

      end if

      ! Compute u (velocity) tendency from wind stress (u_src); applied to level k=1 only
      if (config_wind_stress) then
         do iEdge = 1, nEdges
            tend_u(1,iEdge) =  tend_u(1,iEdge) &
                  + u_src(1,iEdge) / rho_ref / h_edge(1,iEdge)
         end do
      end if

      if (config_bottom_drag) then
         do iEdge = 1,  nEdges
            ! bottom drag is the same as POP:
            ! -c |u| u  where c is unitless and 1.0e-3.
            ! see POP Reference guide, section 3.4.4.
            ! |u| is estimated as sqrt(2*ke) using ke averaged from the two adjacent cells.
            ! Applied to level k=1 only.
            ke_edge = 0.5 * ( ke(1,cellsOnEdge(1,iEdge)) &
                  + ke(1,cellsOnEdge(2,iEdge)))

            tend_u(1,iEdge) = tend_u(1,iEdge)  &
                 - bottom_drag_coef*u(1,iEdge) &
                 *sqrt(2.0*ke_edge)/h_edge(1,iEdge)
         end do
      end if

   end subroutine sw_compute_tend


   subroutine sw_compute_scalar_tend(tendPool, statePool, meshPool, timeLevelIn)
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 
   !
   ! Input: s - current model state
   !        grid - grid metadata
   !
   ! Output: tend - computed scalar tendencies
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! 

      implicit none

      type (mpas_pool_type), intent(inout) :: tendPool
      type (mpas_pool_type), intent(in) :: statePool
      type (mpas_pool_type), intent(in) :: meshPool
      integer, intent(in), optional :: timeLevelIn

      integer :: iCell, iEdge, k, iTracer, cell1, cell2, i
      real (kind=RKIND) :: flux, tracer_edge, r
      real (kind=RKIND) :: invAreaCell1, invAreaCell2, tracer_turb_flux
      integer, dimension(:,:), pointer :: boundaryEdge
      real (kind=RKIND), dimension(:,:), allocatable :: boundaryMask
      real (kind=RKIND), dimension(:,:,:), allocatable:: delsq_tracer
      
      real (kind=RKIND) :: d2fdx2_cell1, d2fdx2_cell2
      real (kind=RKIND), dimension(:), pointer :: dvEdge, dcEdge, areaCell
      real (kind=RKIND), dimension(:,:,:), pointer :: tracers, tracer_tend
      integer, dimension(:,:), pointer :: cellsOnEdge, boundaryCell
      real (kind=RKIND), dimension(:,:,:), pointer :: deriv_two
      real (kind=RKIND) :: coef_3rd_order
      real (kind=RKIND), dimension(:,:), pointer :: u, h_edge

      integer, pointer :: config_tracer_adv_order
      logical, pointer :: config_monotonic

      integer, dimension(:), pointer :: nEdgesOnCell
      integer, dimension(:,:), pointer :: cellsOnCell

      integer, pointer :: nCells, nEdges, nVertices, nVertLevels, nCellsSolve, nEdgesSolve, nVerticesSolve, nTracers
      integer :: timeLevel

      real (kind=RKIND), pointer :: config_h_tracer_eddy_diff2, config_h_tracer_eddy_diff4

      if (present(timeLevelIn)) then
         timeLevel = timeLevelIn
      else
         timeLevel = 1
      endif

      call mpas_pool_get_config(swConfigs, 'config_tracer_adv_order', config_tracer_adv_order)
      call mpas_pool_get_config(swConfigs, 'config_monotonic', config_monotonic)
      call mpas_pool_get_config(swConfigs, 'config_h_tracer_eddy_diff2', config_h_tracer_eddy_diff2)
      call mpas_pool_get_config(swConfigs, 'config_h_tracer_eddy_diff4', config_h_tracer_eddy_diff4)

      call mpas_pool_get_array(statePool, 'u', u, timeLevel)
      call mpas_pool_get_array(statePool, 'h_edge', h_edge, timeLevel)
      call mpas_pool_get_array(statePool, 'tracers', tracers, timeLevel)

      call mpas_pool_get_array(meshPool, 'dcEdge', dcEdge)
      call mpas_pool_get_array(meshPool, 'deriv_two', deriv_two)
      call mpas_pool_get_array(meshPool, 'dvEdge', dvEdge)
      call mpas_pool_get_array(meshPool, 'cellsOnEdge', cellsOnEdge)
      call mpas_pool_get_array(meshPool, 'boundaryCell', boundaryCell)
      call mpas_pool_get_array(meshPool, 'boundaryEdge', boundaryEdge)
      call mpas_pool_get_array(meshPool, 'areaCell', areaCell)
      call mpas_pool_get_array(meshPool, 'nEdgesOnCell', nEdgesOnCell)
      call mpas_pool_get_array(meshPool, 'cellsOnCell', cellsOnCell)

      call mpas_pool_get_array(tendPool, 'tracers', tracer_tend)

      call mpas_pool_get_dimension(meshPool, 'nCells', nCells)
      call mpas_pool_get_dimension(meshPool, 'nEdges', nEdges)
      call mpas_pool_get_dimension(meshPool, 'nVertices', nVertices)
      call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels)
      call mpas_pool_get_dimension(meshPool, 'nCellsSolve', nCellsSolve)
      call mpas_pool_get_dimension(meshPool, 'nEdgesSolve', nEdgesSolve)
      call mpas_pool_get_dimension(meshPool, 'nVerticesSolve', nVerticesSolve)
      call mpas_pool_get_dimension(meshPool, 'nTracers', nTracers)

      coef_3rd_order = 0.
      if (config_tracer_adv_order == 3) coef_3rd_order = 1.0
      if (config_tracer_adv_order == 3 .and. config_monotonic) coef_3rd_order = 0.25

      tracer_tend(:,:,:) = 0.0

      if (config_tracer_adv_order == 2) then

      do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)
            if (cell1 <= nCells .and. cell2 <= nCells) then
               do k = 1, nVertLevels
                  do iTracer = 1, nTracers
                     tracer_edge = 0.5 * (tracers(iTracer,k,cell1) + tracers(iTracer,k,cell2))
                     flux = u(k,iEdge) * dvEdge(iEdge) * h_edge(k,iEdge) * tracer_edge
                     tracer_tend(iTracer,k,cell1) = tracer_tend(iTracer,k,cell1) - flux / areaCell(cell1)
                     tracer_tend(iTracer,k,cell2) = tracer_tend(iTracer,k,cell2) + flux / areaCell(cell2)
                  end do 
               end do 
            end if
      end do 

      else if (config_tracer_adv_order == 3) then

         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)

            !-- if a cell not on the most outside ring of the halo
            if (cell1 <= nCells .and. cell2 <= nCells) then

               do k = 1, nVertLevels

                  d2fdx2_cell1 = 0.0
                  d2fdx2_cell2 = 0.0

                  do iTracer = 1, nTracers
 
                     !-- if not a boundary cell
                     if(boundaryCell(k,cell1) .eq. 0 .and. boundaryCell(k,cell2) .eq. 0) then

                        d2fdx2_cell1 = deriv_two(1,1,iEdge) * tracers(iTracer,k,cell1)
                        d2fdx2_cell2 = deriv_two(1,2,iEdge) * tracers(iTracer,k,cell2)

                        !-- all edges of cell 1
                        do i = 1, nEdgesOnCell(cell1)
                                d2fdx2_cell1 = d2fdx2_cell1 + &
                                deriv_two(i+1,1,iEdge) * tracers(iTracer,k, CellsOnCell(i,cell1))
                        end do

                        !-- all edges of cell 2
                        do i = 1, nEdgesOnCell(cell2)
                                d2fdx2_cell2 = d2fdx2_cell2 + &
                                deriv_two(i+1,2,iEdge) * tracers(iTracer,k,CellsOnCell(i,cell2))
                        end do

                     endif

                     !-- if u > 0:
                     if (u(k,iEdge) > 0) then
                        flux = dvEdge(iEdge) * u(k,iEdge) * h_edge(k,iEdge) * (          &
                             0.5*(tracers(iTracer,k,cell1) + tracers(iTracer,k,cell2))      &
                             -(dcEdge(iEdge) **2) * (d2fdx2_cell1 + d2fdx2_cell2) / 12.          &
                             -(dcEdge(iEdge) **2) * coef_3rd_order*(d2fdx2_cell1 - d2fdx2_cell2) / 12. )
                     !-- else u <= 0:
                     else
                        flux = dvEdge(iEdge) *  u(k,iEdge) * h_edge(k,iEdge) * (          &
                             0.5*(tracers(iTracer,k,cell1) + tracers(iTracer,k,cell2))      &
                             -(dcEdge(iEdge) **2) * (d2fdx2_cell1 + d2fdx2_cell2) / 12.          &
                             +(dcEdge(iEdge) **2) * coef_3rd_order*(d2fdx2_cell1 - d2fdx2_cell2) / 12. )
                     end if

                     !-- update tendency
                     tracer_tend(iTracer,k,cell1) = tracer_tend(iTracer,k,cell1) - flux / areaCell(cell1)
                     tracer_tend(iTracer,k,cell2) = tracer_tend(iTracer,k,cell2) + flux / areaCell(cell2)
                  enddo
               end do
            end if
         end do

      else  if (config_tracer_adv_order == 4) then

         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)

            !-- if an edge is not on the outer-most ring of the halo
            if (cell1 <= nCells .and. cell2 <= nCells) then

               do k = 1, nVertLevels

                  d2fdx2_cell1 = 0.0
                  d2fdx2_cell2 = 0.0

                  do iTracer = 1, nTracers

                     !-- if not a boundary cell
                     if(boundaryCell(k,cell1) .eq. 0 .and. boundaryCell(k,cell2) .eq. 0) then

                        d2fdx2_cell1 = deriv_two(1,1,iEdge) * tracers(iTracer,k,cell1)
                        d2fdx2_cell2 = deriv_two(1,2,iEdge) * tracers(iTracer,k,cell2)

                        !-- all edges of cell 1
                        do i = 1, nEdgesOnCell(cell1)
                                d2fdx2_cell1 = d2fdx2_cell1 + &
                                deriv_two(i+1,1,iEdge) * tracers(iTracer,k, cellsOnCell(i,cell1))
                        end do

                        !-- all edges of cell 2
                        do i = 1, nEdgesOnCell(cell2)
                                d2fdx2_cell2 = d2fdx2_cell2 + &
                                deriv_two(i+1,2,iEdge) * tracers(iTracer,k, cellsOnCell(i,cell2))
                        end do

                     endif

                     flux = dvEdge(iEdge) *  u(k,iEdge) * h_edge(k,iEdge) * (          &
                          0.5*(tracers(iTracer,k,cell1) + tracers(iTracer,k,cell2))      &
                             -(dcEdge(iEdge) **2) * (d2fdx2_cell1 + d2fdx2_cell2) / 12. )

                     !-- update tendency
                     tracer_tend(iTracer,k,cell1) = tracer_tend(iTracer,k,cell1) - flux / areaCell(cell1)
                     tracer_tend(iTracer,k,cell2) = tracer_tend(iTracer,k,cell2) + flux / areaCell(cell2)
                  enddo
               end do
            end if
         end do

      endif   ! if (config_tracer_adv_order == 2 )

      !
      ! tracer tendency: del2 horizontal tracer diffusion, div(h \kappa_2 \nabla \phi)
      !
      if ( config_h_tracer_eddy_diff2 > 0.0 ) then

         !
         ! compute a boundary mask to enforce insulating boundary conditions in the horizontal
         !
         allocate(boundaryMask(nVertLevels, nEdges+1))
         boundaryMask = 1.0
         where(boundaryEdge.eq.1) boundaryMask=0.0

         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)
            invAreaCell1 = 1.0/areaCell(cell1)
            invAreaCell2 = 1.0/areaCell(cell2)

            do k = 1, nVertLevels
              do iTracer = 1,  nTracers
                 ! \kappa_2 \nabla \phi on edge
                 tracer_turb_flux = config_h_tracer_eddy_diff2 &
                    *( tracers(iTracer,k,cell2) - tracers(iTracer,k,cell1)) / dcEdge(iEdge)

                 ! div(h \kappa_2 \nabla \phi) at cell center
                 flux = dvEdge(iEdge) * h_edge(k,iEdge) * tracer_turb_flux * boundaryMask(k, iEdge)
                 tracer_tend(iTracer,k,cell1) = tracer_tend(iTracer,k,cell1) + flux * invAreaCell1
                 tracer_tend(iTracer,k,cell2) = tracer_tend(iTracer,k,cell2) - flux * invAreaCell2
              end do
            end do

         end do

        deallocate(boundaryMask)

      end if

      !
      ! tracer tendency: del4 horizontal tracer diffusion, &
      !    div(h \kappa_4 \nabla [div(h \nabla \phi)])
      !
      if ( config_h_tracer_eddy_diff4 > 0.0 ) then

         !
         ! compute a boundary mask to enforce insulating boundary conditions in the horizontal
         !
         allocate(boundaryMask(nVertLevels, nEdges+1))
         boundaryMask = 1.0
         where(boundaryEdge.eq.1) boundaryMask=0.0

         allocate(delsq_tracer(nTracers, nVertLevels, nCells+1))

         delsq_tracer(:,:,:) = 0.

         ! first del2: div(h \nabla \phi) at cell center
         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)

            do k = 1, nVertLevels
              do iTracer = 1, nTracers
                 delsq_tracer(iTracer,k,cell1) = delsq_tracer(iTracer,k,cell1) &
                    + dvEdge(iEdge) * h_edge(k,iEdge) * (tracers(iTracer,k,cell2) - tracers(iTracer,k,cell1)) / dcEdge(iEdge) * boundaryMask(k,iEdge)
                 delsq_tracer(iTracer,k,cell2) = delsq_tracer(iTracer,k,cell2) &
                    - dvEdge(iEdge) * h_edge(k,iEdge) * (tracers(iTracer,k,cell2) - tracers(iTracer,k,cell1)) / dcEdge(iEdge) * boundaryMask(k,iEdge)
              end do
            end do

         end do

         do iCell = 1, nCells
            r = 1.0 / areaCell(iCell)
            do k = 1, nVertLevels
            do iTracer = 1, nTracers
               delsq_tracer(iTracer,k,iCell) = delsq_tracer(iTracer,k,iCell) * r
            end do
            end do
         end do

         ! second del2: div(h \nabla [delsq_tracer]) at cell center
         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)
            invAreaCell1 = 1.0 / areaCell(cell1)
            invAreaCell2 = 1.0 / areaCell(cell2)

            do k = 1, nVertLevels
            do iTracer = 1, nTracers
               tracer_turb_flux = config_h_tracer_eddy_diff4 * (delsq_tracer(iTracer,k,cell2) - delsq_tracer(iTracer,k,cell1)) / dcEdge(iEdge)
               flux = dvEdge(iEdge) * tracer_turb_flux
               tracer_tend(iTracer,k,cell1) = tracer_tend(iTracer,k,cell1) - flux * invAreaCell1 * boundaryMask(k,iEdge)
               tracer_tend(iTracer,k,cell2) = tracer_tend(iTracer,k,cell2) + flux * invAreaCell2 * boundaryMask(k,iEdge)
            end do
            enddo

         end do

         deallocate(delsq_tracer)
         deallocate(boundaryMask)

      end if

   end subroutine sw_compute_scalar_tend


   subroutine sw_compute_solve_diagnostics(dt, statePool, meshPool, timeLevelIn)
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
   ! Compute diagnostic fields used in the tendency computations
   !
   ! Input: dt          - time step; scales the APVM upwinding applied to pv_edge
   !        statePool   - model state; h and u of time level timeLevel are read
   !        meshPool    - mesh connectivity, lengths, areas and Coriolis values
   !        timeLevelIn - optional state time level to operate on (default 1)
   !
   ! Output: for time level timeLevel of statePool, recomputes h_edge (edges
   !         whose two cells are both <= nCells), circulation, vorticity,
   !         divergence, ke, v, h_vertex, pv_vertex, gradPVt, pv_edge, pv_cell,
   !         vorticity_cell and gradPVn (and vh when built with NCAR_FORMULATION),
   !         and sets u(:,nEdges+1) = 0.
   !         Also overwrites the boundaryCell array obtained from meshPool
   !         (the pool itself is intent(in); only the pointed-to data changes).
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

      implicit none

      real (kind=RKIND), intent(in) :: dt
      type (mpas_pool_type), intent(inout) :: statePool
      type (mpas_pool_type), intent(in) :: meshPool
      integer, intent(in), optional :: timeLevelIn

      integer :: iEdge, iCell, iVertex, k, cell1, cell2, eoe, i, j

      integer, pointer :: nCells, nEdges, nVertices, nVertLevels, vertexDegree
#ifdef NCAR_FORMULATION
      integer, pointer :: nEdgesSolve
#endif
      real (kind=RKIND), dimension(:), pointer :: fVertex, fEdge, dvEdge, dcEdge, areaCell, areaTriangle
      real (kind=RKIND), dimension(:,:), pointer :: vh, weightsOnEdge, kiteAreasOnVertex, h_edge, h, u, v, &
                                                    circulation, vorticity, ke, pv_edge, pv_vertex, pv_cell, &
                                                    gradPVn, gradPVt, divergence, h_vertex, vorticity_cell
      integer, dimension(:,:), pointer :: cellsOnEdge, cellsOnVertex, verticesOnEdge, edgesOnCell, &
                                          edgesOnEdge, edgesOnVertex, boundaryEdge, boundaryCell, cellsOnCell
      integer, dimension(:), pointer :: nEdgesOnCell, nEdgesOnEdge
      real (kind=RKIND) :: r
      ! h1, h2 and fEdge are only referenced by the disabled boundary block at the end
      real (kind=RKIND) :: h1, h2
      real (kind=RKIND) :: d2fdx2_cell1, d2fdx2_cell2
      real (kind=RKIND), dimension(:,:,:), pointer :: deriv_two
      real (kind=RKIND) :: coef_3rd_order

      integer :: timeLevel

      logical, pointer :: config_monotonic
      integer, pointer :: config_thickness_adv_order
      real (kind=RKIND), pointer :: config_apvm_upwinding

      if (present(timeLevelIn)) then
         timeLevel = timeLevelIn
      else
         timeLevel = 1
      end if

      call mpas_pool_get_config(swConfigs, 'config_monotonic', config_monotonic)
      call mpas_pool_get_config(swConfigs, 'config_thickness_adv_order', config_thickness_adv_order)
      call mpas_pool_get_config(swConfigs, 'config_apvm_upwinding', config_apvm_upwinding)

      call mpas_pool_get_array(statePool, 'h', h, timeLevel)
      call mpas_pool_get_array(statePool, 'u', u, timeLevel)
      call mpas_pool_get_array(statePool, 'v', v, timeLevel)
      call mpas_pool_get_array(statePool, 'vh', vh, timeLevel)
      call mpas_pool_get_array(statePool, 'h_edge', h_edge, timeLevel)
      call mpas_pool_get_array(statePool, 'h_vertex', h_vertex, timeLevel)
      call mpas_pool_get_array(statePool, 'circulation', circulation, timeLevel)
      call mpas_pool_get_array(statePool, 'vorticity', vorticity, timeLevel)
      call mpas_pool_get_array(statePool, 'divergence', divergence, timeLevel)
      call mpas_pool_get_array(statePool, 'ke', ke, timeLevel)
      call mpas_pool_get_array(statePool, 'pv_edge', pv_edge, timeLevel)
      call mpas_pool_get_array(statePool, 'pv_vertex', pv_vertex, timeLevel)
      call mpas_pool_get_array(statePool, 'pv_cell', pv_cell, timeLevel)
      call mpas_pool_get_array(statePool, 'vorticity_cell', vorticity_cell, timeLevel)
      call mpas_pool_get_array(statePool, 'gradPVn', gradPVn, timeLevel)
      call mpas_pool_get_array(statePool, 'gradPVt', gradPVt, timeLevel)

      call mpas_pool_get_array(meshPool, 'weightsOnEdge', weightsOnEdge)
      call mpas_pool_get_array(meshPool, 'kiteAreasOnVertex', kiteAreasOnVertex)
      call mpas_pool_get_array(meshPool, 'cellsOnEdge', cellsOnEdge)
      call mpas_pool_get_array(meshPool, 'cellsOnVertex', cellsOnVertex)
      call mpas_pool_get_array(meshPool, 'cellsOnCell', cellsOnCell)
      call mpas_pool_get_array(meshPool, 'verticesOnEdge', verticesOnEdge)
      call mpas_pool_get_array(meshPool, 'nEdgesOnCell', nEdgesOnCell)
      call mpas_pool_get_array(meshPool, 'edgesOnCell', edgesOnCell)
      call mpas_pool_get_array(meshPool, 'nEdgesOnEdge', nEdgesOnEdge)
      call mpas_pool_get_array(meshPool, 'edgesOnEdge', edgesOnEdge)
      call mpas_pool_get_array(meshPool, 'edgesOnVertex', edgesOnVertex)
      call mpas_pool_get_array(meshPool, 'dcEdge', dcEdge)
      call mpas_pool_get_array(meshPool, 'dvEdge', dvEdge)
      call mpas_pool_get_array(meshPool, 'areaCell', areaCell)
      call mpas_pool_get_array(meshPool, 'areaTriangle', areaTriangle)
      call mpas_pool_get_array(meshPool, 'fVertex', fVertex)
      call mpas_pool_get_array(meshPool, 'fEdge', fEdge)
      call mpas_pool_get_array(meshPool, 'deriv_two', deriv_two)

      call mpas_pool_get_dimension(meshPool, 'nCells', nCells)
      call mpas_pool_get_dimension(meshPool, 'nEdges', nEdges)
      call mpas_pool_get_dimension(meshPool, 'nVertices', nVertices)
      call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels)
      call mpas_pool_get_dimension(meshPool, 'vertexDegree', vertexDegree)
#ifdef NCAR_FORMULATION
      ! Needed by the tangential mass-flux loop below; without it this
      ! configuration does not compile under implicit none.
      call mpas_pool_get_dimension(meshPool, 'nEdgesSolve', nEdgesSolve)
#endif

      call mpas_pool_get_array(meshPool, 'boundaryEdge', boundaryEdge)
      call mpas_pool_get_array(meshPool, 'boundaryCell', boundaryCell)

      !
      ! Find those cells that have an edge on the boundary
      !
      boundaryCell(:,:) = 0
      do iEdge = 1, nEdges
         do k = 1, nVertLevels
            if (boundaryEdge(k,iEdge) == 1) then
               cell1 = cellsOnEdge(1,iEdge)
               cell2 = cellsOnEdge(2,iEdge)
               boundaryCell(k,cell1) = 1
               boundaryCell(k,cell2) = 1
            end if
         end do
      end do

      !
      ! Compute height on cell edges at velocity locations
      !   Namelist options control the order of accuracy of the reconstructed h_edge value.
      !   Only edges whose two cells are both <= nCells are updated; other edges keep
      !   their previous h_edge.
      !

      ! Weight of the upwind-biased term: 1 for plain 3rd order, 0.25 when the
      ! monotonic option is on, 0 otherwise (orders 2 and 4 do not use it).
      coef_3rd_order = 0.
      if (config_thickness_adv_order == 3) coef_3rd_order = 1.0
      if (config_thickness_adv_order == 3 .and. config_monotonic) coef_3rd_order = 0.25

      if (config_thickness_adv_order == 2) then

         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)
            if (cell1 <= nCells .and. cell2 <= nCells) then
               do k = 1, nVertLevels
                  h_edge(k,iEdge) = 0.5 * (h(k,cell1) + h(k,cell2))
               end do
            end if
         end do

      else if (config_thickness_adv_order == 3) then

         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)

            !-- if a cell not on the most outside ring of the halo
            if (cell1 <= nCells .and. cell2 <= nCells) then

               do k = 1, nVertLevels

                  d2fdx2_cell1 = 0.0
                  d2fdx2_cell2 = 0.0

                  !-- second derivatives are only used away from boundary cells;
                  !   next to a boundary the correction terms stay zero
                  if (boundaryCell(k,cell1) == 0 .and. boundaryCell(k,cell2) == 0) then

                     d2fdx2_cell1 = deriv_two(1,1,iEdge) * h(k,cell1)
                     d2fdx2_cell2 = deriv_two(1,2,iEdge) * h(k,cell2)

                     !-- all edges of cell 1
                     do i = 1, nEdgesOnCell(cell1)
                        d2fdx2_cell1 = d2fdx2_cell1 + &
                           deriv_two(i+1,1,iEdge) * h(k, cellsOnCell(i,cell1))
                     end do

                     !-- all edges of cell 2
                     do i = 1, nEdgesOnCell(cell2)
                        d2fdx2_cell2 = d2fdx2_cell2 + &
                           deriv_two(i+1,2,iEdge) * h(k, cellsOnCell(i,cell2))
                     end do

                  end if

                  !-- the sign of the upwind-biased term follows the sign of u
                  if (u(k,iEdge) > 0) then
                     h_edge(k,iEdge) =     &
                          0.5*(h(k,cell1) + h(k,cell2))      &
                          -(dcEdge(iEdge) **2) * (d2fdx2_cell1 + d2fdx2_cell2) / 12.          &
                          -(dcEdge(iEdge) **2) * coef_3rd_order*(d2fdx2_cell1 - d2fdx2_cell2) / 12.
                  !-- else u <= 0:
                  else
                     h_edge(k,iEdge) =     &
                          0.5*(h(k,cell1) + h(k,cell2))      &
                          -(dcEdge(iEdge) **2) * (d2fdx2_cell1 + d2fdx2_cell2) / 12.          &
                          +(dcEdge(iEdge) **2) * coef_3rd_order*(d2fdx2_cell1 - d2fdx2_cell2) / 12.
                  end if

               end do   ! do k
            end if      ! if (cell1 <=
         end do         ! do iEdge

      else if (config_thickness_adv_order == 4) then

         do iEdge = 1, nEdges
            cell1 = cellsOnEdge(1,iEdge)
            cell2 = cellsOnEdge(2,iEdge)

            !-- if a cell not on the most outside ring of the halo
            if (cell1 <= nCells .and. cell2 <= nCells) then

               do k = 1, nVertLevels

                  d2fdx2_cell1 = 0.0
                  d2fdx2_cell2 = 0.0

                  !-- second derivatives are only used away from boundary cells
                  if (boundaryCell(k,cell1) == 0 .and. boundaryCell(k,cell2) == 0) then

                     d2fdx2_cell1 = deriv_two(1,1,iEdge) * h(k,cell1)
                     d2fdx2_cell2 = deriv_two(1,2,iEdge) * h(k,cell2)

                     !-- all edges of cell 1
                     do i = 1, nEdgesOnCell(cell1)
                        d2fdx2_cell1 = d2fdx2_cell1 + &
                           deriv_two(i+1,1,iEdge) * h(k, cellsOnCell(i,cell1))
                     end do

                     !-- all edges of cell 2
                     do i = 1, nEdgesOnCell(cell2)
                        d2fdx2_cell2 = d2fdx2_cell2 + &
                           deriv_two(i+1,2,iEdge) * h(k, cellsOnCell(i,cell2))
                     end do

                  end if

                  h_edge(k,iEdge) =   &
                       0.5*(h(k,cell1) + h(k,cell2))      &
                          -(dcEdge(iEdge) **2) * (d2fdx2_cell1 + d2fdx2_cell2) / 12.

               end do   ! do k
            end if      ! if (cell1 <=
         end do         ! do iEdge

      end if   ! if(config_thickness_adv_order == 2)

      !
      ! Set the velocity in the nEdges+1 slot to zero; this is a dummy address
      !    used when reading edges that do not exist
      !
      u(:,nEdges+1) = 0.0

      !
      ! Compute circulation and relative vorticity at each vertex
      !
      circulation(:,:) = 0.0
      do iEdge = 1, nEdges
         do k = 1, nVertLevels
            circulation(k,verticesOnEdge(1,iEdge)) = circulation(k,verticesOnEdge(1,iEdge)) - dcEdge(iEdge) * u(k,iEdge)
            circulation(k,verticesOnEdge(2,iEdge)) = circulation(k,verticesOnEdge(2,iEdge)) + dcEdge(iEdge) * u(k,iEdge)
         end do
      end do
      do iVertex = 1, nVertices
         do k = 1, nVertLevels
            vorticity(k,iVertex) = circulation(k,iVertex) / areaTriangle(iVertex)
         end do
      end do


      !
      ! Compute the divergence at each cell center
      !
      divergence(:,:) = 0.0
      do iEdge = 1, nEdges
         cell1 = cellsOnEdge(1,iEdge)
         cell2 = cellsOnEdge(2,iEdge)
         if (cell1 <= nCells) then
            do k = 1, nVertLevels
               divergence(k,cell1) = divergence(k,cell1) + u(k,iEdge)*dvEdge(iEdge)
            end do
         end if
         if (cell2 <= nCells) then
            do k = 1, nVertLevels
               divergence(k,cell2) = divergence(k,cell2) - u(k,iEdge)*dvEdge(iEdge)
            end do
         end if
      end do
      do iCell = 1, nCells
         r = 1.0 / areaCell(iCell)
         do k = 1, nVertLevels
            divergence(k,iCell) = divergence(k,iCell) * r
         end do
      end do

      !
      ! Compute kinetic energy in each cell
      !
      ke(:,:) = 0.0
      do iCell = 1, nCells
         do i = 1, nEdgesOnCell(iCell)
            iEdge = edgesOnCell(i,iCell)
            do k = 1, nVertLevels
               ke(k,iCell) = ke(k,iCell) + 0.25 * dcEdge(iEdge) * dvEdge(iEdge) * u(k,iEdge)**2
            end do
         end do
         do k = 1, nVertLevels
            ke(k,iCell) = ke(k,iCell) / areaCell(iCell)
         end do
      end do

      !
      ! Compute v (tangential) velocities as weighted sums of neighbouring u
      !
      v(:,:) = 0.0
      do iEdge = 1, nEdges
         do i = 1, nEdgesOnEdge(iEdge)
            eoe = edgesOnEdge(i,iEdge)
            do k = 1, nVertLevels
               v(k,iEdge) = v(k,iEdge) + weightsOnEdge(i,iEdge) * u(k, eoe)
            end do
         end do
      end do

#ifdef NCAR_FORMULATION
      !
      ! Compute mass fluxes tangential to each edge (i.e., through the faces of dual grid cells)
      !
      vh(:,:) = 0.0
      do iEdge = 1, nEdgesSolve
         do j = 1, nEdgesOnEdge(iEdge)
            eoe = edgesOnEdge(j,iEdge)
            do k = 1, nVertLevels
               vh(k,iEdge) = vh(k,iEdge) + weightsOnEdge(j,iEdge) * u(k,eoe) * h_edge(k,eoe)
            end do
         end do
      end do
#endif


      !
      ! Compute height at vertices, pv at vertices, and average pv to edge locations
      !  ( this computes pv_vertex at all vertices bounding real cells and distance-1 ghost cells )
      !
      do iVertex = 1, nVertices
         do k = 1, nVertLevels
            ! kite-area-weighted average of the surrounding cell heights
            h_vertex(k,iVertex) = 0.0
            do i = 1, vertexDegree
               h_vertex(k,iVertex) = h_vertex(k,iVertex) + h(k,cellsOnVertex(i,iVertex)) * kiteAreasOnVertex(i,iVertex)
            end do
            h_vertex(k,iVertex) = h_vertex(k,iVertex) / areaTriangle(iVertex)

            pv_vertex(k,iVertex) = (fVertex(iVertex) + vorticity(k,iVertex)) / h_vertex(k,iVertex)
         end do
      end do


      !
      ! Compute gradient of PV in the tangent direction
      !   ( this computes gradPVt at all edges bounding real cells and distance-1 ghost cells )
      !
      do iEdge = 1, nEdges
         do k = 1, nVertLevels
            gradPVt(k,iEdge) = (pv_vertex(k,verticesOnEdge(2,iEdge)) - pv_vertex(k,verticesOnEdge(1,iEdge))) / &
                               dvEdge(iEdge)
         end do
      end do

      !
      ! Compute pv at the edges as the mean of the two end-vertex values
      !   ( this computes pv_edge at all edges bounding real cells )
      !
      pv_edge(:,:) = 0.0
      do iVertex = 1, nVertices
         do i = 1, vertexDegree
            iEdge = edgesOnVertex(i,iVertex)
            do k = 1, nVertLevels
               pv_edge(k,iEdge) = pv_edge(k,iEdge) + 0.5 * pv_vertex(k,iVertex)
            end do
         end do
      end do


      !
      ! Modify PV edge with upstream bias (tangential component, uses v and gradPVt).
      !
      do iEdge = 1, nEdges
         do k = 1, nVertLevels
            pv_edge(k,iEdge) = pv_edge(k,iEdge) - config_apvm_upwinding * v(k,iEdge) * dt * gradPVt(k,iEdge)
         end do
      end do


      !
      ! Compute pv at cell centers
      !    ( this computes pv_cell for all real cells and distance-1 ghost cells )
      !
      pv_cell(:,:) = 0.0
      vorticity_cell(:,:) = 0.0
      do iVertex = 1, nVertices
         do i = 1, vertexDegree
            iCell = cellsOnVertex(i,iVertex)
            if (iCell <= nCells) then
               do k = 1, nVertLevels
                  pv_cell(k,iCell) = pv_cell(k,iCell) + kiteAreasOnVertex(i, iVertex) * pv_vertex(k, iVertex) / areaCell(iCell)
                  vorticity_cell(k,iCell) = vorticity_cell(k,iCell) + &
                                            kiteAreasOnVertex(i, iVertex) * vorticity(k, iVertex) / areaCell(iCell)
               end do
            end if
         end do
      end do


      !
      ! Compute gradient of PV in normal direction
      !   ( this computes gradPVn for all edges bounding real cells )
      !
      gradPVn(:,:) = 0.0
      do iEdge = 1, nEdges
         if (cellsOnEdge(1,iEdge) <= nCells .and. cellsOnEdge(2,iEdge) <= nCells) then
            do k = 1, nVertLevels
               gradPVn(k,iEdge) = (pv_cell(k,cellsOnEdge(2,iEdge)) - pv_cell(k,cellsOnEdge(1,iEdge))) / &
                                  dcEdge(iEdge)
            end do
         end if
      end do

      !
      ! Modify PV edge with upstream bias (normal component, uses u and gradPVn).
      !
      do iEdge = 1, nEdges
         do k = 1, nVertLevels
            pv_edge(k,iEdge) = pv_edge(k,iEdge) - config_apvm_upwinding * u(k,iEdge) * dt * gradPVn(k,iEdge)
         end do
      end do

      !
      ! set pv_edge = fEdge / h_edge at boundary points
      !   NOTE: this block is currently disabled (commented out); it is kept for reference.
      !
   !  if (maxval(boundaryEdge).ge.0) then
   !  do iEdge = 1,nEdges
   !     cell1 = cellsOnEdge(1,iEdge)
   !     cell2 = cellsOnEdge(2,iEdge)
   !     do k = 1,nVertLevels
   !       if(boundaryEdge(k,iEdge).eq.1) then
   !         v(k,iEdge) = 0.0
   !         if(cell1.gt.0) then
   !            h1 = h(k,cell1)
   !            pv_edge(k,iEdge) = fEdge(iEdge) / h1
   !            h_edge(k,iEdge) = h1
   !         else
   !            h2 = h(k,cell2)
   !            pv_edge(k,iEdge) = fEdge(iEdge) / h2
   !            h_edge(k,iEdge) = h2
   !         endif
   !       endif
   !     enddo
   !  enddo
   !  endif


   end subroutine sw_compute_solve_diagnostics


   subroutine sw_enforce_boundary_edge(tendPool, meshPool)
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
   ! Zero the normal-velocity tendency wherever the mesh flags a boundary edge
   !
   ! Input: tendPool - tendency fields; its 'u' array is modified in place
   !        meshPool - mesh metadata providing boundaryEdge, nEdges, nVertLevels
   !
   ! Output: tend_u(k,iEdge) = 0 for every k in 1..nVertLevels and iEdge in
   !         1..nEdges with boundaryEdge(k,iEdge) == 1; all other entries unchanged
   !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

      implicit none

      type (mpas_pool_type), intent(inout) :: tendPool
      type (mpas_pool_type), intent(in) :: meshPool

      integer, dimension(:,:), pointer :: edgeFlag
      real (kind=RKIND), dimension(:,:), pointer :: uTend
      integer, pointer :: nEdges, nVertLevels

      call mpas_pool_get_array(meshPool, 'boundaryEdge', edgeFlag)

      ! Guard: if no entry anywhere in the flag array is positive there is nothing to zero.
      if (maxval(edgeFlag) <= 0) return

      call mpas_pool_get_dimension(meshPool, 'nEdges', nEdges)
      call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels)
      call mpas_pool_get_array(tendPool, 'u', uTend)

      ! Masked assignment over the owned (level, edge) section only.
      where (edgeFlag(1:nVertLevels, 1:nEdges) == 1) uTend(1:nVertLevels, 1:nEdges) = 0.0

   end subroutine sw_enforce_boundary_edge


end module sw_time_integration
