From 800456e468c57c45b041a1a7af85902c2903b4be Mon Sep 17 00:00:00 2001
From: Antoine Hoffmann <antoine.hoffmann@epfl.ch>
Date: Fri, 31 Mar 2023 13:29:44 +0200
Subject: [PATCH] Closure is better organized

---
 Makefile                       |   2 +-
 matlab/load/load_params.m      |  12 ++
 src/advance_field_mod.F90      |   7 +-
 src/auxval.F90                 |  11 +-
 src/closure_mod.F90            | 242 +++++++++++++++++++++++----------
 src/cosolver_interface_mod.F90 |   5 +-
 src/diagnose.F90               |  38 +++---
 src/grid_mod.F90               |   1 -
 src/model_mod.F90              |   6 +-
 src/moments_eq_rhs_mod.F90     |   5 +-
 src/nonlinear_mod.F90          |  50 +++----
 src/processing_mod.F90         |  23 +---
 src/readinputs.F90             |   4 +
 wk/fast_analysis.m             |  23 +++-
 14 files changed, 272 insertions(+), 157 deletions(-)

diff --git a/Makefile b/Makefile
index 659e365b..bd218c82 100644
--- a/Makefile
+++ b/Makefile
@@ -103,7 +103,7 @@ $(OBJDIR)/time_integration_mod.o $(OBJDIR)/utility_mod.o
  $(OBJDIR)/advance_field_mod.o : src/advance_field_mod.F90 \
    $(OBJDIR)/grid_mod.o $(OBJDIR)/array_mod.o $(OBJDIR)/initial_par_mod.o \
 	 $(OBJDIR)/prec_const_mod.o $(OBJDIR)/time_integration_mod.o $(OBJDIR)/basic_mod.o \
-	 $(OBJDIR)/fields_mod.o $(OBJDIR)/model_mod.o
+	 $(OBJDIR)/fields_mod.o $(OBJDIR)/model_mod.o $(OBJDIR)/closure_mod.o
 	$(F90) -c $(F90FLAGS) $(FPPFLAGS) $(EXTMOD) $(EXTINC) src/advance_field_mod.F90 -o $@
 
  $(OBJDIR)/array_mod.o : src/array_mod.F90 \
diff --git a/matlab/load/load_params.m b/matlab/load/load_params.m
index 8fb8f973..56e12861 100644
--- a/matlab/load/load_params.m
+++ b/matlab/load/load_params.m
@@ -14,8 +14,20 @@ DATA.JMAX    = h5readatt(filename,'/data/input/grid','jmax');
 DATA.Nx      = h5readatt(filename,'/data/input/grid','Nx');
 DATA.Ny      = h5readatt(filename,'/data/input/grid','Ny');
 DATA.L       = h5readatt(filename,'/data/input/grid','Lx');
+try
 DATA.CLOS    = h5readatt(filename,'/data/input/model','CLOS');
 DATA.NL_CLOS = h5readatt(filename,'/data/input/model','NL_CLOS');
+catch
+    try
+        DATA.ha_cl   = h5readatt(filename,'/data/input/closure','hierarchy_closure');
+        DATA.CLOS    = h5readatt(filename,'/data/input/closure','dmax');   
+        DATA.nl_cl   = h5readatt(filename,'/data/input/closure','nonlinear_closure');   
+        DATA.NL_CLOS = h5readatt(filename,'/data/input/closure','nmax');   
+    catch
+        DATA.CLOS = 99;
+        DATA.NL_CLOS = 99;
+    end
+end
 DATA.Na      = h5readatt(filename,'/data/input/model','Na');
 DATA.NU      = h5readatt(filename,'/data/input/model','nu');
 DATA.MUp     = h5readatt(filename,'/data/input/model','mu_p');
diff --git a/src/advance_field_mod.F90 b/src/advance_field_mod.F90
index 173b1966..9ee8c4be 100644
--- a/src/advance_field_mod.F90
+++ b/src/advance_field_mod.F90
@@ -15,8 +15,8 @@ CONTAINS
 
     USE basic,  ONLY: dt
     USE grid,   ONLY:local_na,local_np,local_nj,local_nky,local_nkx,local_nz,&
-                     ngp, ngj, ngz, dmax, parray, jarray, dmax
-    USE model,  ONLY: CLOS
+                     ngp, ngj, ngz
+    USE closure,ONLY: evolve_mom
     use fields, ONLY: moments
     use array,  ONLY: moments_rhs
     USE time_integration, ONLY: updatetlevel, A_E, b_E, ntimelevel
@@ -34,6 +34,7 @@ CONTAINS
       DO ip    =1,local_np
         ipi = ip+ngp/2
       DO ia    =1,local_na
+        IF( evolve_mom(ipi,iji) )&
         moments(ia,ipi,iji,iky,ikx,izi,1) = moments(ia,ipi,iji,iky,ikx,izi,1) &
                + dt*b_E(istage)*moments_rhs(ia,ip,ij,iky,ikx,iz,istage)
       END DO
@@ -55,7 +56,7 @@ CONTAINS
       DO ip    =1,local_np
         ipi = ip+ngp/2
       DO ia    =1,local_na
-        IF((CLOS .NE. 1) .OR. (parray(ipi)+2*jarray(iji) .LE. dmax))&
+        IF( evolve_mom(ipi,iji) )&
         moments(ia,ipi,iji,iky,ikx,izi,updatetlevel) = moments(ia,ipi,iji,iky,ikx,izi,updatetlevel) + &
                           dt*A_E(updatetlevel,istage)*moments_rhs(ia,ip,ij,iky,ikx,iz,istage)
       END DO
diff --git a/src/auxval.F90 b/src/auxval.F90
index f9722ce4..fb8e13ca 100644
--- a/src/auxval.F90
+++ b/src/auxval.F90
@@ -9,6 +9,7 @@ subroutine auxval
   use prec_const
   USE numerics
   USE geometry
+  USE closure, ONLY: set_closure_model, hierarchy_closure
   USE parallel, ONLY: init_parallel_var, my_id, num_procs, num_procs_p, num_procs_z, num_procs_ky, rank_p, rank_ky, rank_z
   USE processing, ONLY: init_process
   IMPLICIT NONE
@@ -39,6 +40,8 @@ subroutine auxval
 
   CALL build_dv4Hp_table ! precompute the hermite fourth derivative table
 
+  CALL set_closure_model ! set the closure scheme in use
+
   !! Display parallel settings
   CALL mpi_barrier(MPI_COMM_WORLD, ierr)
   DO i_ = 0,num_procs-1
@@ -67,7 +70,11 @@ subroutine auxval
   ENDDO
   CALL mpi_barrier(MPI_COMM_WORLD, ierr)
 
-  IF((CLOS .EQ. 1)) &
-    CALL speak('Closure = 1 -> Maximal Napj degree is min(Pmax,2*Jmax+1): D = '// str(dmax))
+  SELECT CASE(hierarchy_closure)
+  CASE('truncation')
+    CALL speak('Truncation closure')
+  CASE('max_degree')
+    CALL speak('Max degree closure -> Maximal Napj degree is D = '// str(dmax))
+  END SELECT
 
 END SUBROUTINE auxval
diff --git a/src/closure_mod.F90 b/src/closure_mod.F90
index 9180feeb..9e18d704 100644
--- a/src/closure_mod.F90
+++ b/src/closure_mod.F90
@@ -1,88 +1,192 @@
 module closure
 ! Contains the routines to define closures
 IMPLICIT NONE
-
-PUBLIC :: apply_closure_model
+! Input
+CHARACTER(len=32),  PUBLIC, PROTECTED :: hierarchy_closure = 'truncation'! closure for the moment hierarchy
+INTEGER,            PUBLIC, PROTECTED :: dmax              = -1          ! max evolved degree moment
+CHARACTER(len=32),  PUBLIC, PROTECTED :: nonlinear_closure = 'truncation'! nonlinear truncation method
+INTEGER,            PUBLIC, PROTECTED :: nmax              = 0           ! upperbound of the nonlinear sum over n
+!  Attributes
+LOGICAL,DIMENSION(:,:), ALLOCATABLE, PUBLIC, PROTECTED :: evolve_mom     ! array that sets if a moment has to be evolved or not
+INTEGER,DIMENSION(:),   ALLOCATABLE, PUBLIC, PROTECTED :: nmaxarray      ! upperbound of the nonlinear sum over n (depend on j)
+ 
+PUBLIC :: closure_readinputs, set_closure_model, apply_closure_model
 
 CONTAINS
 
-! Positive Oob indices are approximated with a model
-SUBROUTINE apply_closure_model
-  USE prec_const, ONLY: xp
-  USE model,      ONLY: CLOS
-  USE grid,       ONLY: local_nj,ngj, jarray,&
-                        local_np,ngp, parray, dmax
-  USE fields,     ONLY: moments
-  USE time_integration, ONLY: updatetlevel
+SUBROUTINE closure_readinputs
+  USE basic, ONLY: lu_in
   IMPLICIT NONE
-  INTEGER ::ij,ip,ia
-  IF (CLOS .EQ. 0) THEN
-    ! zero truncation, An+1=0 for n+1>nmax only
-    CALL ghosts_upper_truncation
-  ELSEIF (CLOS .EQ. 1) THEN
-    ! truncation at highest fully represented kinetic moment
-    ! e.g. Dmax = 3 means
-    ! only Napj s.t. p+2j <= 3 are evolved
-    ! -> (p,j) allowed are (0,0),(1,0),(0,1),(2,0),(1,1),(3,0)
-    ! =>> Dmax = min(Pmax,2*Jmax+1)
-    j: DO ij = 1,local_nj+ngj
-    p: DO ip = 1,local_np+ngp
-      IF ( parray(ip)+2*jarray(ij) .GT. dmax) THEN
-        moments(ia,ip,ij,:,:,:,updatetlevel) = 0._xp
-      ENDIF
-    ENDDO p
-    ENDDO j
-  ELSE
-    ERROR STOP '>> ERROR << Closure scheme not found '
-  ENDIF
-  CALL ghosts_lower_truncation
-END SUBROUTINE apply_closure_model
+  NAMELIST /CLOSURE_PAR/ hierarchy_closure, dmax, nonlinear_closure, nmax
+  READ(lu_in,closure_par)
+END SUBROUTINE
 
-! Positive Oob indices are approximated with a model
-SUBROUTINE ghosts_upper_truncation
-  USE prec_const, ONLY: xp
-  USE grid,       ONLY: local_np,ngp,local_pmax, pmax,&
-                        local_nj,ngj,local_jmax, jmax
-  USE fields,           ONLY: moments
-  USE time_integration, ONLY: updatetlevel
+SUBROUTINE set_closure_model
+  USE grid, ONLY: local_np, ngp, local_nj, ngj, parray, jarray,&
+                  pmax, jmax
   IMPLICIT NONE
-  INTEGER ::ig
-  ! zero truncation, An+1=0 for n+1>nmax
-    ! applies only for the processes that evolve the highest moment degree
-    IF(local_jmax .GE. Jmax) THEN
-      DO ig = 1,ngj/2
-        moments(:,:,local_nj+ngj/2+ig,:,:,:,updatetlevel) = 0._xp
+  INTEGER :: ip,ij
+  ! adapt the dmax if it is set <0
+  IF(dmax .LT. 0) THEN
+    dmax = MIN(pmax,2*jmax+1)
+  ELSEIF(dmax .GT. (pmax+2*jmax)) THEN
+    ERROR STOP "dmax is higher than the maximal moments degree available"
+  ENDIF
+  ! set the evolve mom array
+  ALLOCATE(evolve_mom(local_np+ngp,local_nj+ngj))
+  SELECT CASE(hierarchy_closure)
+  CASE('truncation')
+    DO ip = 1,local_np+ngp
+      DO ij = 1, local_nj+ngj
+        evolve_mom(ip,ij) = (parray(ip).GE.0) .AND. (jarray(ij).GE.0) &
+                      .AND. (parray(ip).LE.pmax) .AND. (jarray(ij).LE.jmax)
       ENDDO
-    ENDIF
-    ! applies only for the process that has largest p index
-    IF(local_pmax .GE. Pmax) THEN
-      DO ig = 1,ngp/2
-        moments(:,local_np+ngp/2+ig,:,:,:,:,updatetlevel) = 0._xp
+    ENDDO
+  CASE('max_degree')
+    DO ip = 1,local_np+ngp
+      DO ij = 1, local_nj+ngj
+          evolve_mom(ip,ij) = (parray(ip).GE.0) .AND. (jarray(ij.GE.0)) &
+                        .AND. (2*parray(ip)+jarray(ij) .GT. dmax)
       ENDDO
+    ENDDO  
+  CASE DEFAULT
+    ERROR STOP "closure scheme not recognized (avail: truncation,max_degree)"
+  END SELECT
+
+  ! Set the nonlinear closure scheme (truncation of sum over n in Sapj)
+  ALLOCATE(nmaxarray(local_nj))
+  SELECT CASE(nonlinear_closure)
+  CASE('truncation')
+    IF(nmax .LT. 0) THEN
+      ERROR STOP "cannot truncate the sum with a number smaller than 0"
+    ELSE
+      nmaxarray(:) = nmax
     ENDIF
-END SUBROUTINE ghosts_upper_truncation
+  CASE('anti_laguerre_aliasing')
+    DO ij = 1,local_nj
+      nmaxarray(ij) = jmax - jarray(ij+ngj/2)
+    ENDDO
+  CASE('full_sum')
+    nmaxarray(:) = jmax
+  CASE DEFAULT
+    ERROR STOP "nonlinear closure scheme not recognized (avail: truncation,anti_laguerre_aliasing,full_sum)"
+  END SELECT
+
+END SUBROUTINE set_closure_model
 
-! Negative OoB indices are 0
-SUBROUTINE ghosts_lower_truncation
+! Positive Oob indices are approximated with a model
+SUBROUTINE apply_closure_model
   USE prec_const, ONLY: xp
-  USE grid,       ONLY: ngp,ngj,local_pmin,local_jmin
-  USE fields,           ONLY: moments
+  USE grid,       ONLY: local_nj,ngj,local_np,ngp,local_na
+  USE fields,     ONLY: moments
   USE time_integration, ONLY: updatetlevel
   IMPLICIT NONE
-  INTEGER :: ig
-! zero truncation, An=0 for n<0
-    IF(local_jmin .EQ. 0) THEN
-      DO ig  = 1,ngj/2
-        moments(:,:,ig,:,:,:,updatetlevel) = 0._xp
-      ENDDO
-    ENDIF
-    ! applies only for the process that has lowest p index
-    IF(local_pmin .EQ. 0) THEN
-      DO ig  = 1,ngp/2
-        moments(:,ig,:,:,:,:,updatetlevel) = 0._xp
-      ENDDO
-    ENDIF
+  INTEGER ::ij,ip,ia
+  SELECT CASE (hierarchy_closure)
+    CASE('truncation','max_degree')
+      DO ij = 1, local_nj+ngj
+        DO ip = 1,local_np+ngp
+          DO ia = 1,local_na
+            IF(.NOT. evolve_mom(ip,ij))&
+              moments(ia,ip,ij,:,:,:,updatetlevel) = 0._xp
+          ENDDO
+        ENDDO
+      ENDDO  
+    CASE DEFAULT
+      ERROR STOP "closure scheme not recognized"
+  END SELECT
+END SUBROUTINE apply_closure_model
+
+! ! Positive Oob indices are approximated with a model
+! SUBROUTINE apply_closure_model
+!   USE prec_const, ONLY: xp
+!   USE grid,       ONLY: local_nj,ngj, jarray,&
+!                         local_np,ngp, parray
+!   USE fields,     ONLY: moments
+!   USE time_integration, ONLY: updatetlevel
+!   IMPLICIT NONE
+!   INTEGER ::ij,ip,ia
+!   SELECT CASE (hierarchy_closure)
+!     CASE('truncation')
+!     ! zero truncation, An+1=0 for n+1>nmax only
+!     CALL ghosts_upper_truncation
+!     CASE('max_degree')
+!     ! truncation at highest fully represented kinetic moment
+!     ! e.g. Dmax = 3 means
+!     ! only Napj s.t. p+2j <= 3 are evolved
+!     ! -> (p,j) allowed are (0,0),(1,0),(0,1),(2,0),(1,1),(3,0)
+!     ! =>> Dmax = min(Pmax,2*Jmax+1)
+!     j: DO ij = 1,local_nj+ngj
+!     p: DO ip = 1,local_np+ngp
+!       IF ( parray(ip)+2*jarray(ij) .GT. dmax) THEN
+!         moments(ia,ip,ij,:,:,:,updatetlevel) = 0._xp
+!       ENDIF
+!     ENDDO p
+!     ENDDO j
+!   END SELECT
+!   CALL ghosts_lower_truncation
+! END SUBROUTINE apply_closure_model
+
+! ! Positive Oob indices are approximated with a model
+! SUBROUTINE ghosts_upper_truncation
+!   USE prec_const, ONLY: xp
+!   USE grid,       ONLY: local_np,ngp,local_pmax, pmax,&
+!                         local_nj,ngj,local_jmax, jmax
+!   USE fields,           ONLY: moments
+!   USE time_integration, ONLY: updatetlevel
+!   IMPLICIT NONE
+!   INTEGER ::ig
+!   ! zero truncation, An+1=0 for n+1>nmax
+!     ! applies only for the processes that evolve the highest moment degree
+!     IF(local_jmax .GE. Jmax) THEN
+!       DO ig = 1,ngj/2
+!         moments(:,:,local_nj+ngj/2+ig,:,:,:,updatetlevel) = 0._xp
+!       ENDDO
+!     ENDIF
+!     ! applies only for the process that has largest p index
+!     IF(local_pmax .GE. Pmax) THEN
+!       DO ig = 1,ngp/2
+!         moments(:,local_np+ngp/2+ig,:,:,:,:,updatetlevel) = 0._xp
+!       ENDDO
+!     ENDIF
+! END SUBROUTINE ghosts_upper_truncation
+
+! ! Negative OoB indices are 0
+! SUBROUTINE ghosts_lower_truncation
+!   USE prec_const, ONLY: xp
+!   USE grid,       ONLY: ngp,ngj,local_pmin,local_jmin
+!   USE fields,           ONLY: moments
+!   USE time_integration, ONLY: updatetlevel
+!   IMPLICIT NONE
+!   INTEGER :: ig
+! ! zero truncation, An=0 for n<0
+!     IF(local_jmin .EQ. 0) THEN
+!       DO ig  = 1,ngj/2
+!         moments(:,:,ig,:,:,:,updatetlevel) = 0._xp
+!       ENDDO
+!     ENDIF
+!     ! applies only for the process that has lowest p index
+!     IF(local_pmin .EQ. 0) THEN
+!       DO ig  = 1,ngp/2
+!         moments(:,ig,:,:,:,:,updatetlevel) = 0._xp
+!       ENDDO
+!     ENDIF
 
-END SUBROUTINE ghosts_lower_truncation
+! END SUBROUTINE ghosts_lower_truncation
+
+
+SUBROUTINE closure_outputinputs(fid)
+  ! Write the input parameters to the results_xx.h5 file
+  USE futils, ONLY: attach, creatd
+  IMPLICIT NONE
+  INTEGER, INTENT(in) :: fid
+  CHARACTER(len=256)  :: str
+  WRITE(str,'(a)') '/data/input/closure'
+  CALL creatd(fid, 0,(/0/),TRIM(str),'Closure Input')
+  CALL attach(fid, TRIM(str),"hierarchy_closure",hierarchy_closure)
+  CALL attach(fid, TRIM(str),             "dmax",dmax)
+  CALL attach(fid, TRIM(str),"nonlinear_closure",nonlinear_closure)
+  CALL attach(fid, TRIM(str),             "nmax",nmax)
+END SUBROUTINE closure_outputinputs
 
 END module closure
diff --git a/src/cosolver_interface_mod.F90 b/src/cosolver_interface_mod.F90
index a8aabcc5..988c3b01 100644
--- a/src/cosolver_interface_mod.F90
+++ b/src/cosolver_interface_mod.F90
@@ -13,10 +13,11 @@ CONTAINS
     USE parallel,    ONLY: num_procs_p, comm_p,dsp_p,rcv_p
     USE grid,        ONLY: &
       local_na, &
-      local_np, total_np, total_nj,&
+      local_np, ngp, total_np, total_nj, ngj,&
       local_nkx, local_nky, local_nz, bar
     USE array,       ONLY: Capj
     USE MPI
+    USE closure,     ONLY: evolve_mom
     IMPLICIT NONE
     LOGICAL, INTENT(IN) :: GK_CO
     COMPLEX(xp), DIMENSION(total_np)    :: local_coll, buffer
@@ -30,7 +31,7 @@ CONTAINS
           a:DO ia = 1,local_na
             j:DO ij = 1,total_nj
               p:DO ip = 1,total_np
-                IF((CLOS .NE. 1) .OR. (p_int+2*j_int .LE. dmax)) THEN !compute for every moments except for closure 1
+              IF(evolve_mom(ip+ngp/2,ij+ngj/2)) THEN !compute for every moments except for closure 1
                   !! Take GK or DK limit
                   IF (GK_CO) THEN ! GK operator (k-dependant)
                     ikx_C = ikx; iky_C = iky; iz_C = iz;
diff --git a/src/diagnose.F90 b/src/diagnose.F90
index 9527819e..9ca937dc 100644
--- a/src/diagnose.F90
+++ b/src/diagnose.F90
@@ -28,16 +28,17 @@ SUBROUTINE diagnose(kstep)
 END SUBROUTINE diagnose
 
 SUBROUTINE init_outfile(comm,file0,file,fid)
-  USE diagnostics_par, ONLY : write_doubleprecision, diag_par_outputinputs, input_fname
-  USE basic,           ONLY : speak, jobnum, basic_outputinputs
-  USE grid,            ONLY : grid_outputinputs
-  USE geometry,        ONLY : geometry_outputinputs
-  USE model,           ONLY : model_outputinputs
-  USE species,         ONLY : species_outputinputs
-  USE collision,       ONLY : coll_outputinputs
-  USE initial_par,     ONLY : initial_outputinputs
-  USE time_integration,ONLY : time_integration_outputinputs
-  USE futils,          ONLY : creatf, creatg, creatd, attach, putfile
+  USE diagnostics_par, ONLY: write_doubleprecision, diag_par_outputinputs, input_fname
+  USE basic,           ONLY: speak, jobnum, basic_outputinputs
+  USE grid,            ONLY: grid_outputinputs
+  USE geometry,        ONLY: geometry_outputinputs
+  USE model,           ONLY: model_outputinputs
+  USE closure,         ONLY: closure_outputinputs
+  USE species,         ONLY: species_outputinputs
+  USE collision,       ONLY: coll_outputinputs
+  USE initial_par,     ONLY: initial_outputinputs
+  USE time_integration,ONLY: time_integration_outputinputs
+  USE futils,          ONLY: creatf, creatg, creatd, attach, putfile
   IMPLICIT NONE
   !input
   INTEGER,            INTENT(IN)    :: comm
@@ -69,14 +70,15 @@ SUBROUTINE init_outfile(comm,file0,file,fid)
   CALL attach(fid, "/data/input/codeinfo",   "author",   AUTHOR) !defined in srcinfo.h
   CALL attach(fid, "/data/input/codeinfo", "execdate", EXECDATE) !defined in srcinfo.h
   CALL attach(fid, "/data/input/codeinfo",     "host",     HOST) !defined in srcinfo.h
-  CALL basic_outputinputs(fid)
-  CALL grid_outputinputs(fid)
-  CALL geometry_outputinputs(fid)
-  CALL diag_par_outputinputs(fid)
-  CALL model_outputinputs(fid)
-  CALL species_outputinputs(fid)
-  CALL coll_outputinputs(fid)
-  CALL initial_outputinputs(fid)
+  CALL            basic_outputinputs(fid)
+  CALL             grid_outputinputs(fid)
+  CALL         geometry_outputinputs(fid)
+  CALL         diag_par_outputinputs(fid)
+  CALL            model_outputinputs(fid)
+  CALL          closure_outputinputs(fid)
+  CALL          species_outputinputs(fid)
+  CALL             coll_outputinputs(fid)
+  CALL          initial_outputinputs(fid)
   CALL time_integration_outputinputs(fid)
   !  Save STDIN (input file) of this run
   IF(jobnum .LE. 99) THEN
diff --git a/src/grid_mod.F90 b/src/grid_mod.F90
index 671f015f..0c7e65b7 100644
--- a/src/grid_mod.F90
+++ b/src/grid_mod.F90
@@ -163,7 +163,6 @@ CONTAINS
     CALL set_kygrid(LINEARITY,N_HD)
     CALL set_kxgrid(shear,Npol,LINEARITY,N_HD)
     CALL set_zgrid (Npol)
-
     ! print*, 'p:',parray
     ! print*, 'j:',jarray
     ! print*, 'ky:',kyarray
diff --git a/src/model_mod.F90 b/src/model_mod.F90
index 6d93410a..7be127dd 100644
--- a/src/model_mod.F90
+++ b/src/model_mod.F90
@@ -4,8 +4,6 @@ MODULE model
   IMPLICIT NONE
   PRIVATE
   ! INPUTS
-  INTEGER,  PUBLIC, PROTECTED ::    CLOS =  0         ! linear truncation method
-  INTEGER,  PUBLIC, PROTECTED :: NL_CLOS =  0         ! nonlinear truncation method
   INTEGER,  PUBLIC, PROTECTED ::    KERN =  0         ! Kernel model
   CHARACTER(len=16), &
             PUBLIC, PROTECTED ::LINEARITY= 'linear'   ! To turn on non linear bracket term
@@ -39,7 +37,7 @@ CONTAINS
     USE prec_const
     IMPLICIT NONE
 
-    NAMELIST /MODEL_PAR/ CLOS, NL_CLOS, KERN, LINEARITY, &
+    NAMELIST /MODEL_PAR/ KERN, LINEARITY, &
                          mu_x, mu_y, N_HD, HDz_h, mu_z, mu_p, mu_j, HYP_V, Na,&
                          nu, k_gB, k_cB, lambdaD, beta, ADIAB_E, tau_e
 
@@ -70,8 +68,6 @@ CONTAINS
     CHARACTER(len=256)  :: str
     WRITE(str,'(a)') '/data/input/model'
     CALL creatd(fid, 0,(/0/),TRIM(str),'Model Input')
-    CALL attach(fid, TRIM(str),      "CLOS",    CLOS)
-    CALL attach(fid, TRIM(str),   "NL_CLOS", NL_CLOS)
     CALL attach(fid, TRIM(str),      "KERN",    KERN)
     CALL attach(fid, TRIM(str), "LINEARITY", LINEARITY)
     CALL attach(fid, TRIM(str),      "mu_x",    mu_x)
diff --git a/src/moments_eq_rhs_mod.F90 b/src/moments_eq_rhs_mod.F90
index 569e33d8..e2a25dd4 100644
--- a/src/moments_eq_rhs_mod.F90
+++ b/src/moments_eq_rhs_mod.F90
@@ -8,10 +8,11 @@ SUBROUTINE compute_moments_eq_rhs
   USE array
   USE fields
   USE grid,       ONLY: local_na, local_np, local_nj, local_nkx, local_nky, local_nz,&
-                        nzgrid,pp2,ngp,ngj,ngz,dmax,&
+                        nzgrid,pp2,ngp,ngj,ngz,&
                         diff_dz_coeff,diff_kx_coeff,diff_ky_coeff,diff_p_coeff,diff_j_coeff,&
                         parray,jarray,kxarray, kyarray, kparray
   USE basic
+  USE closure,    ONLY: evolve_mom
   USE prec_const
   USE collision
   USE time_integration
@@ -55,7 +56,7 @@ SUBROUTINE compute_moments_eq_rhs
             a:DO ia = 1,local_na
               Napj = moments(ia,ipi,iji,iky,ikx,izi,updatetlevel)
               RHS = 0._xp
-              IF((CLOS .NE. 1) .OR. (p_int +2*j_int .LE. dmax)) THEN ! for the closure scheme
+              IF(evolve_mom(ipi,iji)) THEN ! for the closure scheme
                 !! Compute moments_ mixing terms
                 ! Perpendicular dynamic
                 ! term propto n^{p,j}
diff --git a/src/nonlinear_mod.F90 b/src/nonlinear_mod.F90
index 3c783e24..2e2791e4 100644
--- a/src/nonlinear_mod.F90
+++ b/src/nonlinear_mod.F90
@@ -9,7 +9,8 @@ MODULE nonlinear
                          kyarray, AA_y, local_nky_ptr, local_nky_ptr_offset,inv_Ny,&
                          local_nkx_ptr,kxarray, AA_x, inv_Nx,&
                          local_nz,ngz,zarray,nzgrid
-  USE model,       ONLY : LINEARITY, CLOS, NL_CLOS, EM
+  USE model,       ONLY : LINEARITY, EM
+  USE closure,     ONLY : evolve_mom, nmaxarray
   USE prec_const,  ONLY : xp
   USE species,     ONLY : sqrt_tau_o_sigma
   USE time_integration, ONLY : updatetlevel
@@ -23,7 +24,7 @@ MODULE nonlinear
   COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE :: Fx_cmpx, Gy_cmpx
   COMPLEX(xp), DIMENSION(:,:), ALLOCATABLE :: Fy_cmpx, Gx_cmpx, F_conv_G
   INTEGER :: in, is, p_int, j_int, n_int
-  INTEGER :: nmax, smax
+  INTEGER :: smax
   REAL(xp):: sqrt_p, sqrt_pp1
   PUBLIC  :: compute_Sapj, nonlinear_init
 
@@ -58,6 +59,7 @@ SUBROUTINE compute_Sapj
   END SELECT
 END SUBROUTINE compute_Sapj
 
+! Compute the poisson bracket {F,G}
 SUBROUTINE compute_nonlinear
   IMPLICIT NONE
   INTEGER :: iz,ij,ip,eo,ia,ikx,iky,izi,ipi,iji,ini,isi
@@ -68,29 +70,22 @@ SUBROUTINE compute_nonlinear
       j_int=jarray(iji)
       DO ip = 1,local_np ! Loop over Hermite moments
         ipi = ip + ngp/2
+        IF(evolve_mom(ipi,iji)) THEN !compute for every moments except for closure 1
         p_int    = parray(ipi)
         sqrt_p   = SQRT(REAL(p_int,xp))
         sqrt_pp1 = SQRT(REAL(p_int,xp) + 1._xp)
         eo       = min(nzgrid,MODULO(parray(ip),2)+1)
         DO ia = 1,local_na
-          IF((CLOS .NE. 1) .OR. (p_int+2*j_int .LE. dmax)) THEN !compute for every moments except for closure 1
             ! Set non linear sum truncation
-            IF (NL_CLOS .EQ. -2) THEN
-              nmax = Jmax
-            ELSEIF (NL_CLOS .EQ. -1) THEN
-              nmax = Jmax-j_int
-            ELSE
-              nmax = min(NL_CLOS,Jmax-j_int)
-            ENDIF
             bracket_sum_r = 0._xp ! initialize sum over real nonlinear term
-            DO in = 1,nmax+1 ! Loop over laguerre for the sum
+            DO in = 1,nmaxarray(ij)+1 ! Loop over laguerre for the sum
               ini = in+ngj/2
   !-----------!! ELECTROSTATIC CONTRIBUTION
               ! First convolution terms
               F_cmpx(:,:) = phi(:,:,izi) * kernel(ia,ini,:,:,izi,eo)
               ! Second convolution terms
               G_cmpx = 0._xp ! initialization of the sum
-              smax   = MIN( (in-1)+(ij-1), Jmax );
+              smax   = MIN( jarray(ini)+jarray(iji), jmax );
               DO is = 1, smax+1 ! sum truncation on number of moments
                 isi = is + ngj/2
                 G_cmpx(:,:) = G_cmpx(:,:) + &
@@ -100,22 +95,21 @@ SUBROUTINE compute_nonlinear
               CALL poisson_bracket_and_sum(kyarray,kxarray,inv_Ny,inv_Nx,AA_y,AA_x,local_nky_ptr,local_nkx_ptr,F_cmpx,G_cmpx,bracket_sum_r)
   !-----------!! ELECTROMAGNETIC CONTRIBUTION -sqrt(tau)/sigma*{Sum_s dnjs [sqrt(p+1)Nap+1s + sqrt(p)Nap-1s], Kernel psi}
               IF(EM) THEN
-              ! First convolution terms
-              F_cmpx(:,:) = -sqrt_tau_o_sigma(ia) * psi(:,:,izi) * kernel(ia,ini,:,:,izi,eo)
-              ! Second convolution terms
-              G_cmpx = 0._xp ! initialization of the sum
-              smax   = MIN( (in-1)+(ij-1), Jmax );
-              DO is = 1, smax+1 ! sum truncation on number of moments
-                isi = is + ngj/2
-                G_cmpx(:,:)  = G_cmpx(:,:) + &
-                  dnjs(in,ij,is) * (sqrt_pp1*moments(ia,ipi+1,isi,:,:,izi,updatetlevel)&
-                                   +sqrt_p  *moments(ia,ipi-1,isi,:,:,izi,updatetlevel))
-              ENDDO
-              ! this function add its result to bracket_sum_r
-              CALL poisson_bracket_and_sum(kyarray,kxarray,inv_Ny,inv_Nx,AA_y,AA_x,local_nky_ptr,local_nkx_ptr,F_cmpx,G_cmpx,bracket_sum_r)
+                ! First convolution terms
+                F_cmpx(:,:) = -sqrt_tau_o_sigma(ia) * psi(:,:,izi) * kernel(ia,ini,:,:,izi,eo)
+                ! Second convolution terms
+                G_cmpx = 0._xp ! initialization of the sum
+                DO is = 1, smax+1 ! sum truncation on number of moments
+                  isi = is + ngj/2
+                  G_cmpx(:,:)  = G_cmpx(:,:) + &
+                    dnjs(in,ij,is) * (sqrt_pp1*moments(ia,ipi+1,isi,:,:,izi,updatetlevel)&
+                                    +sqrt_p  *moments(ia,ipi-1,isi,:,:,izi,updatetlevel))
+                ENDDO
+                ! this function add its result to bracket_sum_r
+                CALL poisson_bracket_and_sum(kyarray,kxarray,inv_Ny,inv_Nx,AA_y,AA_x,local_nky_ptr,local_nkx_ptr,F_cmpx,G_cmpx,bracket_sum_r)
               ENDIF
             ENDDO
-            ! Put the real nonlinear product into k-space
+            ! Put the real nonlinear product back into k-space
 #ifdef SINGLE_PRECISION
             call fftwf_mpi_execute_dft_r2c(planf, bracket_sum_r, bracket_sum_c)
 #else
@@ -127,10 +121,10 @@ SUBROUTINE compute_nonlinear
                 Sapj(ia,ip,ij,iky,ikx,iz) = bracket_sum_c(ikx,iky)*AA_x(ikx)*AA_y(iky)
               ENDDO
             ENDDO
+            ENDDO
           ELSE
-            Sapj(ia,ip,ij,:,:,iz) = 0._xp
+            Sapj(:,ip,ij,:,:,iz) = 0._xp
           ENDIF
-        ENDDO
       ENDDO
     ENDDO
   ENDDO
diff --git a/src/processing_mod.F90 b/src/processing_mod.F90
index 3046ad30..be8ed50b 100644
--- a/src/processing_mod.F90
+++ b/src/processing_mod.F90
@@ -15,7 +15,7 @@ MODULE processing
    USE geometry,         ONLY: Jacobian, iInt_Jacobian
    USE time_integration, ONLY: updatetlevel
    USE calculus,         ONLY: simpson_rule_z, grad_z, grad_z_5D, grad_z2, grad_z4, grad_z4_5D, interp_z
-   USE model,            ONLY: EM, CLOS, beta, HDz_h
+   USE model,            ONLY: EM, beta, HDz_h
    USE species,          ONLY: tau,q_tau,q_o_sqrt_tau_sigma,sqrt_tau_o_sigma
    USE parallel,         ONLY: num_procs_ky, rank_ky, comm_ky
    USE mpi
@@ -49,7 +49,7 @@ CONTAINS
    !
    SUBROUTINE compute_nadiab_moments
       IMPLICIT NONE
-      INTEGER :: ia,ip,ij,iky,ikx,iz, j_int, p_int
+      INTEGER :: ia,ip,ij,iky,ikx,iz
       !non adiab moments
       DO iz=1,local_nz+ngz
       DO ikx=1,local_nkx
@@ -72,25 +72,6 @@ CONTAINS
       ENDDO
       ENDDO
       ENDDO
-      !! Ensure to kill the moments too high if closue option is set to 1
-      IF(CLOS .EQ. 1) THEN
-         DO iz=1,local_nz+ngz
-         DO ikx=1,local_nkx
-         DO iky=1,local_nky
-         DO ij=1,local_nj+ngj
-         j_int = jarray(ij)
-         DO ip=1,local_np+ngp
-         p_int = parray(ip)
-            DO ia = 1,local_na
-            IF(p_int+2*j_int .GT. dmax) &
-               nadiab_moments(ia,ip,ij,iky,ikx,iz) = 0._xp
-         ENDDO
-         ENDDO
-         ENDDO
-         ENDDO
-         ENDDO
-         ENDDO
-      ENDIF
    END SUBROUTINE compute_nadiab_moments
 
    ! z grid gradients
diff --git a/src/readinputs.F90 b/src/readinputs.F90
index dcc9978b..02f8b5d5 100644
--- a/src/readinputs.F90
+++ b/src/readinputs.F90
@@ -9,6 +9,7 @@ SUBROUTINE readinputs
   USE initial_par,      ONLY: initial_readinputs
   USE time_integration, ONLY: time_integration_readinputs
   USE geometry,         ONLY: geometry_readinputs
+  USE closure,          ONLY: closure_readinputs
 
   USE prec_const
   IMPLICIT NONE
@@ -29,6 +30,9 @@ SUBROUTINE readinputs
   ! Load model parameters from input file
   CALL model_readinputs
 
+  ! Load parameters for moment closure scheme
+  CALL closure_readinputs
+  
   ! Load model parameters from input file
   CALL species_readinputs
 
diff --git a/wk/fast_analysis.m b/wk/fast_analysis.m
index 546abd56..8306684e 100644
--- a/wk/fast_analysis.m
+++ b/wk/fast_analysis.m
@@ -9,11 +9,19 @@ PARTITION  = '/misc/gyacomo23_outputs/';
 % resdir = 'paper_2_GYAC23/CBC/7x4x192x96x32_nu_0.05_muxy_1.0_muz_2.0';
 % resdir = 'paper_2_GYAC23/CBC/Full_NL_7x4x192x96x32_nu_0.05_muxy_1.0_muz_2.0';
 
-%% tests
+%% tests single vs double precision
 % resdir = 'paper_2_GYAC23/precision_study/5x3x128x64x24';
-resdir = 'paper_2_GYAC23/precision_study/5x3x128x64x24_xp';
+% resdir = 'paper_2_GYAC23/precision_study/5x3x128x64x24_dp';
 % resdir = 'paper_2_GYAC23/precision_study/5x3x128x64x24_sp';
-% resdir = 'paper_2_GYAC23/precision_study/5x3x128x64x24_Lx_180';
+% resdir = 'paper_2_GYAC23/precision_study/5x3x128x64x24_sp_clos_1';
+% resdir = 'paper_2_GYAC23/precision_study/3x2x128x64x24_sp_muz_2.0';
+resdir = 'paper_2_GYAC23/precision_study/test_3x2x128x64x24_sp_muz_2.0';
+% resdir = 'paper_2_GYAC23/precision_study/3x2x128x64x24_sp_clos_1';
+
+%% 
+% resdir = 'paper_2_GYAC23/collisionless/kT_5.3/5x3x128x64x24_dp_muz_2.0';
+% resdir = 'paper_2_GYAC23/collisionless/kT_5.3/5x3x128x64x24_dp_muz_2.0_full_NL';
+% resdir = 'paper_2_GYAC23/collisionless/kT_5.3/5x3x128x64x24_dp_muz_2.0_muxy_0';
  %%
 J0 = 00; J1 = 10;
 
@@ -22,7 +30,7 @@ DATADIR = [PARTITION,resdir,'/'];
 data    = {};
 data    = compile_results_low_mem(data,DATADIR,J0,J1);
 
-if 0
+if 1
 %% Plot transport and phi radial profile
 [data.PHI, data.Ts3D] = compile_results_3D(DATADIR,J0,J1,'phi');
 
@@ -50,7 +58,7 @@ options.NAME      = '\phi';
 % options.NAME      = 'Q_x';
 % options.NAME      = 'n_i';
 % options.NAME      = 'n_i-n_e';
-options.PLAN      = 'xy';
+options.PLAN      = 'xz';
 % options.NAME      = 'f_i';
 % options.PLAN      = 'sx';
 options.COMP      = 'avg';
@@ -63,6 +71,11 @@ options.RESOLUTION = 256;
 create_film(data,options,'.gif')
 end
 
+if 0
+%% Performance profiler
+profiler(data)
+end
+
 if 0
 %% Hermite-Laguerre spectrum
 [data.Nipjz, data.Ts3D] = compile_results_3D(DATADIR,J0,J1,'Nipjz');
-- 
GitLab