MODULE processing
  USE prec_const,  ONLY: dp, imagu, SQRT2, SQRT3
  USE grid,        ONLY: &
    local_na, local_np, local_nj, local_nky, local_nkx, local_nz, Ngz,Ngj,Ngp, &
    parray,pmax,ip0,&
    CONTAINSp0,ip1,CONTAINSp1,ip2,CONTAINSp2,ip3,CONTAINSp3,&
    jarray,jmax,ij0, dmax,&
    kyarray, AA_y,&
    kxarray, AA_x,&
    zarray, deltaz, ieven, iodd, inv_deltaz
  USE fields,           ONLY: moments, phi, psi
  USE array,            ONLY : kernel, nadiab_moments, &
                               ddz_napj, ddzND_Napj, interp_napj,&
                               dens, upar, uper, Tpar, Tper, temp
  USE geometry,         ONLY: Jacobian, iInt_Jacobian
  USE time_integration, ONLY: updatetlevel
  USE calculus,         ONLY: simpson_rule_z, grad_z, grad_z2, grad_z4, interp_z
  USE model,            ONLY: EM, CLOS, beta, HDz_h
  USE species,          ONLY: tau,q_tau,q_o_sqrt_tau_sigma,sqrt_tau_o_sigma
  USE basic,            ONLY: t0_process, t1_process, tc_process
  USE parallel,         ONLY: num_procs_ky, rank_ky, comm_ky
  USE mpi
  implicit none

  REAL(dp), PUBLIC, ALLOCATABLE, DIMENSION(:), PROTECTED :: pflux_x, gflux_x
  REAL(dp), PUBLIC, ALLOCATABLE, DIMENSION(:), PROTECTED :: hflux_x
  INTEGER :: ierr
  PUBLIC :: init_process
  PUBLIC :: compute_nadiab_moments_z_gradients_and_interp
  PUBLIC :: compute_density, compute_upar, compute_uperp
  PUBLIC :: compute_Tpar, compute_Tperp, compute_fluid_moments
  PUBLIC :: compute_radial_transport
  PUBLIC :: compute_radial_heatflux
  PUBLIC :: compute_Napjz_spectrum
CONTAINS

SUBROUTINE init_process
  USE grid,       ONLY: local_na
  IMPLICIT NONE
  ALLOCATE( pflux_x(local_na))
  ALLOCATE( gflux_x(local_na))
  ALLOCATE( hflux_x(local_na))
END SUBROUTINE init_process

! 1D diagnostic to compute the average radial particle transport <n_a v_ExB_x>_xyz
SUBROUTINE compute_radial_transport
  IMPLICIT NONE
  COMPLEX(dp) :: pflux_local, gflux_local, integral
  REAL(dp)    :: buffer(2)
  INTEGER     :: i_, root, iky, ikx, ia, iz, in, iodd, ieven
  COMPLEX(dp), DIMENSION(local_nz+Ngz) :: integrant
  DO ia = 1,local_na
    pflux_local = 0._dp ! particle flux
    gflux_local = 0._dp ! gyrocenter flux
    integrant   = 0._dp ! auxiliary variable for z integration
    !!---------- Gyro center flux (drift kinetic) ------------
    ! Electrostatic part
    IF(CONTAINSp0) THEN
      DO iz = 1,local_nz+ngz ! we include ghost for integration
        DO ikx = 1,local_nkx
          DO iky = 1,local_nky
            integrant(iz) = integrant(iz) &
             +Jacobian(iz,ieven)*moments(ia,ip0,ij0,iky,ikx,iz,updatetlevel)&
              *imagu*kyarray(iky)*CONJG(phi(iky,ikx,iz))
          ENDDO
        ENDDO
      ENDDO
    ENDIF
    ! Electromagnetic part
    IF( EM .AND. CONTAINSp1 ) THEN
        DO iz = 1,local_nz+ngz
          DO ikx = 1,local_nkx
            DO iky = 1,local_nky
              integrant(iz) = integrant(iz)&
                -Jacobian(iz,iodd)*sqrt_tau_o_sigma(ia)*moments(ia,ip1,ij0,iky,ikx,iz,updatetlevel)&
                 *imagu*kyarray(iky)*CONJG(psi(iky,ikx,iz))
          ENDDO
        ENDDO
      ENDDO
    ENDIF
    ! Integrate over z
    call simpson_rule_z(local_nz,deltaz,integrant,integral)
    ! Get process local gyrocenter flux with a factor two to account for the negative ky modes
    gflux_local = 2._dp*integral*iInt_Jacobian

    !
    integrant   = 0._dp ! reset auxiliary variable
    !!---------- Particle flux (gyrokinetic) ------------
    ! Electrostatic part
    IF(CONTAINSp0) THEN
      DO iz = 1,local_nz+ngz
          DO ikx = 1,local_nkx
            DO iky = 1,local_nky
              DO in = 1+ngj/2, local_nj+ngj/2 ! only interior points
                integrant(iz) = integrant(iz)+ &
                  Jacobian(iz,ieven)*moments(ia,ip0,in,iky,ikx,iz,updatetlevel)&
                  *imagu*kyarray(iky)*kernel(ia,in,iky,ikx,iz,ieven)*CONJG(phi(iky,ikx,iz))
              ENDDO
            ENDDO
        ENDDO
      ENDDO
    ENDIF
    ! Electromagnetic part
    IF( EM .AND. CONTAINSp1 ) THEN
      DO iz = 1,local_nz+ngz
        DO ikx = 1,local_nkx
          DO iky = 1,local_nky
            DO in = 1+ngj/2, local_nj+ngj/2 ! only interior points
              integrant(iz) = integrant(iz) - &
                Jacobian(iz,iodd)*sqrt_tau_o_sigma(ia)*moments(ia,ip1,in,iky,ikx,iz,updatetlevel)&
                *imagu*kyarray(iky)*kernel(ia,in,iky,ikx,iz,iodd)*CONJG(psi(iky,ikx,iz))
            ENDDO
          ENDDO
        ENDDO
      ENDDO
    ENDIF
    ! Integrate over z
    call simpson_rule_z(local_nz,deltaz,integrant,integral)
    ! Get process local particle flux with a factor two to account for the negative ky modes
    pflux_local = 2._dp*integral*iInt_Jacobian

    !!!!---------- Sum over all processes ----------
    buffer(1) = REAL(gflux_local)
    buffer(2) = REAL(pflux_local)
    root = 0
    !Gather manually among the rank_p=0 processes and perform the sum
    gflux_x(ia) = 0
    pflux_x(ia) = 0
    IF (num_procs_ky .GT. 1) THEN
        !! Everyone sends its local_sum to root = 0
        IF (rank_ky .NE. root) THEN
            CALL MPI_SEND(buffer, 2 , MPI_DOUBLE_PRECISION, root, 1234, comm_ky, ierr)
        ELSE
            ! Recieve from all the other processes
            DO i_ = 0,num_procs_ky-1
                IF (i_ .NE. rank_ky) &
                    CALL MPI_RECV(buffer, 2 , MPI_DOUBLE_PRECISION, i_, 1234, comm_ky, MPI_STATUS_IGNORE, ierr)
                    gflux_x(ia) = gflux_x(ia) + buffer(1)
                    pflux_x(ia) = pflux_x(ia) + buffer(2)
            ENDDO
        ENDIF
    ELSE
      gflux_x(ia) = gflux_local
      pflux_x(ia) = pflux_local
    ENDIF
  ENDDO
    ! if(my_id .eq. 0) write(*,*) 'pflux_ri = ',pflux_ri
END SUBROUTINE compute_radial_transport

! 1D diagnostic to compute the average radial particle transport <T_i v_ExB_x>_xyz
SUBROUTINE compute_radial_heatflux
  IMPLICIT NONE
  COMPLEX(dp) :: hflux_local, integral
  REAL(dp)    :: buffer(2), n_dp
  INTEGER     :: i_, root, in, ia, iky, ikx, iz
  COMPLEX(dp), DIMENSION(local_nz+ngz) :: integrant        ! charge density q_a n_a
  DO ia = 1,local_na
  hflux_local = 0._dp ! heat flux
  integrant   = 0._dp ! z integration auxiliary variable
  !!----------------ELECTROSTATIC CONTRIBUTION---------------------------
  IF(CONTAINSp0 .AND. CONTAINSp2) THEN
    ! Loop to compute gamma_kx = sum_ky sum_j -i k_z Kernel_j Na00 * phi
    DO iz  = 1,local_nz+ngz
    DO ikx = 1,local_nkx
    DO iky = 1,local_nky
      DO in = 1+ngj/2, local_nj+ngj/2 ! only interior points
        n_dp = jarray(in)
        integrant(iz) = integrant(iz) &
        +Jacobian(iz,ieven)*tau(ia)*imagu*kyarray(iky)*CONJG(phi(iky,ikx,iz))&
         *kernel(ia,in,iky,ikx,iz,ieven)*(&
                       0.5_dp*SQRT2*moments(ia,ip2,in  ,iky,ikx,iz,updatetlevel)&
             +(2._dp*n_dp + 1.5_dp)*moments(ia,ip0,in  ,iky,ikx,iz,updatetlevel)&
                      -(n_dp+1._dp)*moments(ia,ip0,in+1,iky,ikx,iz,updatetlevel)&
                              -n_dp*moments(ia,ip0,in-1,iky,ikx,iz,updatetlevel))
      ENDDO
    ENDDO
    ENDDO
    ENDDO
   ENDIF
  IF(EM .AND. CONTAINSp1 .AND. CONTAINSp3) THEN
    !!----------------ELECTROMAGNETIC CONTRIBUTION--------------------
    DO iz  = 1,local_nz+ngz
    DO ikx = 1,local_nkx
    DO iky = 1,local_nky
      DO in = 1+ngj/2, local_nj+ngj/2 ! only interior points
        n_dp = jarray(in)
        integrant(iz) = integrant(iz) &
         +Jacobian(iz,iodd)*tau(ia)*sqrt_tau_o_sigma(ia)*imagu*kyarray(iky)*CONJG(psi(iky,ikx,iz))&
           *kernel(ia,in,iky,ikx,iz,iodd)*(&
                   0.5_dp*SQRT2*SQRT3*moments(ia,ip3,in  ,iky,ikx,iz,updatetlevel)&
                              +1.5_dp*moments(ia,ip1,in  ,iky,ikx,iz,updatetlevel)&
                  +(2._dp*n_dp+1._dp)*moments(ia,ip1,in  ,iky,ikx,iz,updatetlevel)&
                        -(n_dp+1._dp)*moments(ia,ip1,in+1,iky,ikx,iz,updatetlevel)&
                                -n_dp*moments(ia,ip1,in-1,iky,ikx,iz,updatetlevel))
      ENDDO
    ENDDO
    ENDDO
    ENDDO
  ENDIF
  ! Add polarisation contribution
  ! integrant(iz) = integrant(iz) + tau_i*imagu*ky_&
  ! *CONJG(phi(iky,ikx,iz))*phi(iky,ikx,iz) * HF_phi_correction_operator(iky,ikx,iz)
  ! Integrate over z
  call simpson_rule_z(local_nz,deltaz,integrant,integral)
  ! Double it for taking into account the other half plane
  hflux_local = 2._dp*integral*iInt_Jacobian
  buffer(2)   = REAL(hflux_local)
  root = 0
  !Gather manually among the rank_p=0 processes and perform the sum
  hflux_x(ia) = 0
  IF (num_procs_ky .GT. 1) THEN
      !! Everyone sends its local_sum to root = 0
      IF (rank_ky .NE. root) THEN
          CALL MPI_SEND(buffer, 2 , MPI_DOUBLE_PRECISION, root, 1234, comm_ky, ierr)
      ELSE
          ! Recieve from all the other processes
          DO i_ = 0,num_procs_ky-1
              IF (i_ .NE. rank_ky) &
                  CALL MPI_RECV(buffer, 2 , MPI_DOUBLE_PRECISION, i_, 1234, comm_ky, MPI_STATUS_IGNORE, ierr)
                  hflux_x(ia) = hflux_x(ia) + buffer(2)
          ENDDO
      ENDIF
  ELSE
    hflux_x(ia) = hflux_local
  ENDIF
  ENDDO
END SUBROUTINE compute_radial_heatflux

SUBROUTINE compute_nadiab_moments_z_gradients_and_interp
  ! evaluate the non-adiabatique ion moments
  !
  ! n_{pi} = N^{pj} + kernel(j) /tau_i phi delta_p0
  !
  IMPLICIT NONE
  INTEGER :: eo, p_int, j_int, ia,ip,ij,iky,ikx,iz
  CALL cpu_time(t0_process)

  !non adiab moments
  DO iz=1,local_nz+ngz
    DO ikx=1,local_nkx
      DO iky=1,local_nky
        DO ij=1,local_nj+ngj
          DO ip=1,local_np+ngp
            DO ia = 1,local_na
              IF(parray(ip) .EQ. 0) THEN
                nadiab_moments(ia,ip,ij,iky,ikx,iz) = moments(ia,ip,ij,iky,ikx,iz,updatetlevel) &
                                  + q_tau(ia)*kernel(ia,ij,iky,ikx,iz,ieven)*phi(iky,ikx,iz)
              ELSEIF( (parray(ip) .EQ. 1) .AND. (beta .GT. 0) ) THEN
                nadiab_moments(ia,ip,ij,iky,ikx,iz) = moments(ia,ip,ij,iky,ikx,iz,updatetlevel) &
                                  - q_o_sqrt_tau_sigma(ia)*kernel(ia,ij,iky,ikx,iz,ieven)*psi(iky,ikx,iz)
              ELSE
                nadiab_moments(ia,ip,ij,iky,ikx,iz) = moments(ia,ip,ij,iky,ikx,iz,updatetlevel)
              ENDIF
            ENDDO
          ENDDO
        ENDDO
      ENDDO
    ENDDO
  ENDDO

  !! Ensure to kill the moments too high if closue option is set to 1
  IF(CLOS .EQ. 1) THEN
    DO ij=1,local_nj+ngj
      j_int = jarray(ij)
      DO ip=1,local_np+ngp
        p_int = parray(ip)
        DO ia = 1,local_na
          IF(p_int+2*j_int .GT. dmax) &
          nadiab_moments(ia,ip,ij,:,:,:) = 0._dp
        ENDDO
      ENDDO
    ENDDO
  ENDIF

   !------------- INTERP AND GRADIENTS ALONG Z ----------------------------------
    DO ikx = 1,local_nkx
      DO iky = 1,local_nky
        DO ij = 1,local_nj+ngj
          DO ip = 1,local_np+ngp
            DO ia = 1,local_na
              p_int = parray(ip)
              eo    = MODULO(p_int,2)+1 ! Indicates if we are on even or odd z grid
              ! Compute z first derivative
              CALL   grad_z(eo,local_nz,ngz,inv_deltaz,nadiab_moments(ia,ip,ij,iky,ikx,:),ddz_napj(ia,ip,ij,iky,ikx,:))
              ! Parallel numerical diffusion
              IF (HDz_h) THEN
                CALL  grad_z4(local_nz,ngz,inv_deltaz,nadiab_moments(ia,ip,ij,iky,ikx,:),ddzND_Napj(ia,ip,ij,iky,ikx,:))
              ELSE
                CALL  grad_z4(local_nz,ngz,inv_deltaz,moments(ia,ip,ij,iky,ikx,:,updatetlevel),ddzND_Napj(ia,ip,ij,iky,ikx,:))
              ENDIF
              ! Compute even odd grids interpolation
              CALL interp_z(eo,local_nz,ngz,nadiab_moments(ia,ip,ij,iky,ikx,1:local_nz+ngz), interp_napj(ia,ip,ij,iky,ikx,1:local_nz))
            ENDDO
          ENDDO
        ENDDO
      ENDDO
    ENDDO

    ! Phi parallel gradient (not implemented fully, should be negligible)
    ! DO ikx = 1,local_nkx
    !   DO iky = 1,local_nky
    !     CALL grad_z(0,phi(iky,ikx,1:local_nz+ngz), ddz_phi(iky,ikx,1:local_nz))
    !   ENDDO
    ! ENDDO

  ! Execution time end
  CALL cpu_time(t1_process)
  tc_process = tc_process + (t1_process - t0_process)
END SUBROUTINE compute_nadiab_moments_z_gradients_and_interp

SUBROUTINE compute_Napjz_spectrum
  USE fields, ONLY : moments
  USE array,  ONLY : Napjz
  USE time_integration, ONLY : updatetlevel
  IMPLICIT NONE
  REAL(dp), DIMENSION(local_np,local_nj,local_nz) :: local_sum,global_sum, buffer
  INTEGER  :: i_, root, count, ia, ip, ij, iky, ikx, iz
  root = 0
  DO ia=1,local_na
    ! z-moment spectrum
    ! build local sum
    local_sum = 0._dp
    DO ikx = 1,local_nkx
      DO iky = 1,local_nky
        DO ij = 1,local_nj
          DO ip = 1,local_np
            local_sum(ip,ij,iz)  = local_sum(ip,ij,iz)  + &
            (moments(ia,ip+Ngp/2,ij+Ngj/2,iky,ikx,iz+Ngz/2,updatetlevel) &
            * CONJG(moments(ia,ip+Ngp/2,ij+Ngj/2,iky,ikx,iz+Ngz/2,updatetlevel)))
          ENDDO
        ENDDO
      ENDDO
    ENDDO
    ! sum reduction
    buffer     = local_sum
    global_sum = 0._dp
    count = local_np*local_nj*local_nz
    IF (num_procs_ky .GT. 1) THEN
        !! Everyone sends its local_sum to root = 0
        IF (rank_ky .NE. root) THEN
            CALL MPI_SEND(buffer, count , MPI_DOUBLE_PRECISION, root, 5678, comm_ky, ierr)
        ELSE
            ! Recieve from all the other processes
            DO i_ = 0,num_procs_ky-1
                IF (i_ .NE. rank_ky) &
                    CALL MPI_RECV(buffer, count , MPI_DOUBLE_PRECISION, i_, 5678, comm_ky, MPI_STATUS_IGNORE, ierr)
                    global_sum = global_sum + buffer
            ENDDO
        ENDIF
    ELSE
      global_sum = local_sum
    ENDIF
    Napjz(ia,:,:,:) = global_sum
  ENDDO
END SUBROUTINE compute_Napjz_spectrum

!_____________________________________________________________________________!
!!!!! FLUID MOMENTS COMPUTATIONS !!!!!
! Compute the 2D particle density for electron and ions (sum over Laguerre)
SUBROUTINE compute_density
  IMPLICIT NONE
  COMPLEX(dp) :: dens_
  INTEGER :: ia, iz, iky, ikx, ij
  DO ia=1,local_na
    IF ( CONTAINSp0 ) THEN
      ! Loop to compute dens_i = sum_j kernel_j Ni0j at each k
      DO iz = 1,local_nz
        DO iky = 1,local_nky
          DO ikx = 1,local_nkx
            dens_ = 0._dp
            DO ij = 1, local_nj
                dens_ = dens_ + kernel(ia,ij+ngj/2,iky,ikx,iz+ngz/2,0) * moments(ia,ip0,ij+ngj/2,iky,ikx,iz+ngz/2,updatetlevel)
            ENDDO
            dens(ia,iky,ikx,iz) = dens_
          ENDDO
        ENDDO
      ENDDO
    ENDIF
  ENDDO
END SUBROUTINE compute_density

! Compute the 2D particle fluid perp velocity for electron and ions (sum over Laguerre)
SUBROUTINE compute_uperp
  IMPLICIT NONE
  COMPLEX(dp) :: uperp_
  INTEGER :: ia, iz, iky, ikx, ij
  DO ia=1,local_na
    IF ( CONTAINSp0 ) THEN
        DO iz = 1,local_nz
          DO iky = 1,local_nky
            DO ikx = 1,local_nkx
              uperp_ = 0._dp
              DO ij = 1, local_nj
                uperp_ = uperp_ + kernel(ia,ij+ngj/2,iky,ikx,iz+ngz/2,0) *&
                 0.5_dp*(moments(ia,ip0,ij+ngj/2,iky,ikx,iz+ngz/2,updatetlevel) - moments(ia,ip0,ij-1+ngj/2,iky,ikx,iz+ngz/2,updatetlevel))
               ENDDO
              uper(ia,iky,ikx,iz) = uperp_
            ENDDO
          ENDDO
        ENDDO
    ENDIF
  ENDDO
END SUBROUTINE compute_uperp

! Compute the 2D particle fluid par velocity for electron and ions (sum over Laguerre)
SUBROUTINE compute_upar
  IMPLICIT NONE
  INTEGER :: ia, iz, iky, ikx, ij
  COMPLEX(dp) :: upar_
  DO ia=1,local_na
    IF ( CONTAINSp1 ) THEN
      DO iz = 1,local_nz
        DO iky = 1,local_nky
          DO ikx = 1,local_nkx
            upar_ = 0._dp
            DO ij = 1, local_nj
              upar_ = upar_ + kernel(ia,ij+ngj/2,iky,ikx,iz+ngz/2,1)*moments(ia,ip1,ij+ngj/2,iky,ikx,iz+ngz/2,updatetlevel)
             ENDDO
            upar(ia,iky,ikx,iz) = upar_
          ENDDO
        ENDDO
      ENDDO
    ENDIF
  ENDDO
END SUBROUTINE compute_upar

! Compute the 2D particle temperature for electron and ions (sum over Laguerre)
SUBROUTINE compute_tperp
  USE time_integration, ONLY : updatetlevel
  IMPLICIT NONE
  REAL(dp)    :: j_dp
  COMPLEX(dp) :: Tperp_
  INTEGER     :: ia, iz, iky, ikx, ij
  DO ia=1,local_na
    IF ( CONTAINSp0 .AND. CONTAINSp2 ) THEN
        ! Loop to compute T = 1/3*(Tpar + 2Tperp)
        DO iz = 1,local_nz
          DO iky = 1,local_nky
            DO ikx = 1,local_nkx
              Tperp_ = 0._dp
              DO ij = 1, local_nj
                j_dp = REAL(ij-1,dp)
                Tperp_ = Tperp_ + kernel(ia,ij,iky,ikx,iz,0)*&
                    ((2_dp*j_dp+1)*moments(ia,ip0,ij  +ngj/2,iky,ikx,iz+ngz/2,updatetlevel)&
                    -j_dp         *moments(ia,ip0,ij-1+ngj/2,iky,ikx,iz+ngz/2,updatetlevel)&
                    -j_dp+1       *moments(ia,ip0,ij+1+ngj/2,iky,ikx,iz+ngz/2,updatetlevel))
              ENDDO
              Tper(ia,iky,ikx,iz) = Tperp_
            ENDDO
          ENDDO
        ENDDO
    ENDIF
  ENDDO
END SUBROUTINE compute_Tperp

! Compute the 2D particle temperature for electron and ions (sum over Laguerre)
SUBROUTINE compute_Tpar
  USE time_integration, ONLY : updatetlevel
  IMPLICIT NONE
  REAL(dp)    :: j_dp
  COMPLEX(dp) :: Tpar_
  INTEGER     :: ia, iz, iky, ikx, ij

  DO ia=1,local_na
    IF ( CONTAINSp0 .AND. CONTAINSp0 ) THEN
      ! Loop to compute T = 1/3*(Tpar + 2Tperp)
      DO iz = 1,local_nz
        DO iky = 1,local_nky
          DO ikx = 1,local_nkx
            Tpar_ = 0._dp
            DO ij = 1, local_nj
              j_dp = REAL(ij-1,dp)
              Tpar_  = Tpar_ + kernel(ia,ij+ngj/2,iky,ikx,iz+ngz/2,0)*&
               (SQRT2 * moments(ia,ip2,ij+ngj/2,iky,ikx,iz+ngz/2,updatetlevel) &
                      + moments(ia,ip0,ij+ngj/2,iky,ikx,iz+ngz/2,updatetlevel))
            ENDDO
            Tpar(ia,iky,ikx,iz) = Tpar_
          ENDDO
        ENDDO
      ENDDO
    ENDIF
  ENDDO
END SUBROUTINE compute_Tpar

! Compute the 2D particle fluid moments for electron and ions (sum over Laguerre)
SUBROUTINE compute_fluid_moments
  IMPLICIT NONE
  CALL compute_density
  CALL compute_upar
  CALL compute_uperp
  CALL compute_Tpar
  CALL compute_Tperp
  ! Temperature
  temp = (Tpar - 2._dp * Tper)/3._dp - dens
END SUBROUTINE compute_fluid_moments

END MODULE processing