From 76f03816564d14736d8fd4e041efbecacd23287a Mon Sep 17 00:00:00 2001
From: Juan ESCOBAR <juan.escobar@aero.obs-mip.fr>
Date: Wed, 15 Mar 2023 16:32:26 +0100
Subject: [PATCH] Juan 15/03/2023:ZSOLVER/communication.f90, Cray OPENACC
 Opt, pass ztab_halo* by args + dim in haloswap_mnh

---
 .../communication.f90                         | 179 ++++++++++++------
 1 file changed, 124 insertions(+), 55 deletions(-)

diff --git a/src/ZSOLVER/tensorproductmultigrid_Source/communication.f90 b/src/ZSOLVER/tensorproductmultigrid_Source/communication.f90
index f69262def..bb9003ede 100644
--- a/src/ZSOLVER/tensorproductmultigrid_Source/communication.f90
+++ b/src/ZSOLVER/tensorproductmultigrid_Source/communication.f90
@@ -1125,6 +1125,77 @@ contains
         requests_nsT(:) = MPI_REQUEST_NULL
         requests_ewT(:) = MPI_REQUEST_NULL
         !
+        !  Init pointers to the send/receive halo buffers of each existing neighbour
+        !
+#ifdef MNH_GPUDIRECT
+        if (LUseT) then
+           ! Send to south
+           if (Gneighbour_s) then
+              ztab_halo_st_haloTin => tab_halo_st(level,m)%haloTin
+           end if
+           ! Send to north
+           if (Gneighbour_n) then
+              ztab_halo_nt_haloTin => tab_halo_nt(level,m)%haloTin
+           end if
+           ! Send to east
+           if (Gneighbour_e) then
+              ztab_halo_et_haloTin => tab_halo_et(level,m)%haloTin
+           end if
+           ! Send to west
+           if (Gneighbour_w) then
+              ztab_halo_wt_haloTin => tab_halo_wt(level,m)%haloTin
+           end if
+           ! Receive from north
+           if (Gneighbour_n) then
+              ztab_halo_nt_haloTout => tab_halo_nt(level,m)%haloTout
+           end if
+           ! Receive from south
+           if (Gneighbour_s) then
+              ztab_halo_st_haloTout => tab_halo_st(level,m)%haloTout
+           end if
+           ! Receive from west
+           if (Gneighbour_w) then
+              ztab_halo_wt_haloTout => tab_halo_wt(level,m)%haloTout
+           end if
+           ! Receive from east
+           if (Gneighbour_e) then
+              ztab_halo_et_haloTout => tab_halo_et(level,m)%haloTout
+           end if
+        end if
+#endif
+        ! NOTE(review): pointers of absent neighbours are passed too - confirm they are valid
+        call haloswap_mnh_dim(ztab_halo_st_haloTin,ztab_halo_nt_haloTin,&
+                              ztab_halo_et_haloTin,ztab_halo_wt_haloTin,&
+                              ztab_halo_nt_haloTout,ztab_halo_st_haloTout,&
+                              ztab_halo_wt_haloTout,ztab_halo_et_haloTout,&
+                              zst)
+        !
+     end if!  (stepsize == 1) ...
+     if (comm_measuretime) then
+        call finish_timer(t_haloswap(level,m))
+     end if
+  end if !  (m > 0)
+
+contains
+  subroutine haloswap_mnh_dim(pztab_halo_st_haloTin,pztab_halo_nt_haloTin,&
+                              pztab_halo_et_haloTin,pztab_halo_wt_haloTin,&
+                              pztab_halo_nt_haloTout,pztab_halo_st_haloTout,&
+                              pztab_halo_wt_haloTout,pztab_halo_et_haloTout,&
+                              pzst)
+
+    implicit none
+    real :: pztab_halo_st_haloTin(1:a_n,1:halo_size,1:nz+2), &
+            pztab_halo_nt_haloTin(1:a_n,1:halo_size,1:nz+2), &
+            pztab_halo_et_haloTin(1:halo_size,1:a_n+2*halo_size,1:nz+2), &
+            pztab_halo_wt_haloTin(1:halo_size,1:a_n+2*halo_size,1:nz+2), &
+            pztab_halo_nt_haloTout(1:a_n,1:halo_size,1:nz+2), &
+            pztab_halo_st_haloTout(1:a_n,1:halo_size,1:nz+2), &
+            pztab_halo_wt_haloTout(1:halo_size,1:a_n+2*halo_size,1:nz+2), &
+            pztab_halo_et_haloTout(1:halo_size,1:a_n+2*halo_size,1:nz+2), &
+            pzst(1-halo_size:a_n+halo_size,1-halo_size:a_n+halo_size,0:nz+1)
+        !
+        ! Do Comm
+        !
 #ifdef MNH_GPUDIRECT
         if (LUseT) then
            !
@@ -1132,37 +1203,37 @@ contains
            !
            ! Send to south
            if (Gneighbour_s) then
-           ztab_halo_st_haloTin => tab_halo_st(level,m)%haloTin
-           !$acc kernels async(IS_SOUTH) present_cr(zst,ztab_halo_st_haloTin)
+!!$           pztab_halo_st_haloTin => tab_halo_st(level,m)%haloTin
+           !$acc kernels async(IS_SOUTH) present_cr(pzst,pztab_halo_st_haloTin)
            !$mnh_do_concurrent( ii=1:a_n,ij=1:halo_size,ik=1:nz+2 )
-              ztab_halo_st_haloTin(ii,ij,ik) = zst(ii,ij+a_n-halo_size,ik-1)
+              pztab_halo_st_haloTin(ii,ij,ik) = pzst(ii,ij+a_n-halo_size,ik-1)
            !$mnh_end_do()
            !$acc end kernels
            end if
            ! Send to north
            if (Gneighbour_n) then
-           ztab_halo_nt_haloTin => tab_halo_nt(level,m)%haloTin
-           !$acc kernels async(IS_NORTH) present_cr(zst,ztab_halo_nt_haloTin)
+!!$           pztab_halo_nt_haloTin => tab_halo_nt(level,m)%haloTin
+           !$acc kernels async(IS_NORTH) present_cr(pzst,pztab_halo_nt_haloTin)
            !$mnh_do_concurrent( ii=1:a_n,ij=1:halo_size,ik=1:nz+2 )           
-              ztab_halo_nt_haloTin(ii,ij,ik) = zst(ii,ij,ik-1)
+              pztab_halo_nt_haloTin(ii,ij,ik) = pzst(ii,ij,ik-1)
            !$mnh_end_do()
            !$acc end kernels
            end if
            ! Send to east
            if (Gneighbour_e) then
-           ztab_halo_et_haloTin => tab_halo_et(level,m)%haloTin
-           !$acc kernels async(IS_EAST) present_cr(zst,ztab_halo_et_haloTin)
+!!$           pztab_halo_et_haloTin => tab_halo_et(level,m)%haloTin
+           !$acc kernels async(IS_EAST) present_cr(pzst,pztab_halo_et_haloTin)
            !$mnh_do_concurrent( ii=1:halo_size,ij=1:a_n+2*halo_size,ik=1:nz+2 ) 
-              ztab_halo_et_haloTin(ii,ij,ik) = zst(ii+a_n-halo_size,ij-halo_size,ik-1)
+              pztab_halo_et_haloTin(ii,ij,ik) = pzst(ii+a_n-halo_size,ij-halo_size,ik-1)
            !$mnh_end_do()
            !$acc end kernels
            end if
            ! Send to west
            if (Gneighbour_w) then
-           ztab_halo_wt_haloTin => tab_halo_wt(level,m)%haloTin
-           !$acc kernels async(IS_WEST) present_cr(zst,ztab_halo_wt_haloTin)
+!!$           pztab_halo_wt_haloTin => tab_halo_wt(level,m)%haloTin
+           !$acc kernels async(IS_WEST) present_cr(pzst,pztab_halo_wt_haloTin)
            !$mnh_do_concurrent( ii=1:halo_size,ij=1:a_n+2*halo_size,ik=1:nz+2 ) 
-              ztab_halo_wt_haloTin(ii,ij,ik) = zst(ii,ij-halo_size,ik-1)
+              pztab_halo_wt_haloTin(ii,ij,ik) = pzst(ii,ij-halo_size,ik-1)
            !$mnh_end_do()
            !$acc end kernels
            end if
@@ -1177,14 +1248,14 @@ contains
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
            if (Gneighbour_n) then
-           ztab_halo_nt_haloTout => tab_halo_nt(level,m)%haloTout
-           !$acc host_data use_device(ztab_halo_nt_haloTout)
-           call mpi_irecv(ztab_halo_nt_haloTout,size(ztab_halo_nt_haloTout),      &
+!!$           pztab_halo_nt_haloTout => tab_halo_nt(level,m)%haloTout
+           !$acc host_data use_device(pztab_halo_nt_haloTout)
+           call mpi_irecv(pztab_halo_nt_haloTout,size(pztab_halo_nt_haloTout),      &
                        MPI_DOUBLE_PRECISION,neighbour_n_rank,recvtag,  &
                        MPI_COMM_HORIZ, requests_nsT(1), ierr)
            !$acc end host_data
            end if
-           !print*,"mpi_irecv(ztab_halo_nt_haloTout,neighbour_n_rank=",neighbour_n_rank
+           !print*,"mpi_irecv(pztab_halo_nt_haloTout,neighbour_n_rank=",neighbour_n_rank
 #else
            call mpi_irecv(a%st(1,0-(halo_size-1),0),1,      &
                        halo_nst(level,m),neighbour_n_rank,recvtag,  &
@@ -1200,14 +1271,14 @@ contains
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
            if (Gneighbour_s) then
-           ztab_halo_st_haloTout => tab_halo_st(level,m)%haloTout
-           !$acc host_data use_device (ztab_halo_st_haloTout)
-           call mpi_irecv(ztab_halo_st_haloTout,size(ztab_halo_st_haloTout),  &
+!!$           pztab_halo_st_haloTout => tab_halo_st(level,m)%haloTout
+           !$acc host_data use_device (pztab_halo_st_haloTout)
+           call mpi_irecv(pztab_halo_st_haloTout,size(pztab_halo_st_haloTout),  &
                        MPI_DOUBLE_PRECISION,neighbour_s_rank,recvtag,  &
                        MPI_COMM_HORIZ, requests_nsT(2), ierr)
            !$acc end host_data
            end if
-           !print*,"mpi_irecv(ztab_halo_st_haloTout,neighbour_s_rank=",neighbour_s_rank
+           !print*,"mpi_irecv(pztab_halo_st_haloTout,neighbour_s_rank=",neighbour_s_rank
 #else
            call mpi_irecv(a%st(1,a_n+1,0),1,                &
                        halo_nst(level,m),neighbour_s_rank,recvtag,  &
@@ -1229,13 +1300,13 @@ contains
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
            if (Gneighbour_s) then
-           !$acc host_data use_device(ztab_halo_st_haloTin)
-           call mpi_isend(ztab_halo_st_haloTin,size(ztab_halo_st_haloTin),    &
+           !$acc host_data use_device(pztab_halo_st_haloTin)
+           call mpi_isend(pztab_halo_st_haloTin,size(pztab_halo_st_haloTin),    &
                        MPI_DOUBLE_PRECISION,neighbour_s_rank,sendtag,  &
                        MPI_COMM_HORIZ, requests_nsT(3), ierr)
            !$acc end host_data
            end if
-           !print*,"mpi_isend(ztab_halo_st_haloTin,neighbour_s_rank=",neighbour_s_rank
+           !print*,"mpi_isend(pztab_halo_st_haloTin,neighbour_s_rank=",neighbour_s_rank
 #else   
            call mpi_isend(a%st(1,a_n-(halo_size-1),0),1,    &
                        halo_nst(level,m),neighbour_s_rank,sendtag,  &
@@ -1251,13 +1322,13 @@ contains
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
            if (Gneighbour_n) then
-           !$acc host_data use_device(ztab_halo_nt_haloTin)
-           call mpi_isend(ztab_halo_nt_haloTin,size(ztab_halo_nt_haloTin),   &
+           !$acc host_data use_device(pztab_halo_nt_haloTin)
+           call mpi_isend(pztab_halo_nt_haloTin,size(pztab_halo_nt_haloTin),   &
                        MPI_DOUBLE_PRECISION,neighbour_n_rank,sendtag,  &
                        MPI_COMM_HORIZ, requests_nsT(4), ierr)
            !$acc end host_data
            end if
-           !print*,"mpi_isend(ztab_halo_nt_haloTin,neighbour_n_rank=",neighbour_n_rank
+           !print*,"mpi_isend(pztab_halo_nt_haloTin,neighbour_n_rank=",neighbour_n_rank
 #else
            call mpi_isend(a%st(1,1,0),1,                    &
                        halo_nst(level,m),neighbour_n_rank,sendtag,  &
@@ -1278,14 +1349,14 @@ contains
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
            if (Gneighbour_w) then
-           ztab_halo_wt_haloTout => tab_halo_wt(level,m)%haloTout
-           !$acc host_data use_device(ztab_halo_wt_haloTout)
-           call mpi_irecv(ztab_halo_wt_haloTout,size(ztab_halo_wt_haloTout),  &
+!!$           pztab_halo_wt_haloTout => tab_halo_wt(level,m)%haloTout
+           !$acc host_data use_device(pztab_halo_wt_haloTout)
+           call mpi_irecv(pztab_halo_wt_haloTout,size(pztab_halo_wt_haloTout),  &
                        MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
                        MPI_COMM_HORIZ, requests_ewT(1), ierr)
            !$acc end host_data
            end if
-           !print*,"mpi_irecv(ztab_halo_wt_haloTout,neighbour_w_rank=",neighbour_w_rank
+           !print*,"mpi_irecv(pztab_halo_wt_haloTout,neighbour_w_rank=",neighbour_w_rank
 #else
            call mpi_irecv(a%st(0-(halo_size-1),0,0),1,  &
                        halo_wet(level,m),neighbour_w_rank,recvtag, &
@@ -1301,14 +1372,14 @@ contains
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
            if (Gneighbour_e) then
-           ztab_halo_et_haloTout => tab_halo_et(level,m)%haloTout
-           !$acc host_data use_device(ztab_halo_et_haloTout)
-           call mpi_irecv(ztab_halo_et_haloTout,size(ztab_halo_et_haloTout),  &
+!!$           pztab_halo_et_haloTout => tab_halo_et(level,m)%haloTout
+           !$acc host_data use_device(pztab_halo_et_haloTout)
+           call mpi_irecv(pztab_halo_et_haloTout,size(pztab_halo_et_haloTout),  &
                        MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
                        MPI_COMM_HORIZ, requests_ewT(2), ierr)
            !$acc end host_data
            end if
-           !print*,"mpi_irecv(ztab_halo_et_haloTout,neighbour_e_rank=",neighbour_e_rank
+           !print*,"mpi_irecv(pztab_halo_et_haloTout,neighbour_e_rank=",neighbour_e_rank
 #else
            call mpi_irecv(a%st(a_n+1,0,0),1,          &
                        halo_wet(level,m),neighbour_e_rank,recvtag, &
@@ -1325,13 +1396,13 @@ contains
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
            if (Gneighbour_e) then
-           !$acc host_data use_device(ztab_halo_et_haloTin)
-           call mpi_isend(ztab_halo_et_haloTin,size(ztab_halo_et_haloTin),  &
+           !$acc host_data use_device(pztab_halo_et_haloTin)
+           call mpi_isend(pztab_halo_et_haloTin,size(pztab_halo_et_haloTin),  &
                        MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
                        MPI_COMM_HORIZ, requests_ewT(3), ierr)
            !$acc end host_data
            end if
-           !print*,"mpi_isend(ztab_halo_et_haloTin,neighbour_e_rank=",neighbour_e_rank
+           !print*,"mpi_isend(pztab_halo_et_haloTin,neighbour_e_rank=",neighbour_e_rank
 #else
            call mpi_isend(a%st(a_n-(halo_size-1),0,0),1,  &
                        halo_wet(level,m),neighbour_e_rank,sendtag, &
@@ -1347,13 +1418,13 @@ contains
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
            if (Gneighbour_w) then
-           !$acc host_data use_device(ztab_halo_wt_haloTin)
-           call mpi_isend(ztab_halo_wt_haloTin,size(ztab_halo_wt_haloTin),  &
+           !$acc host_data use_device(pztab_halo_wt_haloTin)
+           call mpi_isend(pztab_halo_wt_haloTin,size(pztab_halo_wt_haloTin),  &
                        MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag,   &
                        MPI_COMM_HORIZ, requests_ewT(4), ierr)
            !$acc end host_data
            end if
-           !print*,"mpi_isend(ztab_halo_wt_haloTin,neighbour_w_rank=",neighbour_w_rank
+           !print*,"mpi_isend(pztab_halo_wt_haloTin,neighbour_w_rank=",neighbour_w_rank
 #else
            call mpi_isend(a%st(1,0,0),1,                &
                        halo_wet(level,m),neighbour_w_rank,sendtag,   &
@@ -1372,33 +1443,33 @@ contains
         if (LUseT) then
            if (Gneighbour_n) then
            ! copy north halo for GPU managed
-           !$acc kernels async(IS_NORTH) present_cr(zst,ztab_halo_nt_haloTout)
+           !$acc kernels async(IS_NORTH) present_cr(pzst,pztab_halo_nt_haloTout)
            !$mnh_do_concurrent( ii=1:a_n,ij=1:halo_size,ik=1:nz+2 )
-              zst(ii,ij-halo_size,ik-1) = ztab_halo_nt_haloTout(ii,ij,ik)
+              pzst(ii,ij-halo_size,ik-1) = pztab_halo_nt_haloTout(ii,ij,ik)
            !$mnh_end_do()
            !$acc end kernels
            end if
            if (Gneighbour_s) then
            ! copy south halo for GPU managed
-           !$acc kernels async(IS_SOUTH) present_cr(zst,ztab_halo_st_haloTout)
+           !$acc kernels async(IS_SOUTH) present_cr(pzst,pztab_halo_st_haloTout)
            !$mnh_do_concurrent( ii=1:a_n,ij=1:halo_size,ik=1:nz+2 )
-              zst(ii,ij+a_n,ik-1) = ztab_halo_st_haloTout(ii,ij,ik)
+              pzst(ii,ij+a_n,ik-1) = pztab_halo_st_haloTout(ii,ij,ik)
            !$mnh_end_do()
            !$acc end kernels
            end if
            if (Gneighbour_w) then
            ! copy west halo for GPU managed
-           !$acc kernels async(IS_WEST) present_cr(zst,ztab_halo_wt_haloTout)
+           !$acc kernels async(IS_WEST) present_cr(pzst,pztab_halo_wt_haloTout)
            !$mnh_do_concurrent( ii=1:halo_size,ij=1:a_n+2*halo_size,ik=1:nz+2 )
-              zst(ii-halo_size,ij-halo_size,ik-1) = ztab_halo_wt_haloTout(ii,ij,ik)
+              pzst(ii-halo_size,ij-halo_size,ik-1) = pztab_halo_wt_haloTout(ii,ij,ik)
            !$mnh_end_do()
            !$acc end kernels
            end if
            if (Gneighbour_e) then
            ! copy east halo for GPU managed
-           !$acc kernels async(IS_EAST) present_cr(zst,ztab_halo_et_haloTout)
+           !$acc kernels async(IS_EAST) present_cr(pzst,pztab_halo_et_haloTout)
            !$mnh_do_concurrent( ii=1:halo_size,ij=1:a_n+2*halo_size,ik=1:nz+2 )
-              zst(ii+a_n,ij-halo_size,ik-1) = ztab_halo_et_haloTout(ii,ij,ik)
+              pzst(ii+a_n,ij-halo_size,ik-1) = pztab_halo_et_haloTout(ii,ij,ik)
            !$mnh_end_do()
            !$acc end kernels           
            end if 
@@ -1406,13 +1477,11 @@ contains
            call acc_wait_haloswap_mnh()           
         end if 
 #endif
-      end if!  (stepsize == 1) ...
-      if (comm_measuretime) then
-        call finish_timer(t_haloswap(level,m))
-      end if
-    end if
-
-  contains
+      !
+      ! End Comm
+      !
+    end subroutine haloswap_mnh_dim
+      
     subroutine acc_wait_haloswap_mnh()
       if (Gneighbour_s) then
          !$acc wait(IS_SOUTH)
-- 
GitLab