Skip to content
Snippets Groups Projects
Forked from Méso-NH / Méso-NH code
2287 commits behind, 1061 commits ahead of the upstream repository.
communication.f90 86.70 KiB
!=== COPYRIGHT AND LICENSE STATEMENT ===
!
!  This file is part of the TensorProductMultigrid code.
!  
!  (c) The copyright relating to this work is owned jointly by the
!  Crown, Met Office and NERC [2014]. However, it has been created
!  with the help of the GungHo Consortium, whose members are identified
!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
!  
!  Main Developer: Eike Mueller
!  
!  TensorProductMultigrid is free software: you can redistribute it and/or
!  modify it under the terms of the GNU Lesser General Public License as
!  published by the Free Software Foundation, either version 3 of the
!  License, or (at your option) any later version.
!  
!  TensorProductMultigrid is distributed in the hope that it will be useful,
!  but WITHOUT ANY WARRANTY; without even the implied warranty of
!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!  GNU Lesser General Public License for more details.
!  
!  You should have received a copy of the GNU Lesser General Public License
!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
!  If not, see <http://www.gnu.org/licenses/>.
!
!=== COPYRIGHT AND LICENSE STATEMENT ===


!==================================================================
!
!  MPI communication routines for multigrid code
!
!    Eike Mueller, University of Bath, Feb 2012
!
!==================================================================

module communication
  use messages
  use datatypes

#ifndef MNH
  use mpi
#else
  use modd_mpif
#endif   
  use timer

  use mode_mnh_allocate_mg_halo

  implicit none

public::comm_preinitialise
public::comm_initialise
public::comm_finalise
public::scalarprod_mnh
public::scalarprod
public::print_scalaprod2
public::boundary_mnh
public::haloswap_mnh
public::haloswap
public::ihaloswap_mnh
public::ihaloswap
public::collect
public::distribute
public::i_am_master_mpi
public::master_rank
public::pproc
public::MPI_COMM_HORIZ
public::comm_parameters
public::comm_measuretime


  ! Number of processors
  ! n_proc = 2^(2*pproc), with integer pproc
  ! (computed and validated in comm_preinitialise())
  integer :: pproc

! Rank of master process
  integer, parameter :: master_rank = 0
! Am I the master process?
! (set in comm_preinitialise() from the rank in MPI_COMM_WORLD)
  logical :: i_am_master_mpi

  integer, parameter :: dim = 3  ! Dimension
  integer, parameter :: dim_horiz = 2  ! Horizontal dimension
  ! Created by mpi_cart_create() in comm_initialise() as a non-periodic
  ! (2**pproc) x (2**pproc) process grid
  integer :: MPI_COMM_HORIZ ! Communicator with horizontal partitioning

! Everything not listed in the public statements above is module-private
private

! Data types for halo exchange in both x- and y-direction
! NOTE(review): not referenced in the visible part of this file
  integer, dimension(:,:,:), allocatable :: halo_type

! Array for Halo Exchange <-> GPU copy/managed
! Send (haloTout) / receive (haloTin) staging buffers; allocated with
! mnh_allocate_mg_halo() in comm_initialise() when LUseT is set
  type type_halo_T
     real, dimension(:,:,:), pointer, contiguous :: haloTin,haloTout
  end type type_halo_T
! MPI vector data types
! All arrays below are indexed (level,m) and allocated with shape
! (n_lev,0:pproc) in comm_initialise(); the *T variants and tab_* buffers
! exist only when LUseT is set, the others only when LUseO is set.
  ! Halo for data exchange in north-south direction
  integer, allocatable, dimension(:,:) :: halo_ns
  integer, allocatable, dimension(:,:) :: halo_nst
  type(type_halo_T), allocatable, dimension(:,:) :: tab_halo_nt
  type(type_halo_T), allocatable, dimension(:,:) :: tab_halo_st
  ! Halo for data exchange in west-east direction (transposed layout)
  integer, allocatable, dimension(:,:) :: halo_wet
  type(type_halo_T), allocatable, dimension(:,:) :: tab_halo_wt
  type(type_halo_T), allocatable, dimension(:,:) :: tab_halo_et
  ! Vector data type for interior of field a(level,m)
  integer, allocatable, dimension(:,:) :: interior
  integer, allocatable, dimension(:,:) :: interiorT
  type(type_halo_T), allocatable, dimension(:,:) :: tab_interiorT_ne,tab_interiorT_sw,tab_interiorT_se
  ! Vector data type for one quarter of interior of field
  ! at level a(level,m). This has the same size (and can be
  ! used for communications with) the interior of a(level,m+1)
  integer, allocatable, dimension(:,:) :: sub_interior
  integer, allocatable, dimension(:,:) :: sub_interiorT
  type(type_halo_T), allocatable, dimension(:,:) :: tab_sub_interiorT_ne,tab_sub_interiorT_sw,tab_sub_interiorT_se
  ! Timer for halo swaps
  type(time), allocatable, dimension(:,:) :: t_haloswap
  ! Timer for collect and distribute (indexed by m = 0..pproc)
  type(time), allocatable, dimension(:) :: t_collect
  type(time), allocatable, dimension(:) :: t_distribute
  ! Parallelisation parameters
  ! Measure communication times?
  ! NOTE(review): not assigned in the visible part of this file
  logical :: comm_measuretime

  ! Parallel communication parameters
  type comm_parameters
    ! Size of halos
    integer :: halo_size
  end type comm_parameters

  ! Module copy of the parameters passed to comm_initialise()
  type(comm_parameters) :: comm_param

! Data layout
! ===========
!
!  The number of processes has to be of the form nproc = 2^(2*pproc) to
!  ensure that data can be distributed between processes.
!  The processes are arranged in a (2^pproc) x (2^pproc) cartesian grid
!  in the horizontal plane (i.e. vertical columns are always local to one
!  process), which is implemented via the communicator MPI_COMM_HORIZ.
!  Thus MPI_cart_rank() and MPI_cart_shift() can then be used to
!  easily identify neighbouring processes.

!  The number of data grid cells in each direction has to be a multiple
!  of 2**(L-1) where L is the number of levels (including the coarse
!  and fine level), with the coarse level corresponding to level=1.
!  Also define L_split as the level where we start to pull together
!  data. For levels > L_split each position in the cartesian grid is
!  included in the work, below this only a subset of processes is
!  used.
!
!  Each grid a(level,m) is identified by two numbers:
!  (1) The multigrid level it belongs to (level)
!  (2) The number of active processes that operate on it (2^(2*m)).
!
!  For level > L_split, m=pproc. For L_split we store a(L_split,pproc) and
!  a(L_split,pproc-1), and only processes with even coordinates in both
!  horizontal directions use this grid.
!  Below that level, store a(L_split-1,pproc-1) and a(L_split-1,pproc-2),
!  where only processes for which both horizontal coordinates are
!  multiples of four use the latter. This is continued until only one
!  process is left.
!
!
!  level
!    L          a(L,pproc)
!    L-1        a(L-1,pproc)
!    ...
!    L_split    a(L_split,pproc)    a(L_split  ,pproc-1)
!    L_split-1                      a(L_split-1,pproc-1)  a(L_split-1,pproc-2)
!
!                                                                  ... a(3,1)
!                                                                      a(2,1)
!                                                                      a(1,1)
!
!  When moving from left to right in the above graph the total number of
!  grid cells does not change, but the number of data points per process
!  increases by a factor of 4.
!
! Parallel operations
! ===================
!
!   (*)  Halo exchange. Update halo with data from neighbouring
!        processors in cartesian grid on current (level,m)
!   (*)  Collect data on all processes at (level,m) on those
!        processes that are still active on (level,m-1).
!   (*)  Distribute data at (level,m-1) and duplicate on all processes
!        that are active at (level,m).
!
!   Note that in the cartesian processor grid the first coordinate
!   is the North-South (y-) direction, the second coordinate is the
!   East-West (x-) direction, i.e. the layout is this:
!
!   p_0 (0,0)   p_1 (0,1)   p_2 (0,2)   p_3 (0,3)
!
!   p_4 (1,0)   p_5 (1,1)   p_6 (1,2)   p_7 (1,3)
!
!   p_8 (2,0)   p_9 (2,1)   p_10 (2,2)  p_11 (2,3)
!
!                       [...]
!
!
!   Normal multigrid restriction and prolongation are used to
!   move between levels with fixed m.
!
!

contains

!==================================================================
! Pre-initialise communication routines
!==================================================================
  !> Query MPI_COMM_WORLD, record whether this rank is the master rank and
  !> derive pproc such that the number of processes equals 4**pproc.
  !> Aborts through fatalerror() when the process count is not a power of 4.
  subroutine comm_preinitialise()
    implicit none
    integer :: n_world, my_rank, ierr

    call mpi_comm_size(MPI_COMM_WORLD, n_world, ierr)
    call mpi_comm_rank(MPI_COMM_WORLD, my_rank, ierr)
    i_am_master_mpi = (my_rank == master_rank)

    ! Nearest integer to log_4(nproc); the exact power test below rejects
    ! every process count that is not 4**pproc.
    pproc = nint(log(dble(n_world)) / log(4.0d0))
    if (4**pproc /= n_world) then
      print *, "Number of processors has to be 2^(2*pproc) with integer nproc,pproc=", &
               n_world, pproc
      call fatalerror("Number of processors has to be 2^(2*pproc) with integer pproc.")
    end if

    if (.not. i_am_master_mpi) return
    write(STDOUT,'("PARALLEL RUN")')
    write(STDOUT,'("Number of processors : 2^(2*pproc) = ",I10," with pproc = ",I6)') &
      n_world, pproc
  end subroutine comm_preinitialise

!==================================================================
! Initialise communication routines
!==================================================================
  !> Initialise the communication layer.
  !>
  !> Builds the Cartesian process grid MPI_COMM_HORIZ. Then, for every
  !> (level,m) grid of the multigrid hierarchy, it commits the MPI derived
  !> datatypes, allocates the staging buffers used by the transposed (LUseT)
  !> layout and creates the timers.
  !>
  !> The (level,m) traversal starts at (n_lev,pproc). While m > 0, every
  !> level <= lev_split is visited twice: once with the current m and once
  !> with m-1 (local horizontal size doubled). comm_finalise() repeats the
  !> same traversal when freeing the datatypes.
  !>
  !> n_lev         : number of multigrid levels
  !> lev_split     : level at and below which processes are merged
  !> grid_param    : grid parameters (only %n and %nz are read here)
  !> comm_param_in : communication parameters, copied into comm_param
  subroutine comm_initialise(n_lev,        & !} multigrid parameters
                            lev_split,     & !}
                            grid_param,    & ! Grid parameters
                            comm_param_in)   ! Parallel communication
                                             ! parameters
    implicit none
    integer, intent(in) :: n_lev
    integer, intent(in) :: lev_split
    type(grid_parameters), intent(inout) :: grid_param
    type(comm_parameters), intent(in)    :: comm_param_in
    integer :: n
    integer :: nz
    integer :: ierr
    integer :: count, blocklength, stride
    integer :: m, level, nlocal
    logical :: reduced_m
    integer :: halo_size
    character(len=32) :: t_label

    integer,parameter    :: nb_dims=3
    integer,dimension(nb_dims) :: profil_tab,profil_sous_tab,coord_debut

    n = grid_param%n
    nz = grid_param%nz

    comm_param = comm_param_in
    halo_size = comm_param%halo_size

    ! Create cartesian topology
    call mpi_cart_create(MPI_COMM_WORLD,        & ! Old communicator name
                         dim_horiz,             & ! horizontal dimension
                         (/2**pproc,2**pproc/), & ! extent in each horizontal direction
                         (/.false.,.false./),   & ! periodic?
                         .false.,               & ! reorder?
                         MPI_COMM_HORIZ,        & ! Name of new communicator
                         ierr)

    ! Local size of (horizontal) grid
    nlocal = n/2**pproc

    ! === Set up data type / buffer tables, indexed (level,m) ===
    if (LUseO) then
      allocate(halo_ns(n_lev,0:pproc))
      allocate(interior(n_lev,0:pproc))
      allocate(sub_interior(n_lev,0:pproc))
    end if
    if (LUseT) then
      allocate(halo_nst(n_lev,0:pproc))
      allocate(tab_halo_nt(n_lev,0:pproc))
      allocate(tab_halo_st(n_lev,0:pproc))
      allocate(halo_wet(n_lev,0:pproc))
      allocate(tab_halo_wt(n_lev,0:pproc))
      allocate(tab_halo_et(n_lev,0:pproc))
      allocate(interiorT(n_lev,0:pproc))
      allocate(sub_interiorT(n_lev,0:pproc))
      allocate(tab_interiorT_ne(n_lev,0:pproc))
      allocate(tab_interiorT_sw(n_lev,0:pproc))
      allocate(tab_interiorT_se(n_lev,0:pproc))
      allocate(tab_sub_interiorT_ne(n_lev,0:pproc))
      allocate(tab_sub_interiorT_sw(n_lev,0:pproc))
      allocate(tab_sub_interiorT_se(n_lev,0:pproc))
    end if
    ! Timer
    allocate(t_haloswap(n_lev,0:pproc))
    allocate(t_collect(0:pproc))
    allocate(t_distribute(0:pproc))
    do m=0,pproc
      write(t_label,'("t_collect(",I3,")")') m
      call initialise_timer(t_collect(m),t_label)
      write(t_label,'("t_distribute(",I3,")")') m
      call initialise_timer(t_distribute(m),t_label)
    end do

    m = pproc
    level = n_lev
    reduced_m = .false.
    do while (level > 0)
      ! --- Create halo data types ---
      if (LUseO) then
        ! NS- (y-) direction
        count = nlocal
        blocklength = (nz+2)*halo_size
        stride = (nlocal+2*halo_size)*(nz+2)
        call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION, &
                             halo_ns(level,m),ierr)
        call mpi_type_commit(halo_ns(level,m),ierr)
        call check_commit(ierr,"halo_ns")
      end if
      if (LUseT) then
        ! NS- (y-) direction in the transposed layout st(x,y,z), whose shape
        ! is (nlocal+2*halo_size, nlocal+2*halo_size, nz+2) (see the
        ! interiorT subarray below): one block per vertical level.
        ! NOTE(review): a contiguous block of nlocal*halo_size values only
        ! covers an NS halo strip when halo_size == 1 -- confirm before
        ! using wider halos with LUseT.
        count =        nz+2
        blocklength =  nlocal*halo_size
        stride =       (nlocal+2*halo_size) * (nlocal+2*halo_size)
        call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION, &
                             halo_nst(level,m),ierr)
        call mpi_type_commit(halo_nst(level,m),ierr)
        call check_commit(ierr,"halo_nst")
        ! allocate send/recv buffer for GPU copy/managed
        call mnh_allocate_mg_halo(tab_halo_nt(level,m)%haloTin,nlocal,halo_size,nz+2)
        call mnh_allocate_mg_halo(tab_halo_nt(level,m)%haloTout,nlocal,halo_size,nz+2)
        call mnh_allocate_mg_halo(tab_halo_st(level,m)%haloTin,nlocal,halo_size,nz+2)
        call mnh_allocate_mg_halo(tab_halo_st(level,m)%haloTout,nlocal,halo_size,nz+2)

        ! WE- (x-) direction in the transposed layout: halo_size values
        ! from each of the (nlocal+2*halo_size)*(nz+2) x-rows. This is
        ! the same element count as the (halo_size, nlocal+2*halo_size, nz+2)
        ! buffers below. The previous count carried an extra factor of
        ! halo_size, which over-ran the field whenever halo_size > 1.
        count =        (nz+2)*(nlocal+2*halo_size)
        blocklength =  halo_size
        stride =       nlocal+2*halo_size
        call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION, &
                             halo_wet(level,m),ierr)
        call mpi_type_commit(halo_wet(level,m),ierr)
        call check_commit(ierr,"halo_wet")
        ! allocate send/recv buffer for GPU copy/managed
        call mnh_allocate_mg_halo(tab_halo_wt(level,m)%haloTin,halo_size,nlocal+2*halo_size,nz+2)
        call mnh_allocate_mg_halo(tab_halo_wt(level,m)%haloTout,halo_size,nlocal+2*halo_size,nz+2)
        call mnh_allocate_mg_halo(tab_halo_et(level,m)%haloTin,halo_size,nlocal+2*halo_size,nz+2)
        call mnh_allocate_mg_halo(tab_halo_et(level,m)%haloTout,halo_size,nlocal+2*halo_size,nz+2)
      end if
      ! --- Create interior data types ---
      if (LUseO) then
        count = nlocal
        blocklength = nlocal*(nz+2)
        stride = (nz+2)*(nlocal+2*halo_size)
        call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION,interior(level,m),ierr)
        call mpi_type_commit(interior(level,m),ierr)
        call check_commit(ierr,"interior")
        count = nlocal/2
        blocklength = nlocal/2*(nz+2)
        stride = (nlocal+2*halo_size)*(nz+2)
        call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION,sub_interior(level,m),ierr)
        call mpi_type_commit(sub_interior(level,m),ierr)
        call check_commit(ierr,"sub_interior")
      end if
      if (LUseT) then
        ! interiorT: the nlocal x nlocal x (nz+2) block starting at offset 0
        ! (skipped for empty grids; comm_finalise applies the same guard)
        if ( nlocal /= 0 ) then
          profil_tab      = (/ nlocal+2*halo_size , nlocal+2*halo_size , nz+2 /)
          profil_sous_tab = (/ nlocal             , nlocal             , nz+2 /)
          coord_debut     = (/ 0                  , 0                  , 0    /)
          call MPI_TYPE_CREATE_SUBARRAY(nb_dims,profil_tab,profil_sous_tab,coord_debut,&
               MPI_ORDER_FORTRAN,MPI_DOUBLE_PRECISION,interiorT(level,m),ierr)
          call mpi_type_commit(interiorT(level,m),ierr)
          call check_commit(ierr,"interiorT")
          call mnh_allocate_mg_halo(tab_interiorT_ne(level,m)%haloTin,nlocal,nlocal,nz+2)
          call mnh_allocate_mg_halo(tab_interiorT_ne(level,m)%haloTout,nlocal,nlocal,nz+2)
          call mnh_allocate_mg_halo(tab_interiorT_sw(level,m)%haloTin,nlocal,nlocal,nz+2)
          call mnh_allocate_mg_halo(tab_interiorT_sw(level,m)%haloTout,nlocal,nlocal,nz+2)
          call mnh_allocate_mg_halo(tab_interiorT_se(level,m)%haloTin,nlocal,nlocal,nz+2)
          call mnh_allocate_mg_halo(tab_interiorT_se(level,m)%haloTout,nlocal,nlocal,nz+2)
        end if
        ! sub_interiorT: one quarter (nlocal/2 x nlocal/2) of the interior
        if ( (nlocal/2) /= 0 ) then
          profil_tab      = (/ nlocal+2*halo_size , nlocal+2*halo_size , nz+2 /)
          profil_sous_tab = (/ nlocal/2           , nlocal/2           , nz+2 /)
          coord_debut     = (/ 0                  , 0                  , 0    /)
          call MPI_TYPE_CREATE_SUBARRAY(nb_dims,profil_tab,profil_sous_tab,coord_debut,&
               MPI_ORDER_FORTRAN,MPI_DOUBLE_PRECISION,sub_interiorT(level,m),ierr)
          call mpi_type_commit(sub_interiorT(level,m),ierr)
          call check_commit(ierr,"sub_interiorT")
          call mnh_allocate_mg_halo(tab_sub_interiorT_ne(level,m)%haloTin,nlocal/2,nlocal/2,nz+2)
          call mnh_allocate_mg_halo(tab_sub_interiorT_ne(level,m)%haloTout,nlocal/2,nlocal/2,nz+2)
          call mnh_allocate_mg_halo(tab_sub_interiorT_sw(level,m)%haloTin,nlocal/2,nlocal/2,nz+2)
          call mnh_allocate_mg_halo(tab_sub_interiorT_sw(level,m)%haloTout,nlocal/2,nlocal/2,nz+2)
          call mnh_allocate_mg_halo(tab_sub_interiorT_se(level,m)%haloTin,nlocal/2,nlocal/2,nz+2)
          call mnh_allocate_mg_halo(tab_sub_interiorT_se(level,m)%haloTout,nlocal/2,nlocal/2,nz+2)
        end if
      end if
      ! --- Create timers ---
      write(t_label,'("t_haloswap(",I3,",",I3,")")') level,m
      call initialise_timer(t_haloswap(level,m),t_label)

      ! If we are below L_split, split data: revisit this level with m-1
      ! (four times as many points per process, i.e. nlocal doubled)
      if ( (level .le. lev_split) .and. (m > 0) .and. (.not. reduced_m)) then
        reduced_m = .true.
        m = m-1
        nlocal = 2*nlocal
        cycle
      end if
      reduced_m = .false.
      level = level-1
      nlocal = nlocal/2
    end do

  contains

    !> Abort through fatalerror() if committing the datatype `what` failed.
    !> As in the original single check, this is compiled out when NDEBUG is
    !> defined.
    subroutine check_commit(ierr_in, what)
      integer, intent(in) :: ierr_in
      character(len=*), intent(in) :: what
#ifndef NDEBUG
      if (ierr_in /= 0) &
        call fatalerror("Commit "//what//" failed in mpi_type_commit().")
#endif
    end subroutine check_commit

  end subroutine comm_initialise

!==================================================================
! Finalise communication routines
!==================================================================
  !> Print the communication timers and release every MPI datatype and
  !> table created by comm_initialise().
  !>
  !> The (level,m) walk mirrors comm_initialise(). While m > 0, a level
  !> <= lev_split is visited a second time with m-1 and a doubled local size.
  !> This keeps the nlocal-based guards on interiorT/sub_interiorT
  !> consistent with the ones used at creation time.
  subroutine comm_finalise(n_lev,   &  ! } Multigrid parameters
                           lev_split,  & !}
                           grid_param )  ! } Grid parameters
    implicit none
    integer, intent(in) :: n_lev
    integer, intent(in) :: lev_split
    type(grid_parameters), intent(in) :: grid_param
    ! local var
    logical :: merged
    integer :: lev, mp, n_loc
    integer :: ierr
    character(len=80) :: label

    ! Start on the finest level with all processes active
    n_loc  = grid_param%n / 2**pproc
    mp     = pproc
    lev    = n_lev
    merged = .false.

    if (i_am_master_mpi) write(STDOUT,'(" *** Finalising communications ***")')
    call print_timerinfo("--- Communication timing results ---")

    do while (lev > 0)
      write(label,'("level = ",I3,", m = ",I3)') lev, mp
      call print_timerinfo(label)
      call print_elapsed(t_haloswap(lev,mp),.True.,1.0_rl)

      ! Halo datatypes
      if (LUseO) then
        call mpi_type_free(halo_ns(lev,mp),ierr)
      end if
      if (LUseT) then
        call mpi_type_free(halo_nst(lev,mp),ierr)
        call mpi_type_free(halo_wet(lev,mp),ierr)
      end if
      ! Interior datatypes
      if (LUseO) then
        call mpi_type_free(interior(lev,mp),ierr)
        call mpi_type_free(sub_interior(lev,mp),ierr)
      end if
      if (LUseT) then
        if (n_loc /= 0) call mpi_type_free(interiorT(lev,mp),ierr)
        if (n_loc/2 /= 0) call mpi_type_free(sub_interiorT(lev,mp),ierr)
      end if

      ! Advance: either revisit this level with fewer processes,
      ! or move to the next coarser level.
      if ((lev <= lev_split) .and. (mp > 0) .and. (.not. merged)) then
        merged = .true.
        mp     = mp - 1
        n_loc  = 2*n_loc
      else
        merged = .false.
        lev    = lev - 1
        n_loc  = n_loc/2
      end if
    end do

    do mp = pproc, 0, -1
      write(label,'("m = ",I3)') mp
      call print_timerinfo(label)
      call print_elapsed(t_collect(mp),.True.,1.0_rl)
      call print_elapsed(t_distribute(mp),.True.,1.0_rl)
    end do

    ! Release the datatype handle tables and timers
    if (LUseO) deallocate(halo_ns, interior, sub_interior)
    if (LUseT) deallocate(halo_nst, halo_wet, interiorT, sub_interiorT)
    deallocate(t_haloswap, t_collect, t_distribute)

    if (i_am_master_mpi) write(STDOUT,'("")')
  end subroutine comm_finalise

!==================================================================
!  Scalar product of two fields
!==================================================================
  !> Global scalar product s = <a,b> on process level m.
  !>
  !> Only ranks whose Cartesian coordinates are both multiples of
  !> stepsize = 2**(pproc-m) contribute a local sum; with more than one
  !> process the contributions are summed with mpi_allreduce over
  !> MPI_COMM_HORIZ, so the call is collective and every rank gets s.
  !> The vertical index runs over 0..nz+1 in both layouts.
  !> If LUseO is set the result computed from a%s is returned, otherwise
  !> the one computed from the transposed field a%st.
  !>
  !> m : process level (number of active processes is 4**m)
  !> a, b : input fields
  !> s : resulting scalar product
  subroutine scalarprod_mnh(m, a, b, s)
#ifdef MNH_MPPDB
    USE MODE_MPPDB
#endif
    implicit none
    integer, intent(in) :: m
    type(scalar3d), intent(in) :: a
    type(scalar3d), intent(in) :: b
    real(kind=rl), intent(out) :: s
    !local var
    integer :: nprocs, rank, ierr
    integer :: p_horiz(2)
    integer :: stepsize
    logical :: active
    real(kind=rl) :: local_sum, global_sum
    real(kind=rl) :: local_sumt, global_sumt
    integer :: nlocal, nz, i
    integer :: ix,iy,iz
    real(kind=rl) :: ddot
    integer :: icompy_min,icompy_max, icompx_min,icompx_max
    real , dimension(:,:,:) , pointer , contiguous :: za_st,zb_st

    nlocal = a%ix_max-a%ix_min+1
    nz = a%grid_param%nz

    icompy_min = a%icompy_min
    icompy_max = a%icompy_max
    icompx_min = a%icompx_min
    icompx_max = a%icompx_max

    ! Work out coordinates of processor
    call mpi_comm_size(MPI_COMM_HORIZ,nprocs,ierr)
    call mpi_comm_rank(MPI_COMM_HORIZ,rank,ierr)
    stepsize = 2**(pproc-m)

    ! A single process always contributes. Otherwise only include the local
    ! sum if the processor coordinates are multiples of stepsize.
    active = .true.
    if (nprocs > 1) then
      call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
      active = (stepsize == 1) .or.                   &
               ( (stepsize > 1) .and.                 &
                 (mod(p_horiz(1),stepsize)==0) .and.  &
                 (mod(p_horiz(2),stepsize)==0) )
    end if

    ! Local contribution. The serial and parallel cases share this code, so
    ! both use the local computational bounds icomp* and an explicit
    ! OpenACC reduction.
    local_sum  = 0.0_rl
    local_sumt = 0.0_rl
    if (active) then
      if (LUseO) then
        do i = 1, nlocal
          local_sum = local_sum &
               + ddot((nz+2)*nlocal,a%s(0,1,i),1,b%s(0,1,i),1)
        end do
      end if
      if (LUseT) then
        za_st => a%st
        zb_st => b%st
        !$acc kernels
        !$acc loop collapse(3) reduction(+:local_sumt)
        do iz=0,nz+1
          do iy=icompy_min,icompy_max
            do ix=icompx_min,icompx_max
              local_sumt = local_sumt &
                   + za_st(ix,iy,iz)*zb_st(ix,iy,iz)
            end do
          end do
        end do
        !$acc end kernels
      end if
    end if

    ! Global reduction (nothing to reduce on a single process)
    global_sum  = local_sum
    global_sumt = local_sumt
    if (nprocs > 1) then
      if (LUseO) call mpi_allreduce(local_sum,global_sum,1,MPI_DOUBLE_PRECISION, &
                         MPI_SUM,MPI_COMM_HORIZ,ierr)
      if (LUseT) call mpi_allreduce(local_sumt,global_sumt,1,MPI_DOUBLE_PRECISION, &
                         MPI_SUM,MPI_COMM_HORIZ,ierr)
    end if

    if (LUseO) then
       s = global_sum
    else
       s = global_sumt
    end if

#ifdef MNH
!!$    CALL MPPDB_CHECK0D_REAL_MG(s,"scalarprod_mnh")
#endif

  end subroutine scalarprod_mnh
!-------------------------------------------------------------------------------
  !> Global scalar product s = <a,b> of two fields stored in a%s / b%s,
  !> evaluated on process level m.
  !>
  !> Each column i contributes ddot over (nz+2)*nlocal values starting at
  !> a%s(0,1,i). With more than one process, only ranks whose Cartesian
  !> coordinates are multiples of 2**(pproc-m) contribute, and the result
  !> is summed with mpi_allreduce over MPI_COMM_HORIZ (collective call).
  subroutine scalarprod(m, a, b, s)
    implicit none
    integer, intent(in) :: m
    type(scalar3d), intent(in) :: a
    type(scalar3d), intent(in) :: b
    real(kind=rl), intent(out) :: s
    integer :: n_ranks, my_rank, ierr
    integer :: coords(2)
    integer :: step
    integer, parameter :: dim_horiz = 2
    real(kind=rl) :: partial, total
    integer :: ncol, nz, icol
    logical :: contributes
    real(kind=rl) :: ddot

    ncol = a%ix_max - a%ix_min + 1
    nz   = a%grid_param%nz

    call mpi_comm_size(MPI_COMM_HORIZ, n_ranks, ierr)
    call mpi_comm_rank(MPI_COMM_HORIZ, my_rank, ierr)
    step = 2**(pproc-m)

    ! Decide whether this rank takes part in the sum
    contributes = .true.
    if (n_ranks > 1) then
      call mpi_cart_coords(MPI_COMM_HORIZ, my_rank, dim_horiz, coords, ierr)
      contributes = (step == 1) .or.                  &
                    ( (step > 1) .and.                &
                      (mod(coords(1),step) == 0) .and. &
                      (mod(coords(2),step) == 0) )
    end if

    partial = 0.0_rl
    if (contributes) then
      do icol = 1, ncol
        partial = partial &
                + ddot((nz+2)*ncol, a%s(0,1,icol), 1, b%s(0,1,icol), 1)
      end do
    end if

    if (n_ranks > 1) then
      call mpi_allreduce(partial, total, 1, MPI_DOUBLE_PRECISION, &
                         MPI_SUM, MPI_COMM_HORIZ, ierr)
    else
      total = partial
    end if
    s = total
  end subroutine scalarprod
!==================================================================
!  Print the norm (square root of the scalar product) of one field
!==================================================================
  !> Compute the norm sqrt(<a,a>) of a field on process level m and print
  !> it on the master rank, tagged with `message`, level l and m.
  !>
  !> Must be called on all ranks: scalarprod_mnh performs an mpi_allreduce
  !> over MPI_COMM_HORIZ when more than one process is used.
  subroutine print_scalaprod2(l,m, a, message )
    implicit none
    integer, intent(in) :: l,m
    type(scalar3d), intent(in) :: a
    character(len=*) , intent(in) :: message

    !local
    real(kind=rl) :: s

    call scalarprod_mnh(m, a, a, s)
    s = sqrt(s)
    if (i_am_master_mpi) then
       write(STDOUT,'("Print_norm::",A,2I3,E23.15)') message, l,m,s
       ! Standard F2003 FLUSH statement instead of the non-standard
       ! FLUSH subroutine extension
       flush(STDOUT)
    end if

  end subroutine print_scalaprod2
!==================================================================
! Boundary Neumann
!==================================================================
  !> Apply homogeneous Neumann conditions on the physical domain boundary by
  !> copying the first/last interior layer into the adjacent halo layer
  !> (u(0) = u(1), etc.). Only ranks owning a global edge (global index 1 or
  !> n in x or y) update that edge. The x-edges are handled before the
  !> y-edges, so the y copies also carry the corner values.
  subroutine boundary_mnh(a)              ! data field

    implicit none

    type(scalar3d), intent(inout) :: a

    !local var
    integer :: n, last_x, last_y
    logical :: at_x_lo, at_x_hi, at_y_lo, at_y_hi
    real , dimension(:,:,:) , pointer , contiguous :: za_st

    n = a%grid_param%n
    at_x_lo = (a%ix_min == 1)
    at_x_hi = (a%ix_max == n)
    at_y_lo = (a%iy_min == 1)
    at_y_hi = (a%iy_max == n)
    last_x  = a%icompx_max
    last_y  = a%icompy_max

    ! Original layout a%s(z,y,x)
    if (LUseO) then
      if (at_x_lo) a%s(:,:,0)        = a%s(:,:,1)
      if (at_x_hi) a%s(:,:,last_x+1) = a%s(:,:,last_x)
      if (at_y_lo) a%s(:,0,:)        = a%s(:,1,:)
      if (at_y_hi) a%s(:,last_y+1,:) = a%s(:,last_y,:)
    end if

    ! Transposed layout a%st(x,y,z), updated inside one OpenACC region
    if (LUseT) then
      za_st => a%st
      !$acc kernels
      if (at_x_lo) then
        za_st(0,:,:) = za_st(1,:,:)
      end if
      if (at_x_hi) then
        za_st(last_x+1,:,:) = za_st(last_x,:,:)
      end if
      if (at_y_lo) then
        za_st(:,0,:) = za_st(:,1,:)
      end if
      if (at_y_hi) then
        za_st(:,last_y+1,:) = za_st(:,last_y,:)
      end if
      !$acc end kernels
    end if

  end subroutine boundary_mnh
!==================================================================
!  Initiate asynchronous halo exchange
!
!  For all processes with horizontal indices that are multiples
!  of 2^(pproc-m), update halos with information held by
!  neighbouring processes, e.g. for pproc-m = 1, stepsize=2
!
!                      N  (0,2)
!                      ^
!                      !
!                      v
!
!      W (2,0) <-->  (2,2)  <-->  E (2,4)
!
!                      ^
!                      !
!                      v
!                      S (4,2)
!
!==================================================================
  subroutine ihaloswap_mnh(level,m,       &  ! multigrid- and processor- level
                       a,             &  ! data field
                       send_requests, &  ! send requests (OUT)
                       recv_requests,  &  ! recv requests (OUT)
                       send_requestsT, &  ! send requests T (OUT)
                       recv_requestsT  &  ! recv requests T (OUT)
                       )
    implicit none
    integer, intent(in) :: level
    integer, intent(in) :: m
    integer, intent(out), dimension(4) :: send_requests
    integer, intent(out), dimension(4) :: recv_requests
    integer, intent(out), dimension(4) :: send_requestsT
    integer, intent(out), dimension(4) :: recv_requestsT
    type(scalar3d), intent(inout) :: a
    integer :: a_n  ! horizontal grid size
    integer :: nz   ! vertical grid size
    integer, dimension(2) :: p_horiz
    integer :: stepsize
    integer :: ierr, rank, sendtag, recvtag
    integer :: stat(MPI_STATUS_SIZE)
    integer :: halo_size
    integer :: neighbour_n_rank
    integer :: neighbour_s_rank
    integer :: neighbour_e_rank
    integer :: neighbour_w_rank
    integer :: yoffset, blocklength

    halo_size = comm_param%halo_size

    ! Do nothing if we are only using one processor
    if (m > 0) then
      a_n = a%ix_max-a%ix_min+1
      nz = a%grid_param%nz
      stepsize = 2**(pproc-m)

      ! Work out rank, only execute on relevant processes
      call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
      call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)

      ! Work out ranks of neighbours
      ! W -> E
      call mpi_cart_shift(MPI_COMM_HORIZ,1, stepsize, &
                          neighbour_w_rank,neighbour_e_rank,ierr)
      ! N -> S
      call mpi_cart_shift(MPI_COMM_HORIZ,0, stepsize, &
                          neighbour_n_rank,neighbour_s_rank,ierr)
      if ( (stepsize == 1) .or.                   &
         (  (mod(p_horiz(1),stepsize) == 0) .and. &
            (mod(p_horiz(2),stepsize) == 0) ) ) then
        if (halo_size == 1) then
          ! Do not include corners in send/recv
          yoffset = 1
          blocklength = a_n*(nz+2)*halo_size
        else
          yoffset = 1-halo_size
          blocklength = (a_n+2*halo_size)*(nz+2)*halo_size
        end if
        ! Receive from north
        recvtag = 1002
        if (LUseO) call mpi_irecv(a%s(0,0-(halo_size-1),1),1,       &
                       halo_ns(level,m),neighbour_n_rank,recvtag,   &
                       MPI_COMM_HORIZ, recv_requests(1), ierr)
        recvtag = 1012
        if (LUseT) then
           call mpi_irecv(a%st(1,0-(halo_size-1),0),1,      &
                halo_nst(level,m),neighbour_n_rank,recvtag,  &
                MPI_COMM_HORIZ, recv_requestsT(1), ierr)
        end if
        ! Receive from south
        recvtag = 1003
        if (LUseO) call mpi_irecv(a%s(0,a_n+1,1),1,                 &
                       halo_ns(level,m),neighbour_s_rank,recvtag,   &
                       MPI_COMM_HORIZ, recv_requests(2), ierr)
        recvtag = 1013
        if (LUseT) then
           call mpi_irecv(a%st(1,a_n+1,0),1,                &
                halo_nst(level,m),neighbour_s_rank,recvtag,  &
                MPI_COMM_HORIZ, recv_requestsT(2), ierr)
        end if
        ! Send to south
        sendtag = 1002
        if (LUseO) call mpi_isend(a%s(0,a_n-(halo_size-1),1),1,     &
                       halo_ns(level,m),neighbour_s_rank,sendtag,   &
                       MPI_COMM_HORIZ, send_requests(1), ierr)
        sendtag = 1012
        if (LUseT) then
           call mpi_isend(a%st(1,a_n-(halo_size-1),0),1,    &
                halo_nst(level,m),neighbour_s_rank,sendtag,  &
                MPI_COMM_HORIZ, send_requestsT(1), ierr)
        end if
        ! Send to north
        sendtag = 1003
        if (LUseO) call mpi_isend(a%s(0,1,1),1,                     &
                       halo_ns(level,m),neighbour_n_rank,sendtag,   &
                       MPI_COMM_HORIZ, send_requests(2), ierr)
        sendtag = 1013
        if (LUseT) then
           call mpi_isend(a%st(1,1,0),1,                    &
                       halo_nst(level,m),neighbour_n_rank,sendtag,  &
                       MPI_COMM_HORIZ, send_requestsT(2), ierr)
        end if
        ! Receive from west
        recvtag = 1000
        if (LUseO) call mpi_irecv(a%s(0,yoffset,0-(halo_size-1)),blocklength,  &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
                       MPI_COMM_HORIZ, recv_requests(3), ierr)
        recvtag = 1010
        if (LUseT) then
           call mpi_irecv(a%st(0-(halo_size-1),0,0),1,  &
                halo_wet(level,m),neighbour_w_rank,recvtag, &
                MPI_COMM_HORIZ, recv_requestsT(3), ierr)
        end if
        ! Receive from east
        sendtag = 1001
        if (LUseO) call mpi_irecv(a%s(0,yoffset,a_n+1),blocklength,          &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
                       MPI_COMM_HORIZ, recv_requests(4), ierr)
        sendtag = 1011
        if (LUseT) then
           call mpi_irecv(a%st(a_n+1,0,0),1,          &
                halo_wet(level,m),neighbour_e_rank,recvtag, &
                MPI_COMM_HORIZ, recv_requestsT(4), ierr)
        end if
        ! Send to east
        sendtag = 1000
        if (LUseO) call mpi_isend(a%s(0,yoffset,a_n-(halo_size-1)),blocklength,  &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
                       MPI_COMM_HORIZ, send_requests(3), ierr)
        sendtag = 1010
        if (LUseT) then
           call mpi_isend(a%st(a_n-(halo_size-1),0,0),1,  &
                halo_wet(level,m),neighbour_e_rank,sendtag, &
                MPI_COMM_HORIZ, send_requestsT(3), ierr)
        end if
        ! Send to west
        recvtag = 1001
        if (LUseO) call mpi_isend(a%s(0,yoffset,1),blocklength,                &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag,   &
                       MPI_COMM_HORIZ, send_requests(4), ierr)
        recvtag = 1011
        if (LUseT) then
           call mpi_isend(a%st(1,0,0),1,                &
                       halo_wet(level,m),neighbour_w_rank,sendtag,   &
                       MPI_COMM_HORIZ, send_requestsT(4), ierr)
        end if
      end if
    end if
  end subroutine ihaloswap_mnh
!==================================================================
  subroutine ihaloswap(level,m,       &  ! multigrid- and processor- level
                       a,             &  ! data field
                       send_requests, &  ! send requests (OUT)
                       recv_requests  &  ! recv requests (OUT)
                       )
    !
    ! Non-blocking halo exchange of a%s on multigrid level "level" and
    ! processor level m. This routine only posts the mpi_irecv/mpi_isend
    ! calls. The caller must complete send_requests and recv_requests
    ! (e.g. with mpi_waitall) before reading the halos or overwriting the
    ! data that is being sent.
    !
    ! If nothing is posted (m == 0, or this process does not take part at
    ! this level), both request arrays are returned as MPI_REQUEST_NULL.
    ! Completing them is then a no-op.
    !
    ! Message tags (named after the direction the data travels):
    !   0 : eastwards  (send to east,  receive from west)
    !   1 : westwards  (send to west,  receive from east)
    !   2 : southwards (send to south, receive from north)
    !   3 : northwards (send to north, receive from south)
    !
    implicit none
    integer, intent(in) :: level
    integer, intent(in) :: m
    integer, intent(out), dimension(4) :: send_requests
    integer, intent(out), dimension(4) :: recv_requests
    type(scalar3d), intent(inout) :: a
    integer :: a_n  ! horizontal grid size
    integer :: nz   ! vertical grid size
    integer, dimension(2) :: p_horiz
    integer :: stepsize
    integer :: ierr, rank, sendtag, recvtag
    integer :: halo_size
    integer :: neighbour_n_rank
    integer :: neighbour_s_rank
    integer :: neighbour_e_rank
    integer :: neighbour_w_rank
    integer :: yoffset, blocklength

    halo_size = comm_param%halo_size

    ! intent(out) requests must be defined on every path, so that a
    ! caller's mpi_waitall is safe even when nothing is posted below
    send_requests(:) = MPI_REQUEST_NULL
    recv_requests(:) = MPI_REQUEST_NULL

    ! Do nothing if we are only using one processor
    if (m > 0) then
      a_n = a%ix_max-a%ix_min+1
      nz = a%grid_param%nz
      stepsize = 2**(pproc-m)

      ! Work out rank, only execute on relevant processes
      call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
      call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)

      ! Work out ranks of neighbours
      ! W -> E
      call mpi_cart_shift(MPI_COMM_HORIZ,1, stepsize, &
                          neighbour_w_rank,neighbour_e_rank,ierr)
      ! N -> S
      call mpi_cart_shift(MPI_COMM_HORIZ,0, stepsize, &
                          neighbour_n_rank,neighbour_s_rank,ierr)
      if ( (stepsize == 1) .or.                   &
         (  (mod(p_horiz(1),stepsize) == 0) .and. &
            (mod(p_horiz(2),stepsize) == 0) ) ) then
        if (halo_size == 1) then
          ! Do not include corners in send/recv
          yoffset = 1
          blocklength = a_n*(nz+2)*halo_size
        else
          yoffset = 1-halo_size
          blocklength = (a_n+2*halo_size)*(nz+2)*halo_size
        end if
        ! Receive from north (southward-travelling data)
        recvtag = 2
        call mpi_irecv(a%s(0,0-(halo_size-1),1),1,                  &
                       halo_ns(level,m),neighbour_n_rank,recvtag,   &
                       MPI_COMM_HORIZ, recv_requests(1), ierr)
        ! Receive from south (northward-travelling data)
        recvtag = 3
        call mpi_irecv(a%s(0,a_n+1,1),1,                            &
                       halo_ns(level,m),neighbour_s_rank,recvtag,   &
                       MPI_COMM_HORIZ, recv_requests(2), ierr)
        ! Send to south
        sendtag = 2
        call mpi_isend(a%s(0,a_n-(halo_size-1),1),1,                &
                       halo_ns(level,m),neighbour_s_rank,sendtag,   &
                       MPI_COMM_HORIZ, send_requests(1), ierr)
        ! Send to north
        sendtag = 3
        call mpi_isend(a%s(0,1,1),1,                                &
                       halo_ns(level,m),neighbour_n_rank,sendtag,   &
                       MPI_COMM_HORIZ, send_requests(2), ierr)
        ! Receive from west (eastward-travelling data)
        recvtag = 0
        call mpi_irecv(a%s(0,yoffset,0-(halo_size-1)),blocklength,    &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
                       MPI_COMM_HORIZ, recv_requests(3), ierr)
        ! Receive from east (westward-travelling data)
        recvtag = 1
        call mpi_irecv(a%s(0,yoffset,a_n+1),blocklength,              &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
                       MPI_COMM_HORIZ, recv_requests(4), ierr)
        ! Send to east
        sendtag = 0
        call mpi_isend(a%s(0,yoffset,a_n-(halo_size-1)),blocklength,  &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
                       MPI_COMM_HORIZ, send_requests(3), ierr)
        ! Send to west
        sendtag = 1
        call mpi_isend(a%s(0,yoffset,1),blocklength,                  &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag, &
                       MPI_COMM_HORIZ, send_requests(4), ierr)
      end if
    end if
  end subroutine ihaloswap

!==================================================================
!  Halo exchange
!
!  For all processes with horizontal indices that are multiples
!  of 2^(pproc-m), update halos with information held by
!  neighbouring processes, e.g. for pproc-m = 1, stepsize=2
!
!                      N  (0,2)
!                      ^
!                      !
!                      v
!
!      W (2,0) <-->  (2,2)  <-->  E (2,4)
!
!                      ^
!                      !
!                      v
!                      S (4,2)
!
!==================================================================
  subroutine haloswap_mnh(level,m, &  ! multigrid- and processor- level
                      a)          ! data field
    !
    ! Blocking halo exchange of the fields held in a, on multigrid level
    ! "level" and processor level m. The routine returns only after all
    ! posted messages have completed.
    !
    ! The module flags select which storage is exchanged:
    !   LUseO : a%s, using the derived type halo_ns (north/south) and
    !           blocks of MPI_DOUBLE_PRECISION (east/west)
    !   LUseT : a%st, using the derived types halo_nst / halo_wet, or, in
    !           MNH_GPUDIRECT builds, device-resident pack/unpack buffers
    !           tab_halo_[nsew]t(level,m)%haloTin / %haloTout
    !
    ! Message tags (named after the direction the data travels):
    !   1000 (a%s) / 1010 (a%st) : eastwards  (send to east,  recv from west)
    !   1001 (a%s) / 1011 (a%st) : westwards  (send to west,  recv from east)
    !   1002 (a%s) / 1012 (a%st) : southwards (send to south, recv from north)
    !   1003 (a%s) / 1013 (a%st) : northwards (send to north, recv from south)
    !
    implicit none
    integer, intent(in) :: level
    integer, intent(in) :: m
    type(scalar3d), intent(inout) :: a
    !local var
    integer :: a_n  ! horizontal grid size
    integer :: nz   ! vertical grid size
    integer, dimension(2) :: p_horiz
    integer :: stepsize
    integer :: ierr, rank, sendtag, recvtag
    integer :: halo_size
    integer :: neighbour_n_rank
    integer :: neighbour_s_rank
    integer :: neighbour_e_rank
    integer :: neighbour_w_rank
    integer :: yoffset, blocklength
    integer, dimension(4) :: requests_ns
    integer, dimension(4) :: requests_ew
    integer, dimension(4) :: requests_nsT
    integer, dimension(4) :: requests_ewT

    integer :: ii,ij,ik
    real , pointer , contiguous , dimension(:,:,:) :: zst
    !
    ! Outgoing (haloTin) and incoming (haloTout) buffers for the a%st
    ! exchange. They are associated only in MNH_GPUDIRECT builds, with
    ! LUseT set, and only for directions that have a neighbour.
    ! NOTE(review): in every other case they are passed to haloswap_mnh_dim
    ! unassociated. There they are never referenced, but strictly this is
    ! non-conforming Fortran. Confirm it is acceptable for the compilers in use.
    !
    real , pointer , contiguous , dimension(:,:,:) :: ztab_halo_st_haloTin,ztab_halo_nt_haloTin
    real , pointer , contiguous , dimension(:,:,:) :: ztab_halo_et_haloTin,ztab_halo_wt_haloTin
    !
    real , pointer , contiguous , dimension(:,:,:) :: ztab_halo_nt_haloTout,ztab_halo_st_haloTout
    real , pointer , contiguous , dimension(:,:,:) :: ztab_halo_wt_haloTout,ztab_halo_et_haloTout

    ! OpenACC async queue ids, one per direction
    INTEGER,PARAMETER :: IS_WEST=1 , IS_EAST=2, IS_SOUTH=3, IS_NORTH=4

    ! .true. when a neighbouring process exists in that direction
    LOGICAL :: Gneighbour_s,Gneighbour_n,Gneighbour_e,Gneighbour_w

    halo_size = comm_param%halo_size

    ! Do nothing if we are only using one processor
    if (m > 0) then
      if (comm_measuretime) then
        call start_timer(t_haloswap(level,m))
      end if
      a_n = a%ix_max-a%ix_min+1
      nz = a%grid_param%nz
      stepsize = 2**(pproc-m)

      ! Work out rank, only execute on relevant processes
      call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
      call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)

      ! Work out ranks of neighbours
      ! W -> E
      call mpi_cart_shift(MPI_COMM_HORIZ,1, stepsize, &
                          neighbour_w_rank,neighbour_e_rank,ierr)
      ! N -> S
      call mpi_cart_shift(MPI_COMM_HORIZ,0, stepsize, &
                          neighbour_n_rank,neighbour_s_rank,ierr)
      if ( (stepsize == 1) .or.                   &
         (  (mod(p_horiz(1),stepsize) == 0) .and. &
            (mod(p_horiz(2),stepsize) == 0) ) ) then
        if (halo_size == 1) then
          ! Do not include corners in send/recv
          yoffset = 1
          blocklength = a_n*(nz+2)*halo_size
        else
          yoffset = 1-halo_size
          blocklength = (a_n+2*halo_size)*(nz+2)*halo_size
        end if
        !
        zst => a%st
        !
        ! mpi_cart_shift returns MPI_PROC_NULL where there is no neighbour
        ! (non-periodic boundary)
        Gneighbour_s = (neighbour_s_rank /= MPI_PROC_NULL)
        Gneighbour_n = (neighbour_n_rank /= MPI_PROC_NULL)
        Gneighbour_e = (neighbour_e_rank /= MPI_PROC_NULL)
        Gneighbour_w = (neighbour_w_rank /= MPI_PROC_NULL)
        !
        ! Requests that are never posted stay null, so mpi_waitall on the
        ! full arrays is always safe
        requests_ns(:) = MPI_REQUEST_NULL
        requests_ew(:) = MPI_REQUEST_NULL
        requests_nsT(:) = MPI_REQUEST_NULL
        requests_ewT(:) = MPI_REQUEST_NULL
        !
        !  Init Pointer
        !
#ifdef MNH_GPUDIRECT
        if (LUseT) then
           ! Send to south
           if (Gneighbour_s) then
              ztab_halo_st_haloTin => tab_halo_st(level,m)%haloTin
           end if
           ! Send to north
           if (Gneighbour_n) then
              ztab_halo_nt_haloTin => tab_halo_nt(level,m)%haloTin
           end if
           ! Send to east
           if (Gneighbour_e) then
              ztab_halo_et_haloTin => tab_halo_et(level,m)%haloTin
           end if
           ! Send to west
           if (Gneighbour_w) then
              ztab_halo_wt_haloTin => tab_halo_wt(level,m)%haloTin
           end if
           ! Receive from north
           if (Gneighbour_n) then
              ztab_halo_nt_haloTout => tab_halo_nt(level,m)%haloTout
           end if
           ! Receive from south
           if (Gneighbour_s) then
              ztab_halo_st_haloTout => tab_halo_st(level,m)%haloTout
           end if
           ! Receive from west
           if (Gneighbour_w) then
              ztab_halo_wt_haloTout => tab_halo_wt(level,m)%haloTout
           end if
           ! Receive from east
           if (Gneighbour_e) then
              ztab_halo_et_haloTout => tab_halo_et(level,m)%haloTout
           end if
        end if
#endif
        !
        call haloswap_mnh_dim(ztab_halo_st_haloTin,ztab_halo_nt_haloTin,&
                              ztab_halo_et_haloTin,ztab_halo_wt_haloTin,&
                              ztab_halo_nt_haloTout,ztab_halo_st_haloTout,&
                              ztab_halo_wt_haloTout,ztab_halo_et_haloTout,&
                              zst)
        !
     end if!  (stepsize == 1) ...
     if (comm_measuretime) then
        call finish_timer(t_haloswap(level,m))
     end if
  end if !  (m > 0)

contains
  !
  ! Performs the actual exchange. Buffers are passed as explicit-shape
  ! dummies so that the OpenACC kernels see plain arrays. The bounds come
  ! from the host variables a_n, halo_size and nz.
  !
  subroutine haloswap_mnh_dim(pztab_halo_st_haloTin,pztab_halo_nt_haloTin,&
                              pztab_halo_et_haloTin,pztab_halo_wt_haloTin,&
                              pztab_halo_nt_haloTout,pztab_halo_st_haloTout,&
                              pztab_halo_wt_haloTout,pztab_halo_et_haloTout,&
                              pzst)

    implicit none
    real :: pztab_halo_st_haloTin(1:a_n,1:halo_size,1:nz+2), &
            pztab_halo_nt_haloTin(1:a_n,1:halo_size,1:nz+2), &
            pztab_halo_et_haloTin(1:halo_size,1:a_n+2*halo_size,1:nz+2), &
            pztab_halo_wt_haloTin(1:halo_size,1:a_n+2*halo_size,1:nz+2), &
            pztab_halo_nt_haloTout(1:a_n,1:halo_size,1:nz+2), &
            pztab_halo_st_haloTout(1:a_n,1:halo_size,1:nz+2), &
            pztab_halo_wt_haloTout(1:halo_size,1:a_n+2*halo_size,1:nz+2), &
            pztab_halo_et_haloTout(1:halo_size,1:a_n+2*halo_size,1:nz+2), &
            pzst(1-halo_size:a_n+halo_size,1-halo_size:a_n+halo_size,0:nz+1)
        !
        ! Do Comm
        !
#ifdef MNH_GPUDIRECT
        if (LUseT) then
           !
           ! Pack the send buffers on the device (async, one queue per direction)
           ! NOTE(review): the east/west buffers include the corner columns.
           ! They are packed here, before the north/south halos have been
           ! received and unpacked. For halo_size > 1 the corners therefore
           ! carry stale values, unlike the non-GPUDIRECT path, which waits
           ! for north/south before sending east/west. Confirm whether this
           ! matters for the halo sizes in use.
           !
           ! Send to south
           if (Gneighbour_s) then
           !$acc kernels async(IS_SOUTH) present_cr(pzst,pztab_halo_st_haloTin)
           !$mnh_do_concurrent( ii=1:a_n,ij=1:halo_size,ik=1:nz+2 )
              pztab_halo_st_haloTin(ii,ij,ik) = pzst(ii,ij+a_n-halo_size,ik-1)
           !$mnh_end_do()
           !$acc end kernels
           end if
           ! Send to north
           if (Gneighbour_n) then
           !$acc kernels async(IS_NORTH) present_cr(pzst,pztab_halo_nt_haloTin)
           !$mnh_do_concurrent( ii=1:a_n,ij=1:halo_size,ik=1:nz+2 )
              pztab_halo_nt_haloTin(ii,ij,ik) = pzst(ii,ij,ik-1)
           !$mnh_end_do()
           !$acc end kernels
           end if
           ! Send to east
           if (Gneighbour_e) then
           !$acc kernels async(IS_EAST) present_cr(pzst,pztab_halo_et_haloTin)
           !$mnh_do_concurrent( ii=1:halo_size,ij=1:a_n+2*halo_size,ik=1:nz+2 )
              pztab_halo_et_haloTin(ii,ij,ik) = pzst(ii+a_n-halo_size,ij-halo_size,ik-1)
           !$mnh_end_do()
           !$acc end kernels
           end if
           ! Send to west
           if (Gneighbour_w) then
           !$acc kernels async(IS_WEST) present_cr(pzst,pztab_halo_wt_haloTin)
           !$mnh_do_concurrent( ii=1:halo_size,ij=1:a_n+2*halo_size,ik=1:nz+2 )
              pztab_halo_wt_haloTin(ii,ij,ik) = pzst(ii,ij-halo_size,ik-1)
           !$mnh_end_do()
           !$acc end kernels
           end if
        end if
#endif
        ! Receive from north (southward-travelling data)
        recvtag = 1002
        if (LUseO) call mpi_irecv(a%s(0,0-(halo_size-1),1),1,       &
                       halo_ns(level,m),neighbour_n_rank,recvtag,   &
                       MPI_COMM_HORIZ, requests_ns(1), ierr)
        recvtag = 1012
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           if (Gneighbour_n) then
           !$acc host_data use_device(pztab_halo_nt_haloTout)
           call mpi_irecv(pztab_halo_nt_haloTout,size(pztab_halo_nt_haloTout),      &
                       MPI_DOUBLE_PRECISION,neighbour_n_rank,recvtag,  &
                       MPI_COMM_HORIZ, requests_nsT(1), ierr)
           !$acc end host_data
           end if
           !print*,"mpi_irecv(pztab_halo_nt_haloTout,neighbour_n_rank=",neighbour_n_rank
#else
           call mpi_irecv(a%st(1,0-(halo_size-1),0),1,      &
                       halo_nst(level,m),neighbour_n_rank,recvtag,  &
                       MPI_COMM_HORIZ, requests_nsT(1), ierr)
#endif
        end if
        ! Receive from south (northward-travelling data)
        recvtag = 1003
        if (LUseO) call mpi_irecv(a%s(0,a_n+1,1),1,                 &
                       halo_ns(level,m),neighbour_s_rank,recvtag,   &
                       MPI_COMM_HORIZ, requests_ns(2), ierr)
        recvtag = 1013
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           if (Gneighbour_s) then
           !$acc host_data use_device (pztab_halo_st_haloTout)
           call mpi_irecv(pztab_halo_st_haloTout,size(pztab_halo_st_haloTout),  &
                       MPI_DOUBLE_PRECISION,neighbour_s_rank,recvtag,  &
                       MPI_COMM_HORIZ, requests_nsT(2), ierr)
           !$acc end host_data
           end if
           !print*,"mpi_irecv(pztab_halo_st_haloTout,neighbour_s_rank=",neighbour_s_rank
#else
           call mpi_irecv(a%st(1,a_n+1,0),1,                &
                       halo_nst(level,m),neighbour_s_rank,recvtag,  &
                       MPI_COMM_HORIZ, requests_nsT(2), ierr)
#endif
        end if
#ifdef MNH_GPUDIRECT
        if (LUseT) then
           ! the send buffers must be fully packed before they are sent
           call acc_wait_haloswap_mnh()
        end if
#endif
        ! Send to south
        sendtag = 1002
        if (LUseO) call mpi_isend(a%s(0,a_n-(halo_size-1),1),1,     &
                       halo_ns(level,m),neighbour_s_rank,sendtag,   &
                       MPI_COMM_HORIZ, requests_ns(3), ierr)
        sendtag = 1012
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           if (Gneighbour_s) then
           !$acc host_data use_device(pztab_halo_st_haloTin)
           call mpi_isend(pztab_halo_st_haloTin,size(pztab_halo_st_haloTin),    &
                       MPI_DOUBLE_PRECISION,neighbour_s_rank,sendtag,  &
                       MPI_COMM_HORIZ, requests_nsT(3), ierr)
           !$acc end host_data
           end if
           !print*,"mpi_isend(pztab_halo_st_haloTin,neighbour_s_rank=",neighbour_s_rank
#else
           call mpi_isend(a%st(1,a_n-(halo_size-1),0),1,    &
                       halo_nst(level,m),neighbour_s_rank,sendtag,  &
                       MPI_COMM_HORIZ, requests_nsT(3), ierr)
#endif
        end if
        ! Send to north
        sendtag = 1003
        if (LUseO) call mpi_isend(a%s(0,1,1),1,                     &
                       halo_ns(level,m),neighbour_n_rank,sendtag,   &
                       MPI_COMM_HORIZ, requests_ns(4), ierr)
        sendtag = 1013
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           if (Gneighbour_n) then
           !$acc host_data use_device(pztab_halo_nt_haloTin)
           call mpi_isend(pztab_halo_nt_haloTin,size(pztab_halo_nt_haloTin),   &
                       MPI_DOUBLE_PRECISION,neighbour_n_rank,sendtag,  &
                       MPI_COMM_HORIZ, requests_nsT(4), ierr)
           !$acc end host_data
           end if
           !print*,"mpi_isend(pztab_halo_nt_haloTin,neighbour_n_rank=",neighbour_n_rank
#else
           call mpi_isend(a%st(1,1,0),1,                    &
                       halo_nst(level,m),neighbour_n_rank,sendtag,  &
                       MPI_COMM_HORIZ, requests_nsT(4), ierr)
#endif
        end if
        if (halo_size > 1) then
          ! Wider halos include the corners in the east/west messages, so the
          ! North <-> South exchange has to be complete before they are sent
          if (LUseO) call mpi_waitall(4,requests_ns, MPI_STATUSES_IGNORE, ierr)
          if (LUseT) call mpi_waitall(4,requests_nsT, MPI_STATUSES_IGNORE, ierr)
        end if
        ! Receive from west (eastward-travelling data)
        recvtag = 1000
        if (LUseO) call mpi_irecv(a%s(0,yoffset,0-(halo_size-1)),blocklength,    &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
                       MPI_COMM_HORIZ, requests_ew(1), ierr)
        recvtag = 1010
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           if (Gneighbour_w) then
           !$acc host_data use_device(pztab_halo_wt_haloTout)
           call mpi_irecv(pztab_halo_wt_haloTout,size(pztab_halo_wt_haloTout),  &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
                       MPI_COMM_HORIZ, requests_ewT(1), ierr)
           !$acc end host_data
           end if
           !print*,"mpi_irecv(pztab_halo_wt_haloTout,neighbour_w_rank=",neighbour_w_rank
#else
           call mpi_irecv(a%st(0-(halo_size-1),0,0),1,  &
                       halo_wet(level,m),neighbour_w_rank,recvtag, &
                       MPI_COMM_HORIZ, requests_ewT(1), ierr)
#endif
        end if
        ! Receive from east (westward-travelling data)
        recvtag = 1001
        if (LUseO) call mpi_irecv(a%s(0,yoffset,a_n+1),blocklength,          &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
                       MPI_COMM_HORIZ, requests_ew(2), ierr)
        recvtag = 1011
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           if (Gneighbour_e) then
           !$acc host_data use_device(pztab_halo_et_haloTout)
           call mpi_irecv(pztab_halo_et_haloTout,size(pztab_halo_et_haloTout),  &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
                       MPI_COMM_HORIZ, requests_ewT(2), ierr)
           !$acc end host_data
           end if
           !print*,"mpi_irecv(pztab_halo_et_haloTout,neighbour_e_rank=",neighbour_e_rank
#else
           call mpi_irecv(a%st(a_n+1,0,0),1,          &
                       halo_wet(level,m),neighbour_e_rank,recvtag, &
                       MPI_COMM_HORIZ, requests_ewT(2), ierr)
#endif
        end if
        ! Send to east
        sendtag = 1000
        if (LUseO) call mpi_isend(a%s(0,yoffset,a_n-(halo_size-1)),blocklength,  &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
                       MPI_COMM_HORIZ, requests_ew(3), ierr)
        sendtag = 1010
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           if (Gneighbour_e) then
           !$acc host_data use_device(pztab_halo_et_haloTin)
           call mpi_isend(pztab_halo_et_haloTin,size(pztab_halo_et_haloTin),  &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
                       MPI_COMM_HORIZ, requests_ewT(3), ierr)
           !$acc end host_data
           end if
           !print*,"mpi_isend(pztab_halo_et_haloTin,neighbour_e_rank=",neighbour_e_rank
#else
           call mpi_isend(a%st(a_n-(halo_size-1),0,0),1,  &
                       halo_wet(level,m),neighbour_e_rank,sendtag, &
                       MPI_COMM_HORIZ, requests_ewT(3), ierr)
#endif
        end if
        ! Send to west
        sendtag = 1001
        if (LUseO) call mpi_isend(a%s(0,yoffset,1),blocklength,                &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag,   &
                       MPI_COMM_HORIZ, requests_ew(4), ierr)
        sendtag = 1011
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           if (Gneighbour_w) then
           !$acc host_data use_device(pztab_halo_wt_haloTin)
           call mpi_isend(pztab_halo_wt_haloTin,size(pztab_halo_wt_haloTin),  &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag,   &
                       MPI_COMM_HORIZ, requests_ewT(4), ierr)
           !$acc end host_data
           end if
           !print*,"mpi_isend(pztab_halo_wt_haloTin,neighbour_w_rank=",neighbour_w_rank
#else
           call mpi_isend(a%st(1,0,0),1,                &
                       halo_wet(level,m),neighbour_w_rank,sendtag,   &
                       MPI_COMM_HORIZ, requests_ewT(4), ierr)
#endif
        end if
        if (halo_size == 1) then
          ! Wait for North <-> South communication to complete
          if (LUseO) call mpi_waitall(4,requests_ns, MPI_STATUSES_IGNORE, ierr)
          if (LUseT) call mpi_waitall(4,requests_nsT, MPI_STATUSES_IGNORE, ierr)
        end if
        ! Wait for East <-> West communication to complete
        if (LUseO) call mpi_waitall(4,requests_ew, MPI_STATUSES_IGNORE, ierr)
        if (LUseT) call mpi_waitall(4,requests_ewT, MPI_STATUSES_IGNORE, ierr)
#ifdef MNH_GPUDIRECT
        if (LUseT) then
           ! Unpack the received buffers into the halo of pzst on the device
           if (Gneighbour_n) then
           ! copy north halo for GPU managed
           !$acc kernels async(IS_NORTH) present_cr(pzst,pztab_halo_nt_haloTout)
           !$mnh_do_concurrent( ii=1:a_n,ij=1:halo_size,ik=1:nz+2 )
              pzst(ii,ij-halo_size,ik-1) = pztab_halo_nt_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
           end if
           if (Gneighbour_s) then
           ! copy south halo for GPU managed
           !$acc kernels async(IS_SOUTH) present_cr(pzst,pztab_halo_st_haloTout)
           !$mnh_do_concurrent( ii=1:a_n,ij=1:halo_size,ik=1:nz+2 )
              pzst(ii,ij+a_n,ik-1) = pztab_halo_st_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
           end if
           if (Gneighbour_w) then
           ! copy west halo for GPU managed
           !$acc kernels async(IS_WEST) present_cr(pzst,pztab_halo_wt_haloTout)
           !$mnh_do_concurrent( ii=1:halo_size,ij=1:a_n+2*halo_size,ik=1:nz+2 )
              pzst(ii-halo_size,ij-halo_size,ik-1) = pztab_halo_wt_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
           end if
           if (Gneighbour_e) then
           ! copy east halo for GPU managed
           !$acc kernels async(IS_EAST) present_cr(pzst,pztab_halo_et_haloTout)
           !$mnh_do_concurrent( ii=1:halo_size,ij=1:a_n+2*halo_size,ik=1:nz+2 )
              pzst(ii+a_n,ij-halo_size,ik-1) = pztab_halo_et_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
           end if
           ! wait until all halos have been unpacked
           call acc_wait_haloswap_mnh()
        end if
#endif
      !
      ! End Comm
      !
    end subroutine haloswap_mnh_dim

    !
    ! Wait on the OpenACC async queue of every direction that has a
    ! neighbour (only those queues are ever used by haloswap_mnh_dim)
    !
    subroutine acc_wait_haloswap_mnh()
      if (Gneighbour_s) then
         !$acc wait(IS_SOUTH)
      endif
      if (Gneighbour_n) then
         !$acc wait(IS_NORTH)
      endif
      if (Gneighbour_e) then
         !$acc wait(IS_EAST)
      endif
      if (Gneighbour_w) then
         !$acc wait(IS_WEST)
      endif
    end subroutine acc_wait_haloswap_mnh

  end subroutine haloswap_mnh
!==================================================================
  subroutine haloswap(level,m, &  ! multigrid- and processor- level
                      a)          ! data field
    !
    ! Blocking halo exchange of a%s on multigrid level "level" and
    ! processor level m. The routine returns once all messages have
    ! completed. Only processes whose coordinates are multiples of
    ! 2**(pproc-m) take part. Their neighbours are stepsize processes away.
    !
    ! Message tags (named after the direction the data travels):
    !   0 : eastwards  (send to east,  receive from west)
    !   1 : westwards  (send to west,  receive from east)
    !   2 : southwards (send to south, receive from north)
    !   3 : northwards (send to north, receive from south)
    !
    implicit none
    integer, intent(in) :: level
    integer, intent(in) :: m
    type(scalar3d), intent(inout) :: a
    integer :: a_n  ! horizontal grid size
    integer :: nz   ! vertical grid size
    integer, dimension(2) :: p_horiz
    integer :: stepsize
    integer :: ierr, rank, sendtag, recvtag
    integer :: halo_size
    integer :: neighbour_n_rank
    integer :: neighbour_s_rank
    integer :: neighbour_e_rank
    integer :: neighbour_w_rank
    integer :: yoffset, blocklength
    integer, dimension(4) :: requests_ns
    integer, dimension(4) :: requests_ew

    halo_size = comm_param%halo_size

    ! Do nothing if we are only using one processor
    if (m > 0) then
      if (comm_measuretime) then
        call start_timer(t_haloswap(level,m))
      end if
      a_n = a%ix_max-a%ix_min+1
      nz = a%grid_param%nz
      stepsize = 2**(pproc-m)

      ! Work out rank, only execute on relevant processes
      call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
      call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)

      ! Work out ranks of neighbours
      ! W -> E
      call mpi_cart_shift(MPI_COMM_HORIZ,1, stepsize, &
                          neighbour_w_rank,neighbour_e_rank,ierr)
      ! N -> S
      call mpi_cart_shift(MPI_COMM_HORIZ,0, stepsize, &
                          neighbour_n_rank,neighbour_s_rank,ierr)
      if ( (stepsize == 1) .or.                   &
         (  (mod(p_horiz(1),stepsize) == 0) .and. &
            (mod(p_horiz(2),stepsize) == 0) ) ) then
        if (halo_size == 1) then
          ! Do not include corners in send/recv
          yoffset = 1
          blocklength = a_n*(nz+2)*halo_size
        else
          yoffset = 1-halo_size
          blocklength = (a_n+2*halo_size)*(nz+2)*halo_size
        end if
        ! Receive from north (southward-travelling data)
        recvtag = 2
        call mpi_irecv(a%s(0,0-(halo_size-1),1),1,                  &
                       halo_ns(level,m),neighbour_n_rank,recvtag,   &
                       MPI_COMM_HORIZ, requests_ns(1), ierr)
        ! Receive from south (northward-travelling data)
        recvtag = 3
        call mpi_irecv(a%s(0,a_n+1,1),1,                            &
                       halo_ns(level,m),neighbour_s_rank,recvtag,   &
                       MPI_COMM_HORIZ, requests_ns(2), ierr)
        ! Send to south
        sendtag = 2
        call mpi_isend(a%s(0,a_n-(halo_size-1),1),1,                &
                       halo_ns(level,m),neighbour_s_rank,sendtag,   &
                       MPI_COMM_HORIZ, requests_ns(3), ierr)
        ! Send to north
        sendtag = 3
        call mpi_isend(a%s(0,1,1),1,                                &
                       halo_ns(level,m),neighbour_n_rank,sendtag,   &
                       MPI_COMM_HORIZ, requests_ns(4), ierr)
        if (halo_size > 1) then
          ! Wider halos include the corners in the east/west messages, so the
          ! North <-> South exchange has to be complete before they are sent
          call mpi_waitall(4,requests_ns, MPI_STATUSES_IGNORE, ierr)
        end if
        ! Receive from west (eastward-travelling data)
        recvtag = 0
        call mpi_irecv(a%s(0,yoffset,0-(halo_size-1)),blocklength,    &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
                       MPI_COMM_HORIZ, requests_ew(1), ierr)
        ! Receive from east (westward-travelling data)
        recvtag = 1
        call mpi_irecv(a%s(0,yoffset,a_n+1),blocklength,              &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
                       MPI_COMM_HORIZ, requests_ew(2), ierr)
        ! Send to east
        sendtag = 0
        call mpi_isend(a%s(0,yoffset,a_n-(halo_size-1)),blocklength,  &
                       MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
                       MPI_COMM_HORIZ, requests_ew(3), ierr)
        ! Send to west
        sendtag = 1
        call mpi_isend(a%s(0,yoffset,1),blocklength,                  &
                       MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag, &
                       MPI_COMM_HORIZ, requests_ew(4), ierr)
        if (halo_size == 1) then
          ! Wait for North <-> South communication to complete
          call mpi_waitall(4,requests_ns, MPI_STATUSES_IGNORE, ierr)
        end if
        ! Wait for East <-> West communication to complete
        call mpi_waitall(4,requests_ew, MPI_STATUSES_IGNORE, ierr)
      end if
      if (comm_measuretime) then
        call finish_timer(t_haloswap(level,m))
      end if
    end if
  end subroutine haloswap
!==================================================================
!  Collect from a(level,m) and store on fewer processors
!  in b(level,m-1)
!
!  Example for pproc-m = 1, i.e. stepsize = 2:
!
!   NW (0,0)  <--  NE (0,2)
!
!     ^      .
!     !        .
!                .
!   SW (2,0)       SE (2,2) [send to 0,0]
!
!==================================================================
  subroutine collect(level,m, &    ! multigrid and processor level
                     a, &          ! IN: data on level (level,m)
                     b)            ! OUT: data on level (level,m-1)
    implicit none
    integer, intent(in) :: level
    integer, intent(in) :: m
    type(scalar3d), intent(in) :: a
    type(scalar3d), intent(inout) :: b
    integer :: a_n, b_n   ! horizontal grid sizes
    integer :: nz ! vertical grid size
    integer, dimension(2) :: p_horiz
    integer :: stepsize
    integer :: ierr, source_rank, dest_rank, rank, recv_tag, send_tag, iz
    logical :: corner_nw, corner_ne, corner_sw, corner_se
    integer :: recv_request(3)
    integer :: recv_requestT(3)

    integer :: ii,ij,ik

    real , pointer , contiguous , dimension(:,:,:) :: za_st,zb_st

    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_ne_m_haloTin
    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_sw_m_haloTin
    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_se_m_haloTin

    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_ne_m_1_haloTout
    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_sw_m_1_haloTout
    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_se_m_1_haloTout

    call start_timer(t_collect(m))

    stepsize = 2**(pproc-m)

    a_n = a%ix_max-a%ix_min+1
    b_n = b%ix_max-b%ix_min+1
    nz = b%grid_param%nz

    ! Work out rank, only execute on relevant processes
    call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
    ! Store position in process grid in in p_horiz
    ! Note we can NOT use cart_shift as we need diagonal neighburs as well
    call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)

    ! Ignore all processes that do not participate at this level
    if ( (stepsize .eq. 1) .or. ((mod(p_horiz(1),stepsize) == 0) .and. (mod(p_horiz(2),stepsize)) == 0)) then
      ! Determine position in local 2x2 block
      if (stepsize .eq. 1) then
        corner_nw = ((mod(p_horiz(1),2) == 0) .and. (mod(p_horiz(2),2) == 0))
        corner_ne = ((mod(p_horiz(1),2) == 0) .and. (mod(p_horiz(2),2) == 1))
        corner_sw = ((mod(p_horiz(1),2) == 1) .and. (mod(p_horiz(2),2) == 0))
        corner_se = ((mod(p_horiz(1),2) == 1) .and. (mod(p_horiz(2),2) == 1))
      else
        corner_nw = ((mod(p_horiz(1)/stepsize,2) == 0) .and. (mod(p_horiz(2)/stepsize,2) == 0))
        corner_ne = ((mod(p_horiz(1)/stepsize,2) == 0) .and. (mod(p_horiz(2)/stepsize,2) == 1))
        corner_sw = ((mod(p_horiz(1)/stepsize,2) == 1) .and. (mod(p_horiz(2)/stepsize,2) == 0))
        corner_se = ((mod(p_horiz(1)/stepsize,2) == 1) .and. (mod(p_horiz(2)/stepsize,2) == 1))
      end if
      ! NW receives from the other three processes
      if ( corner_nw ) then
        ! Receive from NE
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1),p_horiz(2)+stepsize/), &
                           source_rank, &
                           ierr)
        recv_tag = 1000
        if (LUseO) call mpi_irecv(b%s(0,1,b_n/2+1),1,sub_interior(level,m-1),source_rank, recv_tag, MPI_COMM_HORIZ, &
                           recv_request(1),ierr)
        recv_tag = 1010
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_sub_interiorT_ne_m_1_haloTout => tab_sub_interiorT_ne(level,m-1)%haloTout
           !$acc host_data use_device(ztab_sub_interiorT_ne_m_1_haloTout)
           call mpi_irecv(ztab_sub_interiorT_ne_m_1_haloTout,size(ztab_sub_interiorT_ne_m_1_haloTout), &
                MPI_DOUBLE_PRECISION,source_rank, recv_tag, MPI_COMM_HORIZ, &
                recv_requestT(1),ierr)
           !$acc end host_data
#else
           call mpi_irecv(b%st(b_n/2+1,1,0),1,sub_interiorT(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
                recv_requestT(1),ierr)
#endif
        end if
#ifndef NDEBUG
  if (ierr .ne. 0) &
    call fatalerror("Collect: receive from NE failed in mpi_irecv().")
#endif
         ! Receive from SW
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1)+stepsize,p_horiz(2)/), &
                           source_rank, &
                           ierr)
        recv_tag = 1001
        if (LUseO) call mpi_irecv(b%s(0,b_n/2+1,1),1,sub_interior(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
                           recv_request(2),ierr)
        recv_tag = 1011
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_sub_interiorT_sw_m_1_haloTout => tab_sub_interiorT_sw(level,m-1)%haloTout
           !$acc host_data use_device(ztab_sub_interiorT_sw_m_1_haloTout)
           call mpi_irecv(ztab_sub_interiorT_sw_m_1_haloTout,size(ztab_sub_interiorT_sw_m_1_haloTout), &
                MPI_DOUBLE_PRECISION,source_rank, recv_tag, MPI_COMM_HORIZ, &
                recv_requestT(2),ierr)
           !$acc end host_data
#else
           call mpi_irecv(b%st(1,b_n/2+1,0),1,sub_interiorT(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
                recv_requestT(2),ierr)         
#endif           
        endif

#ifndef NDEBUG
  if (ierr .ne. 0) &
    call fatalerror("Collect: receive from SW failed in mpi_irecv().")
#endif
        ! Receive from SE
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1)+stepsize,p_horiz(2)+stepsize/), &
                           source_rank, &
                           ierr)
        recv_tag = 1002
        if (LUseO) call mpi_irecv(b%s(0,b_n/2+1,b_n/2+1),1,sub_interior(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
                           recv_request(3),ierr)
        recv_tag = 1012
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_sub_interiorT_se_m_1_haloTout => tab_sub_interiorT_se(level,m-1)%haloTout
           !$acc host_data use_device(ztab_sub_interiorT_se_m_1_haloTout)
           call mpi_irecv(ztab_sub_interiorT_se_m_1_haloTout,size(ztab_sub_interiorT_se_m_1_haloTout), &
                MPI_DOUBLE_PRECISION,source_rank, recv_tag, MPI_COMM_HORIZ, &
                recv_requestT(3),ierr)
           !$acc end host_data
#else
           call mpi_irecv(b%st(b_n/2+1,b_n/2+1,0),1,sub_interiorT(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
                recv_requestT(3),ierr)           
#endif
        end if
#ifndef NDEBUG
  if (ierr .ne. 0) &
    call fatalerror("Collect: receive from SE failed in mpi_irecv().")
#endif
        ! Copy local data while waiting for data from other processes
        if (LUseO) b%s(0:nz+1,1:a_n,1:a_n) = a%s(0:nz+1,1:a_n,1:a_n)
        if (LUseT) then
#ifdef MNH_GPUDIRECT           
           zb_st => b%st
           za_st => a%st
           !$acc kernels present_cr(zb_st,za_st)
           !$mnh_do_concurrent(ii=1:a_n,ij=1:a_n,ik=1:nz+2)
              zb_st(ii,ij,ik-1) = za_st(ii,ij,ik-1)
           !$mnh_end_do()
           !$acc end kernels
#else           
           b%st(1:a_n,1:a_n,0:nz+1) = a%st(1:a_n,1:a_n,0:nz+1)
#endif
        end if
        ! Wait for receives to complete before proceeding
        if (LUseO) call mpi_waitall(3,recv_request,MPI_STATUSES_IGNORE,ierr)
        if (LUseT) call mpi_waitall(3,recv_requestT,MPI_STATUSES_IGNORE,ierr)
#ifdef MNH_GPUDIRECT
        if (LUseT) then
           zb_st => b%st
           ! copy from buffer for GPU DIRECT
           ! Receive from NE
           !$acc kernels present_cr(zb_st,ztab_sub_interiorT_ne_m_1_haloTout)
           !$mnh_do_concurrent(ii=1:b_n/2,ij=1:b_n/2,ik=1:nz+2)
              zb_st(ii+b_n/2,ij,ik-1) = ztab_sub_interiorT_ne_m_1_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels           
           ! Receive from SW
           !$acc kernels present_cr(zb_st,ztab_sub_interiorT_sw_m_1_haloTout)
           !$mnh_do_concurrent(ii=1:b_n/2,ij=1:b_n/2,ik=1:nz+2)
              zb_st(ii,ij+b_n/2,ik-1) = ztab_sub_interiorT_sw_m_1_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
           ! Receive from SE
           !$acc kernels present_cr(zb_st,ztab_sub_interiorT_se_m_1_haloTout)
           !$mnh_do_concurrent(ii=1:b_n/2,ij=1:b_n/2,ik=1:nz+2)
               zb_st(ii+b_n/2,ij+b_n/2,ik-1) = ztab_sub_interiorT_se_m_1_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
        end if
#endif
      end if
      if ( corner_ne ) then
        ! Send to NW
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1),p_horiz(2)-stepsize/), &
                           dest_rank, &
                           ierr)

        za_st => a%st
        
        send_tag = 1000
        if (LUseO) call mpi_send(a%s(0,1,1),1,interior(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
        send_tag = 1010
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_interiorT_ne_m_haloTin => tab_interiorT_ne(level,m)%haloTin
           !$acc kernels present_cr(ztab_interiorT_ne_m_haloTin,za_st)
           !$mnh_do_concurrent(ii=1:a_n,ij=1:a_n,ik=1:nz+2)
              ztab_interiorT_ne_m_haloTin(ii,ij,ik) = za_st(ii,ij,ik-1)
           !$mnh_end_do()
           !$acc end kernels           
           !$acc host_data use_device(ztab_interiorT_ne_m_haloTin)
           call mpi_send(ztab_interiorT_ne_m_haloTin,size(ztab_interiorT_ne_m_haloTin), &
                MPI_DOUBLE_PRECISION,dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
           !$acc end host_data
#else
           call mpi_send(a%st(1,1,0),1,interiorT(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
#endif
        end if
#ifndef NDEBUG
  if (ierr .ne. 0) &
    call fatalerror("Collect: send from NE failed in mpi_send().")
#endif
      end if
      if ( corner_sw ) then
        ! Send to NW
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1)-stepsize,p_horiz(2)/), &
                           dest_rank, &
                           ierr)
        
        za_st => a%st
        
        send_tag = 1001
        if (LUseO) call mpi_send(a%s(0,1,1),1,interior(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
        send_tag = 1011
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_interiorT_sw_m_haloTin => tab_interiorT_sw(level,m)%haloTin
           !$acc kernels present_cr(ztab_interiorT_sw_m_haloTin,za_st)
           !$mnh_do_concurrent(ii=1:a_n,ij=1:a_n,ik=1:nz+2)
              ztab_interiorT_sw_m_haloTin(ii,ij,ik) = za_st(ii,ij,ik-1)
           !$mnh_end_do()
           !$acc end kernels           
           !$acc host_data use_device(ztab_interiorT_sw_m_haloTin)
           call mpi_send(ztab_interiorT_sw_m_haloTin,size(ztab_interiorT_sw_m_haloTin), &
                MPI_DOUBLE_PRECISION,dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
           !$acc end host_data
#else
           call mpi_send(a%st(1,1,0),1,interiorT(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)        
#endif
        end if
#ifndef NDEBUG
  if (ierr .ne. 0) &
    call fatalerror("Collect: send from SW failed in mpi_send().")
#endif
      end if
      if ( corner_se ) then
        ! send to NW
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1)-stepsize,p_horiz(2)-stepsize/), &
                           dest_rank, &
                           ierr)

        za_st => a%st
        
        send_tag = 1002
        if (LUseO) call mpi_send(a%s(0,1,1),1,interior(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
        send_tag = 1012
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_interiorT_se_m_haloTin => tab_interiorT_se(level,m)%haloTin
           !$acc kernels present_cr(ztab_interiorT_se_m_haloTin,za_st)
           !$mnh_do_concurrent(ii=1:a_n,ij=1:a_n,ik=1:nz+2)
              ztab_interiorT_se_m_haloTin(ii,ij,ik) = za_st(ii,ij,ik-1)
           !$mnh_end_do()
           !$acc end kernels
           !$acc host_data use_device(ztab_interiorT_se_m_haloTin)
           call mpi_send(ztab_interiorT_se_m_haloTin,size(ztab_interiorT_se_m_haloTin), &
                MPI_DOUBLE_PRECISION,dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
           !$acc end host_data
#else
           call mpi_send(a%st(1,1,0),1,interiorT(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
#endif
        end if
#ifndef NDEBUG
  if (ierr .ne. 0) &
    call fatalerror("Collect: send from SE failed in mpi_send().")
#endif
      end if

    end if
    call finish_timer(t_collect(m))

  end subroutine collect

!==================================================================
!  Distribute data in a(level,m-1) and store in b(level,m)
!
!  Example for p-m = 1, i.e. stepsize = 2:
!
!   NW (0,0)  -->  NE (2,0)
!
!     !      .
!     v        .
!                .
!   SW (0,2)       SE (2,2) [receives from (0,0)]
!==================================================================
  subroutine distribute(level,m, &  ! multigrid and processor level
                        a,       &  ! IN: Data on level (level,m-1)
                        b)          ! OUT: Data on level (level,m)
    !
    ! Scatter data held on processor level m-1 onto processor level m.
    !
    ! Only processes whose two horizontal process-grid coordinates are both
    ! multiples of stepsize = 2**(pproc-m) take part; the others only start
    ! and stop the timer. Participating processes form 2 x 2 blocks, and the
    ! parity of p_horiz/stepsize gives the position inside a block:
    !   - NW keeps its own part of a by a local copy a(1:b_n,1:b_n) -> b and
    !     sends the other three quadrants (starting at index a_n/2+1) with
    !     non-blocking sends to NE (p1,p2+stepsize), SW (p1+stepsize,p2) and
    !     SE (p1+stepsize,p2+stepsize). It waits for all sends before leaving.
    !   - NE, SW and SE each receive their part from NW with a blocking
    !     mpi_recv into b.
    ! LUseO enables the exchange of the s array (vertical index first),
    ! LUseT the exchange of the st array (vertical index last). Message tags
    ! 1000-1002 (s) and 1010-1012 (st) are the same ones used by collect,
    ! which gathers the quadrants back onto NW.
    ! Under MNH_GPUDIRECT the st data is staged through the device buffers of
    ! tab_sub_interiorT_* / tab_interiorT_* and passed to MPI through an
    ! OpenACC host_data region; otherwise the MPI datatypes sub_interiorT /
    ! interiorT are used directly on a%st / b%st.
    ! NOTE(review): the local copy on NW presumably relies on b_n == a_n/2;
    ! this is not checked here - confirm against the caller.
    !
    implicit none
    integer, intent(in) :: level
    integer, intent(in) :: m
    type(scalar3d), intent(in) :: a
    type(scalar3d), intent(inout) :: b
    integer :: a_n, b_n   ! horizontal grid sizes
    integer :: nz ! vertical grid size
    integer, dimension(2) :: p_horiz
    integer :: stepsize
    integer :: ierr, source_rank, dest_rank, send_tag, recv_tag, rank
    ! Separate MPI status for the st (LUseT) transfers, so that a failure of
    ! the preceding s (LUseO) transfer is not overwritten before it is checked
    integer :: ierr_t
    integer :: stat(MPI_STATUS_SIZE)
    integer :: send_request(3)
    integer :: send_requestT(3)
    logical :: corner_nw, corner_ne, corner_sw, corner_se

    integer :: ii,ij,ik

    real , pointer , contiguous , dimension(:,:,:) :: za_st,zb_st

    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_ne_m_1_haloTin
    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_sw_m_1_haloTin
    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_se_m_1_haloTin

    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_ne_m_haloTout
    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_sw_m_haloTout
    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_se_m_haloTout

    call start_timer(t_distribute(m))

    stepsize = 2**(pproc-m)

    a_n = a%ix_max-a%ix_min+1
    b_n = b%ix_max-b%ix_min+1
    nz = a%grid_param%nz

    ! Stays 0 whenever LUseT is false
    ierr_t = 0

    ! Work out rank and position in the horizontal process grid
    call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
    call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)

    ! Ignore all processes that do not participate at this level
    ! (for stepsize = 1 the test is always true: every process takes part)
    if ( (mod(p_horiz(1),stepsize) == 0) .and. (mod(p_horiz(2),stepsize) == 0) ) then
      ! Work out coordinates in local 2 x 2 block
      ! (for stepsize = 1, p_horiz/stepsize is simply p_horiz)
      corner_nw = ((mod(p_horiz(1)/stepsize,2) == 0) .and. (mod(p_horiz(2)/stepsize,2) == 0))
      corner_ne = ((mod(p_horiz(1)/stepsize,2) == 0) .and. (mod(p_horiz(2)/stepsize,2) == 1))
      corner_sw = ((mod(p_horiz(1)/stepsize,2) == 1) .and. (mod(p_horiz(2)/stepsize,2) == 0))
      corner_se = ((mod(p_horiz(1)/stepsize,2) == 1) .and. (mod(p_horiz(2)/stepsize,2) == 1))

      if ( corner_nw ) then
        za_st => a%st

        ! (Asynchronous) send to NE
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1),p_horiz(2)+stepsize/), &
                           dest_rank, &
                           ierr)
        send_tag = 1000
        if (LUseO) call mpi_isend(a%s(0,1,a_n/2+1), 1,sub_interior(level,m-1),dest_rank, send_tag, &
                       MPI_COMM_HORIZ,send_request(1),ierr)
        send_tag = 1010
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ! Pack the NE quadrant into a contiguous device buffer
           ztab_sub_interiorT_ne_m_1_haloTin => tab_sub_interiorT_ne(level,m-1)%haloTin
           !$acc kernels present_cr(ztab_sub_interiorT_ne_m_1_haloTin,za_st)
           !$mnh_do_concurrent(ii=1:a_n/2,ij=1:a_n/2,ik=1:nz+2)
              ztab_sub_interiorT_ne_m_1_haloTin(ii,ij,ik) = za_st(ii+a_n/2,ij,ik-1)
           !$mnh_end_do()
           !$acc end kernels
           !$acc host_data use_device(ztab_sub_interiorT_ne_m_1_haloTin)
           call mpi_isend(ztab_sub_interiorT_ne_m_1_haloTin,size(ztab_sub_interiorT_ne_m_1_haloTin), &
                MPI_DOUBLE_PRECISION,dest_rank, send_tag, &
                MPI_COMM_HORIZ,send_requestT(1),ierr_t)
           !$acc end host_data
#else
           call mpi_isend(a%st(a_n/2+1,1,0), 1,sub_interiorT(level,m-1),dest_rank, send_tag, &
                MPI_COMM_HORIZ,send_requestT(1),ierr_t)
#endif
        end if
#ifndef NDEBUG
        if ((ierr .ne. 0) .or. (ierr_t .ne. 0)) &
          call fatalerror("Distribute: send to NE failed in mpi_isend().")
#endif

        ! (Asynchronous) send to SW
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1)+stepsize,p_horiz(2)/), &
                           dest_rank, &
                           ierr)
        send_tag = 1001
        if (LUseO) call mpi_isend(a%s(0,a_n/2+1,1),1,sub_interior(level,m-1), dest_rank, send_tag, &
                       MPI_COMM_HORIZ, send_request(2), ierr)
        send_tag = 1011
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ! Pack the SW quadrant into a contiguous device buffer
           ztab_sub_interiorT_sw_m_1_haloTin => tab_sub_interiorT_sw(level,m-1)%haloTin
           !$acc kernels present_cr(ztab_sub_interiorT_sw_m_1_haloTin,za_st)
           !$mnh_do_concurrent(ii=1:a_n/2,ij=1:a_n/2,ik=1:nz+2)
              ztab_sub_interiorT_sw_m_1_haloTin(ii,ij,ik) = za_st(ii,ij+a_n/2,ik-1)
           !$mnh_end_do()
           !$acc end kernels
           !$acc host_data use_device(ztab_sub_interiorT_sw_m_1_haloTin)
           call mpi_isend(ztab_sub_interiorT_sw_m_1_haloTin,size(ztab_sub_interiorT_sw_m_1_haloTin), &
                MPI_DOUBLE_PRECISION, dest_rank, send_tag, &
                MPI_COMM_HORIZ, send_requestT(2), ierr_t)
           !$acc end host_data
#else
           call mpi_isend(a%st(1,a_n/2+1,0),1,sub_interiorT(level,m-1), dest_rank, send_tag, &
                MPI_COMM_HORIZ, send_requestT(2), ierr_t)
#endif
        end if
#ifndef NDEBUG
        if ((ierr .ne. 0) .or. (ierr_t .ne. 0)) &
          call fatalerror("Distribute: send to SW failed in mpi_isend().")
#endif

        ! (Asynchronous) send to SE
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1)+stepsize,p_horiz(2)+stepsize/), &
                           dest_rank, &
                           ierr)
        send_tag = 1002
        if (LUseO) call mpi_isend(a%s(0,a_n/2+1,a_n/2+1),1,sub_interior(level,m-1), dest_rank, send_tag, &
                      MPI_COMM_HORIZ, send_request(3), ierr)
        send_tag = 1012
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ! Pack the SE quadrant into a contiguous device buffer
           ztab_sub_interiorT_se_m_1_haloTin => tab_sub_interiorT_se(level,m-1)%haloTin
           !$acc kernels present_cr(ztab_sub_interiorT_se_m_1_haloTin,za_st)
           !$mnh_do_concurrent(ii=1:a_n/2,ij=1:a_n/2,ik=1:nz+2)
              ztab_sub_interiorT_se_m_1_haloTin(ii,ij,ik) = za_st(ii+a_n/2,ij+a_n/2,ik-1)
           !$mnh_end_do()
           !$acc end kernels
           !$acc host_data use_device(ztab_sub_interiorT_se_m_1_haloTin)
           call mpi_isend(ztab_sub_interiorT_se_m_1_haloTin,size(ztab_sub_interiorT_se_m_1_haloTin), &
                MPI_DOUBLE_PRECISION, dest_rank, send_tag, &
                MPI_COMM_HORIZ, send_requestT(3), ierr_t)
           !$acc end host_data
#else
           call mpi_isend(a%st(a_n/2+1,a_n/2+1,0),1,sub_interiorT(level,m-1), dest_rank, send_tag, &
                MPI_COMM_HORIZ, send_requestT(3), ierr_t)
#endif
        end if
#ifndef NDEBUG
        if ((ierr .ne. 0) .or. (ierr_t .ne. 0)) &
          call fatalerror("Distribute: send to SE failed in mpi_isend().")
#endif

        ! While sending, copy the locally kept quadrant
        if (LUseO) b%s(0:nz+1,1:b_n,1:b_n) = a%s(0:nz+1,1:b_n,1:b_n)
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           zb_st => b%st
           za_st => a%st
           !$acc kernels present_cr(zb_st,za_st)
           !$mnh_do_concurrent(ii=1:b_n,ij=1:b_n,ik=1:nz+2)
              zb_st(ii,ij,ik-1) = za_st(ii,ij,ik-1)
           !$mnh_end_do()
           !$acc end kernels
#else
           b%st(1:b_n,1:b_n,0:nz+1) = a%st(1:b_n,1:b_n,0:nz+1)
#endif
        end if
        ! Only proceed when the async sends have completed
        ! (the send buffers must not be reused before this point)
        if (LUseO) call mpi_waitall(3, send_request, MPI_STATUSES_IGNORE, ierr)
        if (LUseT) call mpi_waitall(3, send_requestT, MPI_STATUSES_IGNORE, ierr)
      end if

      if ( corner_ne ) then
        ! Blocking receive from NW
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1),p_horiz(2)-stepsize/), &
                           source_rank, &
                           ierr)
        zb_st => b%st

        recv_tag = 1000
        if (LUseO) call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
        recv_tag = 1010
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_interiorT_ne_m_haloTout => tab_interiorT_ne(level,m)%haloTout
           !$acc host_data use_device(ztab_interiorT_ne_m_haloTout)
           call mpi_recv(ztab_interiorT_ne_m_haloTout,size(ztab_interiorT_ne_m_haloTout), &
                MPI_DOUBLE_PRECISION,source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr_t)
           !$acc end host_data
           ! Unpack the device buffer into b%st
           !$acc kernels present_cr(zb_st,ztab_interiorT_ne_m_haloTout)
           !$mnh_do_concurrent(ii=1:b_n,ij=1:b_n,ik=1:nz+2)
              zb_st(ii,ij,ik-1) = ztab_interiorT_ne_m_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
#else
           call mpi_recv(b%st(1,1,0),1,interiorT(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr_t)
#endif
        end if
#ifndef NDEBUG
        if ((ierr .ne. 0) .or. (ierr_t .ne. 0)) &
          call fatalerror("Distribute: receive on NE failed in mpi_recv().")
#endif
      end if

      if ( corner_sw ) then
        ! Blocking receive from NW
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1)-stepsize,p_horiz(2)/), &
                           source_rank, &
                           ierr)
        zb_st => b%st

        recv_tag = 1001
        if (LUseO) call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
        recv_tag = 1011
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_interiorT_sw_m_haloTout => tab_interiorT_sw(level,m)%haloTout
           !$acc host_data use_device(ztab_interiorT_sw_m_haloTout)
           call mpi_recv(ztab_interiorT_sw_m_haloTout,size(ztab_interiorT_sw_m_haloTout), &
                MPI_DOUBLE_PRECISION,source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr_t)
           !$acc end host_data
           ! Unpack the device buffer into b%st
           !$acc kernels present_cr(zb_st,ztab_interiorT_sw_m_haloTout)
           !$mnh_do_concurrent(ii=1:b_n,ij=1:b_n,ik=1:nz+2)
              zb_st(ii,ij,ik-1) = ztab_interiorT_sw_m_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
#else
           call mpi_recv(b%st(1,1,0),1,interiorT(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr_t)
#endif
        end if
#ifndef NDEBUG
        if ((ierr .ne. 0) .or. (ierr_t .ne. 0)) &
          call fatalerror("Distribute: receive on SW failed in mpi_recv().")
#endif
      end if

      if ( corner_se ) then
        ! Blocking receive from NW
        call mpi_cart_rank(MPI_COMM_HORIZ, &
                           (/p_horiz(1)-stepsize,p_horiz(2)-stepsize/), &
                           source_rank, &
                           ierr)
        zb_st => b%st

        recv_tag = 1002
        if (LUseO) call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
        recv_tag = 1012
        if (LUseT) then
#ifdef MNH_GPUDIRECT
           ztab_interiorT_se_m_haloTout => tab_interiorT_se(level,m)%haloTout
           !$acc host_data use_device(ztab_interiorT_se_m_haloTout)
           call mpi_recv(ztab_interiorT_se_m_haloTout,size(ztab_interiorT_se_m_haloTout), &
                MPI_DOUBLE_PRECISION,source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr_t)
           !$acc end host_data
           ! Unpack the device buffer into b%st
           !$acc kernels present_cr(zb_st,ztab_interiorT_se_m_haloTout)
           !$mnh_do_concurrent(ii=1:b_n,ij=1:b_n,ik=1:nz+2)
              zb_st(ii,ij,ik-1) = ztab_interiorT_se_m_haloTout(ii,ij,ik)
           !$mnh_end_do()
           !$acc end kernels
#else
           call mpi_recv(b%st(1,1,0),1,interiorT(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr_t)
#endif
        end if
#ifndef NDEBUG
        if ((ierr .ne. 0) .or. (ierr_t .ne. 0)) &
          call fatalerror("Distribute: receive on SE failed in mpi_recv().")
#endif
      end if

    end if
    call finish_timer(t_distribute(m))

  end subroutine distribute

end module communication