module nn_extrapolate
! This module implements fast, in-place nearest-neighbour interpolation/extrapolation.
! Nearest neighbours are found quickly with a kD-tree.
!
! AUTHOR
! Y. Batrak
!
! MODIFICATIONS
! Original  08/2014
  implicit none
  private

    public :: nearest_neighbour_extrapolate
    ! Sort record used by build_kd_tree: a coordinate value paired with the
    ! position (in the current point list) of the point it belongs to.
    ! Arrays of this type are sorted in place by the C library qsort, so the
    ! element size handed to qsort must match this layout (a double precision
    ! real followed by a 64-bit integer: 16 bytes on common compilers).
    type :: indexed
        real   (KIND=KIND(0.0D0))          :: value
        ! SELECTED_INT_KIND(18) portably requests a 64-bit integer. Reusing
        ! KIND(0.0D0) here would borrow a *real* kind number, which is only a
        ! valid 8-byte integer kind on some compilers.
        integer(KIND=SELECTED_INT_KIND(18)) :: idx
    end type

    ! One node of a 2-D kd-tree built by build_kd_tree. Every node stores one
    ! sample point (x = longitude, y = latitude, f = field value) and splits
    ! its two subtrees along either x or y.
    type :: kd_node
        ! .TRUE. when this node splits on z_x_val, .FALSE. when it splits on z_y_val
        logical :: g_is_x_section
        ! Coordinates and value of the point stored in this node
        real    :: z_x_val
        real    :: z_y_val
        real    :: z_f_val
        ! Child links (null when the child is absent)
        type(kd_node), pointer :: left   => null()
        type(kd_node), pointer :: right  => null()
        ! Back-link used to walk up the tree during a search (null at the root)
        type(kd_node), pointer :: parent => null()
    end type kd_node

  contains
    subroutine nearest_neighbour_extrapolate( p_fld, p_lons, p_lats, o_mask, o_ext_mask )
      ! Overwrite selected points of p_fld, in place, with the value of the
      ! nearest valid point.
      !
      ! p_fld      : field; the points selected for extrapolation are overwritten
      ! p_lons     : x coordinate (longitude) of every point
      ! p_lats     : y coordinate (latitude) of every point; used in degrees
      !              (converted with XPI/180 for the cosine weighting)
      ! o_mask     : .TRUE. where p_fld holds a valid source value
      ! o_ext_mask : optional; .TRUE. where p_fld must be filled.
      !              Defaults to .NOT. o_mask.
      !
      ! The squared distance is (dlon*cos(lat))**2 + dlat**2, with lat taken at
      ! the point being filled. Longitude differences are taken literally, so
      ! wrap-around at the date line is not handled.
      ! If there are no source points or no points to fill, nothing is changed.
      use MODD_CSTS, only: XPI
      implicit none
        real,    intent( in out )                         :: &
            p_fld(:)
        real,    intent( in     ), dimension(size(p_fld)) :: &
            p_lons,     &
            p_lats
        logical, intent( in     ) :: o_mask(size(p_fld))
        logical, intent( in     ), optional :: o_ext_mask(size(p_fld))

        logical :: g_ext_mask(size(p_fld))
        integer :: i_all(size(p_fld))

        type(kd_node), pointer :: tz_root, tz_leaf
        integer, allocatable   :: i_mask(:), i_ext_mask(:)

        integer :: ji
        real    :: z_lat

        nullify(tz_root, tz_leaf)

        if( present(o_ext_mask) ) then
            g_ext_mask = o_ext_mask
        else
            g_ext_mask = .NOT. o_mask
        end if

        ! Indices of the source points (i_mask) and of the points to fill (i_ext_mask)
        i_all = [( ji, ji = 1, size(p_fld) )]
        allocate( i_mask(count(o_mask)), i_ext_mask(count(g_ext_mask)) )
        i_mask     = pack( i_all, o_mask )
        i_ext_mask = pack( i_all, g_ext_mask )

        ! With no source point the kd-tree would be empty (building it indexes
        ! out of bounds and every search would return a null leaf); with no
        ! target point there is no work to do.
        if( size(i_mask) == 0 .OR. size(i_ext_mask) == 0 ) then
            write(*,*) 'Nothing to extrapolate...'
            deallocate(i_mask,i_ext_mask)
            return
        end if
        write(*,*) 'Extrapolation of ', size(i_ext_mask), ' points...'

        ! The tree stores copies of the source values, so updating p_fld below
        ! cannot affect later searches.
        call build_kd_tree( p_lons(i_mask), p_lats(i_mask), p_fld(i_mask), .TRUE., null(tz_root), tz_root )

        do ji = 1, size(i_ext_mask)
            z_lat = p_lats(i_ext_mask(ji))
            ! cos(lat)**2 shrinks longitude differences towards the poles
            call nearest_neighbour( tz_root, p_lons(i_ext_mask(ji)), z_lat, tz_leaf, cos(z_lat*XPI/180.)**2 )
            if( associated(tz_leaf) ) p_fld(i_ext_mask(ji)) = tz_leaf%z_f_val
        end do
        deallocate(i_mask,i_ext_mask)

        call release_kd_tree(tz_root)
    end subroutine nearest_neighbour_extrapolate

    recursive subroutine nearest_neighbour( tp_tree, p_lon, p_lat, tp_nn_leaf, z_mlat )
      ! Find the node of the kd-tree rooted at tp_tree that is closest to
      ! (p_lon, p_lat) for the squared distance  z_mlat*dlon**2 + dlat**2.
      !
      ! tp_nn_leaf : nearest node on return; null when tp_tree is null
      ! z_mlat     : weight applied to the squared longitude difference
      !
      ! The search descends to the node reached by find_near_leaf, then walks
      ! back up through the parent links. At every node on the way up it checks
      ! the node itself. When the node's splitting line is closer than the best
      ! distance so far, it also searches the subtree(s) that were not on the
      ! descent path. For such a sub-search the subtree root's parent link is
      ! cut and restored afterwards. The tree is therefore briefly modified
      ! during the call.
      !
      ! The node where the descent stops may still own a child on the far side
      ! of its split (build_kd_tree creates nodes that have only a right child).
      ! That child is searched as well, so a closer point stored there is not
      ! missed.
      implicit none
        type(kd_node), pointer, intent(in) :: tp_tree
        real,                   intent(in) :: p_lon, p_lat
        type(kd_node), pointer             :: tp_nn_leaf
        real,                   intent(in) :: z_mlat

        type(kd_node), pointer :: tz_current, tz_prev
        real    :: z_curr_best, z_plane

        nullify(tp_nn_leaf, tz_prev)
        if( .NOT. associated(tp_tree) ) return

        call find_near_leaf( tp_tree, p_lon, p_lat, tz_current )
        tp_nn_leaf  => tz_current
        z_curr_best =  sq_dist( tz_current )

        ! tz_prev is the child we arrived from (null at the descent end-point)
        do while( associated(tz_current) )
            call consider( tz_current )

            ! Squared distance from the query to this node's splitting line
            if( tz_current%g_is_x_section ) then
                z_plane = z_mlat*(tz_current%z_x_val - p_lon)**2
            else
                z_plane =        (tz_current%z_y_val - p_lat)**2
            end if

            if( z_plane < z_curr_best ) then
                if( associated(tz_current%left, TARGET=tz_prev) ) then
                    call search_branch( tz_current%right )
                else if( associated(tz_current%right, TARGET=tz_prev) ) then
                    call search_branch( tz_current%left )
                else
                    ! Descent end-point: no child visited yet, try whichever exist
                    call search_branch( tz_current%left )
                    call search_branch( tz_current%right )
                end if
            end if

            tz_prev    => tz_current
            tz_current => tz_current%parent
        end do

      contains

        ! Weighted squared distance from the query point to tp_node
        real function sq_dist( tp_node )
            type(kd_node), intent(in) :: tp_node
            sq_dist = z_mlat*(tp_node%z_x_val - p_lon)**2 + (tp_node%z_y_val - p_lat)**2
        end function sq_dist

        ! Make tp_node the current best if it is strictly closer
        subroutine consider( tp_node )
            type(kd_node), pointer, intent(in) :: tp_node
            real :: z_dist
            z_dist = sq_dist( tp_node )
            if( z_dist < z_curr_best ) then
                z_curr_best =  z_dist
                tp_nn_leaf  => tp_node
            end if
        end subroutine consider

        ! Search the subtree tp_branch on its own and merge its best node.
        ! Recursive because nearest_neighbour calls back into it.
        recursive subroutine search_branch( tp_branch )
            type(kd_node), pointer, intent(in) :: tp_branch
            type(kd_node), pointer :: tz_saved_parent, tz_found
            if( .NOT. associated(tp_branch) ) return
            ! Cut the parent link so the sub-search stops at this subtree's root
            tz_saved_parent => tp_branch%parent
            nullify( tp_branch%parent )
            call nearest_neighbour( tp_branch, p_lon, p_lat, tz_found, z_mlat )
            tp_branch%parent => tz_saved_parent
            if( associated(tz_found) ) call consider( tz_found )
        end subroutine search_branch
    end subroutine nearest_neighbour

    recursive subroutine find_near_leaf( tp_tree, p_lon, p_lat, tp_leaf )
      ! Walk down from tp_tree towards (p_lon, p_lat). At each node go left when
      ! the query coordinate on the node's split axis is strictly below the
      ! node's value, otherwise go right. Stop at the first node whose wanted
      ! child is absent and return that node in tp_leaf.
      ! tp_tree must be associated.
      implicit none
        type(kd_node), pointer, intent( in ) :: tp_tree
        real,          intent( in ) :: p_lon, p_lat
        type(kd_node), pointer      :: tp_leaf

        type(kd_node), pointer :: tz_node, tz_child
        logical :: g_go_left

        tz_node => tp_tree
        do
            if( tz_node%g_is_x_section ) then
                g_go_left = p_lon < tz_node%z_x_val
            else
                g_go_left = p_lat < tz_node%z_y_val
            end if

            if( g_go_left ) then
                tz_child => tz_node%left
            else
                tz_child => tz_node%right
            end if

            if( .NOT. associated(tz_child) ) exit
            tz_node => tz_child
        end do

        tp_leaf => tz_node
    end subroutine find_near_leaf

    recursive subroutine build_kd_tree( p_lons, p_lats, p_fld, o_split_x, tp_parent, tp_tree )
      ! Build the kd-tree rooted at tp_tree from the points (p_lons(i), p_lats(i))
      ! carrying the values p_fld(i).
      !
      ! The points are sorted on p_lons (o_split_x = .TRUE.) or on p_lats, and
      ! the median point is stored in tp_tree. Points sorted before the median
      ! go to the left subtree and points after it go to the right subtree. The
      ! split axis alternates from one level to the next.
      !
      ! o_split_x : split axis of this node
      ! tp_parent : stored as tp_tree%parent (pass a null pointer for the root)
      ! tp_tree   : allocated here if not already associated; left untouched
      !             when there are no points
      implicit none
        real, intent( in ) :: p_lons( :            ), &
                              p_lats( size(p_lons) ), &
                              p_fld ( size(p_lons) )
        logical, intent(in):: o_split_x
        type(kd_node), pointer, intent(in) :: tp_parent
        type(kd_node), pointer    :: tp_tree

        ! Allocatable rather than automatic: this routine is recursive and a
        ! large automatic array per level may overflow the stack.
        type(indexed), allocatable :: tz_idxs(:)

        integer :: i_len, i_med, i_med_val
        integer :: ji

        i_len = size( p_lons )
        ! No point, no node (the median/qsort below would index out of bounds)
        if( i_len == 0 ) return

        if( .NOT. associated(tp_tree) ) allocate(tp_tree)

        tp_tree%g_is_x_section = o_split_x
        tp_tree%parent => tp_parent

        if( i_len == 1 ) then
            tp_tree%z_x_val = p_lons(1)
            tp_tree%z_y_val = p_lats(1)
            tp_tree%z_f_val = p_fld (1)
            return
        end if

        allocate( tz_idxs(i_len) )
        tz_idxs%idx = [(ji,ji=1,i_len)]
        if( o_split_x ) then
            tz_idxs%value = p_lons
        else
            tz_idxs%value = p_lats
        end if
        call qsort( tz_idxs )

        i_med     = (1 + i_len)/2
        i_med_val = int( tz_idxs( i_med )%idx )

        tp_tree%z_x_val = p_lons( i_med_val )
        tp_tree%z_y_val = p_lats( i_med_val )
        tp_tree%z_f_val = p_fld ( i_med_val )

        if( i_med > 1 ) &
            call build_kd_tree( p_lons( tz_idxs( 1:i_med - 1 )%idx ),   &
                                p_lats( tz_idxs( 1:i_med - 1 )%idx ),   &
                                p_fld ( tz_idxs( 1:i_med - 1 )%idx ),   &
                                .NOT. o_split_x,                        &
                                tp_tree,                                &
                                tp_tree%left )
        ! i_len >= 2 here, so at least one point follows the median
        call build_kd_tree( p_lons( tz_idxs( i_med + 1:i_len )%idx ),   &
                            p_lats( tz_idxs( i_med + 1:i_len )%idx ),   &
                            p_fld ( tz_idxs( i_med + 1:i_len )%idx ),   &
                            .NOT. o_split_x,                            &
                            tp_tree,                                    &
                            tp_tree%right )
    end subroutine build_kd_tree

    recursive subroutine release_kd_tree( tp_tree )
      ! Deallocate every node of the kd-tree rooted at tp_tree, including that
      ! node itself, and leave tp_tree disassociated. Does nothing when
      ! tp_tree is already null.
      implicit none
        type( kd_node ), pointer :: tp_tree

        if( .NOT. associated(tp_tree) ) return

        call release_kd_tree( tp_tree%left  )
        call release_kd_tree( tp_tree%right )
        ! Free the node itself too; otherwise the root allocated by
        ! build_kd_tree is never released. DEALLOCATE also disassociates
        ! tp_tree, and through it the caller's pointer.
        deallocate( tp_tree )
    end subroutine release_kd_tree

    subroutine qsort( idxs )
      ! Sort idxs in place by ascending %value with the C library qsort.
      ! The C qsort is not guaranteed to be stable, so the order of equal
      ! values is unspecified.
      use, intrinsic :: iso_c_binding
      implicit none
        ! CONTIGUOUS: qsort walks raw memory starting at idxs(1), so a strided
        ! actual argument must be passed as a contiguous copy.
        type(indexed), intent( in out ), target, contiguous :: idxs(:)

        integer(c_size_t) :: elem_count
        integer(c_size_t) :: elem_size

        interface
          subroutine qsort_c(array,elem_count,elem_size,compare) bind(C,name="qsort")
            import
            type   (c_ptr   ), value :: array
            integer(c_size_t), value :: elem_count
            integer(c_size_t), value :: elem_size
            type   (c_funptr), value :: compare !int(*compare)(const void *, const void *)
          end subroutine qsort_c !standard C library qsort
        end interface

        ! Zero or one element is already sorted; this also keeps idxs(1) in bounds
        if( size(idxs) < 2 ) return

        elem_count = size( idxs )
        ! Byte size of one array element, taken from the type itself instead of a
        ! hard-coded 16, so it always matches type(indexed). Assumes 8-bit bytes.
        elem_size  = int( storage_size( idxs ) / 8, c_size_t )

        call qsort_c( c_loc(idxs(1)), elem_count, elem_size, c_funloc(compar) )
    end subroutine qsort

    function compar( a, b ) result(res) bind(C)
      ! qsort comparator for type(indexed) elements: returns -1, 0 or 1 when
      ! a%value is less than, equal to or greater than b%value.
      ! Unordered pairs (a NaN value) compare as equal, so res is always defined.
      use, intrinsic :: iso_c_binding
      implicit none
        type   (c_ptr),value :: a, b
        integer(c_int)       :: res

        type(indexed), pointer :: af, bf

        call c_f_pointer( a, af )
        call c_f_pointer( b, bf )

        res = 0_c_int
        if( af%value < bf%value ) then
            res = -1_c_int
        else if( af%value > bf%value ) then
            res =  1_c_int
        end if
    end function compar
end module
