MODULE grid
  ! Grid module for spatial discretization
  USE prec_const
  USE basic

  IMPLICIT NONE
  PRIVATE

  !   GRID parameters and their default values.
  !   pmaxe, jmaxe, pmaxi, jmaxi, Nx, Lx, Ny, Ly, Nz, Npol, q0, shear and eps
  !   can be overridden through the GRID namelist read in grid_readinputs;
  !   the remaining ones are either derived by the set_* routines of this
  !   module or keep the default given here.
  INTEGER,  PUBLIC, PROTECTED :: pmaxe = 1      ! The maximal electron Hermite-moment computed
  INTEGER,  PUBLIC, PROTECTED :: jmaxe = 1      ! The maximal electron Laguerre-moment computed
  INTEGER,  PUBLIC, PROTECTED :: pmaxi = 1      ! The maximal ion Hermite-moment computed
  INTEGER,  PUBLIC, PROTECTED :: jmaxi = 1      ! The maximal ion Laguerre-moment computed
  INTEGER,  PUBLIC, PROTECTED :: maxj  = 1      ! The maximal Laguerre-moment
  INTEGER,  PUBLIC, PROTECTED :: dmaxe = 1      ! The maximal full GF set of e-moments v^dmax
  INTEGER,  PUBLIC, PROTECTED :: dmaxi = 1      ! The maximal full GF set of i-moments v^dmax
  INTEGER,  PUBLIC, PROTECTED :: Nx    = 16     ! Number of total internal grid points in x
  REAL(dp), PUBLIC, PROTECTED :: Lx    = 1._dp  ! horizontal length of the spatial box
  INTEGER,  PUBLIC, PROTECTED :: Ny    = 16     ! Number of total internal grid points in y
  REAL(dp), PUBLIC, PROTECTED :: Ly    = 1._dp  ! vertical length of the spatial box
  INTEGER,  PUBLIC, PROTECTED :: Nz    = 1      ! Number of total perpendicular planes
  REAL(dp), PUBLIC, PROTECTED :: Npol  = 1._dp  ! number of poloidal turns
  INTEGER,  PUBLIC, PROTECTED :: Odz   = 4      ! order of z interp and derivative schemes
  REAL(dp), PUBLIC, PROTECTED :: q0    = 1._dp  ! safety factor
  REAL(dp), PUBLIC, PROTECTED :: shear = 0._dp  ! magnetic field shear
  REAL(dp), PUBLIC, PROTECTED :: eps   = 0._dp  ! inverse aspect ratio
  INTEGER,  PUBLIC, PROTECTED :: Nkx   = 8      ! Number of total internal grid points in kx
  REAL(dp), PUBLIC, PROTECTED :: Lkx   = 1._dp  ! horizontal length of the fourier box
  INTEGER,  PUBLIC, PROTECTED :: Nky   = 16     ! Number of total internal grid points in ky
  REAL(dp), PUBLIC, PROTECTED :: Lky   = 1._dp  ! vertical length of the fourier box
  ! Real literal (0._dp): "0_dp" would be an INTEGER constant of kind dp.
  REAL(dp), PUBLIC, PROTECTED :: kpar  = 0._dp  ! parallel wave vector component
  ! Cut-off wavenumbers of the Orszag 2/3 anti-aliasing filter
  ! (the kx and ky values are computed in set_kxgrid and set_kygrid)
  REAL(dp), PUBLIC, PROTECTED :: two_third_kxmax
  REAL(dp), PUBLIC, PROTECTED :: two_third_kymax
  REAL(dp), PUBLIC :: two_third_kpmax

  ! 1D antialiasing masks (2/3 rule): 1 for the modes that are kept, 0 otherwise
  REAL(dp), DIMENSION(:), ALLOCATABLE, PUBLIC :: AA_x
  REAL(dp), DIMENSION(:), ALLOCATABLE, PUBLIC :: AA_y

  ! Grids containing position in physical space
  REAL(dp), DIMENSION(:),   ALLOCATABLE, PUBLIC :: xarray
  REAL(dp), DIMENSION(:),   ALLOCATABLE, PUBLIC :: yarray
  ! Local and global z grids; the local one is 2D since it has to store both
  ! the odd and the even grid (second index 0:1, see set_zgrid)
  REAL(dp), DIMENSION(:,:), ALLOCATABLE, PUBLIC :: zarray
  REAL(dp), DIMENSION(:),   ALLOCATABLE, PUBLIC :: zarray_full
  ! Local z weights for the Simpson rule (filled in set_zgrid)
  INTEGER,  DIMENSION(:),   ALLOCATABLE, PUBLIC :: zweights_SR
  REAL(dp), PUBLIC, PROTECTED  ::  deltax,  deltay, deltaz, inv_deltaz, diff_dz_coeff
  INTEGER,  PUBLIC, PROTECTED  ::  ixs,  ixe,  iys,  iye,  izs,  ize
  INTEGER,  PUBLIC, PROTECTED  ::  izgs, izge ! z bounds including the ghost cells
  LOGICAL,  PUBLIC, PROTECTED  ::  SG = .true.! shifted grid flag
  INTEGER,  PUBLIC :: ir,iz ! counters
  ! Data about the parallel distribution along kx (and ky)
  integer(C_INTPTR_T), PUBLIC :: local_nkx, local_nky
  integer(C_INTPTR_T), PUBLIC :: local_nkx_offset, local_nky_offset
  INTEGER,             PUBLIC :: local_nkp
  ! Data about the parallel distribution along p
  INTEGER,             PUBLIC :: local_np_e, local_np_i
  INTEGER,             PUBLIC :: total_np_e, total_np_i
  integer(C_INTPTR_T), PUBLIC :: local_np_e_offset, local_np_i_offset
  INTEGER, DIMENSION(:), ALLOCATABLE, PUBLIC :: counts_np_e, counts_np_i
  INTEGER, DIMENSION(:), ALLOCATABLE, PUBLIC :: displs_np_e, displs_np_i
  ! Data about the parallel distribution along z
  INTEGER,             PUBLIC :: local_nz
  INTEGER,             PUBLIC :: total_nz
  integer(C_INTPTR_T), PUBLIC :: local_nz_offset
  INTEGER, DIMENSION(:), ALLOCATABLE, PUBLIC :: counts_nz
  INTEGER, DIMENSION(:), ALLOCATABLE, PUBLIC :: displs_nz
  ! Number of j points (j is not parallelized); set_jgrid computes these
  ! from the ghost bounds, i.e. the ghost cells are counted
  INTEGER,             PUBLIC :: local_nj_e, local_nj_i
  ! Grids containing position in fourier space
  REAL(dp), DIMENSION(:),     ALLOCATABLE, PUBLIC :: kxarray, kxarray_full
  REAL(dp), DIMENSION(:),     ALLOCATABLE, PUBLIC :: kyarray, kyarray_full
  ! Kperp array depends on kx, ky, z (geometry), eo (even or odd zgrid)
  REAL(dp), DIMENSION(:,:,:,:), ALLOCATABLE, PUBLIC :: kparray
  REAL(dp), PUBLIC, PROTECTED ::  deltakx, deltaky, kx_max, ky_max, kx_min, ky_min!, kp_max
  REAL(dp), PUBLIC, PROTECTED ::  local_kxmax, local_kymax
  INTEGER,  PUBLIC, PROTECTED ::  ikxs, ikxe, ikys, ikye!, ikps, ikpe
  INTEGER,  PUBLIC, PROTECTED :: ikx_0, iky_0, ikx_max, iky_max ! Indices of k-grid origin and max
  INTEGER,  PUBLIC            :: ikx, iky, ip, ij, ikp, pp2, eo ! counters
  LOGICAL,  PUBLIC, PROTECTED :: contains_kx0   = .false. ! flag if the proc contains kx=0 index
  LOGICAL,  PUBLIC, PROTECTED :: contains_ky0   = .false. ! flag if the proc contains ky=0 index
  LOGICAL,  PUBLIC, PROTECTED :: contains_kxmax = .false. ! flag if the proc contains kx=kxmax index

  ! Grids containing the polynomial degrees (local, ghosts included, and full)
  INTEGER,  DIMENSION(:), ALLOCATABLE, PUBLIC :: parray_e, parray_e_full
  INTEGER,  DIMENSION(:), ALLOCATABLE, PUBLIC :: parray_i, parray_i_full
  INTEGER,  DIMENSION(:), ALLOCATABLE, PUBLIC :: jarray_e, jarray_e_full
  INTEGER,  DIMENSION(:), ALLOCATABLE, PUBLIC :: jarray_i, jarray_i_full
  INTEGER,  PUBLIC, PROTECTED ::  ips_e,ipe_e, ijs_e,ije_e ! Start and end indices for pol. deg.
  INTEGER,  PUBLIC, PROTECTED ::  ips_i,ipe_i, ijs_i,ije_i
  INTEGER,  PUBLIC, PROTECTED ::  ipgs_e,ipge_e, ijgs_e,ijge_e ! Ghosts start and end indices
  INTEGER,  PUBLIC, PROTECTED ::  ipgs_i,ipge_i, ijgs_i,ijge_i
  INTEGER,  PUBLIC, PROTECTED ::  deltape, ip0_e, ip1_e, ip2_e ! Pgrid spacing and moment 0,1,2 index
  INTEGER,  PUBLIC, PROTECTED ::  deltapi, ip0_i, ip1_i, ip2_i
  ! Flags telling whether the local p grid (ghosts included, see set_pgrid)
  ! holds the degrees p=0, p=1 and p=2
  LOGICAL,  PUBLIC, PROTECTED ::  CONTAINS_ip0_e, CONTAINS_ip0_i
  LOGICAL,  PUBLIC, PROTECTED ::  CONTAINS_ip1_e, CONTAINS_ip1_i
  LOGICAL,  PUBLIC, PROTECTED ::  CONTAINS_ip2_e, CONTAINS_ip2_i
  ! Useful inverse numbers
  REAL(dp), PUBLIC, PROTECTED :: inv_Nx, inv_Ny, inv_Nz

  ! Public procedures
  PUBLIC :: init_1Dgrid_distr
  PUBLIC :: set_pgrid, set_jgrid
  PUBLIC :: set_kxgrid, set_kygrid, set_zgrid
  PUBLIC :: grid_readinputs, grid_outputinputs
  PUBLIC :: bare, bari

  ! Precomputations: maximal degrees converted to REAL(dp)
  real(dp), PUBLIC, PROTECTED    :: pmaxe_dp, pmaxi_dp, jmaxe_dp,jmaxi_dp

CONTAINS


  SUBROUTINE grid_readinputs
    ! Read the GRID namelist from unit lu_in and derive the quantities that
    ! depend only on it:
    !   - dmaxe, dmaxi           : maximal degree of the full GF moment set
    !   - deltape, deltapi, pp2  : spacing of the Hermite grid and index shift
    !                              that corresponds to p+2
    !   - inv_Nx, inv_Ny, inv_Nz : inverse numbers of grid points
    USE prec_const
    IMPLICIT NONE
    ! Unit of the input file (file duplicated from STDIN). Declared as a named
    ! constant: an initialized local would silently get the SAVE attribute.
    INTEGER, PARAMETER :: lu_in = 90

    NAMELIST /GRID/ pmaxe, jmaxe, pmaxi, jmaxi, &
                    Nx,  Lx,  Ny,  Ly, Nz, Npol, q0, shear, eps, SG
    READ(lu_in,grid)

    !! Compute the maximal degree of full GF moments set
    !   i.e. : all moments N_a^pj s.t. p+2j<=d are simulated (see GF closure)
    dmaxe = min(pmaxe,2*jmaxe+1)
    dmaxi = min(pmaxi,2*jmaxi+1)

    ! If no parallel dim (Nz=1), the moment hierarchy is separable between odds and even P
    !! and since the energy is injected in P=0 and P=2 for density/temperature gradients
    !! there is no need of simulating the odd p which will only be damped.
    !! We define in this case a grid Parray = 0,2,4,...,Pmax i.e. deltap = 2 instead of 1
    !! to spare computation
    IF(Nz .EQ. 1) THEN
      deltape = 2
      deltapi = 2
      pp2     = 1 ! index p+2 is ip+1
    ELSE
      deltape = 1
      deltapi = 1
      pp2     = 2 ! index p+2 is ip+2
    ENDIF

    ! Useful precomputations. inv_Nz is PROTECTED, so this module is the only
    ! place where it can be given a value.
    inv_Nx = 1._dp/REAL(Nx,dp)
    inv_Ny = 1._dp/REAL(Ny,dp)
    inv_Nz = 1._dp/REAL(Nz,dp)

  END SUBROUTINE grid_readinputs

  subroutine init_1Dgrid_distr
    ! Split the Nx/2+1 kx modes in contiguous chunks over the num_procs_kx
    ! processes. Every process gets the integer quotient; the last process
    ! (rank_kx = num_procs_kx-1) additionally takes whatever is left.
    ! Sets local_nkx and local_nkx_offset.
    integer :: nkx_total

    nkx_total = Nx/2 + 1

    ! Nominal chunk size and offset of this process
    local_nkx        = nkx_total/num_procs_kx
    local_nkx_offset = rank_kx*local_nkx

    ! The last process absorbs the remainder of the division
    if (rank_kx + 1 .eq. num_procs_kx) then
      local_nkx = nkx_total - local_nkx_offset
    end if
  end subroutine init_1Dgrid_distr

  SUBROUTINE set_pgrid
    ! Build the grids of Hermite degrees p for electrons and ions.
    !  - parray_*_full : all the degrees 0, deltap, 2*deltap, ...
    !  - ips_*, ipe_*  : index range owned by this process (decomp1D over
    !                    num_procs_p processes, this process being rank_p)
    !  - ipgs_*, ipge_*: same range extended by the ghost cells (enough cells
    !                    to reach p-2 and p+2, i.e. 2/deltap on each side)
    !  - parray_*      : local degrees, ghosts included
    !  - counts_np_*, displs_np_* : sizes and shifts of every process
    !  - ip0_*, ip1_*, ip2_* and the CONTAINS_ip* flags : local index of the
    !    degrees 0, 1, 2. The search runs over the ghost cells too, so a flag
    !    can be .TRUE. although the degree is owned by a neighbour. An index
    !    is only defined when its flag is .TRUE..
    USE prec_const
    IMPLICIT NONE
    INTEGER :: ip, istart, iend, in
    LOGICAL :: dggk_split_e, dggk_split_i

    ! Total number of Hermite polynomials we will evolve
    total_np_e = (Pmaxe/deltape) + 1
    total_np_i = (Pmaxi/deltapi) + 1
    ! Build the full grids on process 0 to diagnose it without comm
    ALLOCATE(parray_e_full(1:total_np_e))
    ALLOCATE(parray_i_full(1:total_np_i))
    DO ip = 1,total_np_e
      parray_e_full(ip) = (ip-1)*deltape
    END DO
    DO ip = 1,total_np_i
      parray_i_full(ip) = (ip-1)*deltapi
    END DO

    !! Parallel data distribution
    ! Local data distribution
    CALL decomp1D(total_np_e, num_procs_p, rank_p, ips_e, ipe_e)
    CALL decomp1D(total_np_i, num_procs_p, rank_p, ips_i, ipe_i)
    local_np_e = ipe_e - ips_e + 1
    local_np_i = ipe_i - ips_i + 1
    ! Ghosts boundaries
    ipgs_e = ips_e - 2/deltape
    ipge_e = ipe_e + 2/deltape
    ipgs_i = ips_i - 2/deltapi
    ipge_i = ipe_i + 2/deltapi
    ! List of shift and local numbers between the different processes (used in scatterv and gatherv)
    ALLOCATE(counts_np_e (1:num_procs_p))
    ALLOCATE(counts_np_i (1:num_procs_p))
    ALLOCATE(displs_np_e (1:num_procs_p))
    ALLOCATE(displs_np_i (1:num_procs_p))
    DO in = 0,num_procs_p-1
      CALL decomp1D(total_np_e, num_procs_p, in, istart, iend)
      counts_np_e(in+1) = iend-istart+1
      displs_np_e(in+1) = istart-1
      CALL decomp1D(total_np_i, num_procs_p, in, istart, iend)
      counts_np_i(in+1) = iend-istart+1
      displs_np_i(in+1) = istart-1
    ENDDO

    ! Local grid computation
    CONTAINS_ip0_e = .FALSE.
    CONTAINS_ip1_e = .FALSE.
    CONTAINS_ip2_e = .FALSE.
    CONTAINS_ip0_i = .FALSE.
    CONTAINS_ip1_i = .FALSE.
    CONTAINS_ip2_i = .FALSE.
    ALLOCATE(parray_e(ipgs_e:ipge_e))
    ALLOCATE(parray_i(ipgs_i:ipge_i))
    DO ip = ipgs_e,ipge_e
      parray_e(ip) = (ip-1)*deltape
      ! Storing indices of particular degrees for fluid moments computations
      SELECT CASE (parray_e(ip))
      CASE (0)
        ip0_e          = ip
        CONTAINS_ip0_e = .TRUE.
      CASE (1)
        ip1_e          = ip
        CONTAINS_ip1_e = .TRUE.
      CASE (2)
        ip2_e          = ip
        CONTAINS_ip2_e = .TRUE.
      END SELECT
    END DO
    DO ip = ipgs_i,ipge_i
      parray_i(ip) = (ip-1)*deltapi
      ! Storing indices of particular degrees for fluid moments computations
      SELECT CASE (parray_i(ip))
      CASE (0)
        ip0_i          = ip
        CONTAINS_ip0_i = .TRUE.
      CASE (1)
        ip1_i          = ip
        CONTAINS_ip1_i = .TRUE.
      CASE (2)
        ip2_i          = ip
        CONTAINS_ip2_i = .TRUE.
      END SELECT
    END DO

    ! DGGK operator uses the moment p=2 for the p=0 term, so the process that
    ! owns p=0 MUST own p=2 as well, for both e and i. Each species is tested
    ! against its own indices, and ip0_*/ip2_* are only read when the
    ! matching CONTAINS flag says they have been defined.
    dggk_split_e = .FALSE.
    IF (CONTAINS_ip0_e) THEN
      IF (ips_e .EQ. ip0_e) THEN
        IF (CONTAINS_ip2_e) THEN
          dggk_split_e = (ipe_e .LT. ip2_e)
        ELSE
          dggk_split_e = .TRUE.
        ENDIF
      ENDIF
    ENDIF
    dggk_split_i = .FALSE.
    IF (CONTAINS_ip0_i) THEN
      IF (ips_i .EQ. ip0_i) THEN
        IF (CONTAINS_ip2_i) THEN
          dggk_split_i = (ipe_i .LT. ip2_i)
        ELSE
          dggk_split_i = .TRUE.
        ENDIF
      ENDIF
    ENDIF
    IF (dggk_split_e .OR. dggk_split_i) &
     WRITE(*,*) "Warning : distribution along p may not work with DGGK"

    ! Precomputations
    pmaxe_dp   = real(pmaxe,dp)
    pmaxi_dp   = real(pmaxi,dp)
  END SUBROUTINE set_pgrid

  subroutine set_jgrid
    ! Build the grids of Laguerre degrees j for electrons and ions. The j
    ! direction is not distributed: every process holds the indices
    ! 1..jmax+1 plus one ghost cell on each side.
    ! Note that local_nj_e and local_nj_i are computed from the ghost
    ! bounds, so they count the two ghost cells.
    use prec_const
    implicit none
    integer :: jdx

    ! Index bounds of the owned cells ...
    ijs_e = 1
    ije_e = jmaxe + 1
    ijs_i = 1
    ije_i = jmaxi + 1
    ! ... and of the ghost layer around them
    ijgs_e = ijs_e - 1
    ijge_e = ije_e + 1
    ijgs_i = ijs_i - 1
    ijge_i = ije_i + 1
    ! Number of local j cells (ghosts included)
    local_nj_e = ijge_e - ijgs_e + 1
    local_nj_i = ijge_i - ijgs_i + 1

    ! Full grids: degree j stored at index j+1
    allocate(jarray_e_full(1:jmaxe+1))
    allocate(jarray_i_full(1:jmaxi+1))
    jarray_e_full(:) = [(jdx-1, jdx = 1,jmaxe+1)]
    jarray_i_full(:) = [(jdx-1, jdx = 1,jmaxi+1)]

    ! Local grids, same index-to-degree map extended to the ghost cells
    allocate(jarray_e(ijgs_e:ijge_e))
    allocate(jarray_i(ijgs_i:ijge_i))
    jarray_e(:) = [(jdx-1, jdx = ijgs_e,ijge_e)]
    jarray_i(:) = [(jdx-1, jdx = ijgs_i,ijge_i)]

    ! Precomputed quantities
    maxj     = max(jmaxi, jmaxe)
    jmaxe_dp = real(jmaxe,dp)
    jmaxi_dp = real(jmaxi,dp)
  end subroutine set_jgrid


  SUBROUTINE set_kxgrid
    ! Build the kx grids: kxarray_full(1:Nkx) and the local slice
    ! kxarray(ikxs:ikxe), both ordered as deltakx*(0 1 2 ...).
    ! The local bounds come from local_nkx and local_nkx_offset, so these must
    ! have been set before the call (init_1Dgrid_distr does it).
    ! Also sets deltakx, kx_min, kx_max, local_kxmax, the kx=0 / kx=kx_max
    ! indices and flags, and the 2/3 anti-aliasing mask AA_x.
    ! The module-level counter ikx is used as loop variable.
    USE prec_const
    USE model, ONLY: LINEARITY
    IMPLICIT NONE

    Nkx = Nx/2+1 ! Defined only on positive kx since fields are real
    ! Grid spacings
    IF (Nx .EQ. 1) THEN
      deltakx = 0._dp
      kx_max  = 0._dp
      kx_min  = 0._dp
    ELSE
      deltakx = 2._dp*PI/Lx
      ! Largest value on the grid, i.e. the one stored at index Nkx. It is
      ! written exactly like the grid values below so that the search for
      ! kx_max in the loop can succeed.
      kx_max  = REAL(Nkx-1,dp) * deltakx
      kx_min  = deltakx
    ENDIF
    ! Build the full grids on process 0 to diagnose it without comm
    ALLOCATE(kxarray_full(1:Nkx))
    DO ikx = 1,Nkx
      kxarray_full(ikx) = REAL(ikx-1,dp) * deltakx
    END DO
    !! Parallel distribution
    ikxs = local_nkx_offset + 1
    ikxe = ikxs + local_nkx - 1
    ALLOCATE(kxarray(ikxs:ikxe))
    local_kxmax = 0._dp
    ! Creating a grid ordered as dk*(0 1 2 3)
    DO ikx = ikxs,ikxe
      kxarray(ikx) = REAL(ikx-1,dp) * deltakx
      ! Finding kx=0
      IF (kxarray(ikx) .EQ. 0._dp) THEN
        ikx_0 = ikx
        contains_kx0 = .true.
      ENDIF
      ! Finding local kxmax value
      IF (ABS(kxarray(ikx)) .GT. local_kxmax) THEN
        local_kxmax = ABS(kxarray(ikx))
      ENDIF
      ! Finding kxmax idx (exact comparison is safe: same expression as kx_max)
      IF (kxarray(ikx) .EQ. kx_max) THEN
        ikx_max = ikx
        contains_kxmax = .true.
      ENDIF
    END DO
    ! Orszag 2/3 filter: keep the modes strictly below 2/3 of the largest kx,
    ! or every mode when the run is linear.
    ! NOTE(review): for Nx = 1 the threshold is 0 and the comparison is strict,
    ! so the single kx=0 mode is masked unless LINEARITY is 'linear' - confirm
    ! that this is intended.
    two_third_kxmax = 2._dp/3._dp*deltakx*(Nkx-1)
    ALLOCATE(AA_x(ikxs:ikxe))
    DO ikx = ikxs,ikxe
      IF ( (kxarray(ikx) .LT. two_third_kxmax) .OR. (LINEARITY .EQ. 'linear')) THEN
        AA_x(ikx) = 1._dp
      ELSE
        AA_x(ikx) = 0._dp
      ENDIF
    END DO
  END SUBROUTINE set_kxgrid

  SUBROUTINE set_kygrid
    ! Build the ky grids kyarray(ikys:ikye) and kyarray_full(1:Nky). The ky
    ! direction is not distributed here: ikys = 1 and ikye = Nky = Ny.
    ! For Ny > 1 the grid is ordered as deltaky*(0 1 2 3 -2 -1), the entry at
    ! index Ny/2+1 being stored with a positive sign.
    ! Also sets deltaky, ky_min, ky_max, local_kymax, iky_0, iky_max,
    ! contains_ky0 and the 2/3 anti-aliasing mask AA_y.
    ! The module-level counter iky is used as loop variable.
    USE prec_const
    USE model, ONLY: LINEARITY
    IMPLICIT NONE

    Nky = Ny
    ALLOCATE(kyarray_full(1:Nky))
    ! Local data
    ! Start and END indices of grid
    ikys = 1
    ikye = Nky
    local_nky = ikye - ikys + 1
    ALLOCATE(kyarray(ikys:ikye))
    IF (Ny .EQ. 1) THEN ! "cancel" y dimension
      deltaky         = 1._dp
      kyarray(1)      = 0._dp
      iky_0           = 1
      contains_ky0    = .true.
      ky_max          = 0._dp
      iky_max         = 1
      ky_min          = 0._dp
      kyarray_full(1) = 0._dp
      local_kymax     = 0._dp
    ELSE ! Build appropriate grid
      deltaky     = 2._dp*PI/Ly
      ! The ky spacing must be used here, not the kx one
      ky_max      = (Ny/2)*deltaky
      ky_min      = deltaky
      ! Creating a grid ordered as dk*(0 1 2 3 -2 -1)
      local_kymax = 0._dp
      DO iky = ikys,ikye
        ! (2*(iky-1))/Nky is 0 on the first half of the grid and 1 on the
        ! second half, which shifts the second half to negative ky
        kyarray(iky) = deltaky*(MODULO(iky-1,Nky/2) - (Nky/2)*((2*(iky-1))/Nky))
        ! The entry at index Ny/2+1 is kept with a positive sign
        IF (iky .EQ. Ny/2+1) kyarray(iky) = -kyarray(iky)
        ! Finding ky=0
        IF (kyarray(iky) .EQ. 0._dp) THEN
          iky_0 = iky
          contains_ky0 = .true.
        ENDIF
        ! Finding local kymax
        IF (ABS(kyarray(iky)) .GT. local_kymax) THEN
          local_kymax = ABS(kyarray(iky))
        ENDIF
        ! Finding kymax: record the ky index (the kx one must not be touched)
        IF (kyarray(iky) .EQ. ky_max) iky_max = iky
      END DO
      ! Build the full grids on process 0 to diagnose it without comm.
      ! The local grid spans 1:Nky, so the full grid is a plain copy.
      kyarray_full(ikys:ikye) = kyarray(ikys:ikye)
    ENDIF
    ! Orszag 2/3 filter: keep |ky| strictly below the threshold, or every
    ! mode when the run is linear.
    ! NOTE(review): for Ny = 1 the threshold below is negative, so the single
    ! ky=0 mode is masked unless LINEARITY is 'linear' - confirm that this is
    ! intended.
    two_third_kymax = 2._dp/3._dp*deltaky*(Nky/2-1)
    ALLOCATE(AA_y(ikys:ikye))
    DO iky = ikys,ikye
      IF ( ((kyarray(iky) .GT. -two_third_kymax) .AND. &
           (kyarray(iky) .LT. two_third_kymax))   .OR. (LINEARITY .EQ. 'linear')) THEN
        AA_y(iky) = 1._dp
      ELSE
        AA_y(iky) = 0._dp
      ENDIF
    END DO
  END SUBROUTINE set_kygrid


  SUBROUTINE set_zgrid
    ! Build the z grids and their parallel distribution.
    !  - zarray_full(1:Nz) : z = (iz-1)*deltaz - PI*Npol
    !  - izs, ize          : index range owned by this process (decomp1D over
    !                        num_procs_z processes, this process being rank_z)
    !  - izgs, izge        : same range with two ghost cells on each side when
    !                        Nz >= 4, no ghost when Nz = 1; any other Nz stops
    !                        the run
    !  - zarray(iz,0:1)    : local grid, ghosts included; column 0 is the
    !                        unshifted grid, column 1 is shifted by deltaz/2
    !                        when SG is set (no shift otherwise)
    !  - counts_nz, displs_nz : sizes and shifts of every process
    !  - zweights_SR       : Simpson weights (1,4,2,4,...,1) of the owned cells
    ! The module-level counter iz is used as loop variable.
    USE prec_const
    IMPLICIT NONE
    ! Working precision: a default REAL here would truncate the z step
    REAL(dp) :: grid_shift, Lz
    INTEGER  :: istart, iend, in, iz_wrap

    total_nz = Nz
    ! Length of the flux tube (in ballooning angle)
    Lz         = 2._dp*PI*Npol
    ! Z stepping (#interval = #points since periodic)
    deltaz        = Lz/REAL(Nz,dp)
    inv_deltaz    = 1._dp/deltaz
    diff_dz_coeff = (deltaz/2._dp)**2
    IF (SG) THEN
      grid_shift = deltaz/2._dp
    ELSE
      grid_shift = 0._dp
    ENDIF
    ! Build the full grids on process 0 to diagnose it without comm
    ALLOCATE(zarray_full(1:Nz))
    ! Note: Npol is reset only now, deltaz above was computed with the value
    ! that Npol had on entry
    IF (Nz .EQ. 1) Npol = 0._dp
    DO iz = 1,total_nz
      zarray_full(iz) = REAL(iz-1,dp)*deltaz - PI*Npol
    END DO
    !! Parallel data distribution
    ! Local data distribution
    CALL decomp1D(total_nz, num_procs_z, rank_z, izs, ize)
    local_nz = ize - izs + 1
    ! Ghosts boundaries (depend on the order of z operators)
    IF(Nz .EQ. 1) THEN
      izgs = izs
      izge = ize
    ELSEIF(Nz .GE. 4) THEN
      izgs = izs - 2
      izge = ize + 2
    ELSE
      ERROR STOP 'Error stop: Nz is not appropriate!!'
    ENDIF
    ! List of shift and local numbers between the different processes (used in scatterv and gatherv)
    ALLOCATE(counts_nz (1:num_procs_z))
    ALLOCATE(displs_nz (1:num_procs_z))
    DO in = 0,num_procs_z-1
      CALL decomp1D(total_nz, num_procs_z, in, istart, iend)
      counts_nz(in+1) = iend-istart+1
      displs_nz(in+1) = istart-1
    ENDDO
    ! Local z array. The ghost cells outside 1..total_nz take the value of
    ! their periodic image: -1 -> Nz-1, 0 -> Nz, Nz+1 -> 1, Nz+2 -> 2.
    ALLOCATE(zarray(izgs:izge,0:1))
    DO iz = izgs,izge
      iz_wrap      = MODULO(iz-1,total_nz) + 1
      zarray(iz,0) = zarray_full(iz_wrap)
      zarray(iz,1) = zarray_full(iz_wrap) + grid_shift
    ENDDO
    ! Weights for Simpson rule: 1 at both ends, 4 at even iz, 2 at odd iz
    ALLOCATE(zweights_SR(izs:ize))
    DO iz = izs,ize
      IF((iz .EQ. 1) .OR. (iz .EQ. Nz)) THEN
        zweights_SR(iz) = 1
      ELSEIF(MODULO(iz-1,2) .EQ. 1) THEN
        zweights_SR(iz) = 4
      ELSE
        zweights_SR(iz) = 2
      ENDIF
    ENDDO
  END SUBROUTINE set_zgrid

  subroutine grid_outputinputs(fidres, str)
    ! Attach the grid parameters to the results_xx.h5 file.
    !   fidres : identifier of the result file, handed over to attach
    !   str    : location the attributes are attached to; trailing blanks
    !            are removed before it is handed over to attach
    use futils, only: attach
    use prec_const
    implicit none

    integer, intent(in) :: fidres
    character(len=256), intent(in) :: str
    character(len=:), allocatable :: where_

    ! Trim once instead of at every call
    where_ = trim(str)

    ! Polynomial degrees
    call attach(fidres, where_, "pmaxe", pmaxe)
    call attach(fidres, where_, "jmaxe", jmaxe)
    call attach(fidres, where_, "pmaxi", pmaxi)
    call attach(fidres, where_, "jmaxi", jmaxi)
    ! Real space grid
    call attach(fidres, where_, "Nx", Nx)
    call attach(fidres, where_, "Lx", Lx)
    call attach(fidres, where_, "Ny", Ny)
    call attach(fidres, where_, "Ly", Ly)
    call attach(fidres, where_, "Nz", Nz)
    ! Geometry
    call attach(fidres, where_, "q0", q0)
    call attach(fidres, where_, "shear", shear)
    call attach(fidres, where_, "eps", eps)
    ! Fourier space grid
    call attach(fidres, where_, "Nkx", Nkx)
    call attach(fidres, where_, "Lkx", Lkx)
    call attach(fidres, where_, "Nky", Nky)
    call attach(fidres, where_, "Lky", Lky)
    ! Shifted grid flag
    call attach(fidres, where_, "SG", SG)
  end subroutine grid_outputinputs

  PURE FUNCTION bare(p_,j_)
    ! Flattened, 1-based index of the electron pair (p_,j_):
    ! (jmaxe+1)*p_ + j_ + 1, i.e. j_ runs fastest and (0,0) maps to 1.
    ! Only reads its arguments and the module variable jmaxe.
    IMPLICIT NONE
    INTEGER, INTENT(IN) :: p_, j_
    INTEGER :: bare
    bare = (jmaxe+1)*p_ + j_ + 1
  END FUNCTION bare

  PURE FUNCTION bari(p_,j_)
    ! Flattened, 1-based index of the ion pair (p_,j_):
    ! (jmaxi+1)*p_ + j_ + 1, i.e. j_ runs fastest and (0,0) maps to 1.
    ! Only reads its arguments and the module variable jmaxi.
    IMPLICIT NONE
    INTEGER, INTENT(IN) :: p_, j_
    INTEGER :: bari
    bari = (jmaxi+1)*p_ + j_ + 1
  END FUNCTION bari

  SUBROUTINE decomp1D( n, numprocs, myid, s, e )
      ! Contiguous block decomposition of the indices 1..n over numprocs
      ! processes. Every process gets n/numprocs indices and the first
      ! MOD(n,numprocs) processes get one more.
      !   n        : number of indices to distribute
      !   numprocs : number of processes
      !   myid     : 0-based rank of the process whose range is requested
      !   s, e     : first and last index of that range (output). When a
      !              process gets nothing, e = s-1.
      INTEGER, INTENT(IN)  :: n, numprocs, myid
      INTEGER, INTENT(OUT) :: s, e
      INTEGER :: nlocal
      INTEGER :: deficit

      nlocal   = n / numprocs
      deficit  = MOD(n,numprocs)
      ! Every lower rank that received an extra index shifts the start by one
      s        = myid * nlocal + 1 + MIN(myid,deficit)
      IF (myid .LT. deficit) nlocal = nlocal + 1
      e = s + nlocal - 1
      ! Never run past n, and make the last process end exactly on n
      IF (e .GT. n .OR. myid .EQ. numprocs-1) e = n
  END SUBROUTINE decomp1D

END MODULE grid