Skip to content
Snippets Groups Projects
multigrid.f90 72.9 KiB
Newer Older
  • Learn to ignore specific revisions
  • !=== COPYRIGHT AND LICENSE STATEMENT ===
    !
    !  This file is part of the TensorProductMultigrid code.
    !  
    !  (c) The copyright relating to this work is owned jointly by the
    !  Crown, Met Office and NERC [2014]. However, it has been created
    !  with the help of the GungHo Consortium, whose members are identified
    !  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
    !  
    !  Main Developer: Eike Mueller
    !  
    !  TensorProductMultigrid is free software: you can redistribute it and/or
    !  modify it under the terms of the GNU Lesser General Public License as
    !  published by the Free Software Foundation, either version 3 of the
    !  License, or (at your option) any later version.
    !  
    !  TensorProductMultigrid is distributed in the hope that it will be useful,
    !  but WITHOUT ANY WARRANTY; without even the implied warranty of
    !  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    !  GNU Lesser General Public License for more details.
    !  
    !  You should have received a copy of the GNU Lesser General Public License
    !  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
    !  If not, see <http://www.gnu.org/licenses/>.
    !
    !=== COPYRIGHT AND LICENSE STATEMENT ===
    
    
    !==================================================================
    !
    !  Geometric multigrid module for cell centred finite volume
    !  discretisation.
    !
    !    Eike Mueller, University of Bath, Feb 2012
    !
    !==================================================================
    module multigrid
    
    
      use parameters
      use datatypes
      use discretisation
      use messages
      use solver
      use conjugategradient
      use communication
      use timer
    
      implicit none
    
    ! Public interface of the module; everything not listed here is private
    ! (see the "private" statement below).
    public::mg_parameters
    public::mg_initialise
    public::mg_finalise
    
    public::mg_solve
    public::measurehaloswap
    public::REST_CELLAVERAGE
    
    ! NOTE(review): REST_KHALIL is defined below but not exported, so callers
    ! cannot select it by name in mg_parameters%restriction - confirm intent.
    public::PROL_CONSTANT
    public::PROL_TRILINEAR
    public::COARSEGRIDSOLVER_SMOOTHER
    public::COARSEGRIDSOLVER_CG
    
    private
    
      ! --- multigrid parameter constants ---
      ! restriction (allowed values of mg_parameters%restriction)
      ! Sum over the 2x2 fine cells covering a coarse cell (no z-coarsening)
      integer, parameter :: REST_CELLAVERAGE = 1
    
      ! Weighted 12-point fine-cell stencil (only handled by restrict_mnh)
      integer, parameter :: REST_KHALIL      = 2
    
      ! prolongation method (allowed values of mg_parameters%prolongation)
      integer, parameter :: PROL_CONSTANT = 1
      integer, parameter :: PROL_TRILINEAR = 2
      ! Coarse grid solver (allowed values of mg_parameters%coarsegridsolver)
      integer, parameter :: COARSEGRIDSOLVER_SMOOTHER = 1
      integer, parameter :: COARSEGRIDSOLVER_CG = 2
    
      ! --- Multigrid parameters type ---
      ! Settings controlling the multigrid hierarchy and V-cycle; mg_initialise
      ! stores a copy of the caller's instance in the module variable mg_param.
      type mg_parameters
        ! Verbosity level
        integer :: verbose
        ! Number of MG levels (level n_lev is the finest grid, level 1 the
        ! coarsest; the horizontal cell count is halved per level)
        integer :: n_lev
        ! First level where data is pulled together: at and below this level
        ! each level is visited twice, the second time with the processor-grid
        ! index m reduced by one
        integer :: lev_split
        ! Number of presmoothing steps
        integer :: n_presmooth
        ! Number of postsmoothing steps
        integer :: n_postsmooth
        ! Number of smoothing steps on coarsest level
        integer :: n_coarsegridsmooth
        ! Prolongation (see PROL_... for allowed values)
        integer :: prolongation
        ! Restriction (see REST_... for allowed values)
        integer :: restriction
        ! Smoother (see SMOOTHER_... for allowed values; those constants are
        ! not defined in this module)
        integer :: smoother
        ! Relaxation factor
        real(kind=rl) :: omega
        ! Smoother on coarse grid (see COARSEGRIDSOLVER_... for allowed values)
        integer :: coarsegridsolver
        ! ordering of grid points for smoother
        integer :: ordering
      end type mg_parameters
    
    ! --- Parameters ---
      ! Module-level copies of the parameter sets passed to mg_initialise
      type(mg_parameters) :: mg_param
      type(model_parameters) :: model_param
      type(smoother_parameters) :: smoother_param
      type(grid_parameters) :: grid_param
      type(comm_parameters) :: comm_param
      type(cg_parameters) :: cg_param
    
    
    ! --- Gridded and scalar data structures ---
      ! All three arrays are allocated by mg_initialise with shape
      ! (n_lev, 0:pproc) and indexed as (level, m).
      ! Solution vector
    
      type(scalar3d), allocatable :: xu_mg(:,:)
    
      ! Second per-level field (presumably the right-hand side - confirm)
      type(scalar3d), allocatable :: xb_mg(:,:)
    
      ! Third per-level field (presumably the residual - confirm)
      type(scalar3d), allocatable :: xr_mg(:,:)
    
    
    ! --- Timer ---
      ! Per-(level, m) timers, allocated as (n_lev, 0:pproc) in mg_initialise
      ! and reported/freed in mg_finalise
      type(time), allocatable, dimension(:,:) :: t_restrict
      type(time), allocatable, dimension(:,:) :: t_prolongate
      type(time), allocatable, dimension(:,:) :: t_residual
      type(time), allocatable, dimension(:,:) :: t_addcorr
      type(time), allocatable, dimension(:,:) :: t_smooth
      type(time), allocatable, dimension(:,:) :: t_coarsesolve
      type(time), allocatable, dimension(:,:) :: t_total
    
    contains
    
    !==================================================================
    ! Initialise multigrid module, check and print out parameters
    !==================================================================
      subroutine mg_initialise(grid_param_in,     &  ! Grid parameters
                               comm_param_in,     &  ! Comm parameters
                               model_param_in,    &  ! Model parameters
                               smoother_param_in, &  ! Smoother parameters
                               mg_param_in,       &  ! Multigrid parameters
                               cg_param_in        &  ! CG parameters
                               )
        ! Stores copies of all parameter sets in module variables and checks
        ! that the horizontal cell count n is a multiple of 2^(n_lev-1) (and at
        ! least that large). It then walks the hierarchy from the finest level
        ! (level = n_lev, m = pproc) down to level 1. For every visited
        ! (level, m) pair it:
        !   * allocates and zeroes xu_mg, xb_mg and xr_mg in both storage
        !     layouts. %s is (0:nz+1, y, x); %st is (x, y, 0:nz+1), allocated
        !     and zeroed on the OpenACC device as well. Each horizontal extent
        !     includes halo_size halo cells.
        !   * records the global index range, the computational range and
        !     whether the grid is active on this processor.
        !   * initialises the per-(level, m) timers.
        ! At or below lev_split a level is visited a second time, with m reduced
        ! by one and nlocal doubled, before moving to the next coarser level.
        ! Finally the CG solver is initialised.

        use discretisation, only : zt1d_discretisation_allocate3d

        type(grid_parameters), intent(in)  :: grid_param_in
        type(comm_parameters), intent(in)  :: comm_param_in
        type(model_parameters), intent(in) :: model_param_in
        type(smoother_parameters), intent(in) :: smoother_param_in
        type(mg_parameters), intent(in)    :: mg_param_in
        type(cg_parameters), intent(in)    :: cg_param_in
        real(kind=rl)                      :: L, H
        integer                            :: n, nz, m, nlocal
        logical                            :: reduced_m
        integer                            :: level
        integer                            :: rank, ierr
        integer, dimension(2)              :: p_horiz
        integer, parameter                 :: dim_horiz = 2
        logical                            :: grid_active
        integer                            :: ix_min, ix_max, iy_min, iy_max
        integer                            :: icompx_min, icompx_max, &
                                              icompy_min, icompy_max
        integer                            :: halo_size
        integer                            :: vertbc
        character(len=32)                  :: t_label

        ! Temporaries for the (x,y,z)-ordered %st storage of each level
        real , dimension(:,:,:) , pointer , contiguous :: zxu_mg_st,zxb_mg_st,zxr_mg_st

        if (i_am_master_mpi) &
          write(STDOUT,*) '*** Initialising multigrid ***'
        ! Keep module-level copies of all parameter sets
        grid_param = grid_param_in
        comm_param = comm_param_in
        mg_param = mg_param_in
        model_param = model_param_in
        smoother_param = smoother_param_in
        cg_param = cg_param_in
        halo_size = comm_param%halo_size
        vertbc = grid_param%vertbc

        ! Check that cell counts are valid: every coarsening step halves n,
        ! so n must be a multiple of 2^(n_lev-1)
        if (grid_param%n < 2**(mg_param%n_lev-1) ) then
          call fatalerror('Number of cells in x-/y- direction has to be at least 2^{n_lev-1}.')
        endif

        if (mod(grid_param%n,2**(mg_param%n_lev-1)) .ne. 0 ) then
          call fatalerror('Number of cells in x-/y- direction is not a multiple of 2^{n_lev-1}.')
        end if
        if (i_am_master_mpi) &
          write(STDOUT,*) ''

        ! Allocate memory for timers
        allocate(t_smooth(mg_param%n_lev,0:pproc))
        allocate(t_total(mg_param%n_lev,0:pproc))
        allocate(t_restrict(mg_param%n_lev,0:pproc))
        allocate(t_residual(mg_param%n_lev,0:pproc))
        allocate(t_prolongate(mg_param%n_lev,0:pproc))
        allocate(t_addcorr(mg_param%n_lev,0:pproc))
        allocate(t_coarsesolve(mg_param%n_lev,0:pproc))

        ! Allocate memory for all levels and initialise fields
        allocate(xu_mg(mg_param%n_lev,0:pproc))
        allocate(xb_mg(mg_param%n_lev,0:pproc))
        allocate(xr_mg(mg_param%n_lev,0:pproc))

        n = grid_param%n
        nlocal = n/(2**pproc)
        nz = grid_param%nz
        L = grid_param%L
        H = grid_param%H
        level = mg_param%n_lev
        m = pproc
        reduced_m = .false.

        ! Work out local processor coordinates (this is necessary to identify
        ! global coordinates)
        call mpi_comm_rank(MPI_COMM_HORIZ,rank,ierr)
        call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
        if (i_am_master_mpi) then
          write(STDOUT, &
            '(" Global gridsize (x,y,z) (pproc = ",I4," )      : ",I8," x ",I8," x ",I8)') &
            pproc, n, n, nz
        end if
        do while (level > 0)
          if (i_am_master_mpi) &
            write(STDOUT, &
              '(" Local gridsize (x,y,z) on level ",I3," m = ",I4," : ",I8," x ",I8," x ",I8)') &
              level, m, nlocal, nlocal, nz
          if (nlocal < 1) then
            call fatalerror('Number of grid points < 1')
          end if

          ! Set sizes of computational grid (take care at boundaries):
          ! interior processor edges extend (halo_size-1) cells into the halo
          if (p_horiz(1) == 0) then
            icompy_min = 1
          else
            icompy_min = 1 - (halo_size - 1)
          end if

          if (p_horiz(2) == 0) then
            icompx_min = 1
          else
            icompx_min = 1 - (halo_size - 1)
          end if

          if (p_horiz(1) == 2**pproc-1) then
            icompy_max = nlocal
          else
            icompy_max = nlocal + (halo_size - 1)
          end if

          if (p_horiz(2) == 2**pproc-1) then
            icompx_max = nlocal
          else
            icompx_max = nlocal + (halo_size - 1)
          end if

          ! Allocate (z,y,x)-ordered data
          allocate(xu_mg(level,m)%s(0:nz+1,                       &
                                1-halo_size:nlocal+halo_size, &
                                1-halo_size:nlocal+halo_size))
          allocate(xb_mg(level,m)%s(0:nz+1,                       &
                                1-halo_size:nlocal+halo_size, &
                                1-halo_size:nlocal+halo_size))
          allocate(xr_mg(level,m)%s(0:nz+1,                       &
                                1-halo_size:nlocal+halo_size, &
                                1-halo_size:nlocal+halo_size))

          xu_mg(level,m)%s(:,:,:) = 0.0_rl
          xb_mg(level,m)%s(:,:,:) = 0.0_rl
          xr_mg(level,m)%s(:,:,:) = 0.0_rl

          ! Allocate (x,y,z)-ordered data, create the device copies and
          ! attach each array to its field
          allocate(zxu_mg_st(1-halo_size:nlocal+halo_size, &
                             1-halo_size:nlocal+halo_size, &
                             0:nz+1))
          !$acc enter data create (zxu_mg_st)
          xu_mg(level,m)%st => zxu_mg_st

          allocate(zxb_mg_st(1-halo_size:nlocal+halo_size, &
                             1-halo_size:nlocal+halo_size, &
                             0:nz+1))
          !$acc enter data create (zxb_mg_st)
          xb_mg(level,m)%st => zxb_mg_st

          allocate(zxr_mg_st(1-halo_size:nlocal+halo_size, &
                             1-halo_size:nlocal+halo_size, &
                             0:nz+1))
          !$acc enter data create (zxr_mg_st)
          ! Previously missing: without this association xr_mg%st stays
          ! undefined, yet mg_finalise deallocates it
          xr_mg(level,m)%st => zxr_mg_st

          ! Zero the %st storage (on the device when compiled with OpenACC)
          !$acc kernels
          zxu_mg_st(:,:,:) = 0.0_rl
          zxb_mg_st(:,:,:) = 0.0_rl
          zxr_mg_st(:,:,:) = 0.0_rl
          !$acc end kernels

          ! NB: 1st coordinate is in the y-direction of the processor grid,
          ! second coordinate is in the x-direction (see comments in
          ! communication module)
          iy_min = (p_horiz(1)/2**(pproc-m))*nlocal+1
          iy_max = (p_horiz(1)/2**(pproc-m)+1)*nlocal
          ix_min = p_horiz(2)/2**(pproc-m)*nlocal+1
          ix_max = (p_horiz(2)/2**(pproc-m)+1)*nlocal
          ! Set grid parameters and local data ranges
          ! Note that only n (and possibly nz) change as we
          ! move down the levels

          xu_mg(level,m)%grid_param%L = L
          xu_mg(level,m)%grid_param%H = H
          xu_mg(level,m)%grid_param%n = n
          xu_mg(level,m)%grid_param%nz = nz
          xu_mg(level,m)%grid_param%vertbc = vertbc
          xu_mg(level,m)%ix_min = ix_min
          xu_mg(level,m)%ix_max = ix_max
          xu_mg(level,m)%iy_min = iy_min
          xu_mg(level,m)%iy_max = iy_max
          xu_mg(level,m)%icompx_min = icompx_min
          xu_mg(level,m)%icompx_max = icompx_max
          xu_mg(level,m)%icompy_min = icompy_min
          xu_mg(level,m)%icompy_max = icompy_max
          xu_mg(level,m)%halo_size = halo_size

          xb_mg(level,m)%grid_param%L = L
          xb_mg(level,m)%grid_param%H = H
          xb_mg(level,m)%grid_param%n = n
          xb_mg(level,m)%grid_param%nz = nz
          xb_mg(level,m)%grid_param%vertbc = vertbc
          xb_mg(level,m)%ix_min = ix_min
          xb_mg(level,m)%ix_max = ix_max
          xb_mg(level,m)%iy_min = iy_min
          xb_mg(level,m)%iy_max = iy_max
          xb_mg(level,m)%icompx_min = icompx_min
          xb_mg(level,m)%icompx_max = icompx_max
          xb_mg(level,m)%icompy_min = icompy_min
          xb_mg(level,m)%icompy_max = icompy_max
          xb_mg(level,m)%halo_size = halo_size

          xr_mg(level,m)%grid_param%L = L
          xr_mg(level,m)%grid_param%H = H
          xr_mg(level,m)%grid_param%n = n
          xr_mg(level,m)%grid_param%nz = nz
          xr_mg(level,m)%grid_param%vertbc = vertbc
          xr_mg(level,m)%ix_min = ix_min
          xr_mg(level,m)%ix_max = ix_max
          xr_mg(level,m)%iy_min = iy_min
          xr_mg(level,m)%iy_max = iy_max
          xr_mg(level,m)%icompx_min = icompx_min
          xr_mg(level,m)%icompx_max = icompx_max
          xr_mg(level,m)%icompy_min = icompy_min
          xr_mg(level,m)%icompy_max = icompy_max
          xr_mg(level,m)%halo_size = halo_size

          ! Are these grids active? On the full processor grid (m == pproc)
          ! always; otherwise only on processors whose coordinates are
          ! multiples of 2^(pproc-m)
          if ( (m == pproc) .or. &
               ( (mod(p_horiz(1),2**(pproc-m)) == 0) .and. &
                 (mod(p_horiz(2),2**(pproc-m)) == 0) ) ) then
            grid_active = .true.
          else
            grid_active = .false.
          end if

          xu_mg(level,m)%isactive = grid_active
          xb_mg(level,m)%isactive = grid_active
          xr_mg(level,m)%isactive = grid_active

          write(t_label,'("t_total(",I3,",",I3,")")') level, m
          call initialise_timer(t_total(level,m),t_label)
          write(t_label,'("t_smooth(",I3,",",I3,")")') level, m
          call initialise_timer(t_smooth(level,m),t_label)
          write(t_label,'("t_restrict(",I3,",",I3,")")') level, m
          call initialise_timer(t_restrict(level,m),t_label)
          write(t_label,'("t_residual(",I3,",",I3,")")') level, m
          call initialise_timer(t_residual(level,m),t_label)
          write(t_label,'("t_prolongate(",I3,",",I3,")")') level, m
          call initialise_timer(t_prolongate(level,m),t_label)
          write(t_label,'("t_addcorrection(",I3,",",I3,")")') level, m
          call initialise_timer(t_addcorr(level,m),t_label)
          write(t_label,'("t_coarsegridsolver(",I3,",",I3,")")') level, m
          call initialise_timer(t_coarsesolve(level,m),t_label)

          ! If we are below L_split, split data: revisit this level with
          ! m-1 and a doubled local size
          if ( (level .le. mg_param%lev_split) .and. &
               (m > 0) .and. (.not. reduced_m) ) then
            reduced_m = .true.
            m = m-1
            nlocal = 2*nlocal
            cycle
          end if
          reduced_m = .false.
          level = level-1
          n = n/2
          nlocal = nlocal/2
        end do
        if (i_am_master_mpi) &
          write(STDOUT,*) ''
        call cg_initialise(cg_param)
      end subroutine mg_initialise
    
    !==================================================================
    ! Finalise: report timers and free memory for all data structures
    !==================================================================
      subroutine mg_finalise()
        implicit none
        ! Counterpart of mg_initialise. It visits the same sequence of
        ! (level, m) pairs and, for each pair:
        !   * prints the V-cycle timers,
        !   * frees the %s storage and the %st storage, including the
        !     OpenACC device copies of %st.
        ! Afterwards the module-level field and timer arrays are deallocated.
        integer :: lev, mproc
        logical :: just_split
        character(len=80) :: heading

        if (i_am_master_mpi) &
          write(STDOUT,*) '*** Finalising multigrid ***'
        call print_timerinfo("--- V-cycle timing results ---")

        lev = mg_param%n_lev
        mproc = pproc
        just_split = .false.
        do
          if (lev <= 0) exit
          write(heading,'("level = ",I3,", m = ",I3)') lev,mproc
          call print_timerinfo(heading)
          call report_level_timers(lev, mproc)
          call free_level_fields(lev, mproc)
          ! At or below lev_split each level was created twice: first with
          ! the current m, then once more with m-1
          if ( (lev <= mg_param%lev_split) .and. (mproc > 0) &
               .and. (.not. just_split) ) then
            just_split = .true.
            mproc = mproc - 1
          else
            just_split = .false.
            lev = lev - 1
          end if
        end do

        deallocate(xu_mg)
        deallocate(xb_mg)
        deallocate(xr_mg)

        deallocate(t_total)
        deallocate(t_smooth)
        deallocate(t_restrict)
        deallocate(t_prolongate)
        deallocate(t_residual)
        deallocate(t_addcorr)
        deallocate(t_coarsesolve)
        if (i_am_master_mpi) write(STDOUT,'("")')

      contains

        ! Print every timer that belongs to grid (lv, mp)
        subroutine report_level_timers(lv, mp)
          integer, intent(in) :: lv, mp
          call print_elapsed(t_smooth(lv,mp),.True.,1.0_rl)
          call print_elapsed(t_restrict(lv,mp),.True.,1.0_rl)
          call print_elapsed(t_prolongate(lv,mp),.True.,1.0_rl)
          call print_elapsed(t_residual(lv,mp),.True.,1.0_rl)
          call print_elapsed(t_addcorr(lv,mp),.True.,1.0_rl)
          call print_elapsed(t_coarsesolve(lv,mp),.True.,1.0_rl)
          call print_elapsed(t_total(lv,mp),.True.,1.0_rl)
        end subroutine report_level_timers

        ! Release the %s and %st storage of xu_mg, xb_mg and xr_mg on grid
        ! (lv, mp); device copies of %st are deleted before deallocation
        subroutine free_level_fields(lv, mp)
          integer, intent(in) :: lv, mp
          deallocate(xu_mg(lv,mp)%s)
          deallocate(xb_mg(lv,mp)%s)
          deallocate(xr_mg(lv,mp)%s)

          !$acc exit data delete(xu_mg(lv,mp)%st)
          deallocate(xu_mg(lv,mp)%st)
          !$acc exit data delete(xb_mg(lv,mp)%st)
          deallocate(xb_mg(lv,mp)%st)
          !$acc exit data delete(xr_mg(lv,mp)%st)
          deallocate(xr_mg(lv,mp)%st)
        end subroutine free_level_fields
      end subroutine mg_finalise
    
    !==================================================================
    ! Restrict from fine -> coarse
    ! Handles both storage layouts: %s (z,y,x) when LUseO and %st (x,y,z)
    ! when LUseT. Supports REST_CELLAVERAGE and REST_KHALIL.
    !==================================================================
      subroutine restrict_mnh(phifine,phicoarse)
        implicit none
        type(scalar3d), intent(in)    :: phifine
        type(scalar3d), intent(inout) :: phicoarse

        integer :: ix_min, ix_max, iy_min, iy_max, n, nz
        integer :: ix, iy, iz
        ! Edge weights of the REST_KHALIL stencil (0 at the domain boundary)
        real(kind=rl) :: xw, xe, xs, xn

        real , dimension(:,:,:) , pointer , contiguous :: zphifine_st , zphicoarse_st

        n      = phicoarse%grid_param%n
        nz     = phicoarse%grid_param%nz

        ix_min = phicoarse%icompx_min
        ix_max = phicoarse%icompx_max
        iy_min = phicoarse%icompy_min
        iy_max = phicoarse%icompy_max

        if (mg_param%restriction == REST_CELLAVERAGE) then
          ! Each coarse cell receives the sum of the 2x2 fine cells it covers.
          ! Do not coarsen in z-direction.
          if (LUseO) then
            do ix=ix_min,ix_max
              do iy=iy_min,iy_max
                do iz=1,nz
                  phicoarse%s(iz,iy,ix) =  &
                    phifine%s(iz  ,2*iy  ,2*ix  ) + &
                    phifine%s(iz  ,2*iy-1,2*ix  ) + &
                    phifine%s(iz  ,2*iy  ,2*ix-1) + &
                    phifine%s(iz  ,2*iy-1,2*ix-1)
                end do
              end do
            end do
          end if
          if (LUseT) then
            ! Local contiguous aliases so the loop can be offloaded.
            ! NOTE(review): the kernels region assumes %st is present on the
            ! device (created with enter data in mg_initialise) - confirm.
            zphifine_st => phifine%st
            zphicoarse_st => phicoarse%st
            !$acc kernels
            do concurrent (ix=ix_min:ix_max, iy=iy_min:iy_max, iz=1:nz)
              zphicoarse_st(ix,iy,iz) =  &
                zphifine_st(2*ix  ,2*iy  ,iz) + &
                zphifine_st(2*ix  ,2*iy-1,iz) + &
                zphifine_st(2*ix-1,2*iy  ,iz) + &
                zphifine_st(2*ix-1,2*iy-1,iz)
            end do
            !$acc end kernels
          end if

        elseif (mg_param%restriction == REST_KHALIL) then
          ! Weighted 12-point stencil; cells beyond the domain edge get weight 0.
          ! NOTE(review): ix/iy are local computational indices but are
          ! compared with the global size n, so the boundary is only detected
          ! correctly when the grid is not split horizontally - confirm.
          if (LUseO) then
            do ix=ix_min,ix_max
              xw = 1.0_rl
              xe = 1.0_rl
              if (ix==1) xw = 0.0_rl
              if (ix==n) xe = 0.0_rl
              do iy=iy_min,iy_max
                xs = 1.0_rl
                xn = 1.0_rl
                if (iy==1) xs = 0.0_rl
                if (iy==n) xn = 0.0_rl
                do iz=1,nz
                  phicoarse%s(iz,iy,ix) = 0.25_rl * ( &
                       phifine%s(iz,2*iy+1,2*ix-1) * xn               + &
                       phifine%s(iz,2*iy+1,2*ix  ) * xn               + &
                       phifine%s(iz,2*iy  ,2*ix-2) * xw               + &
                       phifine%s(iz,2*iy  ,2*ix-1) * (4.0_rl-xw-xn)   + &
                       phifine%s(iz,2*iy  ,2*ix  ) * (4.0_rl-xe-xn)   + &
                       phifine%s(iz,2*iy  ,2*ix+1) * xe               + &
                       phifine%s(iz,2*iy-1,2*ix-2) * xw               + &
                       phifine%s(iz,2*iy-1,2*ix-1) * (4.0_rl-xw-xs)   + &
                       phifine%s(iz,2*iy-1,2*ix  ) * (4.0_rl-xe-xs)   + &
                       phifine%s(iz,2*iy-2,2*ix-1) * xs               + &
                       phifine%s(iz,2*iy-2,2*ix  ) * xs                 &
                       )
                end do
              end do
            end do
          end if
          if (LUseT) then
            ! Same stencil on the (x,y,z) storage; the fine field must be read
            ! from %st here (reading %s with (x,y,z) indices was a bug)
            do iz=1,nz
              do iy=iy_min,iy_max
                xs = 1.0_rl
                xn = 1.0_rl
                if (iy==1) xs = 0.0_rl
                if (iy==n) xn = 0.0_rl
                do ix=ix_min,ix_max
                  xw = 1.0_rl
                  xe = 1.0_rl
                  if (ix==1) xw = 0.0_rl
                  if (ix==n) xe = 0.0_rl
                  phicoarse%st(ix,iy,iz) = 0.25_rl * ( &
                       phifine%st(2*ix-1,2*iy+1,iz) * xn               + &
                       phifine%st(2*ix  ,2*iy+1,iz) * xn               + &
                       phifine%st(2*ix-2,2*iy  ,iz) * xw               + &
                       phifine%st(2*ix-1,2*iy  ,iz) * (4.0_rl-xw-xn)   + &
                       phifine%st(2*ix  ,2*iy  ,iz) * (4.0_rl-xe-xn)   + &
                       phifine%st(2*ix+1,2*iy  ,iz) * xe               + &
                       phifine%st(2*ix-2,2*iy-1,iz) * xw               + &
                       phifine%st(2*ix-1,2*iy-1,iz) * (4.0_rl-xw-xs)   + &
                       phifine%st(2*ix  ,2*iy-1,iz) * (4.0_rl-xe-xs)   + &
                       phifine%st(2*ix-1,2*iy-2,iz) * xs               + &
                       phifine%st(2*ix  ,2*iy-2,iz) * xs                 &
                       )
                end do
              end do
            end do
          end if

        end if
      end subroutine restrict_mnh
    !==================================================================
    ! Restrict from fine -> coarse
    ! Each coarse cell receives the sum of the 2x2 fine cells it covers,
    ! for %s (z,y,x) when LUseO and %st (x,y,z) when LUseT. Only
    ! REST_CELLAVERAGE is handled; for any other restriction value phicoarse
    ! is left unchanged.
    !==================================================================
      subroutine restrict(phifine,phicoarse)
        implicit none
        type(scalar3d), intent(in)    :: phifine
        type(scalar3d), intent(inout) :: phicoarse
        integer :: ix,iy,iz

        integer :: ix_min, ix_max, iy_min, iy_max , nz

        ! nz was previously never assigned before being used as a loop bound
        nz     = phicoarse%grid_param%nz

        ix_min = phicoarse%icompx_min
        ix_max = phicoarse%icompx_max
        iy_min = phicoarse%icompy_min
        iy_max = phicoarse%icompy_max

        ! three dimensional cell average
        if (mg_param%restriction == REST_CELLAVERAGE) then
          ! Do not coarsen in z-direction
          if (LUseO) then
            ! (z,y,x) storage: keep z innermost for contiguous access
            do ix=ix_min,ix_max
              do iy=iy_min,iy_max
                do iz=1,nz
                  phicoarse%s(iz,iy,ix) =  &
                       phifine%s(iz  ,2*iy  ,2*ix  ) + &
                       phifine%s(iz  ,2*iy-1,2*ix  ) + &
                       phifine%s(iz  ,2*iy  ,2*ix-1) + &
                       phifine%s(iz  ,2*iy-1,2*ix-1)
                end do
              end do
            end do
          end if
          if (LUseT) then
            ! (x,y,z) storage: keep x innermost for contiguous access
            do iz=1,nz
              do iy=iy_min,iy_max
                do ix=ix_min,ix_max
                  phicoarse%st(ix,iy,iz) =  &
                       phifine%st(2*ix  ,2*iy  ,iz) + &
                       phifine%st(2*ix  ,2*iy-1,iz) + &
                       phifine%st(2*ix-1,2*iy  ,iz) + &
                       phifine%st(2*ix-1,2*iy-1,iz)
                end do
              end do
            end do
          end if
        end if
      end subroutine restrict
    
    !==================================================================
    ! Prolongate from coarse -> fine
    ! level, m correspond to the fine grid level
    !==================================================================
      subroutine prolongate_mnh(level,m,phicoarse,phifine)
        implicit none
        integer, intent(in) :: level
        integer, intent(in) :: m
        type(scalar3d), intent(in) :: phicoarse
        type(scalar3d), intent(inout) :: phifine
        ! Interpolates phicoarse into phifine, for %s (z,y,x) when LUseO and
        ! for %st (x,y,z) when LUseT. Interpolation is either piecewise
        ! constant or piecewise linear. Afterwards the halos of phifine are
        ! exchanged (ihaloswap_mnh / haloswap_mnh).
        ! The local coarse domain is split into four one-cell-wide boundary
        ! strips (blocks 1-4) and an interior (block 5). With OVERLAPCOMMS
        ! the non-blocking exchange started after blocks 1-4 overlaps the
        ! interior computation.
        integer :: nlocal
        integer, dimension(5) :: ixmin, ixmax, iymin, iymax
        integer :: nz
        logical :: overlap_comms
        integer, dimension(4) :: send_requests, recv_requests
        integer, dimension(4) :: send_requestsT, recv_requestsT
        integer :: ierr
        integer :: iblock

        nlocal = phicoarse%ix_max-phicoarse%ix_min+1
        nz = phicoarse%grid_param%nz

#ifdef OVERLAPCOMMS
        overlap_comms = (nlocal > 2)
#else
        overlap_comms = .false.
#endif
        ! Block 1 (N)
        ixmin(1) = 1
        ixmax(1) = nlocal
        iymin(1) = 1
        iymax(1) = 1
        ! Block 2 (S)
        ixmin(2) = 1
        ixmax(2) = nlocal
        iymin(2) = nlocal
        iymax(2) = nlocal
        ! Block 3 (W)
        ixmin(3) = 1
        ixmax(3) = 1
        iymin(3) = 2
        iymax(3) = nlocal-1
        ! Block 4 (E)
        ixmin(4) = nlocal
        ixmax(4) = nlocal
        iymin(4) = 2
        iymax(4) = nlocal-1
        ! Block 5 (INTERIOR)
        if (overlap_comms) then
          ixmin(5) = 2
          ixmax(5) = nlocal-1
          iymin(5) = 2
          iymax(5) = nlocal-1
        else
          ! If there are no interior cells, do not overlap
          ! communications and calculations, just loop over all cells
          ixmin(5) = 1
          ixmax(5) = nlocal
          iymin(5) = 1
          iymax(5) = nlocal
        end if

        ! *** Constant prolongation or (tri-) linear prolongation ***
        if ( (mg_param%prolongation == PROL_CONSTANT) .or. &
             (mg_param%prolongation == PROL_TRILINEAR) ) then
          if (overlap_comms) then
            ! Loop over cells next to boundary (iblock = 1,...,4)
            do iblock = 1, 4
              if (mg_param%prolongation == PROL_CONSTANT) then
                call loop_over_grid_constant_mnh(iblock)
              end if
              if (mg_param%prolongation == PROL_TRILINEAR) then
                call loop_over_grid_linear_mnh(iblock)
              end if
            end do
            ! Initiate halo exchange
            call ihaloswap_mnh(level,m,phifine,send_requests,recv_requests, &
                               send_requestsT,recv_requestsT)
          end if
          ! Loop over INTERIOR cells
          iblock = 5
          if (mg_param%prolongation == PROL_CONSTANT) then
            call loop_over_grid_constant_mnh(iblock)
          end if
          if (mg_param%prolongation == PROL_TRILINEAR) then
            call loop_over_grid_linear_mnh(iblock)
          end if
          if (overlap_comms) then
            ! Complete the non-blocking exchange started above.
            ! NOTE(review): requests are only waited on for m > 0, presumably
            ! because ihaloswap_mnh posts none for m = 0 - confirm.
            if (m > 0) then
              if (LUseO) then
                call mpi_waitall(4,recv_requests, MPI_STATUSES_IGNORE, ierr)
                call mpi_waitall(4,send_requests, MPI_STATUSES_IGNORE, ierr)
              end if
              if (LUseT) then
                call mpi_waitall(4,recv_requestsT, MPI_STATUSES_IGNORE, ierr)
                call mpi_waitall(4,send_requestsT, MPI_STATUSES_IGNORE, ierr)
              end if
            end if
          else
            ! No overlap: exchange the halos with a blocking swap
            call haloswap_mnh(level,m,phifine)
          end if
        else
          call fatalerror("Unsupported prolongation.")
        end if

        contains

        !------------------------------------------------------------------
        ! The actual loops over the grid for the individual blocks,
        ! when overlapping calculation and communication
        !------------------------------------------------------------------

        !------------------------------------------------------------------
        ! (1) Constant interpolation: each of the 2x2 fine cells covered by
        !     coarse cell (ix,iy) receives the coarse value
        !------------------------------------------------------------------
        subroutine loop_over_grid_constant_mnh(iblock)
          implicit none
          integer, intent(in) :: iblock

          integer :: ix, iy, iz, dix, diy

          if (LUseO) then
            do ix=ixmin(iblock),ixmax(iblock)
              do iy=iymin(iblock),iymax(iblock)
                do dix = -1,0
                  do diy = -1,0
                    do iz=1,nz
                      phifine%s(iz,2*iy+diy,2*ix+dix) = phicoarse%s(iz,iy,ix)
                    end do
                  end do
                end do
              end do
            end do
          end if
          if (LUseT) then
            do iz=1,nz
              do iy=iymin(iblock),iymax(iblock)
                do diy = -1,0
                  do ix=ixmin(iblock),ixmax(iblock)
                    do dix = -1,0
                      phifine%st(2*ix+dix,2*iy+diy,iz) = phicoarse%st(ix,iy,iz)
                    end do
                  end do
                end do
              end do
            end do
          end if

        end subroutine loop_over_grid_constant_mnh

        !------------------------------------------------------------------
        ! (2) Linear interpolation: fine value = coarse value plus a quarter
        !     of the difference to the x- and y-neighbours on the side of
        !     the fine cell (reads one coarse halo cell at block edges)
        !------------------------------------------------------------------
        subroutine loop_over_grid_linear_mnh(iblock)
          implicit none
          integer, intent(in) :: iblock

          integer :: ix, iy, iz, dix, diy
          integer :: ix_min, ix_max, iy_min, iy_max
          real(kind=rl) :: rhox, rhoy
          real , dimension(:,:,:) , pointer , contiguous :: zphifine_st , zphicoarse_st

          ix_min = ixmin(iblock)
          ix_max = ixmax(iblock)
          iy_min = iymin(iblock)
          iy_max = iymax(iblock)

          ! optimisation for Neumann MNH case: all coefficients constant
          ! (previously rhox/rhoy were used without ever being set)
          rhox = 0.25_rl
          rhoy = 0.25_rl

          if (LUseO) then
            do ix=ix_min,ix_max
              do iy=iy_min,iy_max
                ! Piecewise linear interpolation
                do dix = -1,0
                  do diy = -1,0
                    do iz=1,nz
                      phifine%s(iz,2*iy+diy,2*ix+dix) =          &
                           phicoarse%s(iz,iy,ix) +               &
                           rhox*(phicoarse%s(iz,iy,ix+(2*dix+1)) &
                           - phicoarse%s(iz,iy,ix)) +            &
                           rhoy*(phicoarse%s(iz,iy+(2*diy+1),ix) &
                           - phicoarse%s(iz,iy,ix))
                    end do
                  end do
                end do
              end do
            end do
          end if
          if (LUseT) then
            ! Piecewise linear interpolation on local contiguous aliases.
            ! NOTE(review): the kernels region assumes %st is present on the
            ! device (created with enter data in mg_initialise) - confirm.
            zphifine_st => phifine%st
            zphicoarse_st => phicoarse%st
            do diy = -1,0
              do dix = -1,0
                !$acc kernels
                do concurrent (ix=ix_min:ix_max, iy=iy_min:iy_max, iz=1:nz)
                  zphifine_st(2*ix+dix,2*iy+diy,iz) =          &
                       zphicoarse_st(ix,iy,iz) +               &
                       rhox*(zphicoarse_st(ix+(2*dix+1),iy,iz) &
                       - zphicoarse_st(ix,iy,iz)) +            &
                       rhoy*(zphicoarse_st(ix,iy+(2*diy+1),iz) &
                       - zphicoarse_st(ix,iy,iz))
                end do
                !$acc end kernels
              end do
            end do
          end if

        end subroutine loop_over_grid_linear_mnh

      end subroutine prolongate_mnh
    
    !==================================================================
    ! Prolongate from the coarse level to the fine level;
    ! level and m are those of the fine grid
    !==================================================================
      subroutine prolongate(level,m,phicoarse,phifine)
        implicit none
        integer, intent(in) :: level
        integer, intent(in) :: m
        type(scalar3d), intent(in) :: phicoarse
        type(scalar3d), intent(inout) :: phifine
        real(kind=rl) :: tmp
        integer :: nlocal
        integer, dimension(5) :: ixmin, ixmax, iymin, iymax
        integer :: n, nz
        integer :: ix, iy, iz
        integer :: dix, diy, diz
        real(kind=rl) :: rhox, rhoy, rhoz
        real(kind=rl) :: rho_i, sigma_j, h, c1, c2
        logical :: overlap_comms
        integer, dimension(4) :: send_requests, recv_requests
        integer :: ierr
        integer :: iblock
    
        ! Needed for interpolation matrix
    #ifdef PIECEWISELINEAR
    #else
        real(kind=rl) :: dx(4,3), A(3,3), dx_fine(4,2)
        integer :: i,j,k
        real(kind=rl) :: dxu(2), grad(2)
        dx(1,3) = 1.0_rl
        dx(2,3) = 1.0_rl
        dx(3,3) = 1.0_rl
        dx(4,3) = 1.0_rl
    #endif
    
        nlocal = phicoarse%ix_max-phicoarse%ix_min+1
        n = phicoarse%grid_param%n
        nz = phicoarse%grid_param%nz
    
    #ifdef OVERLAPCOMMS
        overlap_comms = (nlocal > 2)
    #else
        overlap_comms = .false.
    #endif
        ! Block 1 (N)
        ixmin(1) = 1
        ixmax(1) = nlocal
        iymin(1) = 1
        iymax(1) = 1
        ! Block 2 (S)
        ixmin(2) = 1
        ixmax(2) = nlocal
        iymin(2) = nlocal
        iymax(2) = nlocal
        ! Block 3 (W)
        ixmin(3) = 1
        ixmax(3) = 1
        iymin(3) = 2
        iymax(3) = nlocal-1
        ! Block 4 (E)
        ixmin(4) = nlocal
        ixmax(4) = nlocal
        iymin(4) = 2
        iymax(4) = nlocal-1
        ! Block 5 (INTERIOR)
        if (overlap_comms) then
          ixmin(5) = 2
          ixmax(5) = nlocal-1
          iymin(5) = 2
          iymax(5) = nlocal-1
        else
          ! If there are no interior cells, do not overlap
          ! communications and calculations, just loop over interior cells
          ixmin(5) = 1
          ixmax(5) = nlocal
          iymin(5) = 1
          iymax(5) = nlocal
        end if
    
        ! *** Constant prolongation or (tri-) linear prolongation ***
        if ( (mg_param%prolongation == PROL_CONSTANT) .or. &
             (mg_param%prolongation == PROL_TRILINEAR) ) then
          if (overlap_comms) then
            ! Loop over cells next to boundary (iblock = 1,...,4)
            do iblock = 1, 4
              if (mg_param%prolongation == PROL_CONSTANT) then
                call loop_over_grid_constant(iblock)
              end if
              if (mg_param%prolongation == PROL_TRILINEAR) then
                call loop_over_grid_linear(iblock)
              end if
            end do
            ! Initiate halo exchange
            call ihaloswap(level,m,phifine,send_requests,recv_requests)
          end if
          ! Loop over INTERIOR cells
          iblock = 5
          if (mg_param%prolongation == PROL_CONSTANT) then
            call loop_over_grid_constant(iblock)
          end if
          if (mg_param%prolongation == PROL_TRILINEAR) then
            call loop_over_grid_linear(iblock)
          end if
          if (overlap_comms) then
            if (m > 0) then
              call mpi_waitall(4,recv_requests, MPI_STATUSES_IGNORE, ierr)
            end if
          else
            call haloswap(level,m,phifine)
          end if
        else
          call fatalerror("Unsupported prolongation.")
        end if
    
        contains
    
        !------------------------------------------------------------------
        ! The actual loops over the grid for the individual blocks,
        ! when overlapping calculation and communication
        !------------------------------------------------------------------
    
        !------------------------------------------------------------------
        ! (1) Constant interpolation
        !------------------------------------------------------------------
        subroutine loop_over_grid_constant(iblock)
          implicit none
          integer, intent(in) :: iblock
          integer :: ix,iy,iz
          do ix=ixmin(iblock),ixmax(iblock)