From 1dccb6f8cb18b577530a87f896c8041f93f6612f Mon Sep 17 00:00:00 2001
From: Antoine <antoine.hoffmann@epfl.ch>
Date: Mon, 11 Sep 2023 09:34:22 +0200
Subject: [PATCH] Advances in implementation of ExB shear -1D routines
 development -NL factor to correct moving grid method -still an issue with the
 1D r2c FFT -ExB not working yet.

---
 Makefile                   |   9 +-
 src/ExB_shear_flow_mod.F90 |  45 +++-
 src/fourier_mod.F90        | 429 +++++++++++++++++++++----------------
 src/grid_mod.F90           |  22 +-
 src/model_mod.F90          |   5 +
 src/nonlinear_mod.F90      |  29 +--
 6 files changed, 317 insertions(+), 222 deletions(-)

diff --git a/Makefile b/Makefile
index d2e104d2..3cbd289b 100644
--- a/Makefile
+++ b/Makefile
@@ -217,7 +217,7 @@ $(OBJDIR)/time_integration_mod.o $(OBJDIR)/utility_mod.o $(OBJDIR)/CLA_mod.o
 	$(F90) -c $(F90FLAGS) $(FPPFLAGS) $(EXTMOD) $(EXTINC) src/fields_mod.F90 -o $@
 
  $(OBJDIR)/fourier_mod.o : src/fourier_mod.F90 \
- 	 $(OBJDIR)/basic_mod.o $(OBJDIR)/prec_const_mod.o
+ 	 $(OBJDIR)/basic_mod.o $(OBJDIR)/prec_const_mod.o $(OBJDIR)/utility_mod.o
 	$(F90) -c $(F90FLAGS) $(FPPFLAGS) $(EXTMOD) $(EXTINC) src/fourier_mod.F90 -o $@
 
  $(OBJDIR)/geometry_mod.o : src/geometry_mod.F90 \
@@ -273,8 +273,8 @@ $(OBJDIR)/time_integration_mod.o $(OBJDIR)/utility_mod.o $(OBJDIR)/CLA_mod.o
 	$(F90) -c $(F90FLAGS) $(FPPFLAGS) $(EXTMOD) $(EXTINC) src/moments_eq_rhs_mod.F90 -o $@
 
  $(OBJDIR)/nonlinear_mod.o : src/nonlinear_mod.F90 \
- 	 $(OBJDIR)/array_mod.o $(OBJDIR)/basic_mod.o $(OBJDIR)/fourier_mod.o \
-	 $(OBJDIR)/fields_mod.o $(OBJDIR)/grid_mod.o $(OBJDIR)/model_mod.o\
+ 	 $(OBJDIR)/array_mod.o $(OBJDIR)/basic_mod.o $(OBJDIR)/ExB_shear_flow_mod.o \
+	 $(OBJDIR)/fourier_mod.o $(OBJDIR)/fields_mod.o $(OBJDIR)/grid_mod.o $(OBJDIR)/model_mod.o\
 	 $(OBJDIR)/prec_const_mod.o $(OBJDIR)/time_integration_mod.o
 	$(F90) -c $(F90FLAGS) $(FPPFLAGS) $(EXTMOD) $(EXTINC) src/nonlinear_mod.F90 -o $@
 
@@ -337,8 +337,7 @@ $(OBJDIR)/time_integration_mod.o $(OBJDIR)/utility_mod.o $(OBJDIR)/CLA_mod.o
 	$(F90) -c $(F90FLAGS) $(FPPFLAGS) $(EXTMOD) $(EXTINC) src/time_integration_mod.F90 -o $@
 
  $(OBJDIR)/utility_mod.o : src/utility_mod.F90  \
-   $(OBJDIR)/grid_mod.o $(OBJDIR)/basic_mod.o $(OBJDIR)/prec_const_mod.o \
-   $(OBJDIR)/time_integration_mod.o
+   $(OBJDIR)/basic_mod.o $(OBJDIR)/prec_const_mod.o $(OBJDIR)/time_integration_mod.o
 	$(F90) -c $(F90FLAGS) $(FPPFLAGS) $(EXTMOD) $(EXTINC) src/utility_mod.F90 -o $@
 
  $(OBJDIR)/CLA_mod.o : src/CLA_mod.F90  \
diff --git a/src/ExB_shear_flow_mod.F90 b/src/ExB_shear_flow_mod.F90
index 0bb11019..a204f37f 100644
--- a/src/ExB_shear_flow_mod.F90
+++ b/src/ExB_shear_flow_mod.F90
@@ -2,15 +2,17 @@ MODULE ExB_shear_flow
     ! This module contains the necessary tools to implement ExB shearing flow effects.
     ! The algorithm is taken from the presentation of Hammett et al. 2006 (APS) and
     ! it the one used in GS2.
-    USE prec_const, ONLY: xp
+    USE prec_const, ONLY: xp, imagu
 
     IMPLICIT NONE
     ! Variables
-    REAL(xp), PUBLIC, PROTECTED :: gamma_E = 0._xp ! ExB background shearing rate \gamma_E
-    REAL(xp), DIMENSION(:), ALLOCATABLE, PUBLIC, PROTECTED :: sky_ExB      ! shift of the kx modes, kx* = kx + s(ky)
-    INTEGER,  DIMENSION(:), ALLOCATABLE, PUBLIC, PROTECTED :: jump_ExB     ! jump to do to shift the kx grids
-    LOGICAL,  DIMENSION(:), ALLOCATABLE, PUBLIC, PROTECTED :: shiftnow_ExB ! Indicates if there is a line to shift
-
+    REAL(xp),   PUBLIC, PROTECTED :: gamma_E = 0._xp     ! ExB background shearing rate \gamma_E
+    REAL(xp),   PUBLIC, PROTECTED :: t0, inv_t0 = 0._xp  ! charact. shear time
+    REAL(xp),   DIMENSION(:),   ALLOCATABLE, PUBLIC, PROTECTED :: sky_ExB      ! shift of the kx modes, kx* = kx + s(ky)
+    INTEGER,    DIMENSION(:),   ALLOCATABLE, PUBLIC, PROTECTED :: jump_ExB     ! jump to do to shift the kx grids
+    LOGICAL,    DIMENSION(:),   ALLOCATABLE, PUBLIC, PROTECTED :: shiftnow_ExB ! Indicates if there is a line to shift
+    COMPLEX(xp),DIMENSION(:,:), ALLOCATABLE, PUBLIC, PROTECTED :: ExB_NL_factor! factor for nonlinear term
+    COMPLEX(xp),DIMENSION(:,:), ALLOCATABLE, PUBLIC, PROTECTED :: inv_ExB_NL_factor
     ! Routines
     PUBLIC :: Setup_ExB_shear_flow, Apply_ExB_shear_flow, Update_ExB_shear_flow
 
@@ -18,7 +20,8 @@ CONTAINS
 
     ! Setup the variables for the ExB shear
     SUBROUTINE Setup_ExB_shear_flow
-        USE grid,       ONLY : local_nky
+        USE grid,  ONLY : total_nkx, local_nky, deltakx, deltaky
+        USE model, ONLY : ExBrate  
         IMPLICIT NONE
 
         ! Setup the ExB shift
@@ -30,6 +33,20 @@ CONTAINS
         jump_ExB     = 0
         ALLOCATE(shiftnow_ExB(local_nky))
         shiftnow_ExB = .FALSE.
+
+        ! Setup nonlinear factor
+        ALLOCATE(    ExB_NL_factor(total_nkx,local_nky))
+        ALLOCATE(inv_ExB_NL_factor(total_nkx,local_nky))
+            ExB_NL_factor = 1._xp
+        inv_ExB_NL_factor = 1._xp
+        IF(ExBrate .NE. 0) THEN
+            t0     = deltakx/deltaky/ExBrate
+            inv_t0 = 1._xp/t0
+        ELSE ! avoid 1/0 division (t0 is killed anyway in this case)
+            t0     = 0._xp
+            inv_t0 = 0._xp
+        ENDIF
+
     END SUBROUTINE Setup_ExB_shear_flow
 
     ! Update according to the current ExB shear value
@@ -109,12 +126,14 @@ CONTAINS
 
     ! update the ExB shear value for the next time step
     SUBROUTINE Update_ExB_shear_flow
-        USE basic,      ONLY: dt, chrono_ExBs, start_chrono, stop_chrono
-        USE grid,       ONLY: local_nky, kyarray, inv_dkx
+        USE basic,      ONLY: dt, time, chrono_ExBs, start_chrono, stop_chrono
+        USE grid,       ONLY: local_nky, kyarray, inv_dkx, xarray,&
+                              local_nkx, ikyarray, inv_ikyarray, deltakx, deltaky, deltax
         USE model,      ONLY: ExBrate
         IMPLICIT NONE
         ! local var
-        INTEGER :: iky
+        INTEGER :: iky, ix
+        REAL(xp):: dtExBshear
         CALL start_chrono(chrono_ExBs)
         ! update the ExB shift, jumps and flags
         shiftnow_ExB = .FALSE.
@@ -125,6 +144,12 @@ CONTAINS
             ! in shiftnow_ExB and will use it in Shift_fields to avoid
             ! zero-shiftings that may be majoritary.
             shiftnow_ExB(iky) = (abs(jump_ExB(iky)) .GT. 0)
+            ! Update the ExB nonlinear factor
+            dtExBshear = time - t0*inv_ikyarray(iky)*ANINT(ikyarray(iky)*time*inv_t0,xp)
+            DO ix = 1,local_nkx
+                ExB_NL_factor(ix,iky) = EXP(-imagu*xarray(ix)*ExBrate*ikyarray(iky)*dtExBshear)
+            inv_ExB_NL_factor(ix,iky) = 1._xp/ExB_NL_factor(ix,iky)
+            ENDDO
         ENDDO
         CALL stop_chrono(chrono_ExBs)
     END SUBROUTINE Update_ExB_shear_flow
diff --git a/src/fourier_mod.F90 b/src/fourier_mod.F90
index 743e2d17..d78ce276 100644
--- a/src/fourier_mod.F90
+++ b/src/fourier_mod.F90
@@ -1,5 +1,5 @@
 MODULE fourier
-    USE prec_const, ONLY: xp, c_xp_c, c_xp_r, imagu
+    USE prec_const, ONLY: xp, c_xp_c, c_xp_r, imagu, mpi_xp_c
     use, intrinsic :: iso_c_binding
     implicit none
 
@@ -16,24 +16,29 @@ MODULE fourier
     LOGICAL, PUBLIC, PROTECTED :: FFT2D = .TRUE.
 
     !! Module accessible routines
-    PUBLIC :: init_grid_distr_and_plans, poisson_bracket_and_sum, finalize_plans
+    PUBLIC :: init_grid_distr_and_plans, poisson_bracket_and_sum, finalize_plans, apply_inv_ExB_NL_factor
 
     !! Module variables
     CHARACTER(2)                :: FFT_ALGO ! use of 2D or 1D routines
+    !! 2D fft specific variables (C interface)
     type(C_PTR)                 :: cdatar_f, cdatar_g, cdatar_c
     type(C_PTR)                 :: cdatac_f, cdatac_g, cdatac_c
     type(C_PTR) ,        PUBLIC :: planf, planb
     integer(C_INTPTR_T)         :: i, ix, iy
     integer(C_INTPTR_T), PUBLIC :: alloc_local_1, alloc_local_2
-    integer(C_INTPTR_T)         :: NX_, NY_, NY_halved 
+    integer(C_INTPTR_T)         :: NX_, NY_, NY_halved, local_nky_ 
     real   (c_xp_r), pointer, PUBLIC :: real_data_f(:,:), real_data_g(:,:), bracket_sum_r(:,:)
     complex(c_xp_c), pointer, PUBLIC :: cmpx_data_f(:,:), cmpx_data_g(:,:), bracket_sum_c(:,:)
-    !! 1D fft specific variables
-    type(C_PTR), PUBLIC :: plan_kx2x_c2c ! transform from (kx,ky) to ( x,ky) (complex to complex)
+    REAL(xp),                 PUBLIC :: inv_Nx_, inv_Ny_
+    !! 1D fft specific variables (full fortran interface)
     type(C_PTR), PUBLIC :: plan_ky2y_c2r ! transform from ( x,ky) to ( x, y) (complex to real)
     type(C_PTR), PUBLIC :: plan_y2ky_r2c ! transform from ( x, y) to ( x,ky) (real to complex)
+    type(C_PTR), PUBLIC :: plan_kx2x_c2c ! transform from (kx,ky) to ( x,ky) (complex to complex)
     type(C_PTR), PUBLIC :: plan_x2kx_c2c ! transform from ( x,ky) to (kx,ky) (complex to complex)
-    complex(c_xp_c), pointer, PUBLIC :: ky_x_data(:,:)
+    COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE :: f_kxky_l ! working arrays
+    COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE :: f_xky_l  ! temp. 1D ifft algo stage (local mpi)
+    REAL(xp),    DIMENSION(:,:), ALLOCATABLE :: bracket_sum_xy_g   ! poisson bracket sum in real space
+    COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE :: bracket_sum_xky_g  ! final poisson bracket in complex space
 
     CONTAINS
     !******************************************************************************!
@@ -47,23 +52,17 @@ MODULE fourier
         INTEGER, INTENT(IN)  :: Nx,Ny, communicator
         INTEGER(C_INTPTR_T), INTENT(OUT) :: local_nkx_ptr,local_nkx_ptr_offset
         INTEGER(C_INTPTR_T), INTENT(OUT) :: local_nky_ptr,local_nky_ptr_offset
-        NX_ = Nx; NY_ = Ny
-        NY_halved = NY_/2 + 1
-        IF(FFT2D) THEN
-            FFT_ALGO = '2D'
-        ELSE
-            FFT_ALGO = '1D'
-        ENDIF
-        CALL speak('FFT algorithm :' // FFT_ALGO)
-
-        SELECT CASE (FFT_ALGO)
-        CASE ('2D')
-            CALL fft2D_distr_and_plans(Nx,Ny,communicator,&
-                    local_nkx_ptr,local_nkx_ptr_offset,local_nky_ptr,local_nky_ptr_offset)
-        CASE ('1D')
-            CALL fft1D_distr_and_plans(Nx,Ny,communicator,&
-                    local_nkx_ptr,local_nkx_ptr_offset,local_nky_ptr,local_nky_ptr_offset)
-        END SELECT
+        NX_        = Nx; NY_ = Ny
+        inv_Nx_    = 1._xp/NX_
+        inv_Ny_    = 1._xp/NY_        
+        NY_halved  = NY_/2 + 1
+        ! Call FFTW 2D mpi routines to distribute the data and init 2D MPI FFTW plans
+        CALL fft2D_distr_and_plans(Nx,Ny,communicator,&
+                local_nkx_ptr,local_nkx_ptr_offset,local_nky_ptr,local_nky_ptr_offset)
+        local_nky_ = local_nky_ptr ! store number of local ky in the module
+        ! Init 1D MPI FFTW plans for ExB rate correction factor
+        CALL fft1D_plans
+        ! store data distr. in the module for the poisson_bracket function
     END SUBROUTINE init_grid_distr_and_plans
 
     !------------- 2D fft initialization and mpi distribution
@@ -134,235 +133,287 @@ MODULE fourier
 
     !******************************************************************************!
     !------------- 1D initialization with balanced data distribution
-    SUBROUTINE fft1D_distr_and_plans(Nx,Ny,communicator,&
-                local_nkx_ptr,local_nkx_ptr_offset,local_nky_ptr,local_nky_ptr_offset)
+    SUBROUTINE fft1D_plans
         USE utility,  ONLY: decomp1D
         USE parallel, ONLY: num_procs_ky, rank_ky
         IMPLICIT NONE
-        INTEGER, INTENT(IN)  :: Nx,Ny, communicator
-        INTEGER(C_INTPTR_T), INTENT(OUT) :: local_nkx_ptr,local_nkx_ptr_offset
-        INTEGER(C_INTPTR_T), INTENT(OUT) :: local_nky_ptr,local_nky_ptr_offset
         ! local var
-        INTEGER :: is,ie    !start and end indices
-        INTEGER :: rank     ! rank of each 1D fourier transforms
-        INTEGER :: n        ! size of the data to fft
-        INTEGER :: howmany  ! howmany 1D fourier transforms
+        integer(C_INTPTR_T) :: rank     ! rank of each 1D fourier transforms
+        integer(C_INTPTR_T) :: n        ! size of the data to fft
+        integer(C_INTPTR_T) :: howmany  ! howmany 1D fourier transforms
         COMPLEX, DIMENSION(:,:), ALLOCATABLE:: in, out
-        INTEGER :: inembed, onembed
-        INTEGER :: istride, ostride
-        INTEGER :: idist, odist
-        INTEGER :: sign
-        INTEGER :: flags
-        COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE:: fkxky
-        COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE:: fxky_l, fxky_g
-        REAL(xp),    DIMENSION(:,:), ALLOCATABLE:: fxy, fkxy
-
-        ! number of kx (no data distr.)
-        local_nkx_ptr        = Nx
-        local_nkx_ptr_offset = 0
-
-        !! Distributinon of data and definition of the size of the arrays that will be worked on
-        ! balanced distribution among the processes for ky
-        CALL  decomp1D( (Ny/2+1), num_procs_ky, rank_ky, is, ie )
-        local_nky_ptr        = ie - is + 1
-        local_nky_ptr_offset = is - 1
-        ! give the rest of the points to the last process
-        if (rank_ky .EQ. num_procs_ky-1) local_nky_ptr = (Ny/2+1)-local_nky_ptr_offset
+        integer(C_INTPTR_T) :: inembed, onembed
+        integer(C_INTPTR_T) :: istride, ostride
+        integer(C_INTPTR_T) :: idist, odist
+        integer(C_INTPTR_T) :: sign
+        integer(C_INTPTR_T) :: flags
 
-        !! Allocate temporary array for plan making
-        ALLOCATE(  fkxky(Nx,local_nky_ptr))
-        ALLOCATE( fxky_l(Nx,local_nky_ptr))
-        ALLOCATE( fxky_g(Nx,Ny/2+1))
-        ALLOCATE(    fxy(Nx,Ny))
-        !! Plan of the 4 many transforms required
-        ! 1. (kx,ky) -> (x,ky), C -> C, transforms
+        !! Plan of the 4 1D many transforms required
+        !----------- 1: FFTx and inv through local ky data
+        !1.1 (kx,ky) -> (x,ky), C -> C, transforms
+        ! in:
+        ALLOCATE( f_kxky_l(NX_,local_nky_))
+        ! out:
+        ALLOCATE(  f_xky_l(NX_,local_nky_))
+        ! transform parameters
         rank    = 1              ! 1D transform
-        n       = Nx             ! all kx modes
-        howmany = local_nky_ptr  ! all local ky
-        inembed = Nx             ! all data must be transformed
-        onembed = Nx
-        idist   = Nx             ! distance between data to transforms (x columns)
-        odist   = Nx
+        n       = NX_            ! all kx modes
+        howmany = local_nky_     ! all local ky
+        inembed = NX_            ! all data must be transformed
+        onembed = NX_
+        idist   = NX_            ! distance between data to transforms (x columns)
+        odist   = NX_
         istride = 1              ! contiguous data
         ostride = 1
 #ifdef SINGLE_PRECISION
         CALL sfftw_plan_many_dft(plan_kx2x_c2c, rank, n, howmany,&
-                                 fkxky, inembed, istride, idist,&
-                                fxky_l, onembed, ostride, odist,& 
+                                 f_kxky_l, inembed, istride, idist,&
+                                  f_xky_l, onembed, ostride, odist,& 
                                  FFTW_BACKWARD, FFTW_PATIENT)                
 #else
         CALL dfftw_plan_many_dft(plan_kx2x_c2c, rank, n, howmany,&
-                                 fkxky, inembed, istride, idist,&
-                                fxky_l, onembed, ostride, odist,& 
+                                 f_kxky_l, inembed, istride, idist,&
+                                  f_xky_l, onembed, ostride, odist,& 
                                  FFTW_BACKWARD, FFTW_PATIENT)    
 #endif
-        ! 1.5 MPI communication along ky (from fxky_l to fxky_g)
-        ! 2. (x,ky) -> (x,y), C -> R, transforms
+        ! 1.2 (x,ky) -> (kx,ky), C -> C, transforms
+        ! in:  f_xky_l
+        ! out: f_kxky_l
+        ! transform parameters
         rank    = 1              ! 1D transform
-        n       = Ny             ! all ky modes
-        howmany = Nx             ! all kx
-        inembed = Ny/2+1         ! all ky must be transformed
-        onembed = Ny             ! to all y
-        idist   = 1              ! distance between two slice to transforms (y row)
-        odist   = 1
-        istride = Nx             ! non contiguous data
-        ostride = Nx
+        n       = NX_            ! all kx modes
+        howmany = local_nky_     ! all local ky
+        inembed = NX_            ! all data must be transformed
+        onembed = NX_
+        idist   = NX_            ! distance between data to transforms (x columns)
+        odist   = NX_
+        istride = 1              ! contiguous data
+        ostride = 1
 #ifdef SINGLE_PRECISION
-        CALL sfftw_plan_many_dft_c2r(plan_ky2y_c2r, rank, n, howmany,&
-                                   fxky_g, inembed, istride, idist,&
-                                      fxy, onembed, ostride, odist,& 
-                                     FFTW_BACKWARD, FFTW_PATIENT)                
+        CALL sfftw_plan_many_dft(plan_x2kx_c2c, rank, n, howmany,&
+                                 f_xky_l, inembed, istride, idist,&
+                                f_kxky_l, onembed, ostride, odist,& 
+                                FFTW_FORWARD, FFTW_PATIENT)                
 #else
-        CALL dfftw_plan_many_dft_c2r(plan_ky2y_c2r, rank, n, howmany,&
-                                   fxky_g, inembed, istride, idist,&
-                                      fxy, onembed, ostride, odist,& 
-                                     FFTW_BACKWARD, FFTW_PATIENT)    
+        CALL dfftw_plan_many_dft(plan_x2kx_c2c, rank, n, howmany,&
+                                 f_xky_l, inembed, istride, idist,&
+                                f_kxky_l, onembed, ostride, odist,& 
+                                FFTW_FORWARD, FFTW_PATIENT)    
 #endif
-        ! 3. (x,y) -> (x,ky), R -> C, transforms
+
+        !----------- 2: FFTy and inv through global ky data
+        ! 2.1 (x,y) -> (x,ky), R -> C, transforms (bplan_y in GENE)
+        ! in:
+        ALLOCATE(bracket_sum_xy_g(NX_,NY_))
+        ! out:
+        ALLOCATE(bracket_sum_xky_g(NX_,NY_/2+1))
+        ! transform parameters
         rank    = 1              ! 1D transform
-        n       = Ny             ! all y
-        howmany = Nx             ! all x
-        inembed = Ny             ! all y must be used
-        onembed = Ny/2+1         ! to all ky
+        n       = NY_             ! all y
+        howmany = NX_             ! all x
+        inembed = NY_             ! all y must be used
+        onembed = NY_/2+1         ! to all ky
         idist   = 1              ! distance between two slice to transforms (y row)
         odist   = 1
-        istride = Nx             ! non contiguous data
-        ostride = Nx
+        istride = 1              ! contiguous data
+        ostride = 1
 #ifdef SINGLE_PRECISION
         CALL sfftw_plan_many_dft_r2c(plan_y2ky_r2c, rank, n, howmany,&
-                                     fxy, inembed, istride, idist,&
-                                  fxky_g, onembed, ostride, odist,& 
-                                    FFTW_FORWARD, FFTW_PATIENT)                
+                             bracket_sum_xy_g, 0, 1, NY_,&
+                            bracket_sum_xky_g, 0, 1, NY_/2+1,& 
+                                    FFTW_PATIENT)                
 #else
         CALL dfftw_plan_many_dft_r2c(plan_y2ky_r2c, rank, n, howmany,&
-                                     fxy, inembed, istride, idist,&
-                                  fxky_g, onembed, ostride, odist,& 
-                                    FFTW_FORWARD, FFTW_PATIENT)    
+                             bracket_sum_xy_g, 0, 1, NY_,&
+                            bracket_sum_xky_g, 0, 1, NY_/2+1,& 
+                                    FFTW_PATIENT)    
 #endif
-        ! 3.5 MPI splitting along ky (from fxky_g to fxky_l)
-        ! 4. (x,ky) -> (kx,ky), C -> C, transforms
-        rank    = 1              ! 1D transform
-        n       = Nx             ! all x
-        howmany = local_nky_ptr  ! only local ky
-        inembed = Nx             ! all x must be used
-        onembed = local_nky_ptr  ! to the local ky
-        idist   = 1              ! distance between two slice to transforms (x row)
+        !-----------
+        ! 2.2 (x,ky) -> (x,y), C -> R, transforms (fplan_y in GENE)
+        ! in:  bracket_sum_xky_g
+        ! out: bracket_sum_xy_g
+        ! transform parameters
+        rank    = 1               ! 1D transform
+        n       = NY_             ! all y
+        howmany = NX_             ! all x
+        inembed = NY_/2+1         ! all y must be used
+        onembed = NY_             ! to all ky
+        idist   = 1               ! distance between two slice to transforms (y row)
         odist   = 1
-        istride = Nx             ! non contiguous data
-        ostride = Nx
+        istride = NX_             ! non contiguous data
+        ostride = NX_
 #ifdef SINGLE_PRECISION
-        CALL sfftw_plan_many_dft(plan_y2ky_r2c, rank, n, howmany,&
-                                fxky_l, inembed, istride, idist,&
-                                 fkxky, onembed, ostride, odist,& 
-                                FFTW_FORWARD, FFTW_PATIENT)                
+        CALL sfftw_plan_many_dft_c2r(plan_ky2y_c2r, rank, n, howmany,&
+                            bracket_sum_xky_g, 0, 1, NY_/2+1,&
+                             bracket_sum_xy_g, 0, 1, odist,& 
+                                    FFTW_PATIENT)                
 #else
-        CALL dfftw_plan_many_dft(plan_y2ky_r2c, rank, n, howmany,&
-                                fxky_l, inembed, istride, idist,&
-                                 fkxky, onembed, ostride, odist,& 
-                                FFTW_FORWARD, FFTW_PATIENT)    
+        CALL dfftw_plan_many_dft_c2r(plan_ky2y_c2r, rank, n, howmany,&
+                            bracket_sum_xky_g, 0, 1, NY_/2+1,&
+                             bracket_sum_xy_g, 0, 1, NY_,& 
+                                    FFTW_PATIENT)    
 #endif
-    END SUBROUTINE fft1D_distr_and_plans
-    !******************************************************************************!
 
-    !******************************************************************************!
-    ! High level routine to ifft a 2D comple array into a real one
-    ! It uses variables from the module as the plans
-    SUBROUTINE iFFT_2D_c2r
-        IMPLICIT NONE
-        SELECT CASE (FFT_ALGO)
-        CASE ('2D')
-        CASE ('1D')
-        END SELECT
-    END SUBROUTINE iFFT_2D_c2r
-    !******************************************************************************!
-
-    !******************************************************************************!
-    SUBROUTINE FFT_2D_r2c
-        IMPLICIT NONE
-        SELECT CASE (FFT_ALGO)
-        CASE ('2D')
-        CASE ('1D')
-        END SELECT
-    END SUBROUTINE FFT_2D_r2c  
+    ! Free mem (optional)
+    DEALLOCATE(f_xky_l,f_kxky_l)
+    DEALLOCATE(bracket_sum_xky_g,bracket_sum_xy_g)
+END SUBROUTINE fft1D_plans
     !******************************************************************************!
 
     !******************************************************************************!
     !!! Compute the poisson bracket to real space and sum it to the bracket_sum_r
     !   module variable (convolution theorem)
-    SUBROUTINE poisson_bracket_and_sum(ky_, kx_, inv_Ny, inv_Nx, AA_y, AA_x,&
-                                        local_nky_ptr, local_nkx_ptr, F_, G_, sum_real_)
+    SUBROUTINE poisson_bracket_and_sum( ky_, kx_, inv_Ny, inv_Nx, AA_y, AA_x,&
+                                        local_nky_ptr, local_nkx_ptr, F_, G_,&
+                                        ExB, ExB_NL_factor, sum_real_)
+        USE parallel, ONLY: my_id, num_procs_ky, comm_ky, rank_ky
         IMPLICIT NONE
-        INTEGER(C_INTPTR_T),                  INTENT(IN) :: local_nkx_ptr,local_nky_ptr
-        REAL(xp),                             INTENT(IN) :: inv_Nx, inv_Ny
-        REAL(xp), DIMENSION(local_nky_ptr),   INTENT(IN) :: ky_, AA_y, AA_x
+        INTEGER(C_INTPTR_T),                  INTENT(IN)    :: local_nkx_ptr,local_nky_ptr
+        REAL(xp),                             INTENT(IN)    :: inv_Nx, inv_Ny
+        REAL(xp), DIMENSION(local_nky_ptr),   INTENT(IN)    :: ky_, AA_y, AA_x
         REAL(xp), DIMENSION(local_nky_ptr,local_nkx_ptr), INTENT(IN) :: kx_
-        COMPLEX(c_xp_c), DIMENSION(local_nky_ptr,local_nkx_ptr) &
-                                                         :: F_(:,:), G_(:,:)
-        real(c_xp_r), pointer,             INTENT(INOUT) :: sum_real_(:,:)
+        COMPLEX(c_xp_c), DIMENSION(local_nky_ptr,local_nkx_ptr), &
+                                            INTENT(IN)      :: F_, G_
+        COMPLEX(xp), DIMENSION(local_nkx_ptr,local_nky_ptr), &
+                                            INTENT(IN)      :: ExB_NL_factor
+        LOGICAL, INTENT(IN) :: ExB
+        real(c_xp_r), pointer,              INTENT(INOUT)   :: sum_real_(:,:)
+        ! local variables
         INTEGER :: ikx,iky
-        !! Anti aliasing
-        DO ikx = 1,local_nkx_ptr
-            F_(:,ikx) = F_(:,ikx)*AA_y(:)*AA_x(ikx)
-            G_(:,ikx) = G_(:,ikx)*AA_y(:)*AA_x(ikx)
-        ENDDO
-        !------------------------------------------------------------------
-
-        !-------------------- First term df/dx x dg/dy --------------------
+        COMPLEX(xp), DIMENSION(local_nkx_ptr,local_nky_ptr) :: ikxF, ikyG, ikyF, ikxG
+        REAL(xp),    DIMENSION(NX_,2*(NY_/2 + 1)) :: ddxf, ddyg, ddyf, ddxg
+        
+        ! Build the fields to convolve
+        ! Store df/dx, dg/dy and df/dy, dg/dx
         DO ikx = 1,local_nkx_ptr
         DO iky = 1,local_nky_ptr
-            cmpx_data_f(ikx,iky) = imagu*kx_(iky,ikx)*F_(iky,ikx)
-            cmpx_data_g(ikx,iky) = imagu*ky_(iky)    *G_(iky,ikx)
+            ikxF(ikx,iky) = imagu*kx_(iky,ikx)*F_(iky,ikx)*AA_y(iky)*AA_x(ikx)
+            ikyG(ikx,iky) = imagu*ky_(iky)    *G_(iky,ikx)*AA_y(iky)*AA_x(ikx)
+            ikyF(ikx,iky) = imagu*ky_(iky)    *F_(iky,ikx)*AA_y(iky)*AA_x(ikx)
+            ikxG(ikx,iky) = imagu*kx_(iky,ikx)*G_(iky,ikx)*AA_y(iky)*AA_x(ikx)
         ENDDO
         ENDDO
-
-        !CALL iFFT_2D_c2r(cmpx_data_f,real_data_f)
-
+        IF(ExB) THEN 
+            ! Apply the ExB shear correction factor exp(ixkySJdT)
+            CALL apply_ExB_NL_factor(ikxF,ExB_NL_factor)
+            CALL apply_ExB_NL_factor(ikyG,ExB_NL_factor)
+            CALL apply_ExB_NL_factor(ikyF,ExB_NL_factor)
+            CALL apply_ExB_NL_factor(ikxG,ExB_NL_factor)
+        ENDIF
+        !-------------------- First term df/dx x dg/dy --------------------
 #ifdef SINGLE_PRECISION
-        call fftwf_mpi_execute_dft_c2r(planb, cmpx_data_f, real_data_f)
-        call fftwf_mpi_execute_dft_c2r(planb, cmpx_data_g, real_data_g)
+        call fftwf_mpi_execute_dft_c2r(planb, ikxF, real_data_f)
+        call fftwf_mpi_execute_dft_c2r(planb, ikyG, real_data_g)
 #else
-        call fftw_mpi_execute_dft_c2r(planb, cmpx_data_f, real_data_f)
-        call fftw_mpi_execute_dft_c2r(planb, cmpx_data_g, real_data_g)
+        call fftw_mpi_execute_dft_c2r(planb, ikxF, real_data_f)
+        call fftw_mpi_execute_dft_c2r(planb, ikyG, real_data_g)
 #endif
         sum_real_ = sum_real_ + real_data_f*real_data_g*inv_Ny*inv_Nx
         !--------------------------------------------------------------------
 
         !-------------------- Second term -df/dy x dg/dx --------------------
-        DO ikx = 1,local_nkx_ptr
-        DO iky = 1,local_nky_ptr
-            cmpx_data_f(ikx,iky) = imagu*ky_(iky)    *F_(iky,ikx)
-            cmpx_data_g(ikx,iky) = imagu*kx_(iky,ikx)*G_(iky,ikx)
-        ENDDO
-        ENDDO
 #ifdef SINGLE_PRECISION
-        call fftwf_mpi_execute_dft_c2r(planb, cmpx_data_f, real_data_f)
-        call fftwf_mpi_execute_dft_c2r(planb, cmpx_data_g, real_data_g)
+        call fftwf_mpi_execute_dft_c2r(planb, ikyF, real_data_f)
+        call fftwf_mpi_execute_dft_c2r(planb, ikxG, real_data_g)
 #else
-        call fftw_mpi_execute_dft_c2r(planb, cmpx_data_f, real_data_f)
-        call fftw_mpi_execute_dft_c2r(planb, cmpx_data_g, real_data_g)
+        call fftw_mpi_execute_dft_c2r(planb, ikyF, real_data_f)
+        call fftw_mpi_execute_dft_c2r(planb, ikxG, real_data_g)
 #endif
         sum_real_ = sum_real_ - real_data_f*real_data_g*inv_Ny*inv_Nx
     END SUBROUTINE poisson_bracket_and_sum
+    !******************************************************************************!
+
+    !******************************************************************************!
+    ! Apply the exp(xkySJdt) factor to the Poisson bracket fields 
+    ! (see Mcmillan et al. 2019)
+    SUBROUTINE apply_ExB_NL_factor(fkxky,ExB_NL_factor)
+        IMPLICIT NONE
+        COMPLEX(xp), DIMENSION(NX_,local_nky_), INTENT(INOUT)  :: fkxky
+        COMPLEX(xp), DIMENSION(NX_,local_nky_), INTENT(IN)     :: ExB_NL_factor
+        ! local variables
+        COMPLEX(xp), DIMENSION(NX_,local_nky_) :: fxky
+        CALL iFFT_kxky_to_xky(fkxky,fxky)
+        fxky = fxky*ExB_NL_factor*inv_Nx_
+        CALL FFT_xky_to_kxky(fxky,fkxky)
+    END SUBROUTINE apply_ExB_NL_factor
+
+    SUBROUTINE apply_inv_ExB_NL_factor(fxy,inv_ExB_NL_factor)
+        IMPLICIT NONE
+        !REAL(xp),    DIMENSION(NX_,2*(NY_/2+1)), INTENT(INOUT) :: fxy
+        real(c_xp_r), pointer,                   INTENT(INOUT) :: fxy(:,:)
+        COMPLEX(xp), DIMENSION(NX_,local_nky_),  INTENT(IN)    :: inv_ExB_NL_factor
+        ! local variables
+        COMPLEX(xp), DIMENSION(NX_,local_nky_) :: fxky
+        bracket_sum_xy_g = fxy
+        CALL FFT_xy_to_xky(bracket_sum_xy_g,fxky)
+        fxky = fxky*inv_ExB_NL_factor
+        CALL iFFT_xky_to_xy(fxky,bracket_sum_xy_g)
+        fxy = bracket_sum_xy_g
+    END SUBROUTINE apply_inv_ExB_NL_factor
+
+    !******************************************************************************!
+    ! High level FFT routines
+    SUBROUTINE iFFT_kxky_to_xky(in_kxky,out_xky)
+        IMPLICIT NONE
+        COMPLEX(xp), DIMENSION(NX_,local_nky_), INTENT(IN)  :: in_kxky
+        COMPLEX(xp), DIMENSION(NX_,local_nky_), INTENT(OUT) :: out_xky
+#ifdef SINGLE_PRECISION
+        CALL sfftw_execute_dft(plan_kx2x_c2c, in_kxky, out_xky)
+#else 
+        CALL dfftw_execute_dft(plan_kx2x_c2c, in_kxky, out_xky)
+#endif
+    END SUBROUTINE iFFT_kxky_to_xky
+
+    SUBROUTINE FFT_xky_to_kxky(in_xky,out_kxky)
+        IMPLICIT NONE
+        COMPLEX(xp), DIMENSION(NX_,local_nky_), INTENT(IN)  :: in_xky
+        COMPLEX(xp), DIMENSION(NX_,local_nky_), INTENT(OUT) :: out_kxky
+#ifdef SINGLE_PRECISION
+        CALL sfftw_execute_dft(plan_x2kx_c2c, in_xky, out_kxky)
+#else 
+        CALL dfftw_execute_dft(plan_x2kx_c2c, in_xky, out_kxky)
+#endif
+    END SUBROUTINE FFT_xky_to_kxky
+
+    SUBROUTINE FFT_xy_to_xky(in_xy,out_xky)
+        IMPLICIT NONE
+        REAL(xp),    DIMENSION(NX_,NY_),     INTENT(IN)  :: in_xy
+        !real(c_xp_r), pointer,               INTENT(IN)  :: in_xy(:,:)
+        COMPLEX(xp), DIMENSION(NX_,NY_/2+1), INTENT(OUT) :: out_xky
+#ifdef SINGLE_PRECISION
+        CALL sfftw_execute_dft_r2c(plan_y2ky_r2c, in_xy, out_xky)
+#else 
+        CALL dfftw_execute_dft_r2c(plan_y2ky_r2c, in_xy, out_xky)
+#endif
+    END SUBROUTINE FFT_xy_to_xky
 
+    SUBROUTINE iFFT_xky_to_xy(in_xky,out_xy)
+        IMPLICIT NONE
+        COMPLEX(xp), DIMENSION(NX_,NY_/2+1), INTENT(IN)  :: in_xky
+        REAL(xp),    DIMENSION(NX_,NY_),     INTENT(OUT) :: out_xy
+        !real(c_xp_r), pointer,               INTENT(OUT) :: out_xy(:,:)
+#ifdef SINGLE_PRECISION
+        CALL sfftw_execute_dft_c2r(plan_ky2y_c2r, in_xky, out_xy)
+#else 
+        CALL dfftw_execute_dft_c2r(plan_ky2y_c2r, in_xky, out_xy)
+#endif
+    END SUBROUTINE iFFT_xky_to_xy
+
+    !******************************************************************************!
 
     SUBROUTINE finalize_plans
         USE basic, ONLY: speak
         IMPLICIT NONE
         CALL speak('..plan Destruction.')
-
-        SELECT CASE (FFT_ALGO)
-        CASE ('2D')
-            call fftw_destroy_plan(planb)
-            call fftw_destroy_plan(planf)
-            call fftw_mpi_cleanup()
-            call fftw_free(cdatar_f)
-            call fftw_free(cdatar_g)
-            call fftw_free(cdatar_c)
-            call fftw_free(cdatac_f)
-            call fftw_free(cdatac_g)
-            call fftw_free(cdatac_c)
-        CASE ('1D')
-        END SELECT
+        call fftw_destroy_plan(planb)
+        call fftw_destroy_plan(planf)
+        call fftw_mpi_cleanup()
+        call fftw_free(cdatar_f)
+        call fftw_free(cdatar_g)
+        call fftw_free(cdatar_c)
+        call fftw_free(cdatac_f)
+        call fftw_free(cdatac_g)
+        call fftw_free(cdatac_c)
     END SUBROUTINE finalize_plans
 
 END MODULE fourier
diff --git a/src/grid_mod.F90 b/src/grid_mod.F90
index ae65d843..70fd7e09 100644
--- a/src/grid_mod.F90
+++ b/src/grid_mod.F90
@@ -25,7 +25,9 @@ MODULE grid
   INTEGER,  DIMENSION(:),   ALLOCATABLE, PUBLIC,PROTECTED :: jarray,  jarray_full
   REAL(xp), DIMENSION(:,:), ALLOCATABLE, PUBLIC,PROTECTED :: kxarray ! ExB shear makes it ky dependant
   REAL(xp), DIMENSION(:),   ALLOCATABLE, PUBLIC,PROTECTED :: kxarray_full
+  REAL(xp), DIMENSION(:),   ALLOCATABLE, PUBLIC,PROTECTED :: xarray
   REAL(xp), DIMENSION(:),   ALLOCATABLE, PUBLIC,PROTECTED :: kyarray, kyarray_full
+  REAL(xp), DIMENSION(:),   ALLOCATABLE, PUBLIC,PROTECTED :: ikyarray, inv_ikyarray !mode indices arrays
   REAL(xp), DIMENSION(:,:), ALLOCATABLE, PUBLIC,PROTECTED :: zarray
   REAL(xp), DIMENSION(:),   ALLOCATABLE, PUBLIC,PROTECTED :: zarray_full
   REAL(xp), DIMENSION(:,:,:,:), ALLOCATABLE, PUBLIC,PROTECTED :: kparray !kperp
@@ -72,7 +74,7 @@ MODULE grid
   integer(C_INTPTR_T), PUBLIC,PROTECTED :: local_nkx_ptr_offset, local_nky_ptr_offset
   ! Grid spacing and limits
   REAL(xp), PUBLIC, PROTECTED ::  deltap, deltaz, inv_deltaz, inv_dkx
-  REAL(xp), PUBLIC, PROTECTED ::  deltakx, deltaky, kx_max, ky_max, kx_min, ky_min!, kp_max
+  REAL(xp), PUBLIC, PROTECTED ::  deltakx, deltaky, deltax, kx_max, ky_max, kx_min, ky_min!, kp_max
   INTEGER , PUBLIC, PROTECTED ::  local_pmin,  local_pmax
   INTEGER , PUBLIC, PROTECTED ::  local_jmin,  local_jmax
   REAL(xp), PUBLIC, PROTECTED ::  local_kymin, local_kymax
@@ -214,6 +216,8 @@ CONTAINS
     local_nky_offset = local_nky_ptr_offset
     ALLOCATE(kyarray_full(Nky))
     ALLOCATE(kyarray(local_nky))
+    ALLOCATE(ikyarray(local_nky))
+    ALLOCATE(inv_ikyarray(local_nky))
     ALLOCATE(AA_y(local_nky))
     !!---------------- RADIAL KX INDICES (not parallelized)
     Nkx       = Nx
@@ -223,8 +227,9 @@ CONTAINS
     local_nkx_ptr = ikxe - ikxs + 1
     local_nkx     = ikxe - ikxs + 1
     local_nkx_offset = ikxs - 1
-    ALLOCATE(kxarray(local_nky,local_Nkx))
     ALLOCATE(kxarray_full(total_nkx))
+    ALLOCATE(kxarray(local_nky,local_Nkx))
+    ALLOCATE(xarray(total_nkx))
     ALLOCATE(AA_x(local_nkx))
     !!---------------- PARALLEL Z GRID (parallelized)
     total_nz = Nz
@@ -432,10 +437,17 @@ CONTAINS
       ! indexation (|1 2 3||1 2 3|... local_nky|)
       IF(Ny .EQ. 1) THEN
         kyarray(iky)      = deltaky
+        kyarray(iky)      = iky-1        
         kyarray_full(iky) = deltaky
         SINGLE_KY         = .TRUE.
       ELSE
-        kyarray(iky) = kyarray_full(iky+local_nky_offset)
+        kyarray(iky)      = kyarray_full(iky+local_nky_offset)
+        ikyarray(iky)     = REAL((iky+local_nky_offset)-1,xp)
+        IF(ikyarray(iky) .GT. 0) THEN
+          inv_ikyarray(iky) = 1._xp/ikyarray(iky)
+        ELSE
+          inv_ikyarray(iky) = 0._xp
+        ENDIF
       ENDIF
       ! Finding kx=0
       IF (kyarray(iky) .EQ. 0) THEN
@@ -490,12 +502,14 @@ CONTAINS
       Lx = Lx_adapted*Nexc
     ENDIF
     deltakx = 2._xp*PI/Lx
-    inv_dkx = 1._xp/deltakx   
+    inv_dkx = 1._xp/deltakx 
+    deltax  = Lx/REAL(Nx,xp) ! periodic donc pas Lx/(Nx-1)  
     IF(MODULO(total_nkx,2) .EQ. 0) THEN ! Even number of kx (-2 -1 0 1 2 3)
       ! Creating a grid ordered as dk*(0 1 2 3 -2 -1)
       DO ikx = 1,total_nkx
         kxarray_full(ikx) = deltakx*REAL(MODULO(ikx-1,total_nkx/2)-(total_nkx/2)*FLOOR(2.*real(ikx-1)/real(total_nkx)),xp)
         IF (ikx .EQ. total_nkx/2+1) kxarray_full(ikx) = -kxarray_full(ikx)
+        xarray(ikx) = REAL(ikx-1,xp)*deltax
       END DO
       kx_max = MAXVAL(kxarray_full)!(total_nkx/2)*deltakx
       kx_min = MINVAL(kxarray_full)!-kx_max+deltakx
diff --git a/src/model_mod.F90 b/src/model_mod.F90
index f468a0cd..4072421a 100644
--- a/src/model_mod.F90
+++ b/src/model_mod.F90
@@ -31,6 +31,7 @@ MODULE model
   ! Auxiliary variable
   LOGICAL,  PUBLIC, PROTECTED ::      EM =  .true.    ! Electromagnetic effects flag
   LOGICAL,  PUBLIC, PROTECTED ::  MHD_PD =  .true.    ! MHD pressure drift
+  LOGICAL,  PUBLIC, PROTECTED ::     ExB =  .false.   ! presence of ExB background shearing rate
   ! Removes Landau damping in temperature and higher equation (Ivanov 2022)
   LOGICAL,  PUBLIC, PROTECTED :: RM_LD_T_EQ = .false.
   ! Flag to force the reality condition symmetry for the kx at ky=0
@@ -78,6 +79,10 @@ CONTAINS
       EM = .FALSE.
     ENDIF
 
+    IF(ExBrate .GT. 0) THEN
+      ExB = .TRUE.
+    ENDIF
+
   END SUBROUTINE model_readinputs
 
   SUBROUTINE model_outputinputs(fid)
diff --git a/src/nonlinear_mod.F90 b/src/nonlinear_mod.F90
index be57b22f..33fed8f1 100644
--- a/src/nonlinear_mod.F90
+++ b/src/nonlinear_mod.F90
@@ -1,18 +1,20 @@
 MODULE nonlinear
   USE array,       ONLY : dnjs, Sapj, kernel
-  USE fourier,     ONLY : bracket_sum_r, bracket_sum_c, planf, planb, poisson_bracket_and_sum
+  USE fourier,     ONLY : bracket_sum_r, bracket_sum_c, planf, planb, poisson_bracket_and_sum,&
+                          apply_inv_ExB_NL_factor
   USE fields,      ONLY : phi, psi, moments
-  USE grid,        ONLY: local_na, &
+  USE grid,        ONLY : local_na, &
                          local_np,ngp,parray,pmax,&
                          local_nj,ngj,jarray,jmax, local_nj_offset, dmax,&
                          kyarray, AA_y, local_nky_ptr, local_nky_ptr_offset,inv_Ny,&
                          local_nkx_ptr,kxarray, AA_x, inv_Nx,&
                          local_nz,ngz,zarray,nzgrid, deltakx, iky0, contains_kx0, contains_ky0
-  USE model,       ONLY : LINEARITY, EM, ikxZF, ZFamp
+  USE model,       ONLY : LINEARITY, EM, ikxZF, ZFamp, ExB
   USE closure,     ONLY : evolve_mom, nmaxarray
   USE prec_const,  ONLY : xp
   USE species,     ONLY : sqrt_tau_o_sigma
   USE time_integration, ONLY : updatetlevel
+  USE ExB_shear_flow,   ONLY : ExB_NL_factor, inv_ExB_NL_factor
   use, intrinsic :: iso_c_binding
 
   IMPLICIT NONE
@@ -20,8 +22,6 @@ MODULE nonlinear
   INCLUDE 'fftw3-mpi.f03'
 
   COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE :: F_cmpx, G_cmpx
-  COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE :: Fx_cmpx, Gy_cmpx
-  COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE :: Fy_cmpx, Gx_cmpx, F_conv_G
   INTEGER :: in, is, p_int, j_int, n_int
   INTEGER :: smax
   REAL(xp):: sqrt_p, sqrt_pp1
@@ -33,13 +33,6 @@ SUBROUTINE nonlinear_init
   IMPLICIT NONE
   ALLOCATE( F_cmpx(local_nky_ptr,local_nkx_ptr))
   ALLOCATE( G_cmpx(local_nky_ptr,local_nkx_ptr))
-
-  ALLOCATE(Fx_cmpx(local_nky_ptr,local_nkx_ptr))
-  ALLOCATE(Gy_cmpx(local_nky_ptr,local_nkx_ptr))
-  ALLOCATE(Fy_cmpx(local_nky_ptr,local_nkx_ptr))
-  ALLOCATE(Gx_cmpx(local_nky_ptr,local_nkx_ptr))
-
-  ALLOCATE(F_conv_G(local_nky_ptr,local_nkx_ptr))
 END SUBROUTINE nonlinear_init
 
 SUBROUTINE compute_Sapj
@@ -105,7 +98,9 @@ SUBROUTINE compute_nonlinear
                   dnjs(in,ij,is) * moments(ia,ipi,isi,:,:,izi,updatetlevel)
               ENDDO s1
               ! this function adds its result to bracket_sum_r
-                CALL poisson_bracket_and_sum(kyarray,kxarray,inv_Ny,inv_Nx,AA_y,AA_x,local_nky_ptr,local_nkx_ptr,F_cmpx,G_cmpx,bracket_sum_r)
+                CALL poisson_bracket_and_sum( kyarray,kxarray,inv_Ny,inv_Nx,AA_y,AA_x,&
+                                              local_nky_ptr,local_nkx_ptr,F_cmpx,G_cmpx,&
+                                              ExB, ExB_NL_factor, bracket_sum_r)
   !-----------!! ELECTROMAGNETIC CONTRIBUTION -sqrt(tau)/sigma*{Sum_s dnjs [sqrt(p+1)Nap+1s + sqrt(p)Nap-1s], Kernel psi}
               IF(EM) THEN
                 ! First convolution terms
@@ -119,9 +114,15 @@ SUBROUTINE compute_nonlinear
                                     +sqrt_p  *moments(ia,ipi-1,isi,:,:,izi,updatetlevel))
                 ENDDO s2
                 ! this function adds its result to bracket_sum_r
-                CALL poisson_bracket_and_sum(kyarray,kxarray,inv_Ny,inv_Nx,AA_y,AA_x,local_nky_ptr,local_nkx_ptr,F_cmpx,G_cmpx,bracket_sum_r)
+                CALL poisson_bracket_and_sum( kyarray,kxarray,inv_Ny,inv_Nx,AA_y,AA_x,&
+                                              local_nky_ptr,local_nkx_ptr,F_cmpx,G_cmpx,&
+                                              ExB, ExB_NL_factor,bracket_sum_r)
               ENDIF
             ENDDO n
+            ! Apply the ExB shearing rate factor before going back to k-space
+            IF (ExB) THEN
+              CALL apply_inv_ExB_NL_factor(bracket_sum_r,inv_ExB_NL_factor)
+            ENDIF
             ! Put the real nonlinear product back into k-space
 #ifdef SINGLE_PRECISION
             call fftwf_mpi_execute_dft_r2c(planf, bracket_sum_r, bracket_sum_c)
-- 
GitLab