From 30efd6a2070c260013c9ad75fc15e86a385e725f Mon Sep 17 00:00:00 2001
From: Juan Escobar <juan.escobar@aero.obs-mip.fr>
Date: Thu, 12 Aug 2021 17:12:33 +0200
Subject: [PATCH] Juan 12/08/2021:communication.f90, in distribute add
 host_data use_device + acc parallel loop collapse

---
 .../communication.f90                         | 120 +++++++++++++++---
 1 file changed, 100 insertions(+), 20 deletions(-)

diff --git a/src/ZSOLVER/tensorproductmultigrid_Source/communication.f90 b/src/ZSOLVER/tensorproductmultigrid_Source/communication.f90
index 2bf1bc797..35e357ed9 100644
--- a/src/ZSOLVER/tensorproductmultigrid_Source/communication.f90
+++ b/src/ZSOLVER/tensorproductmultigrid_Source/communication.f90
@@ -1542,8 +1542,7 @@ contains
 #else
            call mpi_irecv(b%st(1,b_n/2+1,0),1,sub_interiorT(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
                 recv_requestT(2),ierr)         
-#endif
-           
+#endif           
         endif
 
 #ifndef NDEBUG
@@ -1578,7 +1577,18 @@ contains
 #endif
         ! Copy local data while waiting for data from other processes
         if (LUseO) b%s(0:nz+1,1:a_n,1:a_n) = a%s(0:nz+1,1:a_n,1:a_n)
-        if (LUseT) b%st(1:a_n,1:a_n,0:nz+1) = a%st(1:a_n,1:a_n,0:nz+1)
+        if (LUseT) then
+#ifdef MNH_GPUDIREC           
+           zb_st => b%st
+           za_st => a%st
+           !$acc parallel loop collapse(3)
+           do concurrent (ii=1:a_n,ij=1:a_n,ik=1:nz+2)
+              zb_st(ii,ij,ik-1) = za_st(ii,ij,ik-1)
+           end do
+#else           
+           b%st(1:a_n,1:a_n,0:nz+1) = a%st(1:a_n,1:a_n,0:nz+1)
+#endif
+        end if
         ! Wait for receives to complete before proceeding
         if (LUseO) call mpi_waitall(3,recv_request,MPI_STATUSES_IGNORE,ierr)
         if (LUseT) call mpi_waitall(3,recv_requestT,MPI_STATUSES_IGNORE,ierr)
@@ -1736,6 +1746,18 @@ contains
     integer :: send_requestT(3)
     logical :: corner_nw, corner_ne, corner_sw, corner_se
 
+    integer :: ii,ij,ik
+
+    real , pointer , contiguous , dimension(:,:,:) :: za_st,zb_st
+
+    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_ne_m_1_haloTin
+    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_sw_m_1_haloTin
+    real , pointer , contiguous , dimension(:,:,:) :: ztab_sub_interiorT_se_m_1_haloTin
+
+    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_ne_m_haloTout
+    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_sw_m_haloTout
+    real , pointer , contiguous , dimension(:,:,:) :: ztab_interiorT_se_m_haloTout
+
     call start_timer(t_distribute(m))
 
     stepsize = 2**(pproc-m)
@@ -1768,17 +1790,25 @@ contains
                            (/p_horiz(1),p_horiz(2)+stepsize/), &
                            dest_rank, &
                            ierr)
-
+        
+        za_st => a%st
+        
         send_tag = 1000
         if (LUseO) call mpi_isend(a%s(0,1,a_n/2+1), 1,sub_interior(level,m-1),dest_rank, send_tag, &
                        MPI_COMM_HORIZ,send_request(1),ierr)
         send_tag = 1010
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
-           tab_sub_interiorT_ne(level,m-1)%haloTin(1:a_n/2,1:a_n/2,1:nz+2) = a%st(a_n/2+1:a_n,1:a_n/2,0:nz+1)
-           call mpi_isend(tab_sub_interiorT_ne(level,m-1)%haloTin,size(tab_sub_interiorT_ne(level,m-1)%haloTin), &
+           ztab_sub_interiorT_ne_m_1_haloTin => tab_sub_interiorT_ne(level,m-1)%haloTin
+           !$acc parallel loop collapse(3)
+           do concurrent (ii=1:a_n/2,ij=1:a_n/2,ik=1:nz+2)
+              ztab_sub_interiorT_ne_m_1_haloTin(ii,ij,ik) = za_st(ii+a_n/2,ij,ik-1)
+           end do
+           !$acc host_data use_device(ztab_sub_interiorT_ne_m_1_haloTin)
+           call mpi_isend(ztab_sub_interiorT_ne_m_1_haloTin,size(ztab_sub_interiorT_ne_m_1_haloTin), &
                 MPI_DOUBLE_PRECISION,dest_rank, send_tag, &
                 MPI_COMM_HORIZ,send_requestT(1),ierr)
+           !$acc end host_data
 #else
            call mpi_isend(a%st(a_n/2+1,1,0), 1,sub_interiorT(level,m-1),dest_rank, send_tag, &
                 MPI_COMM_HORIZ,send_requestT(1),ierr)          
@@ -1799,10 +1829,16 @@ contains
         send_tag = 1011
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
-           tab_sub_interiorT_sw(level,m-1)%haloTin(1:a_n/2,1:a_n/2,1:nz+2) = a%st(1:a_n/2,a_n/2+1:a_n,0:nz+1)
-           call mpi_isend(tab_sub_interiorT_sw(level,m-1)%haloTin,size(tab_sub_interiorT_sw(level,m-1)%haloTin), &
+           ztab_sub_interiorT_sw_m_1_haloTin => tab_sub_interiorT_sw(level,m-1)%haloTin
+           !$acc parallel loop collapse(3)
+           do concurrent (ii=1:a_n/2,ij=1:a_n/2,ik=1:nz+2)
+              ztab_sub_interiorT_sw_m_1_haloTin(ii,ij,ik) = za_st(ii,ij+a_n/2,ik-1)
+           end do
+           !$acc host_data use_device(ztab_sub_interiorT_sw_m_1_haloTin)
+           call mpi_isend(ztab_sub_interiorT_sw_m_1_haloTin,size(ztab_sub_interiorT_sw_m_1_haloTin), &
                 MPI_DOUBLE_PRECISION, dest_rank, send_tag, &
                 MPI_COMM_HORIZ, send_requestT(2), ierr)
+           !$acc end host_data
 #else
            call mpi_isend(a%st(1,a_n/2+1,0),1,sub_interiorT(level,m-1), dest_rank, send_tag, &
                 MPI_COMM_HORIZ, send_requestT(2), ierr)           
@@ -1824,10 +1860,16 @@ contains
         send_tag = 1012
         if (LUseT) then
 #ifdef MNH_GPUDIRECT
-           tab_sub_interiorT_se(level,m-1)%haloTin(1:a_n/2,1:a_n/2,1:nz+2) = a%st(a_n/2+1:a_n,a_n/2+1:a_n,0:nz+1)
-           call mpi_isend(tab_sub_interiorT_se(level,m-1)%haloTin,size(tab_sub_interiorT_se(level,m-1)%haloTin), &
+           ztab_sub_interiorT_se_m_1_haloTin => tab_sub_interiorT_se(level,m-1)%haloTin
+           !$acc parallel loop collapse(3)
+           do concurrent (ii=1:a_n/2,ij=1:a_n/2,ik=1:nz+2)
+              ztab_sub_interiorT_se_m_1_haloTin(ii,ij,ik) = za_st(ii+a_n/2,ij+a_n/2,ik-1)
+           end do
+           !$acc host_data use_device(ztab_sub_interiorT_se_m_1_haloTin)
+           call mpi_isend(ztab_sub_interiorT_se_m_1_haloTin,size(ztab_sub_interiorT_se_m_1_haloTin), &
                 MPI_DOUBLE_PRECISION, dest_rank, send_tag, &
                 MPI_COMM_HORIZ, send_requestT(3), ierr)
+           !$acc end host_data
 #else
            call mpi_isend(a%st(a_n/2+1,a_n/2+1,0),1,sub_interiorT(level,m-1), dest_rank, send_tag, &
                 MPI_COMM_HORIZ, send_requestT(3), ierr)          
@@ -1839,7 +1881,18 @@ contains
 #endif
         ! While sending, copy local data
         if (LUseO) b%s(0:nz+1,1:b_n,1:b_n) = a%s(0:nz+1,1:b_n,1:b_n)
-        if (LUseT) b%st(1:b_n,1:b_n,0:nz+1) = a%st(1:b_n,1:b_n,0:nz+1)
+        if (LUseT) then
+#ifdef MNH_GPUDIREC                     
+           zb_st => b%st
+           za_st => a%st
+           !$acc parallel loop collapse(3)
+           do concurrent (ii=1:b_n,ij=1:b_n,ik=1:nz+2)
+              zb_st(ii,ij,ik-1) = za_st(ii,ij,ik-1)
+           end do
+#else
+           b%st(1:b_n,1:b_n,0:nz+1) = a%st(1:b_n,1:b_n,0:nz+1)
+#endif
+        end if
         ! Only proceed when async sends to complete
         if (LUseO) call mpi_waitall(3, send_request, MPI_STATUSES_IGNORE, ierr)
         if (LUseT) call mpi_waitall(3, send_requestT, MPI_STATUSES_IGNORE, ierr)
@@ -1851,14 +1904,23 @@ contains
                            (/p_horiz(1),p_horiz(2)-stepsize/), &
                            source_rank, &
                            ierr)
+        
+        zb_st => b%st
+        
         recv_tag = 1000
         if (LUseO) call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
         recv_tag = 1010
         if (LUseT) then
-#ifdef MNH_GPUDIRECT
-           call mpi_recv(tab_interiorT_ne(level,m)%haloTout,size(tab_interiorT_ne(level,m)%haloTout), &
+#ifdef MNH_GPUDIREC
+           ztab_interiorT_ne_m_haloTout => tab_interiorT_ne(level,m)%haloTout
+           !$acc host_data use_device(ztab_interiorT_ne_m_haloTout)
+           call mpi_recv(ztab_interiorT_ne_m_haloTout,size(ztab_interiorT_ne_m_haloTout), &
                 MPI_DOUBLE_PRECISION,source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
-           b%st(1:a_n,1:a_n,0:nz+1) = tab_interiorT_ne(level,m)%haloTout(1:a_n,1:a_n,1:nz+2)
+           !$acc end host_data
+           !$acc parallel loop collapse(3)
+           do concurrent (ii=1:a_n,ij=1:a_n,ik=1:nz+2)
+              zb_st(ii,ij,ik-1) = ztab_interiorT_ne_m_haloTout(ii,ij,ik)
+           end do
 #else
            call mpi_recv(b%st(1,1,0),1,interiorT(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)           
 #endif
@@ -1874,14 +1936,23 @@ contains
                            (/p_horiz(1)-stepsize,p_horiz(2)/), &
                            source_rank, &
                            ierr)
+
+        zb_st => b%st
+        
         recv_tag = 1001
         if (LUseO) call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
         recv_tag = 1011
         if (LUseT) then
-#ifdef MNH_GPUDIRECT           
-           call mpi_recv(tab_interiorT_sw(level,m)%haloTout,size(tab_interiorT_sw(level,m)%haloTout), &
+#ifdef MNH_GPUDIRECT
+           ztab_interiorT_sw_m_haloTout => tab_interiorT_sw(level,m)%haloTout
+           !$acc host_data use_device(ztab_interiorT_sw_m_haloTout)
+           call mpi_recv(ztab_interiorT_sw_m_haloTout,size(ztab_interiorT_sw_m_haloTout), &           
                 MPI_DOUBLE_PRECISION,source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
-           b%st(1:a_n,1:a_n,0:nz+1) = tab_interiorT_sw(level,m)%haloTout(1:a_n,1:a_n,1:nz+2)
+           !$acc end host_data
+           !$acc parallel loop collapse(3)
+           do concurrent (ii=1:a_n,ij=1:a_n,ik=1:nz+2)
+              zb_st(ii,ij,ik-1) = ztab_interiorT_sw_m_haloTout(ii,ij,ik)
+           end do
 #else
            call mpi_recv(b%st(1,1,0),1,interiorT(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)           
 #endif
@@ -1897,14 +1968,23 @@ contains
                            (/p_horiz(1)-stepsize,p_horiz(2)-stepsize/), &
                            source_rank, &
                            ierr)
+
+        zb_st => b%st
+        
         recv_tag = 1002
         if (LUseO) call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
         recv_tag = 1012
         if (LUseT) then
-#ifdef MNH_GPUDIRECT           
-           call mpi_recv(tab_interiorT_se(level,m)%haloTout,size(tab_interiorT_se(level,m)%haloTout), &
+#ifdef MNH_GPUDIRECT
+           ztab_interiorT_se_m_haloTout => tab_interiorT_se(level,m)%haloTout
+           !$acc host_data use_device(ztab_interiorT_ne_m_haloTout)
+           call mpi_recv(ztab_interiorT_se_m_haloTout,size(ztab_interiorT_se_m_haloTout), &
                 MPI_DOUBLE_PRECISION,source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
-           b%st(1:a_n,1:a_n,0:nz+1) = tab_interiorT_se(level,m)%haloTout(1:a_n,1:a_n,1:nz+2)
+           !$acc end host_data
+           !$acc parallel loop collapse(3)
+           do concurrent (ii=1:a_n,ij=1:a_n,ik=1:nz+2)
+              zb_st(ii,ij,ik-1) = ztab_interiorT_se_m_haloTout(ii,ij,ik)
+           end do           
 #else
            call mpi_recv(b%st(1,1,0),1,interiorT(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)           
 #endif
-- 
GitLab