diff --git a/src/ZSOLVER/tensorproductmultigrid_Source/mode_mg.f90 b/src/ZSOLVER/tensorproductmultigrid_Source/mode_mg.f90
index 806825f1b47d9e14c0d528fe57666e2931faeb0e..11ade04e65232027b59ee864585f82ad81bc9a1b 100644
--- a/src/ZSOLVER/tensorproductmultigrid_Source/mode_mg.f90
+++ b/src/ZSOLVER/tensorproductmultigrid_Source/mode_mg.f90
@@ -58,6 +58,8 @@ contains
 
 subroutine mg_init_mnh(KN,KNZ,PL,PH,PA_K,PB_K,PC_K,PD_K)
 
+use MODE_OPENACC_SET_DEVICE
+  
 implicit none
 
 integer       , optional , intent (in) :: KN,KNZ
@@ -73,7 +75,9 @@ logical :: gisinit
   if  (.not. gisinit ) then
      call mpi_init(ierr)
   end if
-
+  !
+  ! get default device type     
+  call MNH_OPENACC_GET_DEVICE_AT_INIT()
   ! ... and pre initialise communication module
   call comm_preinitialise()
 
diff --git a/src/ZSOLVER/tensorproductmultigrid_Source/multigrid.f90 b/src/ZSOLVER/tensorproductmultigrid_Source/multigrid.f90
index d00404df25f62bc23ed6844d84f92a913cd9ad39..0d7eedec5fbb922c57b19dd5b17737219ba65586 100644
--- a/src/ZSOLVER/tensorproductmultigrid_Source/multigrid.f90
+++ b/src/ZSOLVER/tensorproductmultigrid_Source/multigrid.f90
@@ -1169,6 +1169,7 @@ contains
 ! Multigrid V-cycle
 !==================================================================
   recursive subroutine mg_vcycle_mnh(b,u,r,finelevel,splitlevel,level,m)
+    use MODE_OPENACC_SET_DEVICE
     implicit none
     integer, intent(in)                                     :: finelevel
     type(scalar3d), intent(inout), dimension(finelevel,0:pproc) :: b
@@ -1182,7 +1183,9 @@ contains
     integer                                                 :: nlocalx, nlocaly
     integer                                                 :: halo_size
 
-     real , dimension(:,:,:) , pointer ::  zu_level_1_m_st
+    real , dimension(:,:,:) , pointer ::  zu_level_1_m_st
+
+    integer :: iswitch_cpu_gpu = 5
 
     nlocalx = u(level,m)%ix_max-u(level,m)%ix_min+1
     nlocaly = u(level,m)%iy_max-u(level,m)%iy_min+1
@@ -1240,7 +1243,20 @@ contains
            !$acc end kernels
         end if
         ! solve on coarser grid
-        call mg_vcycle_mnh(b,u,r,finelevel,splitlevel,level-1,m)
+        ! switch from GPU to CPU if level == iswitch_cpu_gpu 
+        if (level .EQ. iswitch_cpu_gpu ) then
+           !print*,' enter mg_vcycle_mnh level=', iswitch_cpu_gpu
+           !call MNH_OPENACC_GET_DEVICE()
+           call MNH_OPENACC_SET_DEVICE_HOST()
+           !call MNH_OPENACC_GET_DEVICE()
+        end if
+        call mg_vcycle_mnh(b,u,r,finelevel,splitlevel,level-1,m)       
+        if (level .EQ. iswitch_cpu_gpu) then
+           !print*,' exit mg_vcycle_mnh level=', iswitch_cpu_gpu
+           !call MNH_OPENACC_GET_DEVICE()
+           call MNH_OPENACC_SET_DEVICE_DEFAULT()
+           !call MNH_OPENACC_GET_DEVICE()
+        end if
       end if
       ! Prolongate error
       call start_timer(t_prolongate(level,m))
@@ -1462,6 +1478,7 @@ contains
     type(scalar3d) :: z_one 
     real(kind=rl) :: alpha, beta, pq, rz, rz_old
     integer :: ierr
+    real(kind=rl) , pointer , dimension(:,:,:) :: z_one_st
 
     solvertype = solver_param%solvertype
     resreduction = solver_param%resreduction
@@ -1492,9 +1509,12 @@ contains
        z_one%s(:,:,:) = 0.0_rl
        z_one%s(1:z_one%grid_param%nz,1:z_one%icompy_max,1:z_one%icompx_max) = 1.0_rl
     end if
-    if (LUseT) then
-       z_one%st(:,:,:) = 0.0_rl
-       z_one%st(1:z_one%icompx_max,1:z_one%icompy_max,1:z_one%grid_param%nz) = 1.0_rl
+    if (LUseT) then       
+       z_one_st => z_one%st
+       !$acc kernels
+       z_one_st(:,:,:) = 0.0_rl
+       z_one_st(1:z_one%icompx_max,1:z_one%icompy_max,1:z_one%grid_param%nz) = 1.0_rl
+       !$acc end kernels
     end if
     !   Mean / Norm of B
     call scalarprod_mnh(pproc,z_one,z_one, mean_initial )