module modi_extrapolate
contains
subroutine extrapolate( p_fld, p_sdf, p_lons, p_lats, p_f, o_mask, p_tol, k_maxiter )
! Extrapolation in normal direction.
!
! PURPOSE
! Fills the masked points of P_FLD by repeatedly applying an upwind
! discretisation of
!     n . grad(u) = f,      n = grad(P_SDF) / |grad(P_SDF)|
! (Jacobi-style: every sweep reads the previous iterate only) until the
! largest change between two sweeps is <= P_TOL, or K_MAXITER sweeps are done.
! Masked interior points and masked grid corners are updated; all other
! boundary points keep their input values.
!
! ARGUMENTS
! p_fld     : field to extrapolate, updated in place.
! p_sdf     : field whose gradient gives the extrapolation direction
!             (presumably a signed distance to the known region -- confirm);
!             by default the points with p_sdf > 0 are the ones updated.
! p_lons    : longitudes of the grid points.
! p_lats    : latitudes of the grid points. Both coordinate arrays are fed to
!             cos() and multiplied by the earth radius, so they are
!             presumably in radians -- confirm.
! p_f       : optional right-hand side f (default 0.).
! o_mask    : optional mask of the points to update (default p_sdf > 0).
! p_tol     : optional convergence threshold on max|change| (default 5.E-4).
! k_maxiter : optional maximum number of sweeps (default: unlimited).
!
! AUTHOR
! Y. Batrak
!
! MODIFICATIONS
! Original  07/2014
  use modd_csts, only: x_r => xradius
  implicit none
    real,    intent( in out )           :: p_fld(:,:)
    real,    intent( in     )           :: p_sdf ( size(p_fld,1), size(p_fld,2) ), &
                                           p_lons( size(p_fld,1), size(p_fld,2) ), &
                                           p_lats( size(p_fld,1), size(p_fld,2) )
    real,    intent( in     ), optional :: p_f   ( size(p_fld,1), size(p_fld,2) )
    logical, intent( in     ), optional :: o_mask( size(p_fld,1), size(p_fld,2) )
    real,    intent( in     ), optional :: p_tol
    integer, intent( in     ), optional :: k_maxiter

    real, dimension( size(p_fld,1), size(p_fld,2) ) :: &
        z_dx,   & ! first-index component of the unit normal
        z_dy,   & ! second-index component of the unit normal
        z_res,  & ! gradient norm, later the next iterate
        z_f,    & ! right-hand side
        z_heav, & ! 1. at points to update, 0. elsewhere
        dlon,   & ! distance between (x-1,y) and (x+1,y)
        dlat      ! distance between (x,y-1) and (x,y+1)
    logical :: g_mask( size(p_fld,1), size(p_fld,2) )
    real    :: z_upd, z_hx, z_hy, z_tol
    integer :: i_m, i_n, i_maxiter
    integer :: j_x, j_y, jj

    i_m = size( p_fld, 1 )
    i_n = size( p_fld, 2 )
    ! Centred differences need at least one interior point along each axis;
    ! smaller grids would be indexed out of bounds.
    if( i_m < 3 .or. i_n < 3 ) return

    z_f    = 0.
    g_mask = p_sdf > 0
    if(present( p_f    )) z_f    = p_f
    if(present( o_mask )) g_mask = o_mask

    z_tol = 5.E-4
    if(present( p_tol     )) z_tol     = p_tol
    i_maxiter = huge( i_maxiter )
    if(present( k_maxiter )) i_maxiter = k_maxiter

    ! Metric: centred (two-cell) distances at interior points ...
    do j_y = 2, i_n - 1
        do j_x = 2, i_m - 1
            dlon(j_x,j_y) = fdlon(j_x,j_y)
            dlat(j_x,j_y) = fdlat(j_x,j_y)
        end do
    end do
    ! ... copied outward to the boundary. The first copy is limited to the
    ! interior range of the second index so that no unset value is read.
    dlon([1,i_m],2:i_n-1) = dlon([2,i_m-1],2:i_n-1)
    dlon(:,[1,i_n])       = dlon(:,[2,i_n-1])
    dlat([1,i_m],2:i_n-1) = dlat([2,i_m-1],2:i_n-1)
    dlat(:,[1,i_n])       = dlat(:,[2,i_n-1])

    z_heav = merge( 1., 0., g_mask )

    ! Gradient of p_sdf, zeroed outside the mask. Centred inside the grid,
    ! one-sided on its boundary (dlon/dlat span two cells, hence the *2.).
    do j_y = 1, i_n
        do j_x = 2, i_m - 1
            z_dx(j_x,j_y) = z_heav(j_x,j_y)*( p_sdf(j_x+1,j_y) - p_sdf(j_x-1,j_y) )/dlon(j_x,j_y)
        end do
        z_dx(1,  j_y) = z_heav(1,  j_y)*( p_sdf(2,  j_y) - p_sdf(1,    j_y) )/dlon(2,    j_y)*2.
        z_dx(i_m,j_y) = z_heav(i_m,j_y)*( p_sdf(i_m,j_y) - p_sdf(i_m-1,j_y) )/dlon(i_m-1,j_y)*2.
    end do
    do j_y = 2, i_n - 1
        do j_x = 1, i_m
            z_dy(j_x,j_y) = z_heav(j_x,j_y)*( p_sdf(j_x,j_y+1) - p_sdf(j_x,j_y-1) )/dlat(j_x,j_y)
        end do
    end do
    do j_x = 1, i_m
        z_dy(j_x,1  ) = z_heav(j_x,1  )*( p_sdf(j_x,2  ) - p_sdf(j_x,1    ) )/dlat(j_x,2    )*2.
        z_dy(j_x,i_n) = z_heav(j_x,i_n)*( p_sdf(j_x,i_n) - p_sdf(j_x,i_n-1) )/dlat(j_x,i_n-1)*2.
    end do

    ! Normalise to a unit normal; zero gradients stay zero.
    z_res = sqrt( z_dx**2 + z_dy**2 )
    where( z_res == 0. ) z_res = 1.
    z_dx = z_dx/z_res
    z_dy = z_dy/z_res

    ! Extrapolation
    z_res = p_fld
    z_upd = huge( z_upd )
    jj = 0
    do while( z_upd > z_tol .and. jj < i_maxiter )
        jj = jj + 1

        do j_y = 2, i_n - 1
            do j_x = 2, i_m - 1
                if( .not. g_mask(j_x,j_y) ) cycle

                if( abs( z_dx(j_x,j_y) ) + abs( z_dy(j_x,j_y) ) /= 0. ) then
                    ! Upwind selector per axis: +1 picks the lower-index
                    ! neighbour, -1 the higher one. dlon/dlat are positive
                    ! distances, so no grid-orientation sign is needed.
                    z_hx = merge( 1., -1., z_dx(j_x,j_y) > 0. )
                    z_hy = merge( 1., -1., z_dy(j_x,j_y) > 0. )

                    ! NOTE(review): in 1-D this reduces to p_up + f*dlon/(4|n|),
                    ! whereas the corner update gives p_up + f*dlon/(2|n|);
                    ! confirm which scaling of p_f is intended.
                    z_res(j_x,j_y) = (z_dx(j_x,j_y)/(.5*dlon(j_x,j_y))*( p_fld(j_x-1,j_y  )*(1.+z_hx) -   &
                                                                         p_fld(j_x+1,j_y  )*(1.-z_hx) ) + &
                                      z_dy(j_x,j_y)/(.5*dlat(j_x,j_y))*( p_fld(j_x,  j_y-1)*(1.+z_hy) -   &
                                                                         p_fld(j_x,  j_y+1)*(1.-z_hy) ) + &
                                      z_f (j_x,j_y)                                                     )/&
                                     (z_dx(j_x,j_y)*z_hx/dlon(j_x,j_y)*2. + z_dy(j_x,j_y)*z_hy/dlat(j_x,j_y)*2.)*.5
                else
                    ! No direction available: average of three lower-index neighbours.
                    z_res(j_x,j_y) = ( p_fld(j_x-1,j_y) + p_fld(j_x,j_y-1) + p_fld(j_x-1,j_y-1) )/3.
                end if
            end do
        end do

        ! The four grid corners (steps of i_n-1 / i_m-1 visit 1 and the last index).
        do j_y = 1, i_n, i_n - 1
            do j_x = 1, i_m, i_m - 1
                if( g_mask(j_x,j_y) ) z_res(j_x,j_y) = corner_value( j_x, j_y )
            end do
        end do

        z_upd = maxval( abs( z_res - p_fld ) )
        p_fld = z_res
    end do
  contains
    ! Upwind update of the corner point (x,y) using only in-grid neighbours.
    pure function corner_value( x, y ) result( res )
      implicit none
        integer, intent( in ) :: x, y
        real                  :: res
        integer :: i_xin, i_yin
        real    :: z_wx, z_wy

        ! The only neighbours of a corner that lie inside the grid.
        i_xin = merge( x + 1, x - 1, x == 1 )
        i_yin = merge( y + 1, y - 1, y == 1 )

        ! Axis weight |n|/h (h = half the centred distance). It is 0. when the
        ! upwind neighbour (lower index for n > 0, higher otherwise) is off-grid.
        z_wx = 0.
        if( merge( x - 1, x + 1, z_dx(x,y) > 0. ) == i_xin ) &
            z_wx = abs( z_dx(x,y) )/( .5*dlon(x,y) )
        z_wy = 0.
        if( merge( y - 1, y + 1, z_dy(x,y) > 0. ) == i_yin ) &
            z_wy = abs( z_dy(x,y) )/( .5*dlat(x,y) )

        if( z_wx + z_wy > 0. ) then
            res = ( z_wx*p_fld(i_xin,y) + z_wy*p_fld(x,i_yin) + z_f(x,y) )/( z_wx + z_wy )
        else
            ! Zero normal or no upwind data: mean of the two in-grid neighbours.
            res = .5*( p_fld(i_xin,y) + p_fld(x,i_yin) )
        end if
    end function corner_value

    ! Distance between (x-1,y) and (x+1,y): earth radius times a local
    ! flat-earth combination of the longitude (scaled by cos(lat)) and
    ! latitude differences.
    pure function fdlon(x,y) result(res)
      implicit none
        integer, intent( in ) :: x, y
        real                  :: res

        res = x_r*sqrt(cos(p_lats(x,y))**2*(p_lons(x+1,y)-p_lons(x-1,y))**2+(p_lats(x+1,y)-p_lats(x-1,y))**2)
    end function fdlon

    ! Distance between (x,y-1) and (x,y+1), same approximation as fdlon.
    pure function fdlat(x,y) result(res)
      implicit none
        integer, intent( in ) :: x, y
        real                  :: res

        res = x_r*sqrt(cos(p_lats(x,y))**2*(p_lons(x,y+1)-p_lons(x,y-1))**2+(p_lats(x,y+1)-p_lats(x,y-1))**2)
    end function fdlat
end subroutine extrapolate
end module modi_extrapolate
#if 0
module modi_grad
contains
subroutine grad( p_fld, p_lons, p_lats, p_dx, p_dy, p_mask )
! Gradient of a field on a regular longitude/latitude grid.
!
! p_dx receives the derivative along the first index and p_dy the derivative
! along the second. Centred differences are used inside the grid and one-sided
! differences on its boundary. Every value is multiplied by p_mask (1. when
! absent). The spacings are read from the first two grid points only, so the
! grid is assumed to be regular in longitude and latitude.
  use modd_csts, only: x_r => xradius
  implicit none
    real, intent( in  )                                            :: &
        p_fld(:,:)
    real, intent( in  ), dimension( size(p_fld,1), size(p_fld,2) ) :: &
        p_lons,     &
        p_lats
    real, intent( out ), dimension( size(p_fld,1), size(p_fld,2) ) :: &
        p_dx,       &
        p_dy
    real, intent( in  ), optional                                  :: &
        p_mask( size(p_fld,1), size(p_fld,2) )

    integer :: i_m, i_n
    integer :: j_x, j_y
    real    :: z_dlon, z_dlat
    real    :: z_hx, z_hy

    real    :: z_mask( size(p_fld,1), size(p_fld,2) )

    z_mask = 1.
    if(present( p_mask )) z_mask = p_mask

    i_m = size( p_fld, 1 )
    i_n = size( p_fld, 2 )

    z_dlon =  ( p_lons(2,1) - p_lons(1,1) )
    ! NOTE(review): sign flipped -- presumably latitudes decrease along the
    ! second index; confirm against the callers.
    z_dlat = -( p_lats(1,2) - p_lats(1,1) )

    ! Centred distance along the second index (two latitude steps).
    z_hy = 2.*x_r*z_dlat

    do j_y = 2, i_n - 1
        do j_x = 2, i_m - 1
            z_hx = 2.*x_r*cos( p_lats(j_x,j_y) )*z_dlon

            p_dx(j_x,j_y) = z_mask(j_x,j_y)*( p_fld(j_x+1,j_y  ) - p_fld(j_x-1,  j_y) )/z_hx
            p_dy(j_x,j_y) = z_mask(j_x,j_y)*( p_fld(j_x,  j_y+1) - p_fld(j_x,  j_y-1) )/z_hy
        end do
        ! Centred p_dy at j_x = 1 and j_x = i_m, set unconditionally so that
        ! these intent(out) values are always defined.
        p_dy(1,  j_y) = z_mask(1,  j_y)*( p_fld(1,  j_y+1) - p_fld(1,  j_y-1) )/z_hy
        p_dy(i_m,j_y) = z_mask(i_m,j_y)*( p_fld(i_m,j_y+1) - p_fld(i_m,j_y-1) )/z_hy
    end do

    ! Centred p_dx at j_y = 1 and j_y = i_n.
    do j_x = 2, i_m - 1
        p_dx(j_x,  1) = z_mask(j_x,  1)*(p_fld(j_x+1,  1)-p_fld(j_x-1,  1))/(2.*x_r*cos(p_lats(j_x,1  ))*z_dlon)
        p_dx(j_x,i_n) = z_mask(j_x,i_n)*(p_fld(j_x+1,i_n)-p_fld(j_x-1,i_n))/(2.*x_r*cos(p_lats(j_x,i_n))*z_dlon)
    end do

    ! One-sided p_dx at j_x = 1 and j_x = i_m (single longitude step).
    do j_y = 1, i_n
        p_dx( 1,   j_y ) = z_mask(1,   j_y)*(p_fld(2,  j_y)-p_fld(1,    j_y))/(x_r*cos(p_lats(1,  j_y))*z_dlon)
        p_dx( i_m, j_y ) = z_mask(i_m, j_y)*(p_fld(i_m,j_y)-p_fld(i_m-1,j_y))/(x_r*cos(p_lats(i_m,j_y))*z_dlon)
    end do
    ! One-sided p_dy at j_y = 1 and j_y = i_n (half of z_hy, hence the *2.).
    do j_x = 1, i_m
        p_dy( j_x, 1   ) = z_mask(j_x, 1  )*( p_fld(j_x,  2) - p_fld(j_x,    1) )/z_hy*2.
        p_dy( j_x, i_n ) = z_mask(j_x, i_n)*( p_fld(j_x,i_n) - p_fld(j_x,i_n-1) )/z_hy*2.
    end do
end subroutine grad
end module modi_grad
#endif
