!=== COPYRIGHT AND LICENSE STATEMENT ===
    !
    !  This file is part of the TensorProductMultigrid code.
    !  
    !  (c) The copyright relating to this work is owned jointly by the
    !  Crown, Met Office and NERC [2014]. However, it has been created
    !  with the help of the GungHo Consortium, whose members are identified
    !  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
    !  
    !  Main Developer: Eike Mueller
    !  
    !  TensorProductMultigrid is free software: you can redistribute it and/or
    !  modify it under the terms of the GNU Lesser General Public License as
    !  published by the Free Software Foundation, either version 3 of the
    !  License, or (at your option) any later version.
    !  
    !  TensorProductMultigrid is distributed in the hope that it will be useful,
    !  but WITHOUT ANY WARRANTY; without even the implied warranty of
    !  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    !  GNU Lesser General Public License for more details.
    !  
    !  You should have received a copy of the GNU Lesser General Public License
    !  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
    !  If not, see <http://www.gnu.org/licenses/>.
    !
    !=== COPYRIGHT AND LICENSE STATEMENT ===
    
    
    !==================================================================
    !
    !  MPI communication routines for multigrid code
    !
    !    Eike Mueller, University of Bath, Feb 2012
    !
    !==================================================================
    
    module communication
      use messages
      use datatypes
      use parameters
    
      use timer
    
      implicit none
    
    public::comm_preinitialise
    public::comm_initialise
    public::comm_finalise
    
    public::scalarprod_mnh
    
    public::haloswap_mnh
    
    public::ihaloswap_mnh
    
    public::ihaloswap
    public::collect
    public::distribute
    public::i_am_master_mpi
    public::master_rank
    public::pproc
    public::MPI_COMM_HORIZ
    public::comm_parameters
    public::comm_measuretime
    
    
      ! Number of processors
      ! n_proc = 2^(2*pproc), with integer pproc
      integer :: pproc
    
    ! Rank of master process
      integer, parameter :: master_rank = 0
    ! Am I the master process?
      logical :: i_am_master_mpi
    
      integer, parameter :: dim = 3  ! Dimension
      integer, parameter :: dim_horiz = 2  ! Horizontal dimension
      integer :: MPI_COMM_HORIZ ! Communicator with horizontal partitioning
    
    private
    
    ! Data types for halo exchange in both x- and y-direction
      integer, dimension(:,:,:), allocatable :: halo_type
    
    ! MPI vector data types
      ! Halo for data exchange in north-south direction
      integer, allocatable, dimension(:,:) :: halo_ns
    
      integer, allocatable, dimension(:,:) :: halo_nst
      integer, allocatable, dimension(:,:) :: halo_wet
    
      ! Vector data type for interior of field a(level,m)
      integer, allocatable, dimension(:,:) :: interior
    
      integer, allocatable, dimension(:,:) :: interiorT 
    
  ! Vector data type for one quarter of the interior of the field
  ! at level a(level,m). This has the same size as (and can be
  ! used for communication with) the interior of a(level,m+1)
      integer, allocatable, dimension(:,:) :: sub_interior
    
      integer, allocatable, dimension(:,:) :: sub_interiorT 
    
      ! Timer for halo swaps
      type(time), allocatable, dimension(:,:) :: t_haloswap
      ! Timer for collect and distribute
      type(time), allocatable, dimension(:) :: t_collect
      type(time), allocatable, dimension(:) :: t_distribute
      ! Parallelisation parameters
      ! Measure communication times?
      logical :: comm_measuretime
    
      ! Parallel communication parameters
      type comm_parameters
        ! Size of halos
        integer :: halo_size
      end type comm_parameters
    
      type(comm_parameters) :: comm_param
    
    ! Data layout
    ! ===========
    !
    !  The number of processes has to be of the form nproc = 2^(2*pproc) to
    !  ensure that data can be distributed between processes.
    !  The processes are arranged in a (2^pproc) x (2^pproc) cartesian grid
    !  in the horizontal plane (i.e. vertical columns are always local to one
    !  process), which is implemented via the communicator MPI_COMM_HORIZ.
!  MPI_cart_rank() and MPI_cart_shift() can then be used to
!  easily identify neighbouring processes.

!  The number of data grid cells in each direction has to be a multiple
!  of 2**(L-1), where L is the number of levels (including the coarse
!  and fine level), with the coarse level corresponding to level=1.
!  We also define L_split as the level at which we start to pull together
!  data. For levels > L_split every position in the cartesian grid takes
!  part in the work; below this level only a subset of the processes is
!  used.
    !
    !  Each grid a(level,m) is identified by two numbers:
    !  (1) The multigrid level it belongs to (level)
    !  (2) The number of active processes that operate on it (2^(2*m)).
    !
!  For level > L_split, m=pproc. At L_split we store a(L_split,pproc) and
!  a(L_split,pproc-1), and only processes with even coordinates in both
!  horizontal directions use the latter grid.
!  Below that level we store a(L_split-1,pproc-1) and a(L_split-1,pproc-2),
!  where only processes for which both horizontal coordinates are
!  multiples of four use the latter. This is continued until only one
!  process is left.
    !
    !
    !  level
    !    L          a(L,pproc)
    !    L-1        a(L-1,pproc)
    !    ...
    !    L_split    a(L_split,pproc)    a(L_split  ,pproc-1)
    !    L_split-1                      a(L_split-1,pproc-1)  a(L_split-1,pproc-2)
    !
    !                                                                  ... a(3,1)
    !                                                                      a(2,1)
    !                                                                      a(1,1)
    !
    !  When moving from left to right in the above graph the total number of
    !  grid cells does not change, but the number of data points per process
    !  increases by a factor of 4.
    !
    ! Parallel operations
    ! ===================
    !
    !   (*)  Halo exchange. Update halo with data from neighbouring
    !        processors in cartesian grid on current (level,m)
    !   (*)  Collect data on all processes at (level,m) on those
    !        processes that are still active on (level,m-1).
    !   (*)  Distribute data at (level,m-1) and duplicate on all processes
    !        that are active at (level,m).
    !
    !   Note that in the cartesian processor grid the first coordinate
    !   is the North-South (y-) direction, the second coordinate is the
    !   East-West (x-) direction, i.e. the layout is this:
    !
    !   p_0 (0,0)   p_1 (0,1)   p_2 (0,2)   p_3 (0,3)
    !
    !   p_4 (1,0)   p_5 (1,1)   p_6 (1,2)   p_7 (1,3)
    !
    !   p_8 (2,0)   p_9 (2,1)   p_10 (2,2)  p_11 (2,3)
    !
    !                       [...]
    !
    !
    !   Normal multigrid restriction and prolongation are used to
    !   move between levels with fixed m.
    !
    !
    
    contains
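
!==================================================================
! Illustrative sketch (not referenced elsewhere in this module):
! the activity test described in the data layout notes above.
! A process with cartesian coordinates p_horiz works on grid
! a(level,m) iff both coordinates are multiples of
! stepsize = 2^(pproc-m); this is the same condition tested inline
! in scalarprod_mnh() and the haloswap routines.
!==================================================================
  logical function active_on_level(m, p_horiz)
    implicit none
    integer, intent(in)               :: m        ! processor level
    integer, dimension(2), intent(in) :: p_horiz  ! cartesian coordinates
    integer :: stepsize
    stepsize = 2**(pproc-m)
    active_on_level = (stepsize == 1) .or.                    &
                      ( (mod(p_horiz(1),stepsize) == 0) .and. &
                        (mod(p_horiz(2),stepsize) == 0) )
  end function active_on_level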
    
    !==================================================================
    ! Pre-initialise communication routines
    !==================================================================
      subroutine comm_preinitialise()
        implicit none
        integer :: nproc, ierr, rank
        call mpi_comm_size(MPI_COMM_WORLD, nproc, ierr)
        call mpi_comm_rank(MPI_COMM_WORLD, rank, ierr)
        i_am_master_mpi = (rank == master_rank)
        ! Check that nproc = 2^(2*p)
        pproc = floor(log(1.0d0*nproc)/log(4.0d0))
        if ( (nproc - 4**pproc) .ne. 0) then
          call fatalerror("Number of processors has to be 2^(2*pproc) with integer pproc.")
        end if
        if (i_am_master_mpi) then
          write(STDOUT,'("PARALLEL RUN")')
          write(STDOUT,'("Number of processors : 2^(2*pproc) = ",I10," with pproc = ",I6)') &
            nproc, pproc
        end if
    ! (Halo data types are created later, in comm_initialise)
    
      end subroutine comm_preinitialise
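
!==================================================================
! Sketch only (not called by this module): an integer-only variant
! of the nproc = 2^(2*pproc) check performed above with
! floor(log(...)); it restates the same arithmetic without
! floating-point operations.
!==================================================================
  logical function is_power_of_four(nproc)
    implicit none
    integer, intent(in) :: nproc
    integer :: np
    np = nproc
    ! Strip factors of four; a power of four reduces to exactly 1
    do while ( (np > 1) .and. (mod(np,4) == 0) )
      np = np/4
    end do
    is_power_of_four = (np == 1)
  end function is_power_of_four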
    
    !==================================================================
    ! Initialise communication routines
    !==================================================================
      subroutine comm_initialise(n_lev,        & !} multigrid parameters
                                lev_split,     & !}
                                grid_param,    & ! Grid parameters
                                comm_param_in)   ! Parallel communication
                                                 ! parameters
        implicit none
        integer, intent(in) :: n_lev
        integer, intent(in) :: lev_split
        type(grid_parameters), intent(inout) :: grid_param
        type(comm_parameters), intent(in)    :: comm_param_in
        integer :: n
        integer :: nz
        integer :: rank, nproc, ierr
        integer :: count, blocklength, stride
        integer, dimension(2) :: p_horiz
        integer :: m, level, nlocal
        logical :: reduced_m
        integer :: halo_size
        character(len=32) :: t_label
    
    
        integer,parameter    :: nb_dims=3
        integer,dimension(nb_dims) :: profil_tab,profil_sous_tab,coord_debut
    
    
        n = grid_param%n
        nz = grid_param%nz
    
        comm_param = comm_param_in
        halo_size = comm_param%halo_size
    
        call mpi_comm_size(MPI_COMM_WORLD, nproc, ierr)
    
        ! Create cartesian topology
        call mpi_cart_create(MPI_COMM_WORLD,        & ! Old communicator name
                             dim_horiz,             & ! horizontal dimension
                             (/2**pproc,2**pproc/), & ! extent in each horizontal direction
                             (/.false.,.false./),   & ! periodic?
                             .true.,                & ! reorder?
                             MPI_COMM_HORIZ,        & ! Name of new communicator
                             ierr)
   ! Work out rank and coordinates in the cartesian grid
        call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
        call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
    
        ! Local size of (horizontal) grid
        nlocal = n/2**pproc
    
        ! === Set up data types ===
        ! Halo for exchange in north-south direction
    
        if (LUseO) allocate(halo_ns(n_lev,0:pproc))
        if (LUseT) allocate(halo_nst(n_lev,0:pproc))
        if (LUseT) allocate(halo_wet(n_lev,0:pproc))
    
        if (LUseO) allocate(interior(n_lev,0:pproc))
        if (LUseO) allocate(sub_interior(n_lev,0:pproc))
        if (LUseT) allocate(interiorT(n_lev,0:pproc))
        if (LUseT) allocate(sub_interiorT(n_lev,0:pproc))
    
        ! Timer
        allocate(t_haloswap(n_lev,0:pproc))
        allocate(t_collect(0:pproc))
        allocate(t_distribute(0:pproc))
        do m=0,pproc
          write(t_label,'("t_collect(",I3,")")') m
          call initialise_timer(t_collect(m),t_label)
          write(t_label,'("t_distribute(",I3,")")') m
          call initialise_timer(t_distribute(m),t_label)
        end do
    
        m = pproc
        level = n_lev
        reduced_m = .false.
        do while (level > 0)
          ! --- Create halo data types ---
    
      if (LUseO) then
      ! NS- (y-) direction
          count = nlocal
          blocklength = (nz+2)*halo_size
          stride = (nlocal+2*halo_size)*(nz+2)
          call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION, &
                               halo_ns(level,m),ierr)
          call mpi_type_commit(halo_ns(level,m),ierr)
    
          endif
      ! transposed storage
          if (LUseT) then
          ! NS- (y-) transpose direction
          count =        nz+2                                        ! nlocal
          blocklength =  nlocal*halo_size                            ! (nz+2)*halo_size
          stride =       (nlocal+2*halo_size) * (nlocal+2*halo_size) ! (nlocal+2*halo_size)*(nz+2)
          call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION, &
                               halo_nst(level,m),ierr)
          call mpi_type_commit(halo_nst(level,m),ierr)
         ! WE- (x-) transpose direction
          count =        (nz+2)*(nlocal+2*halo_size)*halo_size       ! nlocal
          blocklength =  1*halo_size                                 ! (nz+2)*halo_size
          stride =       nlocal+2*halo_size                          ! (nlocal+2*halo_size)*(nz+2)
          call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION, &
                               halo_wet(level,m),ierr)
          call mpi_type_commit(halo_wet(level,m),ierr)
          endif
    
#ifndef NDEBUG
      if (ierr .ne. 0) &
        call fatalerror("Commit halo_ns failed in mpi_type_commit().")
#endif
          ! --- Create interior data types ---
    
          if (LUseO) then
             count = nlocal
             blocklength = nlocal*(nz+2)
             stride = (nz+2)*(nlocal+2*halo_size)
             call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION,interior(level,m),ierr)
             call mpi_type_commit(interior(level,m),ierr)
             count = nlocal/2
             blocklength = nlocal/2*(nz+2)
             stride = (nlocal+2*halo_size)*(nz+2)
             call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION,sub_interior(level,m),ierr)
             call mpi_type_commit(sub_interior(level,m),ierr)
          end if
          if (LUseT) then
             ! interiorT
             if ( nlocal /= 0 ) then
                profil_tab      = (/ nlocal+2*halo_size , nlocal+2*halo_size , nz+2 /)
                profil_sous_tab = (/ nlocal             , nlocal             , nz+2 /)
                coord_debut     = (/ 0                  , 0                  , 0    /)
                call MPI_TYPE_CREATE_SUBARRAY(nb_dims,profil_tab,profil_sous_tab,coord_debut,&
                     MPI_ORDER_FORTRAN,MPI_DOUBLE_PRECISION,interiorT(level,m),ierr)
                call mpi_type_commit(interiorT(level,m),ierr)
             end if
             ! sub_interiorT
             if ( (nlocal/2) /= 0 ) then
                profil_tab      = (/ nlocal+2*halo_size , nlocal+2*halo_size , nz+2 /)
                profil_sous_tab = (/ nlocal/2           , nlocal/2           , nz+2 /)
                coord_debut     = (/ 0                  , 0                  , 0    /)
                call MPI_TYPE_CREATE_SUBARRAY(nb_dims,profil_tab,profil_sous_tab,coord_debut,&
                     MPI_ORDER_FORTRAN,MPI_DOUBLE_PRECISION,sub_interiorT(level,m),ierr)
                call mpi_type_commit(sub_interiorT(level,m),ierr)
             end if
          end if
    
          ! --- Create timers ---
          write(t_label,'("t_haloswap(",I3,",",I3,")")') level,m
          call initialise_timer(t_haloswap(level,m),t_label)
    
          ! If we are below L_split, split data
          if ( (level .le. lev_split) .and. (m > 0) .and. (.not. reduced_m)) then
            reduced_m = .true.
            m = m-1
            nlocal = 2*nlocal
            cycle
          end if
          reduced_m = .false.
          level = level-1
          nlocal = nlocal/2
        end do
    
      end subroutine comm_initialise
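
!==================================================================
! Worked sketch of the halo_ns layout set up in comm_initialise()
! (illustrative only, not called by the solver). Assuming the
! a%s(0:nz+1, 1-halo_size:nlocal+halo_size, 1-halo_size:nlocal+halo_size)
! storage implied by the calls above, the north/south halo is one
! strided block per interior x-index:
!   count       = nlocal                      (interior x-indices)
!   blocklength = (nz+2)*halo_size            (all z levels of halo_size y-rows)
!   stride      = (nlocal+2*halo_size)*(nz+2) (size of one full (z,y) slice)
!==================================================================
  subroutine build_halo_ns_type(nlocal, nz, halo_size_in, newtype)
    implicit none
    integer, intent(in)  :: nlocal, nz, halo_size_in
    integer, intent(out) :: newtype
    integer :: ierr
    call mpi_type_vector(nlocal,                            &  ! count
                         (nz+2)*halo_size_in,               &  ! blocklength
                         (nlocal+2*halo_size_in)*(nz+2),    &  ! stride
                         MPI_DOUBLE_PRECISION, newtype, ierr)
    call mpi_type_commit(newtype, ierr)
  end subroutine build_halo_ns_type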
    
    !==================================================================
    ! Finalise communication routines
    !==================================================================
    
      subroutine comm_finalise(n_lev,   &  ! } Multigrid parameters
                               lev_split,  & !}
                               grid_param )  ! } Grid parameters
    
        implicit none
        integer, intent(in) :: n_lev
        integer, intent(in) :: lev_split
    
        type(grid_parameters), intent(in) :: grid_param
        ! local var
    
    logical :: reduced_m
    integer :: level, m
    integer :: ierr
    integer :: n, nlocal
    character(len=80) :: s
    
    
       ! Local size of (horizontal) grid
        n = grid_param%n
        nlocal = n/2**pproc
    
    
        m = pproc
        level = n_lev
        reduced_m = .false.
        if (i_am_master_mpi) then
          write(STDOUT,'(" *** Finalising communications ***")')
        end if
        call print_timerinfo("--- Communication timing results ---")
        do while (level > 0)
          write(s,'("level = ",I3,", m = ",I3)') level, m
          call print_timerinfo(s)
          ! --- Print out timer information ---
          call print_elapsed(t_haloswap(level,m),.True.,1.0_rl)
          ! --- Free halo data types ---
    
          if (LUseO) call mpi_type_free(halo_ns(level,m),ierr)
          if (LUseT) call mpi_type_free(halo_nst(level,m),ierr)
          if (LUseT) call mpi_type_free(halo_wet(level,m),ierr)
    
          ! --- Free interior data types ---
    
          if (LUseO) call mpi_type_free(interior(level,m),ierr)
          if (LUseO) call mpi_type_free(sub_interior(level,m),ierr)
    
          if (LUseT .and. (nlocal /= 0 ) ) call mpi_type_free(interiorT(level,m),ierr)
          if (LUseT .and. ( (nlocal/2) /= 0 ) ) call mpi_type_free(sub_interiorT(level,m),ierr)
    
          ! If we are below L_split, split data
          if ( (level .le. lev_split) .and. (m > 0) .and. (.not. reduced_m)) then
            reduced_m = .true.
        m = m-1
        nlocal = 2*nlocal
        cycle
          end if
          reduced_m = .false.
      level = level-1
      nlocal = nlocal/2
        end do
        do m=pproc,0,-1
          write(s,'("m = ",I3)') m
          call print_timerinfo(s)
          ! --- Print out timer information ---
          call print_elapsed(t_collect(m),.True.,1.0_rl)
          call print_elapsed(t_distribute(m),.True.,1.0_rl)
        end do
    
        ! Deallocate arrays
    
        if (LUseO) deallocate(halo_ns)
        if (LUseT) deallocate(halo_nst,halo_wet)
    
        if (LUseO) deallocate(interior)
        if (LUseO) deallocate(sub_interior)
        if (LUseT) deallocate(interiorT)
        if (LUseT) deallocate(sub_interiorT)
    
    
        deallocate(t_haloswap)
        deallocate(t_collect)
        deallocate(t_distribute)
        if (i_am_master_mpi) then
          write(STDOUT,'("")')
        end if
    
      end subroutine comm_finalise
    
    !==================================================================
    !  Scalar product of two fields
    !==================================================================
    
      subroutine scalarprod_mnh(m, a, b, s)
        implicit none
        integer, intent(in) :: m
        type(scalar3d), intent(in) :: a
        type(scalar3d), intent(in) :: b
        real(kind=rl), intent(out) :: s
    
        integer :: nprocs, rank, ierr
        integer :: p_horiz(2)
        integer :: stepsize
        integer, parameter :: dim_horiz = 2
    
        real(kind=rl) :: local_sum, global_sum 
        real(kind=rl) :: local_sumt,global_sumt
    
    integer :: nlocal, nz, i, ix, iy, iz
    
        real(kind=rl) :: ddot
    
    
        integer :: iy_min,iy_max, ix_min,ix_max
        real , dimension(:,:,:) , pointer :: za_st,zb_st
    
    
        nlocal = a%ix_max-a%ix_min+1
        nz = a%grid_param%nz
    
    
        iy_min = a%iy_min
        iy_max = a%iy_max
        ix_min = a%ix_min
        ix_max = a%ix_max
    
    
        ! Work out coordinates of processor
        call mpi_comm_size(MPI_COMM_HORIZ,nprocs,ierr)
        call mpi_comm_rank(MPI_COMM_HORIZ,rank,ierr)
        stepsize = 2**(pproc-m)
        if (nprocs > 1) then
      ! Only include the local sum if the processor coordinates
          ! are multiples of stepsize
          call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
          if ( (stepsize == 1) .or.                   &
               ( (stepsize > 1) .and.                 &
                 (mod(p_horiz(1),stepsize)==0) .and.  &
                 (mod(p_horiz(2),stepsize)==0) ) ) then
    
            if (LUseO) then
               local_sum = 0.0_rl
               do i = 1, nlocal
                  local_sum = local_sum &
                       + ddot((nz+2)*nlocal,a%s(0,1,i),1,b%s(0,1,i),1)
               end do
            end if
            if (LUseT) then
               local_sumt = 0.0_rl
               do iz=0,nz+1
    
                  do iy=a%icompy_min,a%icompy_max
                     do ix=a%icompx_min,a%icompx_max
    
                        local_sumt = local_sumt &
                             + a%st(ix,iy,iz)*b%st(ix,iy,iz)
                     end do
                  end do
               end do
        end if
      else
        ! Processes that are not active at this level contribute zero
        if (LUseO) local_sum = 0.0_rl
        if (LUseT) local_sumt = 0.0_rl
      end if

      if (LUseO) call mpi_allreduce(local_sum,global_sum,1,MPI_DOUBLE_PRECISION, &
                         MPI_SUM,MPI_COMM_HORIZ,ierr)
      if (LUseT) call mpi_allreduce(local_sumt,global_sumt,1,MPI_DOUBLE_PRECISION, &
                         MPI_SUM,MPI_COMM_HORIZ,ierr)
    else
      ! Single process: no reduction needed, sum the local field directly
      if (LUseO) then
        global_sum = 0.0_rl
        do i = 1, nlocal
          global_sum = global_sum &
                     + ddot((nz+2)*nlocal,a%s(0,1,i),1,b%s(0,1,i),1)
        end do
      end if
      if (LUseT) then
        za_st => a%st
        zb_st => b%st
        global_sumt = 0.0_rl
        do iz=0,nz+1
          do iy=iy_min,iy_max
            do ix=ix_min,ix_max
              global_sumt = global_sumt &
                          + za_st(ix,iy,iz)*zb_st(ix,iy,iz)
            end do
          end do
        end do
      end if
    end if
    ! Return the global scalar product
    if (LUseO) s = global_sum
    if (LUseT) s = global_sumt

  end subroutine scalarprod_mnh
    !-------------------------------------------------------------------------------
    
      subroutine scalarprod(m, a, b, s)
        implicit none
        integer, intent(in) :: m
        type(scalar3d), intent(in) :: a
        type(scalar3d), intent(in) :: b
        real(kind=rl), intent(out) :: s
        integer :: nprocs, rank, ierr
        integer :: p_horiz(2)
        integer :: stepsize
        integer, parameter :: dim_horiz = 2
        real(kind=rl) :: local_sum, global_sum
        integer :: nlocal, nz, i
        real(kind=rl) :: ddot
    
        nlocal = a%ix_max-a%ix_min+1
        nz = a%grid_param%nz
        ! Work out coordinates of processor
        call mpi_comm_size(MPI_COMM_HORIZ,nprocs,ierr)
        call mpi_comm_rank(MPI_COMM_HORIZ,rank,ierr)
        stepsize = 2**(pproc-m)
        if (nprocs > 1) then
      ! Only include the local sum if the processor coordinates
          ! are multiples of stepsize
          call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
          if ( (stepsize == 1) .or.                   &
               ( (stepsize > 1) .and.                 &
                 (mod(p_horiz(1),stepsize)==0) .and.  &
                 (mod(p_horiz(2),stepsize)==0) ) ) then
            local_sum = 0.0_rl
            do i = 1, nlocal
              local_sum = local_sum &
                        + ddot((nz+2)*nlocal,a%s(0,1,i),1,b%s(0,1,i),1)
            end do
          else
            local_sum = 0.0_rl
          end if
          call mpi_allreduce(local_sum,global_sum,1,MPI_DOUBLE_PRECISION, &
                             MPI_SUM,MPI_COMM_HORIZ,ierr)
        else
          global_sum = 0.0_rl
          do i = 1, nlocal
            global_sum = global_sum &
                       + ddot((nz+2)*nlocal,a%s(0,1,i),1,b%s(0,1,i),1)
          end do
        end if
        s = global_sum
      end subroutine scalarprod
    !==================================================================
    
!  Print the norm, sqrt(<a,a>), of one field
    !==================================================================
      subroutine print_scalaprod2(l,m, a, message )
        implicit none
        integer, intent(in) :: l,m
        type(scalar3d), intent(in) :: a
        character(len=*) , intent(in) :: message
    
        !local 
        real(kind=rl) :: s
    
        call scalarprod_mnh(m, a, a, s)
        s = sqrt(s)
        if (i_am_master_mpi) then
           write(STDOUT,'("Print_norm::",A,2I3,E23.15)') message, l,m,s
           call flush(STDOUT)
        end if
    
      end subroutine print_scalaprod2
    !==================================================================
    
    ! Boundary Neumann
    !==================================================================
      subroutine boundary_mnh(a)              ! data field
    
        implicit none
     
        type(scalar3d), intent(inout) :: a
    
    
        integer :: n, ix_min,ix_max,iy_min,iy_max
    
        integer :: icompx_max,icompy_max
    
        real , dimension(:,:,:) , pointer :: za_st
    
    ! Update real boundary for the Neumann case:  u(0) = u(1) , etc ...
    
        n      = a%grid_param%n
        ix_min = a%ix_min
        ix_max = a%ix_max
        iy_min = a%iy_min
        iy_max = a%iy_max
    
        if ( ix_min == 1 ) then
           a%s(:,:,0) = a%s(:,:,1)
        endif
        if ( ix_max == n ) then
           a%s(:,:,a%icompx_max+1) = a%s(:,:,a%icompx_max)
        endif
        if ( iy_min == 1 ) then
           a%s(:,0,:) = a%s(:,1,:)
        endif
        if ( iy_max == n ) then
           a%s(:,a%icompy_max+1,:) = a%s(:,a%icompy_max,:)
        endif
    
    
        za_st => a%st
        icompx_max = a%icompx_max
        icompy_max = a%icompy_max
    
        !$acc kernels
    
           !acc kernels
           za_st(0,:,:) = za_st(1,:,:)
           !acc end kernels
    
           !acc kernels
           za_st(icompx_max+1,:,:) = za_st(icompx_max,:,:)
           !acc end kernels
    
           !acc kernels
           za_st(:,0,:) = za_st(:,1,:)
           !acc end kernels
    
           !acc kernels
           za_st(:,icompy_max+1,:) = za_st(:,icompy_max,:)
       !acc end kernels

    !$acc end kernels

      end subroutine boundary_mnh
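
!==================================================================
! Minimal 1-D sketch of the Neumann update above (illustrative
! only, not called by the solver): copying the first/last interior
! value into the neighbouring ghost cell makes the one-sided
! difference across the physical boundary vanish,
! (u(1)-u(0))/dx = 0, i.e. the discrete form of du/dn = 0.
!==================================================================
  subroutine neumann_1d_sketch(u, n)
    implicit none
    integer, intent(in) :: n
    real(kind=rl), intent(inout) :: u(0:n+1)  ! interior cells 1..n plus ghosts
    u(0)   = u(1)   ! mirror at the low boundary
    u(n+1) = u(n)   ! mirror at the high boundary
  end subroutine neumann_1d_sketch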
    !==================================================================
    
    !  Initiate asynchronous halo exchange
    !
    !  For all processes with horizontal indices that are multiples
    !  of 2^(pproc-m), update halos with information held by
    !  neighbouring processes, e.g. for pproc-m = 1, stepsize=2
    !
    !                      N  (0,2)
    !                      ^
    !                      !
    !                      v
    !
    !      W (2,0) <-->  (2,2)  <-->  E (2,4)
    !
    !                      ^
    !                      !
    !                      v
    !                      S (4,2)
    !
    
    !==================================================================
      subroutine ihaloswap_mnh(level,m,       &  ! multigrid- and processor- level
                           a,             &  ! data field
                           send_requests, &  ! send requests (OUT)
    
                           recv_requests,  &  ! recv requests (OUT)
                           send_requestsT, &  ! send requests T (OUT)
                           recv_requestsT  &  ! recv requests T (OUT)
    
                           )
        implicit none
        integer, intent(in) :: level
        integer, intent(in) :: m
        integer, intent(out), dimension(4) :: send_requests
        integer, intent(out), dimension(4) :: recv_requests
    
        integer, intent(out), dimension(4) :: send_requestsT
        integer, intent(out), dimension(4) :: recv_requestsT
    
        type(scalar3d), intent(inout) :: a
        integer :: a_n  ! horizontal grid size
        integer :: nz   ! vertical grid size
        integer, dimension(2) :: p_horiz
        integer :: stepsize
        integer :: ierr, rank, sendtag, recvtag
        integer :: stat(MPI_STATUS_SIZE)
        integer :: halo_size
        integer :: neighbour_n_rank
        integer :: neighbour_s_rank
        integer :: neighbour_e_rank
        integer :: neighbour_w_rank
        integer :: yoffset, blocklength
    
        halo_size = comm_param%halo_size
    
        ! Do nothing if we are only using one processor
        if (m > 0) then
          a_n = a%ix_max-a%ix_min+1
          nz = a%grid_param%nz
          stepsize = 2**(pproc-m)
    
          ! Work out rank, only execute on relevant processes
          call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
          call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
    
          ! Work out ranks of neighbours
          ! W -> E
          call mpi_cart_shift(MPI_COMM_HORIZ,1, stepsize, &
                              neighbour_w_rank,neighbour_e_rank,ierr)
          ! N -> S
          call mpi_cart_shift(MPI_COMM_HORIZ,0, stepsize, &
                              neighbour_n_rank,neighbour_s_rank,ierr)
          if ( (stepsize == 1) .or.                   &
             (  (mod(p_horiz(1),stepsize) == 0) .and. &
                (mod(p_horiz(2),stepsize) == 0) ) ) then
            if (halo_size == 1) then
              ! Do not include corners in send/recv
              yoffset = 1
              blocklength = a_n*(nz+2)*halo_size
            else
              yoffset = 1-halo_size
              blocklength = (a_n+2*halo_size)*(nz+2)*halo_size
            end if
            ! Receive from north
    
            recvtag = 1002
            if (LUseO) call mpi_irecv(a%s(0,0-(halo_size-1),1),1,       &
    
                           halo_ns(level,m),neighbour_n_rank,recvtag,   &
                           MPI_COMM_HORIZ, recv_requests(1), ierr)
    
            recvtag = 1012
            if (LUseT) call mpi_irecv(a%st(1,0-(halo_size-1),0),1,      &
                           halo_nst(level,m),neighbour_n_rank,recvtag,  &
                           MPI_COMM_HORIZ, recv_requestsT(1), ierr)
    
            ! Receive from south
    
            recvtag = 1003
            if (LUseO) call mpi_irecv(a%s(0,a_n+1,1),1,                 &
    
                           halo_ns(level,m),neighbour_s_rank,recvtag,   &
                           MPI_COMM_HORIZ, recv_requests(2), ierr)
    
            recvtag = 1013
            if (LUseT) call mpi_irecv(a%st(1,a_n+1,0),1,                &
                           halo_nst(level,m),neighbour_s_rank,recvtag,  &
                           MPI_COMM_HORIZ, recv_requestsT(2), ierr)
    
        ! Send to south
        sendtag = 1002
            if (LUseO) call mpi_isend(a%s(0,a_n-(halo_size-1),1),1,     &
    
                           halo_ns(level,m),neighbour_s_rank,sendtag,   &
                           MPI_COMM_HORIZ, send_requests(1), ierr)
    
            sendtag = 1012
            if (LUseT) call mpi_isend(a%st(1,a_n-(halo_size-1),0),1,    &
                           halo_nst(level,m),neighbour_s_rank,sendtag,  &
                           MPI_COMM_HORIZ, send_requestsT(1), ierr)
    
        ! Send to north
        sendtag = 1003
            if (LUseO) call mpi_isend(a%s(0,1,1),1,                     &
    
                           halo_ns(level,m),neighbour_n_rank,sendtag,   &
                           MPI_COMM_HORIZ, send_requests(2), ierr)
    
            sendtag = 1013
            if (LUseT) call mpi_isend(a%st(1,1,0),1,                    &
                           halo_nst(level,m),neighbour_n_rank,sendtag,  &
                           MPI_COMM_HORIZ, send_requestsT(2), ierr)
    
            ! Receive from west
    
            recvtag = 1000
            if (LUseO) call mpi_irecv(a%s(0,yoffset,0-(halo_size-1)),blocklength,  &
    
                           MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
                           MPI_COMM_HORIZ, recv_requests(3), ierr)
    
            recvtag = 1010
            if (LUseT) call mpi_irecv(a%st(0-(halo_size-1),0,0),1,  &
                           halo_wet(level,m),neighbour_w_rank,recvtag, &
                           MPI_COMM_HORIZ, recv_requestsT(3), ierr)
    
            ! Receive from east
    
        recvtag = 1001
            if (LUseO) call mpi_irecv(a%s(0,yoffset,a_n+1),blocklength,          &
    
                           MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
                           MPI_COMM_HORIZ, recv_requests(4), ierr)
    
        recvtag = 1011
            if (LUseT) call mpi_irecv(a%st(a_n+1,0,0),1,          &
                           halo_wet(level,m),neighbour_e_rank,recvtag, &
                           MPI_COMM_HORIZ, recv_requestsT(4), ierr)
    
        ! Send to east
        sendtag = 1000
            if (LUseO) call mpi_isend(a%s(0,yoffset,a_n-(halo_size-1)),blocklength,  &
    
                           MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
                           MPI_COMM_HORIZ, send_requests(3), ierr)
    
            sendtag = 1010
            if (LUseT) call mpi_isend(a%st(a_n-(halo_size-1),0,0),1,  &
                           halo_wet(level,m),neighbour_e_rank,sendtag, &
                           MPI_COMM_HORIZ, send_requestsT(3), ierr)
    
        ! Send to west
        sendtag = 1001
            if (LUseO) call mpi_isend(a%s(0,yoffset,1),blocklength,                &
    
                           MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag,   &
                           MPI_COMM_HORIZ, send_requests(4), ierr)
    
        sendtag = 1011
            if (LUseT) call mpi_isend(a%st(1,0,0),1,                &
                           halo_wet(level,m),neighbour_w_rank,sendtag,   &
                           MPI_COMM_HORIZ, send_requestsT(4), ierr)
    
          end if
        end if
      end subroutine ihaloswap_mnh
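
!==================================================================
! Usage sketch (hypothetical caller, not part of this module):
! ihaloswap_mnh() only initiates the exchange, so the returned
! request arrays must be completed with mpi_waitall() before the
! halo cells of the field are read. The sketch assumes the calling
! process is active at (level,m); inactive processes get no
! requests and must skip the waits.
!==================================================================
  subroutine example_overlapped_haloswap_mnh(level, m, a)
    implicit none
    integer, intent(in) :: level, m
    type(scalar3d), intent(inout) :: a
    integer, dimension(4) :: send_requests,  recv_requests
    integer, dimension(4) :: send_requestsT, recv_requestsT
    integer, dimension(MPI_STATUS_SIZE,4) :: statuses
    integer :: ierr
    call ihaloswap_mnh(level, m, a, send_requests, recv_requests, &
                       send_requestsT, recv_requestsT)
    ! ... interior-only work could overlap with the communication here ...
    if (LUseO) call mpi_waitall(4, recv_requests,  statuses, ierr)
    if (LUseT) call mpi_waitall(4, recv_requestsT, statuses, ierr)
    if (LUseO) call mpi_waitall(4, send_requests,  statuses, ierr)
    if (LUseT) call mpi_waitall(4, send_requestsT, statuses, ierr)
  end subroutine example_overlapped_haloswap_mnh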
    
    !==================================================================
      subroutine ihaloswap(level,m,       &  ! multigrid- and processor- level
                           a,             &  ! data field
                           send_requests, &  ! send requests (OUT)
                           recv_requests  &  ! recv requests (OUT)
                           )
        implicit none
        integer, intent(in) :: level
        integer, intent(in) :: m
        integer, intent(out), dimension(4) :: send_requests
        integer, intent(out), dimension(4) :: recv_requests
        type(scalar3d), intent(inout) :: a
        integer :: a_n  ! horizontal grid size
        integer :: nz   ! vertical grid size
        integer, dimension(2) :: p_horiz
        integer :: stepsize
        integer :: ierr, rank, sendtag, recvtag
        integer :: stat(MPI_STATUS_SIZE)
        integer :: halo_size
        integer :: neighbour_n_rank
        integer :: neighbour_s_rank
        integer :: neighbour_e_rank
        integer :: neighbour_w_rank
        integer :: yoffset, blocklength
    
        halo_size = comm_param%halo_size
    
        ! Do nothing if we are only using one processor
        if (m > 0) then
          a_n = a%ix_max-a%ix_min+1
          nz = a%grid_param%nz
          stepsize = 2**(pproc-m)
    
          ! Work out rank, only execute on relevant processes
          call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
          call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
    
          ! Work out ranks of neighbours
          ! W -> E
          call mpi_cart_shift(MPI_COMM_HORIZ,1, stepsize, &
                              neighbour_w_rank,neighbour_e_rank,ierr)
          ! N -> S
          call mpi_cart_shift(MPI_COMM_HORIZ,0, stepsize, &
                              neighbour_n_rank,neighbour_s_rank,ierr)
          if ( (stepsize == 1) .or.                   &
             (  (mod(p_horiz(1),stepsize) == 0) .and. &
                (mod(p_horiz(2),stepsize) == 0) ) ) then
            if (halo_size == 1) then
              ! Do not include corners in send/recv
              yoffset = 1
              blocklength = a_n*(nz+2)*halo_size
            else
              yoffset = 1-halo_size
              blocklength = (a_n+2*halo_size)*(nz+2)*halo_size
            end if
            ! Receive from north
            recvtag = 2
            call mpi_irecv(a%s(0,0-(halo_size-1),1),1,                  &
                           halo_ns(level,m),neighbour_n_rank,recvtag,   &
                           MPI_COMM_HORIZ, recv_requests(1), ierr)
            ! Receive from south
            recvtag = 3
            call mpi_irecv(a%s(0,a_n+1,1),1,                            &
                           halo_ns(level,m),neighbour_s_rank,recvtag,   &
                           MPI_COMM_HORIZ, recv_requests(2), ierr)
            ! Send to south
            sendtag = 2
            call mpi_isend(a%s(0,a_n-(halo_size-1),1),1,                &
                           halo_ns(level,m),neighbour_s_rank,sendtag,   &
                           MPI_COMM_HORIZ, send_requests(1), ierr)
            ! Send to north
            sendtag = 3
            call mpi_isend(a%s(0,1,1),1,                                &
                           halo_ns(level,m),neighbour_n_rank,sendtag,   &
                           MPI_COMM_HORIZ, send_requests(2), ierr)
            ! Receive from west
            recvtag = 0
            call mpi_irecv(a%s(0,yoffset,0-(halo_size-1)),blocklength,    &
                           MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
                           MPI_COMM_HORIZ, recv_requests(3), ierr)
            ! Receive from east
        recvtag = 1
            call mpi_irecv(a%s(0,yoffset,a_n+1),blocklength,          &
                           MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
                           MPI_COMM_HORIZ, recv_requests(4), ierr)
            ! Send to east
            sendtag = 0
            call mpi_isend(a%s(0,yoffset,a_n-(halo_size-1)),blocklength,  &
                           MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
                           MPI_COMM_HORIZ, send_requests(3), ierr)
            ! Send to west
        sendtag = 1
            call mpi_isend(a%s(0,yoffset,1),blocklength,                &
                           MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag,   &
                           MPI_COMM_HORIZ, send_requests(4), ierr)
          end if
        end if
      end subroutine ihaloswap
    
    !==================================================================
    !  Halo exchange
    !
    !  For all processes with horizontal indices that are multiples
    !  of 2^(pproc-m), update halos with information held by
    !  neighbouring processes, e.g. for pproc-m = 1, stepsize=2
    !
    !                      N  (0,2)
    !                      ^
    !                      !
    !                      v
    !
    !      W (2,0) <-->  (2,2)  <-->  E (2,4)
    !
    !                      ^
    !                      !
    !                      v
    !                      S (4,2)
    !
    
    !==================================================================
      subroutine haloswap_mnh(level,m, &  ! multigrid- and processor- level
                          a)          ! data field
        implicit none
        integer, intent(in) :: level
        integer, intent(in) :: m
        type(scalar3d), intent(inout) :: a
    
        integer :: a_n  ! horizontal grid size
        integer :: nz   ! vertical grid size
        integer, dimension(2) :: p_horiz
        integer :: stepsize
        integer :: ierr, rank, sendtag, recvtag
        integer :: stat(MPI_STATUS_SIZE)
        integer :: halo_size
        integer :: neighbour_n_rank
        integer :: neighbour_s_rank
        integer :: neighbour_e_rank
        integer :: neighbour_w_rank
        integer :: yoffset, blocklength
        integer, dimension(4) :: requests_ns
        integer, dimension(4) :: requests_ew
    
        integer, dimension(4) :: requests_nsT
        integer, dimension(4) :: requests_ewT
    
    
        halo_size = comm_param%halo_size
    
        ! Do nothing if we are only using one processor
        if (m > 0) then
          if (comm_measuretime) then
            call start_timer(t_haloswap(level,m))
          end if
          a_n = a%ix_max-a%ix_min+1
          nz = a%grid_param%nz
          stepsize = 2**(pproc-m)
    
          ! Work out rank, only execute on relevant processes
          call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
          call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
    
          ! Work out ranks of neighbours