From 99881a66a5bd3090e2439e89a50b8ba6addf8d18 Mon Sep 17 00:00:00 2001
From: Juan Escobar <escj@aero.obs-mip.fr>
Date: Tue, 18 Dec 2018 15:17:29 +0100
Subject: [PATCH] Juan 18/12/2018: add original tensorproductmultigrid/Source

---
 .../communication.f90                         |  968 ++++++++++++++
 .../conjugategradient.f90                     |  275 ++++
 tensorproductmultigrid_Source/datatypes.f90   |  381 ++++++
 .../discretisation.f90                        |  879 +++++++++++++
 tensorproductmultigrid_Source/messages.f90    |  100 ++
 tensorproductmultigrid_Source/mg_main.f90     |  700 ++++++++++
 tensorproductmultigrid_Source/multigrid.f90   | 1141 +++++++++++++++++
 tensorproductmultigrid_Source/parameters.f90  |   58 +
 tensorproductmultigrid_Source/profiles.f90    |  174 +++
 tensorproductmultigrid_Source/timer.f90       |  184 +++
 10 files changed, 4860 insertions(+)
 create mode 100644 tensorproductmultigrid_Source/communication.f90
 create mode 100644 tensorproductmultigrid_Source/conjugategradient.f90
 create mode 100644 tensorproductmultigrid_Source/datatypes.f90
 create mode 100644 tensorproductmultigrid_Source/discretisation.f90
 create mode 100644 tensorproductmultigrid_Source/messages.f90
 create mode 100644 tensorproductmultigrid_Source/mg_main.f90
 create mode 100644 tensorproductmultigrid_Source/multigrid.f90
 create mode 100644 tensorproductmultigrid_Source/parameters.f90
 create mode 100644 tensorproductmultigrid_Source/profiles.f90
 create mode 100644 tensorproductmultigrid_Source/timer.f90

diff --git a/tensorproductmultigrid_Source/communication.f90 b/tensorproductmultigrid_Source/communication.f90
new file mode 100644
index 000000000..f0856420b
--- /dev/null
+++ b/tensorproductmultigrid_Source/communication.f90
@@ -0,0 +1,968 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!  MPI communication routines for multigrid code
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+
+module communication
+  use messages
+  use datatypes
+  use parameters
+  use mpi
+  use timer
+
+  implicit none
+
+public::comm_preinitialise
+public::comm_initialise
+public::comm_finalise
+public::scalarprod
+public::haloswap
+public::ihaloswap
+public::collect
+public::distribute
+public::i_am_master_mpi
+public::master_rank
+public::pproc
+public::MPI_COMM_HORIZ
+public::comm_parameters
+public::comm_measuretime
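+
+! Typical call sequence (sketch): comm_preinitialise() is called once to
+! derive pproc from the number of MPI ranks; comm_initialise() then builds
+! the horizontal cartesian communicator and the halo/interior MPI data
+! types; haloswap()/ihaloswap(), collect() and distribute() are used during
+! the solve; comm_finalise() frees the data types and prints the timers.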
+
+
+  ! Number of processors
+  ! n_proc = 2^(2*pproc), with integer pproc
+  integer :: pproc
+
+! Rank of master process
+  integer, parameter :: master_rank = 0
+! Am I the master process?
+  logical :: i_am_master_mpi
+
+  integer, parameter :: dim = 3  ! Dimension
+  integer, parameter :: dim_horiz = 2  ! Horizontal dimension
+  integer :: MPI_COMM_HORIZ ! Communicator with horizontal partitioning
+
+private
+
+! Data types for halo exchange in both x- and y-direction
+  integer, dimension(:,:,:), allocatable :: halo_type
+
+! MPI vector data types
+  ! Halo for data exchange in north-south direction
+  integer, allocatable, dimension(:,:) :: halo_ns
+  ! Vector data type for interior of field a(level,m)
+  integer, allocatable, dimension(:,:) :: interior
+  ! Vector data type for one quarter of interior of field
+  ! at level a(level,m). This has the same size (and can be
+  ! used for communications with) the interior of a(level,m+1)
+  integer, allocatable, dimension(:,:) :: sub_interior
+  ! Timer for halo swaps
+  type(time), allocatable, dimension(:,:) :: t_haloswap
+  ! Timer for collect and distribute
+  type(time), allocatable, dimension(:) :: t_collect
+  type(time), allocatable, dimension(:) :: t_distribute
+  ! Parallelisation parameters
+  ! Measure communication times?
+  logical :: comm_measuretime
+
+  ! Parallel communication parameters
+  type comm_parameters
+    ! Size of halos
+    integer :: halo_size
+  end type comm_parameters
+
+  type(comm_parameters) :: comm_param
+
+! Data layout
+! ===========
+!
+!  The number of processes has to be of the form nproc = 2^(2*pproc) to
+!  ensure that data can be distributed between processes.
+!  The processes are arranged in a (2^pproc) x (2^pproc) cartesian grid
+!  in the horizontal plane (i.e. vertical columns are always local to one
+!  process), which is implemented via the communicator MPI_COMM_HORIZ.
+!  MPI_cart_rank() and MPI_cart_shift() can then be used to identify
+!  neighbouring processes easily.
+
+!  The number of data grid cells in each direction has to be a multiple
+!  of 2**(L-1) where L is the number of levels (including the coarse
+!  and fine level), with the coarse level corresponding to level=1.
+!  Also define L_split as the level where we start to pull together
+!  data. For levels > L_split each position in the cartesian grid is
+!  included in the work, below this only a subset of processes is
+!  used.
+!
+!  Each grid a(level,m) is identified by two numbers:
+!  (1) The multigrid level it belongs to (level)
+!  (2) The number of active processes that operate on it (2^(2*m)).
+!
+!  For level > L_split, m=pproc. For L_split we store a(L_split,pproc) and
+!  a(L_split,pproc-1), and only processes with even coordinates in both
+!  horizontal directions use this grid.
+!  Below that level, store a(L_split-1,pproc-1) and a(L_split-1,pproc-2),
+!  where only processes for which both horizontal coordinates are
+!  multiples of four use the latter. This is continued until only one
+!  process is left.
+!
+!
+!  level
+!    L          a(L,pproc)
+!    L-1        a(L-1,pproc)
+!    ...
+!    L_split    a(L_split,pproc)    a(L_split  ,pproc-1)
+!    L_split-1                      a(L_split-1,pproc-1)  a(L_split-1,pproc-2)
+!
+!                                                                  ... a(3,1)
+!                                                                      a(2,1)
+!                                                                      a(1,1)
+!
+!  When moving from left to right in the above graph the total number of
+!  grid cells does not change, but the number of data points per process
+!  increases by a factor of 4.
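+!
+!  For example, with pproc = 2 (nproc = 16) the processes form a 4x4 grid.
+!  Above L_split all 16 processes work on a(level,2); at L_split the grid
+!  a(L_split,1) is only used by the four processes with even coordinates,
+!  and at L_split-1 the grid a(L_split-1,0) is only used by process (0,0),
+!  whose coordinates are multiples of four.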
+!
+! Parallel operations
+! ===================
+!
+!   (*)  Halo exchange. Update halo with data from neighbouring
+!        processors in cartesian grid on current (level,m)
+!   (*)  Collect data on all processes at (level,m) on those
+!        processes that are still active on (level,m-1).
+!   (*)  Distribute data at (level,m-1) and duplicate on all processes
+!        that are active at (level,m).
+!
+!   Note that in the cartesian processor grid the first coordinate
+!   is the North-South (y-) direction, the second coordinate is the
+!   East-West (x-) direction, i.e. the layout is this:
+!
+!   p_0 (0,0)   p_1 (0,1)   p_2 (0,2)   p_3 (0,3)
+!
+!   p_4 (1,0)   p_5 (1,1)   p_6 (1,2)   p_7 (1,3)
+!
+!   p_8 (2,0)   p_9 (2,1)   p_10 (2,2)  p_11 (2,3)
+!
+!                       [...]
+!
+!
+!   Normal multigrid restriction and prolongation are used to
+!   move between levels with fixed m.
+!
+!
+
+contains
+
+!==================================================================
+! Pre-initialise communication routines
+!==================================================================
+  subroutine comm_preinitialise()
+    implicit none
+    integer :: nproc, ierr, rank
+    call mpi_comm_size(MPI_COMM_WORLD, nproc, ierr)
+    call mpi_comm_rank(MPI_COMM_WORLD, rank, ierr)
+    i_am_master_mpi = (rank == master_rank)
+    ! Check that nproc = 2^(2*p)
+    pproc = floor(log(1.0d0*nproc)/log(4.0d0))
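+    ! e.g. nproc = 4  => pproc = 1,  nproc = 16 => pproc = 2,
+    !      nproc = 64 => pproc = 3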
+    if ( (nproc - 4**pproc) .ne. 0) then
+      call fatalerror("Number of processors has to be 2^(2*pproc) with integer pproc.")
+    end if
+    if (i_am_master_mpi) then
+      write(STDOUT,'("PARALLEL RUN")')
+      write(STDOUT,'("Number of processors : 2^(2*pproc) = ",I10," with pproc = ",I6)') &
+        nproc, pproc
+    end if
+    ! (the halo data types are created later, in comm_initialise)
+
+  end subroutine comm_preinitialise
+
+!==================================================================
+! Initialise communication routines
+!==================================================================
+  subroutine comm_initialise(n_lev,        & !} multigrid parameters
+                            lev_split,     & !}
+                            grid_param,    & ! Grid parameters
+                            comm_param_in)   ! Parallel communication
+                                             ! parameters
+    implicit none
+    integer, intent(in) :: n_lev
+    integer, intent(in) :: lev_split
+    type(grid_parameters), intent(inout) :: grid_param
+    type(comm_parameters), intent(in)    :: comm_param_in
+    integer :: n
+    integer :: nz
+    integer :: rank, nproc, ierr
+    integer :: count, blocklength, stride
+    integer, dimension(2) :: p_horiz
+    integer :: m, level, nlocal
+    logical :: reduced_m
+    integer :: halo_size
+    character(len=32) :: t_label
+
+    n = grid_param%n
+    nz = grid_param%nz
+
+    comm_param = comm_param_in
+    halo_size = comm_param%halo_size
+
+    call mpi_comm_size(MPI_COMM_WORLD, nproc, ierr)
+
+    ! Create cartesian topology
+    call mpi_cart_create(MPI_COMM_WORLD,        & ! Old communicator name
+                         dim_horiz,             & ! horizontal dimension
+                         (/2**pproc,2**pproc/), & ! extent in each horizontal direction
+                         (/.false.,.false./),   & ! periodic?
+                         .true.,                & ! reorder?
+                         MPI_COMM_HORIZ,        & ! Name of new communicator
+                         ierr)
+    ! Calculate rank and coordinates in the cartesian grid
+    call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
+    call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
+
+    ! Local size of (horizontal) grid
+    nlocal = n/2**pproc
+
+    ! === Set up data types ===
+    ! Halo for exchange in north-south direction
+    allocate(halo_ns(n_lev,0:pproc))
+    ! Interior data types
+    allocate(interior(n_lev,0:pproc))
+    allocate(sub_interior(n_lev,0:pproc))
+    ! Timer
+    allocate(t_haloswap(n_lev,0:pproc))
+    allocate(t_collect(0:pproc))
+    allocate(t_distribute(0:pproc))
+    do m=0,pproc
+      write(t_label,'("t_collect(",I3,")")') m
+      call initialise_timer(t_collect(m),t_label)
+      write(t_label,'("t_distribute(",I3,")")') m
+      call initialise_timer(t_distribute(m),t_label)
+    end do
+
+    m = pproc
+    level = n_lev
+    reduced_m = .false.
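+    ! Walk down the (level,m) hierarchy described in the data layout
+    ! comment above and create the halo/interior data types and a timer
+    ! for every grid a(level,m)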
+    do while (level > 0)
+      ! --- Create halo data types ---
+      ! NS- (y-) direction
+      count = nlocal
+      blocklength = (nz+2)*halo_size
+      stride = (nlocal+2*halo_size)*(nz+2)
+      call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION, &
+                           halo_ns(level,m),ierr)
+      call mpi_type_commit(halo_ns(level,m),ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Commit halo_ns failed in mpi_type_commit().")
+#endif
+      ! --- Create interior data types ---
+      count = nlocal
+      blocklength = nlocal*(nz+2)
+      stride = (nz+2)*(nlocal+2*halo_size)
+      call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION,interior(level,m),ierr)
+      call mpi_type_commit(interior(level,m),ierr)
+      count = nlocal/2
+      blocklength = nlocal/2*(nz+2)
+      stride = (nlocal+2*halo_size)*(nz+2)
+      call mpi_type_vector(count,blocklength,stride,MPI_DOUBLE_PRECISION,sub_interior(level,m),ierr)
+      call mpi_type_commit(sub_interior(level,m),ierr)
+
+      ! --- Create timers ---
+      write(t_label,'("t_haloswap(",I3,",",I3,")")') level,m
+      call initialise_timer(t_haloswap(level,m),t_label)
+
+      ! If we are below L_split, split data
+      if ( (level .le. lev_split) .and. (m > 0) .and. (.not. reduced_m)) then
+        reduced_m = .true.
+        m = m-1
+        nlocal = 2*nlocal
+        cycle
+      end if
+      reduced_m = .false.
+      level = level-1
+      nlocal = nlocal/2
+    end do
+
+  end subroutine comm_initialise
+
+!==================================================================
+! Finalise communication routines
+!==================================================================
+  subroutine comm_finalise(n_lev,   &  ! }
+                           lev_split)  ! } Multigrid parameters
+    implicit none
+    integer, intent(in) :: n_lev
+    integer, intent(in) :: lev_split
+    logical :: reduced_m
+    integer :: level, m
+    integer :: ierr
+    character(len=80) :: s
+    m = pproc
+    level = n_lev
+    reduced_m = .false.
+    if (i_am_master_mpi) then
+      write(STDOUT,'(" *** Finalising communications ***")')
+    end if
+    call print_timerinfo("--- Communication timing results ---")
+    do while (level > 0)
+      write(s,'("level = ",I3,", m = ",I3)') level, m
+      call print_timerinfo(s)
+      ! --- Print out timer information ---
+      call print_elapsed(t_haloswap(level,m),.True.,1.0_rl)
+      ! --- Free halo data types ---
+      call mpi_type_free(halo_ns(level,m),ierr)
+      ! --- Free interior data types ---
+      call mpi_type_free(interior(level,m),ierr)
+      call mpi_type_free(sub_interior(level,m),ierr)
+      ! If we are below L_split, split data
+      if ( (level .le. lev_split) .and. (m > 0) .and. (.not. reduced_m)) then
+        reduced_m = .true.
+        m = m-1
+        cycle
+      end if
+      reduced_m = .false.
+      level = level-1
+    end do
+    do m=pproc,0,-1
+      write(s,'("m = ",I3)') m
+      call print_timerinfo(s)
+      ! --- Print out timer information ---
+      call print_elapsed(t_collect(m),.True.,1.0_rl)
+      call print_elapsed(t_distribute(m),.True.,1.0_rl)
+    end do
+
+    ! Deallocate arrays
+    deallocate(halo_ns)
+    deallocate(interior)
+    deallocate(sub_interior)
+    deallocate(t_haloswap)
+    deallocate(t_collect)
+    deallocate(t_distribute)
+    if (i_am_master_mpi) then
+      write(STDOUT,'("")')
+    end if
+
+  end subroutine comm_finalise
+
+!==================================================================
+!  Scalar product of two fields
+!==================================================================
+  subroutine scalarprod(m, a, b, s)
+    implicit none
+    integer, intent(in) :: m
+    type(scalar3d), intent(in) :: a
+    type(scalar3d), intent(in) :: b
+    real(kind=rl), intent(out) :: s
+    integer :: nprocs, rank, ierr
+    integer :: p_horiz(2)
+    integer :: stepsize
+    integer, parameter :: dim_horiz = 2
+    real(kind=rl) :: local_sum, global_sum
+    integer :: nlocal, nz, i
+    real(kind=rl) :: ddot
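+    ! (ddot is the external double precision dot product from BLAS)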
+
+    nlocal = a%ix_max-a%ix_min+1
+    nz = a%grid_param%nz
+    ! Work out coordinates of processor
+    call mpi_comm_size(MPI_COMM_HORIZ,nprocs,ierr)
+    call mpi_comm_rank(MPI_COMM_HORIZ,rank,ierr)
+    stepsize = 2**(pproc-m)
+    if (nprocs > 1) then
+      ! Only include the local sum if the processor coordinates
+      ! are multiples of stepsize
+      call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
+      if ( (stepsize == 1) .or.                   &
+           ( (stepsize > 1) .and.                 &
+             (mod(p_horiz(1),stepsize)==0) .and.  &
+             (mod(p_horiz(2),stepsize)==0) ) ) then
+        local_sum = 0.0_rl
+        do i = 1, nlocal
+          local_sum = local_sum &
+                    + ddot((nz+2)*nlocal,a%s(0,1,i),1,b%s(0,1,i),1)
+        end do
+      else
+        local_sum = 0.0_rl
+      end if
+      call mpi_allreduce(local_sum,global_sum,1,MPI_DOUBLE_PRECISION, &
+                         MPI_SUM,MPI_COMM_HORIZ,ierr)
+    else
+      global_sum = 0.0_rl
+      do i = 1, nlocal
+        global_sum = global_sum &
+                   + ddot((nz+2)*nlocal,a%s(0,1,i),1,b%s(0,1,i),1)
+      end do
+    end if
+    s = global_sum
+  end subroutine scalarprod
+
+!==================================================================
+!  Initiate asynchronous halo exchange
+!
+!  For all processes with horizontal indices that are multiples
+!  of 2^(pproc-m), update halos with information held by
+!  neighbouring processes, e.g. for pproc-m = 1, stepsize=2
+!
+!                      N  (0,2)
+!                      ^
+!                      !
+!                      v
+!
+!      W (2,0) <-->  (2,2)  <-->  E (2,4)
+!
+!                      ^
+!                      !
+!                      v
+!                      S (4,2)
+!
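+!  Note that this routine only initiates the exchange; the caller is
+!  responsible for completing the requests returned in send_requests
+!  and recv_requests, e.g. with mpi_waitall().
+!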
+!==================================================================
+  subroutine ihaloswap(level,m,       &  ! multigrid- and processor- level
+                       a,             &  ! data field
+                       send_requests, &  ! send requests (OUT)
+                       recv_requests  &  ! recv requests (OUT)
+                       )
+    implicit none
+    integer, intent(in) :: level
+    integer, intent(in) :: m
+    integer, intent(out), dimension(4) :: send_requests
+    integer, intent(out), dimension(4) :: recv_requests
+    type(scalar3d), intent(inout) :: a
+    integer :: a_n  ! horizontal grid size
+    integer :: nz   ! vertical grid size
+    integer, dimension(2) :: p_horiz
+    integer :: stepsize
+    integer :: ierr, rank, sendtag, recvtag
+    integer :: stat(MPI_STATUS_SIZE)
+    integer :: halo_size
+    integer :: neighbour_n_rank
+    integer :: neighbour_s_rank
+    integer :: neighbour_e_rank
+    integer :: neighbour_w_rank
+    integer :: yoffset, blocklength
+
+    halo_size = comm_param%halo_size
+
+    ! Do nothing if we are only using one processor
+    if (m > 0) then
+      a_n = a%ix_max-a%ix_min+1
+      nz = a%grid_param%nz
+      stepsize = 2**(pproc-m)
+
+      ! Work out rank, only execute on relevant processes
+      call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
+      call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
+
+      ! Work out ranks of neighbours
+      ! W -> E
+      call mpi_cart_shift(MPI_COMM_HORIZ,1, stepsize, &
+                          neighbour_w_rank,neighbour_e_rank,ierr)
+      ! N -> S
+      call mpi_cart_shift(MPI_COMM_HORIZ,0, stepsize, &
+                          neighbour_n_rank,neighbour_s_rank,ierr)
+      if ( (stepsize == 1) .or.                   &
+         (  (mod(p_horiz(1),stepsize) == 0) .and. &
+            (mod(p_horiz(2),stepsize) == 0) ) ) then
+        if (halo_size == 1) then
+          ! Do not include corners in send/recv
+          yoffset = 1
+          blocklength = a_n*(nz+2)*halo_size
+        else
+          yoffset = 1-halo_size
+          blocklength = (a_n+2*halo_size)*(nz+2)*halo_size
+        end if
+        ! Receive from north
+        recvtag = 2
+        call mpi_irecv(a%s(0,0-(halo_size-1),1),1,                  &
+                       halo_ns(level,m),neighbour_n_rank,recvtag,   &
+                       MPI_COMM_HORIZ, recv_requests(1), ierr)
+        ! Receive from south
+        recvtag = 3
+        call mpi_irecv(a%s(0,a_n+1,1),1,                            &
+                       halo_ns(level,m),neighbour_s_rank,recvtag,   &
+                       MPI_COMM_HORIZ, recv_requests(2), ierr)
+        ! Send to south
+        sendtag = 2
+        call mpi_isend(a%s(0,a_n-(halo_size-1),1),1,                &
+                       halo_ns(level,m),neighbour_s_rank,sendtag,   &
+                       MPI_COMM_HORIZ, send_requests(1), ierr)
+        ! Send to north
+        sendtag = 3
+        call mpi_isend(a%s(0,1,1),1,                                &
+                       halo_ns(level,m),neighbour_n_rank,sendtag,   &
+                       MPI_COMM_HORIZ, send_requests(2), ierr)
+        ! Receive from west
+        recvtag = 0
+        call mpi_irecv(a%s(0,yoffset,0-(halo_size-1)),blocklength,    &
+                       MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
+                       MPI_COMM_HORIZ, recv_requests(3), ierr)
+        ! Receive from east
+        recvtag = 1
+        call mpi_irecv(a%s(0,yoffset,a_n+1),blocklength,          &
+                       MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
+                       MPI_COMM_HORIZ, recv_requests(4), ierr)
+        ! Send to east
+        sendtag = 0
+        call mpi_isend(a%s(0,yoffset,a_n-(halo_size-1)),blocklength,  &
+                       MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
+                       MPI_COMM_HORIZ, send_requests(3), ierr)
+        ! Send to west
+        sendtag = 1
+        call mpi_isend(a%s(0,yoffset,1),blocklength,                &
+                       MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag,   &
+                       MPI_COMM_HORIZ, send_requests(4), ierr)
+      end if
+    end if
+  end subroutine ihaloswap
+
+!==================================================================
+!  Halo exchange
+!
+!  For all processes with horizontal indices that are multiples
+!  of 2^(pproc-m), update halos with information held by
+!  neighbouring processes, e.g. for pproc-m = 1, stepsize=2
+!
+!                      N  (0,2)
+!                      ^
+!                      !
+!                      v
+!
+!      W (2,0) <-->  (2,2)  <-->  E (2,4)
+!
+!                      ^
+!                      !
+!                      v
+!                      S (4,2)
+!
+!==================================================================
+  subroutine haloswap(level,m, &  ! multigrid- and processor- level
+                      a)          ! data field
+    implicit none
+    integer, intent(in) :: level
+    integer, intent(in) :: m
+    type(scalar3d), intent(inout) :: a
+    integer :: a_n  ! horizontal grid size
+    integer :: nz   ! vertical grid size
+    integer, dimension(2) :: p_horiz
+    integer :: stepsize
+    integer :: ierr, rank, sendtag, recvtag
+    integer :: stat(MPI_STATUS_SIZE)
+    integer :: halo_size
+    integer :: neighbour_n_rank
+    integer :: neighbour_s_rank
+    integer :: neighbour_e_rank
+    integer :: neighbour_w_rank
+    integer :: yoffset, blocklength
+    integer, dimension(4) :: requests_ns
+    integer, dimension(4) :: requests_ew
+
+    halo_size = comm_param%halo_size
+
+    ! Do nothing if we are only using one processor
+    if (m > 0) then
+      if (comm_measuretime) then
+        call start_timer(t_haloswap(level,m))
+      end if
+      a_n = a%ix_max-a%ix_min+1
+      nz = a%grid_param%nz
+      stepsize = 2**(pproc-m)
+
+      ! Work out rank, only execute on relevant processes
+      call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
+      call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
+
+      ! Work out ranks of neighbours
+      ! W -> E
+      call mpi_cart_shift(MPI_COMM_HORIZ,1, stepsize, &
+                          neighbour_w_rank,neighbour_e_rank,ierr)
+      ! N -> S
+      call mpi_cart_shift(MPI_COMM_HORIZ,0, stepsize, &
+                          neighbour_n_rank,neighbour_s_rank,ierr)
+      if ( (stepsize == 1) .or.                   &
+         (  (mod(p_horiz(1),stepsize) == 0) .and. &
+            (mod(p_horiz(2),stepsize) == 0) ) ) then
+        if (halo_size == 1) then
+          ! Do not include corners in send/recv
+          yoffset = 1
+          blocklength = a_n*(nz+2)*halo_size
+        else
+          yoffset = 1-halo_size
+          blocklength = (a_n+2*halo_size)*(nz+2)*halo_size
+        end if
+        ! Receive from north
+        recvtag = 2
+        call mpi_irecv(a%s(0,0-(halo_size-1),1),1,                  &
+                       halo_ns(level,m),neighbour_n_rank,recvtag,   &
+                       MPI_COMM_HORIZ, requests_ns(1), ierr)
+        ! Receive from south
+        recvtag = 3
+        call mpi_irecv(a%s(0,a_n+1,1),1,                            &
+                       halo_ns(level,m),neighbour_s_rank,recvtag,   &
+                       MPI_COMM_HORIZ, requests_ns(2), ierr)
+        ! Send to south
+        sendtag = 2
+        call mpi_isend(a%s(0,a_n-(halo_size-1),1),1,                &
+                       halo_ns(level,m),neighbour_s_rank,sendtag,   &
+                       MPI_COMM_HORIZ, requests_ns(3), ierr)
+        ! Send to north
+        sendtag = 3
+        call mpi_isend(a%s(0,1,1),1,                                &
+                       halo_ns(level,m),neighbour_n_rank,sendtag,   &
+                       MPI_COMM_HORIZ, requests_ns(4), ierr)
+        if (halo_size > 1) then
+          ! Wait for North <-> South communication to complete
+          call mpi_waitall(4,requests_ns, MPI_STATUSES_IGNORE, ierr)
+        end if
+        ! Receive from west
+        recvtag = 0
+        call mpi_irecv(a%s(0,yoffset,0-(halo_size-1)),blocklength,    &
+                       MPI_DOUBLE_PRECISION,neighbour_w_rank,recvtag, &
+                       MPI_COMM_HORIZ, requests_ew(1), ierr)
+        ! Receive from east
+        recvtag = 1
+        call mpi_irecv(a%s(0,yoffset,a_n+1),blocklength,          &
+                       MPI_DOUBLE_PRECISION,neighbour_e_rank,recvtag, &
+                       MPI_COMM_HORIZ, requests_ew(2), ierr)
+        ! Send to east
+        sendtag = 0
+        call mpi_isend(a%s(0,yoffset,a_n-(halo_size-1)),blocklength,  &
+                       MPI_DOUBLE_PRECISION,neighbour_e_rank,sendtag, &
+                       MPI_COMM_HORIZ, requests_ew(3), ierr)
+        ! Send to west
+        sendtag = 1
+        call mpi_isend(a%s(0,yoffset,1),blocklength,                &
+                       MPI_DOUBLE_PRECISION,neighbour_w_rank,sendtag,   &
+                       MPI_COMM_HORIZ, requests_ew(4), ierr)
+        if (halo_size == 1) then
+          ! Wait for North <-> South communication to complete
+          call mpi_waitall(4,requests_ns, MPI_STATUSES_IGNORE, ierr)
+        end if
+        ! Wait for East <-> West communication to complete
+        call mpi_waitall(4,requests_ew, MPI_STATUSES_IGNORE, ierr)
+      end if
+      if (comm_measuretime) then
+        call finish_timer(t_haloswap(level,m))
+      end if
+    end if
+  end subroutine haloswap
+
+!==================================================================
+!  Collect from a(level,m) and store on less processors
+!  in b(level,m-1)
+!
+!  Example for pproc-m = 1, i.e. stepsize = 2:
+!
+!   NW (0,0)  <--  NE (0,2)
+!
+!     ^      .
+!     !        .
+!                .
+!   SW (2,0)       SE (2,2) [send to 0,0]
+!
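+!  The NE, SW and SE processes send their full interior (data type
+!  interior(level,m)); on the NW process each incoming message fills one
+!  quarter of b (data type sub_interior(level,m-1)).
+!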
+!==================================================================
+  subroutine collect(level,m, &    ! multigrid and processor level
+                     a, &          ! IN: data on level (level,m)
+                     b)            ! OUT: data on level (level,m-1)
+    implicit none
+    integer, intent(in) :: level
+    integer, intent(in) :: m
+    type(scalar3d), intent(in) :: a
+    type(scalar3d), intent(inout) :: b
+    integer :: a_n, b_n   ! horizontal grid sizes
+    integer :: nz ! vertical grid size
+    integer, dimension(2) :: p_horiz
+    integer :: stepsize
+    integer :: ierr, source_rank, dest_rank, rank, recv_tag, send_tag, iz
+    logical :: corner_nw, corner_ne, corner_sw, corner_se
+    integer :: recv_request(3)
+
+    call start_timer(t_collect(m))
+
+    stepsize = 2**(pproc-m)
+
+    a_n = a%ix_max-a%ix_min+1
+    b_n = b%ix_max-b%ix_min+1
+    nz = b%grid_param%nz
+
+    ! Work out rank, only execute on relevant processes
+    call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
+    ! Store position in the process grid in p_horiz
+    ! Note we can NOT use cart_shift as we need diagonal neighbours as well
+    call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
+
+    ! Ignore all processes that do not participate at this level
+    if ( (stepsize .eq. 1) .or. ((mod(p_horiz(1),stepsize) == 0) .and. (mod(p_horiz(2),stepsize)) == 0)) then
+      ! Determine position in local 2x2 block
+      if (stepsize .eq. 1) then
+        corner_nw = ((mod(p_horiz(1),2) == 0) .and. (mod(p_horiz(2),2) == 0))
+        corner_ne = ((mod(p_horiz(1),2) == 0) .and. (mod(p_horiz(2),2) == 1))
+        corner_sw = ((mod(p_horiz(1),2) == 1) .and. (mod(p_horiz(2),2) == 0))
+        corner_se = ((mod(p_horiz(1),2) == 1) .and. (mod(p_horiz(2),2) == 1))
+      else
+        corner_nw = ((mod(p_horiz(1)/stepsize,2) == 0) .and. (mod(p_horiz(2)/stepsize,2) == 0))
+        corner_ne = ((mod(p_horiz(1)/stepsize,2) == 0) .and. (mod(p_horiz(2)/stepsize,2) == 1))
+        corner_sw = ((mod(p_horiz(1)/stepsize,2) == 1) .and. (mod(p_horiz(2)/stepsize,2) == 0))
+        corner_se = ((mod(p_horiz(1)/stepsize,2) == 1) .and. (mod(p_horiz(2)/stepsize,2) == 1))
+      end if
+      ! NW receives from the other three processes
+      if ( corner_nw ) then
+        ! Receive from NE
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1),p_horiz(2)+stepsize/), &
+                           source_rank, &
+                           ierr)
+        recv_tag = 0
+        call mpi_irecv(b%s(0,1,b_n/2+1),1,sub_interior(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
+                           recv_request(1),ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Collect: receive from NE failed in mpi_irecv().")
+#endif
+         ! Receive from SW
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1)+stepsize,p_horiz(2)/), &
+                           source_rank, &
+                           ierr)
+        recv_tag = 1
+        call mpi_irecv(b%s(0,b_n/2+1,1),1,sub_interior(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
+                           recv_request(2),ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Collect: receive from SW failed in mpi_irecv().")
+#endif
+        ! Receive from SE
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1)+stepsize,p_horiz(2)+stepsize/), &
+                           source_rank, &
+                           ierr)
+        recv_tag = 2
+        call mpi_irecv(b%s(0,b_n/2+1,b_n/2+1),1,sub_interior(level,m-1), source_rank, recv_tag, MPI_COMM_HORIZ, &
+                           recv_request(3),ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Collect: receive from SE failed in mpi_irecv().")
+#endif
+        ! Copy local data while waiting for data from other processes
+        do iz=0,nz+1
+          b%s(iz,1:a_n,1:a_n) = a%s(iz,1:a_n,1:a_n)
+        end do
+        ! Wait for receives to complete before proceeding
+        call mpi_waitall(3,recv_request,MPI_STATUSES_IGNORE,ierr)
+      end if
+      if ( corner_ne ) then
+        ! Send to NW
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1),p_horiz(2)-stepsize/), &
+                           dest_rank, &
+                           ierr)
+        send_tag = 0
+        call mpi_send(a%s(0,1,1),1,interior(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Collect: send from NE failed in mpi_send().")
+#endif
+      end if
+      if ( corner_sw ) then
+        ! Send to NW
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1)-stepsize,p_horiz(2)/), &
+                           dest_rank, &
+                           ierr)
+        send_tag = 1
+        call mpi_send(a%s(0,1,1),1,interior(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Collect: send from SW failed in mpi_send().")
+#endif
+      end if
+      if ( corner_se ) then
+        ! send to NW
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1)-stepsize,p_horiz(2)-stepsize/), &
+                           dest_rank, &
+                           ierr)
+        send_tag = 2
+        call mpi_send(a%s(0,1,1),1,interior(level,m),dest_rank,send_tag,MPI_COMM_HORIZ,ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Collect: send from SE failed in mpi_send().")
+#endif
+      end if
+
+    end if
+    call finish_timer(t_collect(m))
+
+  end subroutine collect
+
+!==================================================================
+!  Distribute data in a(level,m-1) and store in b(level,m)
+!
+!  Example for pproc-m = 1, i.e. stepsize = 2:
+!
+!   NW (0,0)  -->  NE (0,2)
+!
+!     !      .
+!     v        .
+!                .
+!   SW (2,0)       SE (2,2) [receives from 0,0]
+!==================================================================
+  subroutine distribute(level,m, &  ! multigrid and processor level
+                        a,       &  ! IN: Data on level (level,m-1)
+                        b)          ! OUT: Data on level (level,m)
+    implicit none
+    integer, intent(in) :: level
+    integer, intent(in) :: m
+    type(scalar3d), intent(in) :: a
+    type(scalar3d), intent(inout) :: b
+    integer :: a_n, b_n   ! horizontal grid sizes
+    integer :: nz ! vertical grid size
+    integer, dimension(2) :: p_horiz
+    integer :: stepsize
+    integer :: ierr, source_rank, dest_rank, send_tag, recv_tag, rank, iz
+    integer :: stat(MPI_STATUS_SIZE)
+    integer :: send_request(3)
+    logical :: corner_nw, corner_ne, corner_sw, corner_se
+
+    call start_timer(t_distribute(m))
+
+    stepsize = 2**(pproc-m)
+
+    a_n = a%ix_max-a%ix_min+1
+    b_n = b%ix_max-b%ix_min+1
+    nz = a%grid_param%nz
+
+    ! Work out rank, only execute on relevant processes
+    call mpi_comm_rank(MPI_COMM_HORIZ, rank, ierr)
+    call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
+
+    ! Ignore all processes that do not participate at this level
+    if ( (stepsize .eq. 1) .or. ((mod(p_horiz(1),stepsize) == 0) .and. (mod(p_horiz(2),stepsize)) == 0)) then
+      ! Work out coordinates in local 2 x 2 block
+      if (stepsize .eq. 1) then
+        corner_nw = ((mod(p_horiz(1),2) == 0) .and. (mod(p_horiz(2),2) == 0))
+        corner_ne = ((mod(p_horiz(1),2) == 0) .and. (mod(p_horiz(2),2) == 1))
+        corner_sw = ((mod(p_horiz(1),2) == 1) .and. (mod(p_horiz(2),2) == 0))
+        corner_se = ((mod(p_horiz(1),2) == 1) .and. (mod(p_horiz(2),2) == 1))
+      else
+        corner_nw = ((mod(p_horiz(1)/stepsize,2) == 0) .and. (mod(p_horiz(2)/stepsize,2) == 0))
+        corner_ne = ((mod(p_horiz(1)/stepsize,2) == 0) .and. (mod(p_horiz(2)/stepsize,2) == 1))
+        corner_sw = ((mod(p_horiz(1)/stepsize,2) == 1) .and. (mod(p_horiz(2)/stepsize,2) == 0))
+        corner_se = ((mod(p_horiz(1)/stepsize,2) == 1) .and. (mod(p_horiz(2)/stepsize,2) == 1))
+      end if
+      if ( corner_nw ) then
+        ! (Asynchronous) send to NE
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1),p_horiz(2)+stepsize/), &
+                           dest_rank, &
+                           ierr)
+
+        send_tag = 0
+        call mpi_isend(a%s(0,1,a_n/2+1), 1,sub_interior(level,m-1),dest_rank, send_tag, &
+                       MPI_COMM_HORIZ,send_request(1),ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Distribute: send to NE failed in mpi_isend().")
+#endif
+        ! (Asynchronous) send to SW
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1)+stepsize,p_horiz(2)/), &
+                           dest_rank, &
+                           ierr)
+        send_tag = 1
+        call mpi_isend(a%s(0,a_n/2+1,1),1,sub_interior(level,m-1), dest_rank, send_tag, &
+                       MPI_COMM_HORIZ, send_request(2), ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Distribute: send to SW failed in mpi_isend().")
+#endif
+        ! (Asynchronous) send to SE
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1)+stepsize,p_horiz(2)+stepsize/), &
+                           dest_rank, &
+                           ierr)
+        send_tag = 2
+        call mpi_isend(a%s(0,a_n/2+1,a_n/2+1),1,sub_interior(level,m-1), dest_rank, send_tag, &
+                      MPI_COMM_HORIZ, send_request(3), ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Distribute: send to SE failed in mpi_isend().")
+#endif
+        ! While sending, copy local data
+        do iz=0,nz+1
+          b%s(iz,1:b_n,1:b_n) = a%s(iz,1:b_n,1:b_n)
+        end do
+        ! Only proceed once the asynchronous sends have completed
+        call mpi_waitall(3, send_request, MPI_STATUSES_IGNORE, ierr)
+      end if
+      if ( corner_ne ) then
+
+        ! Receive from NW
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1),p_horiz(2)-stepsize/), &
+                           source_rank, &
+                           ierr)
+        recv_tag = 0
+        call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Distribute: receive on NE failed in mpi_recv().")
+#endif
+      end if
+      if ( corner_sw ) then
+        ! Receive from NW
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1)-stepsize,p_horiz(2)/), &
+                           source_rank, &
+                           ierr)
+        recv_tag = 1
+        call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Distribute: receive on SW failed in mpi_recv().")
+#endif
+      end if
+      if ( corner_se ) then
+        ! Receive from NW
+        call mpi_cart_rank(MPI_COMM_HORIZ, &
+                           (/p_horiz(1)-stepsize,p_horiz(2)-stepsize/), &
+                           source_rank, &
+                           ierr)
+        recv_tag = 2
+        call mpi_recv(b%s(0,1,1),1,interior(level,m),source_rank,recv_tag,MPI_COMM_HORIZ,stat,ierr)
+#ifndef NDEBUG
+  if (ierr .ne. 0) &
+    call fatalerror("Distribute: receive on NW failed in mpi_recv().")
+#endif
+      end if
+
+    end if
+    call finish_timer(t_distribute(m))
+
+  end subroutine distribute
+
+end module communication
diff --git a/tensorproductmultigrid_Source/conjugategradient.f90 b/tensorproductmultigrid_Source/conjugategradient.f90
new file mode 100644
index 000000000..b58adae6b
--- /dev/null
+++ b/tensorproductmultigrid_Source/conjugategradient.f90
@@ -0,0 +1,275 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+! Conjugate gradient solver
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+module conjugategradient
+
+  use parameters
+  use datatypes
+  use discretisation
+  use messages
+  use communication
+  use mpi
+
+  implicit none
+
+public::cg_parameters
+public::cg_initialise
+public::cg_finalise
+public::cg_solve
+
+private
+
+  ! --- Conjugate gradient parameters type ---
+  type cg_parameters
+    ! Verbosity level
+    integer :: verbose
+    ! Maximal number of iterations
+    integer :: maxiter
+    ! Required residual reduction
+    real(kind=rl) :: resreduction
+    ! Smoother iterations in preconditioner
+    integer :: n_prec
+  end type cg_parameters
+
+! --- Parameters ---
+  type(cg_parameters) :: cg_param
+  type(grid_parameters) :: grid_param
+
+contains
+
+!==================================================================
+! Initialise conjugate gradient module
+!==================================================================
+  subroutine cg_initialise(cg_param_in)       ! Conjugate gradient
+                                              ! parameters
+    implicit none
+    type(cg_parameters), intent(in)    :: cg_param_in
+
+    if (i_am_master_mpi) then
+      write(STDOUT,*) '*** Initialising Conjugate gradient ***'
+      write(STDOUT,*) ''
+    end if
+    cg_param = cg_param_in
+  end subroutine cg_initialise
+
+!==================================================================
+! Finalise conjugate gradient module
+!==================================================================
+  subroutine cg_finalise()
+    implicit none
+
+    if (i_am_master_mpi) then
+      write(STDOUT,*) '*** Finalising Conjugate gradient ***'
+      write(STDOUT,*) ''
+    end if
+  end subroutine cg_finalise
+
+!==================================================================
+! Solve A.u = b.
+!==================================================================
+  subroutine cg_solve(level,m,b,u)
+    implicit none
+    integer, intent(in)           :: level
+    integer, intent(in)           :: m
+    type(scalar3d), intent(in)    :: b    ! RHS vector
+    type(scalar3d), intent(inout) :: u    ! solution vector
+    type(scalar3d)                :: p    ! } Auxiliary vectors
+    type(scalar3d)                :: r    ! } Auxiliary vectors
+    type(scalar3d)                :: Ap   ! }
+    type(scalar3d)                :: z    ! }
+    integer                       :: n_lin
+    real(kind=rl)                 :: res0, rz, rz_old, res, alpha
+    integer :: i
+    logical :: solver_converged = .false.
+    integer :: n, nz, nlocal, halo_size
+    real(kind=rl) :: pAp
+
+    ! Initialise auxiliary fields
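+    ! (copy the grid metadata and index ranges of u; storage for the
+    !  auxiliary fields is allocated further below)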
+    p%grid_param = u%grid_param
+    p%ix_min = u%ix_min
+    p%ix_max = u%ix_max
+    p%iy_min = u%iy_min
+    p%iy_max = u%iy_max
+    p%icompx_min = u%icompx_min
+    p%icompx_max = u%icompx_max
+    p%icompy_min = u%icompy_min
+    p%icompy_max = u%icompy_max
+    p%halo_size = u%halo_size
+
+    r%grid_param = u%grid_param
+    r%ix_min = u%ix_min
+    r%ix_max = u%ix_max
+    r%iy_min = u%iy_min
+    r%iy_max = u%iy_max
+    r%icompx_min = u%icompx_min
+    r%icompx_max = u%icompx_max
+    r%icompy_min = u%icompy_min
+    r%icompy_max = u%icompy_max
+    r%halo_size = u%halo_size
+
+    z%grid_param = u%grid_param
+    z%ix_min = u%ix_min
+    z%ix_max = u%ix_max
+    z%iy_min = u%iy_min
+    z%iy_max = u%iy_max
+    z%icompx_min = u%icompx_min
+    z%icompx_max = u%icompx_max
+    z%icompy_min = u%icompy_min
+    z%icompy_max = u%icompy_max
+    z%halo_size = u%halo_size
+
+    Ap%grid_param = u%grid_param
+    Ap%ix_min = u%ix_min
+    Ap%ix_max = u%ix_max
+    Ap%iy_min = u%iy_min
+    Ap%iy_max = u%iy_max
+    Ap%icompx_min = u%icompx_min
+    Ap%icompx_max = u%icompx_max
+    Ap%icompy_min = u%icompy_min
+    Ap%icompy_max = u%icompy_max
+    Ap%halo_size = u%halo_size
+
+    n = u%ix_max-u%ix_min+1
+    nz = u%grid_param%nz
+
+    nlocal = u%ix_max - u%ix_min + 1
+    halo_size = u%halo_size
+
+    n_lin = (nlocal+2*halo_size)**2 * (nz+2)
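+    ! n_lin = number of local degrees of freedom, including halos and the
+    ! vertical boundary layers; used as the vector length in the BLAS calls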
+
+    allocate(r%s(0:nz+1,                     &
+             1-halo_size:nlocal+halo_size,   &
+             1-halo_size:nlocal+halo_size) )
+    allocate(z%s(0:nz+1,                     &
+             1-halo_size:nlocal+halo_size,   &
+             1-halo_size:nlocal+halo_size) )
+    allocate(p%s(0:nz+1,                     &
+             1-halo_size:nlocal+halo_size,   &
+             1-halo_size:nlocal+halo_size) )
+    allocate(Ap%s(0:nz+1,                    &
+             1-halo_size:nlocal+halo_size,   &
+             1-halo_size:nlocal+halo_size) )
+    r%s = 0.0_rl
+    z%s = 0.0_rl
+    p%s = 0.0_rl
+    Ap%s = 0.0_rl
+
+    ! Initialise
+    ! r <- b - A.u
+    call calculate_residual(level,m,b,u,r)
+    ! z <- M^{-1} r
+    if (cg_param%n_prec > 0) then
+      call smooth(level,m,cg_param%n_prec,DIRECTION_FORWARD,r,z)
+      call smooth(level,m,cg_param%n_prec,DIRECTION_BACKWARD,r,z)
+    else
+      call dcopy(n_lin,r%s,1,z%s,1)
+    end if
+    ! p <- z
+    call dcopy(n_lin,z%s,1,p%s,1)
+    ! rz_old = <r,z>
+    call scalarprod(m,r,z,rz_old)
+    ! res0 <- ||r||
+    call scalarprod(m,r,r,res0)
+    res0 = dsqrt(res0)
+    if (cg_param%verbose > 0) then
+      if (i_am_master_mpi) then
+        write(STDOUT,'("    *** CG Solver ( ",I10," dof ) ***")') n_lin
+        write(STDOUT,'("    <CG> Initial residual ||r_0|| = ",E12.6)') res0
+      end if
+    endif
+    if (res0 > tolerance) then
+      do i=1,cg_param%maxiter
+        ! Ap <- A.p
+        call haloswap(level,m,p)
+        call apply(p,Ap)
+        ! alpha <- res_old / <p,A.p>
+        call scalarprod(m,p,Ap,pAp)
+        alpha = rz_old/pAp
+        ! x <- x + alpha*p
+        call daxpy(n_lin,alpha,p%s,1,u%s,1)
+        ! r <- r - alpha*A.p
+        call daxpy(n_lin,-alpha,Ap%s,1,r%s,1)
+        call scalarprod(m,r,r,res)
+        res = dsqrt(res)
+        if (cg_param%verbose > 1) then
+          if (i_am_master_mpi) then
+            write(STDOUT,'("    <CG> Iteration ",I6," ||r|| = ",E12.6)') &
+              i, res
+          end if
+        end if
+        if ( (res/res0 < cg_param%resreduction) .or. &
+             (res < tolerance ) ) then
+          solver_converged = .true.
+          exit
+        end if
+        ! z <- M^{-1} r
+        z%s = 0.0_rl
+        if (cg_param%n_prec > 0) then
+          call smooth(level,m,cg_param%n_prec,DIRECTION_FORWARD,r,z)
+          call smooth(level,m,cg_param%n_prec,DIRECTION_BACKWARD,r,z)
+        else
+          call dcopy(n_lin,r%s,1,z%s,1)
+        end if
+        call scalarprod(m,r,z,rz)
+        ! p <- res/res_old*p
+        call dscal(n_lin,rz/rz_old,p%s,1)
+        ! p <- p + z
+        call daxpy(n_lin,1.0_rl,z%s,1,p%s,1)
+        rz_old = rz
+      end do
+    else
+      res = res0
+      solver_converged = .true.
+    end if
+    if (cg_param%verbose>0) then
+      if (solver_converged) then
+        if (i_am_master_mpi) then
+          write(STDOUT,'("    <CG> Final residual    ||r|| = ",E12.6)') res
+          write(STDOUT,'("    <CG> CG solver converged after ",I6," iterations rho_avg = ",F10.6)') i, (res/res0)**(1.0_rl/i)
+        end if
+      else
+        call warning("    <CG> Solver did not converge")
+      endif
+    end if
+
+    deallocate(r%s)
+    deallocate(z%s)
+    deallocate(p%s)
+    deallocate(Ap%s)
+  end subroutine cg_solve
+
+end module conjugategradient
+
diff --git a/tensorproductmultigrid_Source/datatypes.f90 b/tensorproductmultigrid_Source/datatypes.f90
new file mode 100644
index 000000000..b0bfbecc3
--- /dev/null
+++ b/tensorproductmultigrid_Source/datatypes.f90
@@ -0,0 +1,381 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!   Grid data types for three dimensional cell centred grids.
+!   We always assume that the number of grid cells and the extent in
+!   the x- and y-directions are identical.
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+
+
+module datatypes
+
+  use mpi
+  use parameters
+  use messages
+
+  implicit none
+
+! Vertical boundary conditions
+  integer, parameter :: VERTBC_DIRICHLET = 1
+  integer, parameter :: VERTBC_NEUMANN = 2
+
+! Parameters of three dimensional grid
+  type grid_parameters
+    integer :: n         ! Total number of grid cells in horizontal direction
+    integer :: nz        ! Total number of grid cells in vertical direction
+    real(kind=rl) :: L   ! Global extent of grid in horizontal direction
+    real(kind=rl) :: H   ! Global extent of grid in vertical direction
+    integer :: vertbc    ! Vertical boundary condition (see VERTBC_DIRICHLET
+                         ! and VERTBC_NEUMANN)
+    logical :: graded    ! Is the vertical grid graded?
+  end type grid_parameters
+
+! Three dimensional scalar field s(z,y,x)
+  type scalar3d
+    integer :: ix_min     ! } (Inclusive) range of locally owned cells
+    integer :: ix_max     ! } in the x-direction
+    integer :: iy_min     ! } (these ranges DO NOT include halo cells)
+    integer :: iy_max     ! } in the y-direction
+    integer :: icompx_min ! } (Inclusive) ranges of computational cells,
+    integer :: icompx_max ! } in local coords. All cells in these ranges
+    integer :: icompy_min ! } are included in calculations, e.g. in the
+    integer :: icompy_max ! } smoother. This allows duplicating operations
+                          ! } on part of the halo for RB Gauss Seidel
+    integer :: halo_size  ! Size of halos
+    logical :: isactive   ! Is this field active, i.e. used on one of the
+                          ! active processes on coarser grids?
+    real(kind=rl),allocatable :: s(:,:,:)
+    type(grid_parameters) :: grid_param
+  end type scalar3d
+
+public::VERTBC_DIRICHLET
+public::VERTBC_NEUMANN
+public::scalar3d
+public::grid_parameters
+public::L2norm
+public::daxpy_scalar3d
+public::save_scalar3d
+public::create_scalar3d
+public::volscale_scalar3d
+public::destroy_scalar3d
+public::volume_of_element
+public::r_grid
+
+private
+
+  ! Vertical grid: this array of length nz+1 stores the
+  ! vertices of the grid in the vertical direction
+  real(kind=rl), allocatable :: r_grid(:)
+
+  contains
+
+!==================================================================
+! volume of element on cubed sphere grid
+! NB: ix,iy are global indices
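+! Returns |T_ij| = h^2*(1+rho_i^2+sigma_j^2)^(-3/2) with h = 2/n and
+! rho_i, sigma_j the cell centre coordinates in [-1,1]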
+!==================================================================
+  real(kind=rl) function volume_of_element(ix,iy,grid_param)
+    implicit none
+    integer, intent(in) :: ix
+    integer, intent(in) :: iy
+    type(grid_parameters), intent(in) :: grid_param
+    real(kind=rl) :: h
+    real(kind=rl) :: rho_i, sigma_j
+    h = 2.0_rl/grid_param%n
+    rho_i = 2.0_rl*(ix-0.5_rl)/grid_param%n-1.0_rl
+    sigma_j = 2.0_rl*(iy-0.5_rl)/grid_param%n-1.0_rl
+    volume_of_element = (1.0_rl+rho_i**2+sigma_j**2)**(-1.5_rl)*h**2
+  end function volume_of_element
+
+!==================================================================
+! Create scalar3d field on fine grid and set to zero
+!==================================================================
+  subroutine create_scalar3d(comm_horiz,grid_param, halo_size, phi)
+    implicit none
+
+    integer, intent(in)                 :: comm_horiz  ! Horizontal communicator
+    type(grid_parameters), intent(in)   :: grid_param  ! Grid parameters
+    integer, intent(in)                 :: halo_size   ! Halo size
+    type(scalar3d), intent(inout)       :: phi         ! Field to create
+    integer                             :: nproc       ! Number of processes
+    integer                             :: rank, ierr  ! rank and MPI error
+    integer, dimension(2)               :: p_horiz     ! position in 2d
+                                                       ! processor grid
+    integer                             :: nlocal      ! Local number of
+                                                       ! cells in horizontal
+                                                       ! direction
+    integer, parameter                  :: dim_horiz = 2 ! horiz. dimension
+
+    phi%grid_param = grid_param
+    call mpi_comm_size(comm_horiz, nproc, ierr)
+    nlocal = grid_param%n/sqrt(1.0*nproc)
+
+    ! Work out position in 2d processor grid
+    call mpi_comm_rank(comm_horiz, rank, ierr)
+    call mpi_cart_coords(comm_horiz,rank,dim_horiz,p_horiz,ierr)
+    ! Set local data ranges
+    ! NB: p_horiz stores (py,px) in that order (see comment in
+    ! communication module)
+    phi%iy_min = p_horiz(1)*nlocal + 1
+    phi%iy_max = (p_horiz(1)+1)*nlocal
+    phi%ix_min = p_horiz(2)*nlocal + 1
+    phi%ix_max = (p_horiz(2)+1)*nlocal
+    ! Set computational ranges. Note that these are different at
+    ! the edges of the domain!
+    if (p_horiz(1) == 0) then
+      phi%icompy_min = 1
+    else
+      phi%icompy_min = 1 - (halo_size - 1)
+    end if
+    if (p_horiz(1) == floor(sqrt(1.0_rl*nproc))-1) then
+      phi%icompy_max = nlocal
+    else
+      phi%icompy_max = nlocal + (halo_size - 1)
+    end if
+    if (p_horiz(2) == 0) then
+      phi%icompx_min = 1
+    else
+      phi%icompx_min = 1 - (halo_size - 1)
+    end if
+    if (p_horiz(2) == floor(sqrt(1.0_rl*nproc))-1) then
+      phi%icompx_max = nlocal
+    else
+      phi%icompx_max = nlocal + (halo_size - 1)
+    end if
+    ! Set halo size
+    phi%halo_size = halo_size
+    ! Set field to active
+    phi%isactive = .true.
+    ! Allocate memory
+    allocate(phi%s(0:grid_param%nz+1,            &
+                   1-halo_size:nlocal+halo_size, &
+                   1-halo_size:nlocal+halo_size))
+    phi%s(:,:,:) = 0.0_rl
+
+  end subroutine create_scalar3d
+
+!==================================================================
+! Destroy scalar3d field on fine grid
+!==================================================================
+  subroutine destroy_scalar3d(phi)
+    implicit none
+    type(scalar3d), intent(inout) :: phi
+
+    deallocate(phi%s)
+
+  end subroutine destroy_scalar3d
+
+!==================================================================
+! Scale fields with volume of element
+! Either multiply with volume factor |T| v_k (power = 1)
+! or divide by it (power = -1)
+!==================================================================
+  subroutine volscale_scalar3d(phi,power)
+    implicit none
+    type(scalar3d), intent(inout) :: phi
+    integer, intent(in) :: power
+    integer :: ix, iy, iz
+    integer :: ierr
+    integer :: nlocalx, nlocaly
+    real(kind=rl) :: vol_h, vol_r, h, tmp
+
+    if (.not. ( ( power .eq. 1) .or. (power .eq. -1) ) ) then
+      call fatalerror("power has to be -1 or 1 when volume-scaling fields")
+    end if
+
+    nlocalx = phi%ix_max-phi%ix_min+1
+    nlocaly = phi%iy_max-phi%iy_min+1
+
+    if (phi%isactive) then
+      do ix=1,nlocalx
+        do iy=1,nlocaly
+#ifdef CARTESIANGEOMETRY
+          h = phi%grid_param%L/phi%grid_param%n
+          vol_h = h**2
+#else
+          vol_h = volume_of_element(ix+(phi%ix_min-1), &
+                                    iy+(phi%iy_min-1), &
+                                    phi%grid_param)
+#endif
+          do iz=1,phi%grid_param%nz
+#ifdef CARTESIANGEOMETRY
+            vol_r = r_grid(iz+1)-r_grid(iz)
+#else
+            vol_r = (r_grid(iz+1)**3 - r_grid(iz)**3)/3.0_rl
+#endif
+            if (power == 1) then
+              tmp = vol_h*vol_r
+            else
+              tmp = 1.0_rl/(vol_h*vol_r)
+            end if
+            phi%s(iz,iy,ix) = tmp*phi%s(iz,iy,ix)
+          end do
+        end do
+      end do
+    end if
+
+  end subroutine volscale_scalar3d
+
+!==================================================================
+! Calculate L2 norm
+! If phi_is_volumeintegral is .true. then phi is interpreted
+! as the volume integral in a cell, otherwise it is interpreted as the
+! average value in a cell.
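+! In either case the value computed below is
+!   l2norm = sqrt( sum_{cells} w * phi^2 ),
+! with w = |T| v_k for cell averages and w = 1/(|T| v_k) for volume
+! integrals; the sum is taken over all processes via mpi_allreduce.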
+!==================================================================
+  real(kind=rl) function l2norm(phi,phi_is_volumeintegral)
+    implicit none
+    type(scalar3d), intent(in) :: phi
+    logical, optional :: phi_is_volumeintegral
+    integer :: ix, iy, iz
+    real(kind=rl) :: tmp, global_tmp
+    integer :: ierr
+    integer :: nlocalx, nlocaly
+    real(kind=rl) :: vol_h, vol_r, h
+    logical :: divide_by_volume
+    real(kind=rl) :: volume_factor
+    if (present(phi_is_volumeintegral)) then
+      divide_by_volume = phi_is_volumeintegral
+    else
+      divide_by_volume = .false.
+    end if
+
+    nlocalx = phi%ix_max-phi%ix_min+1
+    nlocaly = phi%iy_max-phi%iy_min+1
+
+    tmp = 0.0_rl
+    if (phi%isactive) then
+      do ix=1,nlocalx
+        do iy=1,nlocaly
+#ifdef CARTESIANGEOMETRY
+          h = phi%grid_param%L/phi%grid_param%n
+          vol_h = h**2
+#else
+          vol_h = volume_of_element(ix+(phi%ix_min-1), &
+                                    iy+(phi%iy_min-1), &
+                                    phi%grid_param)
+#endif
+          do iz=1,phi%grid_param%nz
+#ifdef CARTESIANGEOMETRY
+            vol_r = r_grid(iz+1)-r_grid(iz)
+#else
+            vol_r = (r_grid(iz+1)**3 - r_grid(iz)**3)/3.0_rl
+#endif
+            if (divide_by_volume) then
+              volume_factor = 1.0_rl/(vol_h*vol_r)
+            else
+              volume_factor = vol_h*vol_r
+            end if
+            tmp = tmp + volume_factor*phi%s(iz,iy,ix)**2
+          end do
+        end do
+      end do
+    end if
+
+    call mpi_allreduce(tmp,global_tmp, 1, &
+                       MPI_DOUBLE_PRECISION,MPI_SUM,MPI_COMM_WORLD,ierr)
+    l2norm = dsqrt(global_tmp)
+  end function l2norm
+
+!==================================================================
+! calculate phi <- phi + alpha*dphi
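+! (thin wrapper around BLAS daxpy, applied to the full local arrays
+! including the halos and the ghost levels iz=0 and iz=nz+1)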
+!==================================================================
+  subroutine daxpy_scalar3d(alpha,dphi,phi)
+    implicit none
+    real(kind=rl), intent(in) :: alpha
+    type(scalar3d), intent(in) :: dphi
+    type(scalar3d), intent(inout) :: phi
+    integer :: nlin
+    integer :: nlocalx, nlocaly
+
+    nlocalx = phi%ix_max-phi%ix_min+1
+    nlocaly = phi%iy_max-phi%iy_min+1
+    nlin = (nlocalx+2*phi%halo_size) &
+         * (nlocaly+2*phi%halo_size) &
+         * (phi%grid_param%nz+2)
+
+    call daxpy(nlin,alpha,dphi%s,1,phi%s,1)
+
+  end subroutine daxpy_scalar3d
+
+!==================================================================
+! Save scalar field to file
+!==================================================================
+  subroutine save_scalar3d(comm_horiz,phi,filename)
+    implicit none
+    integer, intent(in) :: comm_horiz
+    type(scalar3d), intent(in) :: phi
+    character(*), intent(in)   :: filename
+    integer :: file_id = 100
+    integer :: ix,iy,iz
+    integer :: nlocal
+    integer :: rank, nproc, ierr
+    character(len=21) :: s
+
+    nlocal = phi%ix_max-phi%ix_min+1
+
+    ! Get number of processes and rank
+    call mpi_comm_size(comm_horiz, nproc, ierr)
+    call mpi_comm_rank(comm_horiz, rank, ierr)
+
+    write(s,'(I10.10,"_",I10.10)') nproc, rank
+
+    open(unit=file_id,file=trim(filename)//"_"//trim(s)//".dat")
+    write(file_id,*) "# 3d scalar data file"
+    write(file_id,*) "# ==================="
+    write(file_id,*) "# Data is written as s(iz,iy,ix) "
+    write(file_id,*) "# with the leftmost index running fastest"
+    write(file_id,'(" n  = ",I8)') phi%grid_param%n
+    write(file_id,'(" nz = ",I8)') phi%grid_param%nz
+    write(file_id,'(" L  = ",F20.10)') phi%grid_param%L
+    write(file_id,'(" H  = ",F20.10)') phi%grid_param%H
+    write(file_id,'(" ix_min  = ",I10)') phi%ix_min
+    write(file_id,'(" ix_max  = ",I10)') phi%ix_max
+    write(file_id,'(" iy_min  = ",I10)') phi%iy_min
+    write(file_id,'(" iy_max  = ",I10)') phi%iy_max
+    write(file_id,'(" icompx_min  = ",I10)') phi%icompx_min
+    write(file_id,'(" icompx_max  = ",I10)') phi%icompx_max
+    write(file_id,'(" icompy_min  = ",I10)') phi%icompy_min
+    write(file_id,'(" icompy_max  = ",I10)') phi%icompy_max
+    write(file_id,'(" halosize    = ",I10)') phi%halo_size
+
+
+    do ix=1-phi%halo_size,nlocal+phi%halo_size
+      do iy=1-phi%halo_size,nlocal+phi%halo_size
+        do iz=0,phi%grid_param%nz+1
+          write(file_id,'(E24.15)') phi%s(iz,iy,ix)
+        end do
+      end do
+    end do
+    close(file_id)
+  end subroutine save_scalar3d
+
+end module datatypes
diff --git a/tensorproductmultigrid_Source/discretisation.f90 b/tensorproductmultigrid_Source/discretisation.f90
new file mode 100644
index 000000000..87227b862
--- /dev/null
+++ b/tensorproductmultigrid_Source/discretisation.f90
@@ -0,0 +1,879 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!  Discretisation module of the model problem
+!
+!
+!  -omega2 * (d^2/dx^2 + d^2/dy^2 + lambda2 * d^2/dz^2 ) u
+!                                                + delta u = RHS
+!  [Cartesian]
+!
+!  or
+!
+!  -omega2 * (laplace_{2d} + lambda2/r^2 d/dr (r^2 d/dr)) u
+!                                                + delta u = RHS
+!  [Spherical]
+!
+!  We use a cell centred finite volume discretisation.
+!  The equation is discretised either in a unit cube or on 1/6th
+!  of a cubed sphere grid.
+!
+!  The vertical grid spacing is not necessarily uniform and can
+!  be chosen by specifying the vertical grid in a vector.
+!
+!  The following boundary conditions are used:
+!
+!  * Dirichlet in the horizontal
+!  * Neumann in the vertical
+!
+!  For delta = 0 the operator reduces to the Poisson operator.
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+
+module discretisation
+
+  use parameters
+  use messages
+  use datatypes
+  use communication
+  use mpi
+
+  implicit none
+
+private
+
+  type model_parameters
+    real(kind=rl) :: omega2   ! omega^2
+    real(kind=rl) :: lambda2  ! lambda^2
+    real(kind=rl) :: delta    ! delta
+  end type model_parameters
+
+! --- Stencil ---
+!
+
+! Grid traversal direction in SOR
+  integer, parameter :: DIRECTION_FORWARD = 1
+  integer, parameter :: DIRECTION_BACKWARD = 2
+
+! Ordering in SOR
+  ! Lexicographic ordering
+  integer, parameter :: ORDERING_LEX = 1
+  ! Red-black ordering
+  integer, parameter :: ORDERING_RB = 2
+
+  type smoother_parameters
+    ! smoother
+    integer :: smoother
+    ! relaxation parameter
+    real(kind=rl) :: rho
+    ! ordering of degrees of freedom
+    integer :: ordering
+  end type smoother_parameters
+
+  ! Allowed smoothers
+  integer, parameter :: SMOOTHER_LINE_SOR = 3
+  integer, parameter :: SMOOTHER_LINE_SSOR = 4
+  integer, parameter :: SMOOTHER_LINE_JAC = 6
+
+  ! Number of levels
+  integer :: nlev
+
+  ! Grid parameters
+  type(grid_parameters) :: grid_param
+
+  ! Model parameters
+  type(model_parameters) :: model_param
+
+  ! Smoother parameters
+  type(smoother_parameters) :: smoother_param
+
+  ! Arrays for measuring the residual reduction
+  real(kind=rl), allocatable :: log_resreduction(:)
+  integer, allocatable :: nsmooth_total(:)
+
+  ! Data structure for storing the vertical discretisation
+  type vertical_coefficients
+    real(kind=rl), allocatable :: a(:)
+    real(kind=rl), allocatable :: b(:)
+    real(kind=rl), allocatable :: c(:)
+    real(kind=rl), allocatable :: d(:)
+  end type vertical_coefficients
+
+  ! Storage for vertical coefficients
+  type(vertical_coefficients) :: vert_coeff
+
+public::discretisation_initialise
+public::discretisation_finalise
+public::smooth
+public::line_SOR
+public::line_SSOR
+public::line_jacobi
+public::calculate_residual
+public::apply
+public::model_parameters
+public::smoother_parameters
+public::volume_of_element
+public::SMOOTHER_LINE_SOR
+public::SMOOTHER_LINE_SSOR
+public::SMOOTHER_LINE_JAC
+public::DIRECTION_FORWARD
+public::DIRECTION_BACKWARD
+public::ORDERING_LEX
+public::ORDERING_RB
+
+contains
+
+!==================================================================
+! Initialise module
+!==================================================================
+  subroutine discretisation_initialise(grid_param_in, &
+                                       model_param_in, &
+                                       smoother_param_in, &
+                                       nlev_in)
+    implicit none
+    type(grid_parameters), intent(in)   :: grid_param_in
+    type(model_parameters), intent(in)   :: model_param_in
+    type(smoother_parameters), intent(in)   :: smoother_param_in
+    integer, intent(in) :: nlev_in
+    integer :: k
+    grid_param = grid_param_in
+    model_param = model_param_in
+    smoother_param = smoother_param_in
+    nlev = nlev_in
+    allocate(log_resreduction(nlev))
+    allocate(nsmooth_total(nlev))
+    log_resreduction(:) = 0.0_rl
+    nsmooth_total(:) = 0
+    allocate(r_grid(grid_param%nz+1))
+    if (grid_param%graded) then
+      do k=1,grid_param%nz+1
+        r_grid(k) = grid_param%H*(1.0_rl*(k-1.0_rl)/grid_param%nz)**2
+      end do
+    else
+      do k=1,grid_param%nz+1
+        r_grid(k) = grid_param%H*(1.0_rl*(k-1.0_rl)/grid_param%nz)
+      end do
+    end if
+#ifdef CARTESIANGEOMETRY
+#else
+    r_grid(:) = 1.0_rl + r_grid(:)
+#endif
+    ! Allocate arrays for vertical discretisation matrices
+    ! and calculate matrix entries
+    allocate(vert_coeff%a(grid_param%nz))
+    allocate(vert_coeff%b(grid_param%nz))
+    allocate(vert_coeff%c(grid_param%nz))
+    allocate(vert_coeff%d(grid_param%nz))
+    call construct_vertical_coeff()
+  end subroutine discretisation_initialise
+
+!==================================================================
+! Finalise module
+!==================================================================
+  subroutine discretisation_finalise()
+    implicit none
+    integer :: level
+    real(kind=rl) :: rho_avg
+#ifdef MEASURESMOOTHINGRATE
+    if (i_am_master_mpi) then
+      write(STDOUT,'("Average smoothing rates:")')
+      do level=nlev,1,-1
+        if (nsmooth_total(level) > 0) then
+          rho_avg = exp(log_resreduction(level)/nsmooth_total(level))
+        else
+          rho_avg = 1.0_rl
+        end if
+        write(STDOUT,'("rho_{avg}(",I3,") = ",E10.4," ( ",I5," x )")') &
+          level, rho_avg, nsmooth_total(level)
+      end do
+    end if
+#endif
+    deallocate(log_resreduction)
+    deallocate(nsmooth_total)
+    deallocate(r_grid)
+    ! Deallocate storage for vertical discretisation matrices
+    deallocate(vert_coeff%a)
+    deallocate(vert_coeff%b)
+    deallocate(vert_coeff%c)
+    deallocate(vert_coeff%d)
+  end subroutine discretisation_finalise
+
+!==================================================================
+! Construct alpha_{i',j'} and |T_{ij}| needed for the
+! horizontal stencil
+! ( alpha_{i+1,j},
+!   alpha_{i-1,j},
+!   alpha_{i,j+1},
+!   alpha_{i,j-1},
+!   alpha_{ij})
+! (ix,iy) are LOCAL indices of the grid boxes, which are
+! converted to global indices
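+! In the Cartesian branch the coefficients are 1 between interior
+! cells; the factor 2 on cells touching the lateral boundary
+! presumably accounts for the halved distance to the Dirichlet
+! boundary.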
+!==================================================================
+  subroutine construct_alpha_T(grid_param,ix,iy,alpha_T,Tij)
+    implicit none
+    type(grid_parameters), intent(in) :: grid_param
+    integer, intent(in) :: ix
+    integer, intent(in) :: iy
+    real(kind=rl), intent(inout), dimension(5) :: alpha_T
+    real(kind=rl), intent(out) :: Tij
+    real(kind=rl)                  :: h, rho_i, sigma_j
+#ifdef CARTESIANGEOMETRY
+    h = grid_param%L/grid_param%n
+    ! Cartesian coefficients
+    Tij = h**2
+    if (ix == grid_param%n) then
+      alpha_T(1) = 2.0_rl
+    else
+      alpha_T(1) = 1.0_rl
+    end if
+    if (ix == 1) then
+      alpha_T(2) = 2.0_rl
+    else
+      alpha_T(2) = 1.0_rl
+    end if
+    if (iy == grid_param%n) then
+      alpha_T(3) = 2.0_rl
+    else
+      alpha_T(3) = 1.0_rl
+    end if
+    if (iy == 1) then
+      alpha_T(4) = 2.0_rl
+    else
+      alpha_T(4) = 1.0_rl
+    end if
+#else
+    ! Coefficients in cubed sphere geometry
+    ! (rho_i,sigma_j) \in [-1,1] x [-1,1] are the coordinates of the
+    ! cubed sphere segment
+    h = 2.0_rl/grid_param%n
+    Tij = volume_of_element(ix,iy,grid_param)
+    rho_i = 2.0_rl*(1.0_rl*ix-0.5_rl)/grid_param%n-1.0_rl
+    sigma_j = 2.0_rl*(1.0_rl*iy-0.5_rl)/grid_param%n-1.0_rl
+    ! alpha_{i+1,j}
+    if (ix == grid_param%n) then
+      alpha_T(1) = 2.0_rl*DSQRT((1.0_rl+(rho_i+0.25_rl*h)**2)/(1.0_rl+sigma_j**2))
+    else
+      alpha_T(1) = DSQRT((1.0_rl+(rho_i+0.5_rl*h)**2)/(1.0_rl+sigma_j**2))
+    end if
+    ! alpha_{i-1,j}
+    if (ix == 1) then
+      alpha_T(2) = 2.0_rl*DSQRT((1.0_rl+(rho_i-0.25_rl*h)**2)/(1.0_rl+sigma_j**2))
+    else
+      alpha_T(2) = DSQRT((1.0_rl+(rho_i-0.5_rl*h)**2)/(1.0_rl+sigma_j**2))
+    end if
+    ! alpha_{i,j+1}
+    if (iy == grid_param%n) then
+      alpha_T(3) = 2.0_rl*DSQRT((1.0_rl+(sigma_j+0.25_rl*h)**2)/(1.0_rl+rho_i**2))
+    else
+      alpha_T(3) = DSQRT((1.0_rl+(sigma_j+0.5_rl*h)**2)/(1.0_rl+rho_i**2))
+    end if
+    ! alpha_{i,j-1}
+    if (iy == 1) then
+      alpha_T(4) = 2.0_rl*DSQRT((1.0_rl+(sigma_j-0.25_rl*h)**2)/(1.0_rl+rho_i**2))
+    else
+      alpha_T(4) = DSQRT((1.0_rl+(sigma_j-0.5_rl*h)**2)/(1.0_rl+rho_i**2))
+    end if
+#endif
+    alpha_T(5) = alpha_T(1) + alpha_T(2) + alpha_T(3) + alpha_T(4)
+  end subroutine construct_alpha_T
+
+!==================================================================
+! Construct coefficients of tridiagonal matrix A_T
+! describing the coupling in the vertical direction and the
+! diagonal matrix diag(d)
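+! For each level k the loop below builds a_k (= delta*vol_r), the
+! coupling b_k to level k+1 and the coupling c_k to level k-1; a_k,
+! b_k and c_k are stored divided by the diagonal scaling
+! d_k = -omega2*(r_{k+1}-r_k), which is stored separately.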
+!==================================================================
+subroutine construct_vertical_coeff()
+  implicit none
+  real(kind=rl) :: a_k_tmp, b_k_tmp, c_k_tmp, d_k_tmp
+  real(kind=rl) :: omega2, lambda2, delta, vol_r, surface_k, surface_kp1
+  integer :: k
+  omega2 = model_param%omega2
+  lambda2 = model_param%lambda2
+  delta = model_param%delta
+  do k = 1, grid_param%nz
+#ifdef  CARTESIANGEOMETRY
+    vol_r = r_grid(k+1)-r_grid(k)
+    surface_k = 1.0_rl
+    surface_kp1 = 1.0_rl
+#else
+    vol_r = (r_grid(k+1)**3 - r_grid(k)**3)/3.0_rl
+    surface_k = r_grid(k)**2
+    surface_kp1 = r_grid(k+1)**2
+#endif
+    ! Diagonal element
+    a_k_tmp = delta*vol_r
+    ! off diagonal elements
+    ! Boundary conditions
+    ! Top
+    if (k == grid_param%nz) then
+      if (grid_param%vertbc == VERTBC_DIRICHLET) then
+        b_k_tmp = - 2.0_rl * omega2*lambda2 &
+                           * surface_kp1/(r_grid(k+1)-r_grid(k))
+      else
+        b_k_tmp = 0.0_rl
+      end if
+    else
+      b_k_tmp = - 2.0_rl*omega2*lambda2 &
+              * surface_kp1/(r_grid(k+2)-r_grid(k))
+    end if
+    ! Bottom
+    if (k == 1) then
+      if (grid_param%vertbc == VERTBC_DIRICHLET) then
+        c_k_tmp = - 2.0_rl * omega2*lambda2 &
+                           * surface_k/(r_grid(k+1)-r_grid(k))
+      else
+        c_k_tmp = 0.0_rl
+      end if
+    else
+      c_k_tmp = - 2.0_rl * omega2 * lambda2 &
+              * surface_k/(r_grid(k+1)-r_grid(k-1))
+    end if
+    ! Diagonal matrix d_k
+    d_k_tmp = - omega2 * (r_grid(k+1)-r_grid(k))
+    vert_coeff%a(k) = a_k_tmp/d_k_tmp
+    vert_coeff%b(k) = b_k_tmp/d_k_tmp
+    vert_coeff%c(k) = c_k_tmp/d_k_tmp
+    vert_coeff%d(k) = d_k_tmp
+  end do
+end subroutine construct_vertical_coeff
+
+!==================================================================
+! Calculate local residual r = b - A.u
+!==================================================================
+  subroutine calculate_residual(level,m,b,u,r)
+    implicit none
+    integer, intent(in)                :: level
+    integer, intent(in)                :: m
+    type(scalar3d), intent(in)         :: b
+    type(scalar3d), intent(inout)      :: u
+    type(scalar3d), intent(inout)      :: r
+    integer :: ix,iy,iz
+
+    ! r <- A.u
+    call apply(u,r)
+    ! r <- b - r = b - A.u
+    do ix=u%icompx_min,u%icompx_max
+      do iy=u%icompy_min,u%icompy_max
+        do iz=1,u%grid_param%nz
+          r%s(iz,iy,ix) = b%s(iz,iy,ix) - r%s(iz,iy,ix)
+        end do
+      end do
+    end do
+  end subroutine calculate_residual
+
+!==================================================================
+! Apply operator v = A.u
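+!
+! As implemented below, for each column (ix,iy) and level iz
+!   v(iz) = d_k * [ ((a_k-b_k-c_k)*|T_ij| - alpha_T(5)) * u(iz)
+!                   + b_k*|T_ij|*u(iz+1) + c_k*|T_ij|*u(iz-1)
+!                   + alpha_T(1)*u(ix+1) + alpha_T(2)*u(ix-1)
+!                   + alpha_T(3)*u(iy+1) + alpha_T(4)*u(iy-1) ],
+! with the u(iz+1)/u(iz-1) terms dropped at the top/bottom level.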
+!==================================================================
+  subroutine apply(u,v)
+    implicit none
+    type(scalar3d), intent(in)         :: u
+    type(scalar3d), intent(inout)      :: v
+    real(kind=rl), dimension(5) :: alpha_T
+    real(kind=rl) :: Tij
+    real(kind=rl) :: a_k, b_k, c_k, d_k
+    integer :: ix,iy,iz
+    real(kind=rl) :: tmp
+
+    do ix=u%icompx_min,u%icompx_max
+      do iy=u%icompy_min,u%icompy_max
+        ! Construct horizontal part of stencil
+        call construct_alpha_T(u%grid_param,  &
+                               ix+u%ix_min-1, &
+                               iy+u%iy_min-1, &
+                               alpha_T,Tij)
+        do iz=1,u%grid_param%nz
+          a_k = vert_coeff%a(iz)
+          b_k = vert_coeff%b(iz)
+          c_k = vert_coeff%c(iz)
+          d_k = vert_coeff%d(iz)
+          tmp = ((a_k-b_k-c_k)*Tij - alpha_T(5)) * u%s(iz,iy,ix)
+          if (iz < grid_param%nz) then
+            tmp = tmp + b_k*Tij * u%s(iz+1,iy,ix)
+          end if
+          if (iz > 1) then
+            tmp = tmp + c_k*Tij * u%s(iz-1,iy,ix)
+          end if
+          tmp = tmp + alpha_T(1) * u%s(iz,  iy  ,ix+1) &
+                    + alpha_T(2) * u%s(iz,  iy  ,ix-1) &
+                    + alpha_T(3) * u%s(iz,  iy+1,ix  ) &
+                    + alpha_T(4) * u%s(iz,  iy-1,ix  )
+          v%s(iz,iy,ix) = d_k*tmp
+        end do
+      end do
+    end do
+  end subroutine apply
+
+!==================================================================
+!==================================================================
+!
+!     S M O O T H E R S
+!
+!==================================================================
+!==================================================================
+
+!==================================================================
+! Perform nsmooth smoother iterations
+!==================================================================
+  subroutine smooth(level,m,nsmooth,direction,b,u)
+    implicit none
+    integer, intent(in) :: level
+    integer, intent(in) :: m
+    integer, intent(in) :: nsmooth       ! Number of smoothing steps
+    integer, intent(in) :: direction     ! Direction
+    type(scalar3d), intent(inout) :: b   ! RHS
+    type(scalar3d), intent(inout) :: u   ! solution vector
+    integer :: i
+    real(kind=rl) :: log_res_initial, log_res_final
+    type(scalar3d) :: r
+    integer :: halo_size
+    integer :: nlocal, nz
+
+#ifdef MEASURESMOOTHINGRATE
+    r%ix_min = u%ix_min
+    r%ix_max = u%ix_max
+    r%iy_min = u%iy_min
+    r%iy_max = u%iy_max
+    r%icompx_min = u%icompx_min
+    r%icompx_max = u%icompx_max
+    r%icompy_min = u%icompy_min
+    r%icompy_max = u%icompy_max
+    r%halo_size = u%halo_size
+    r%isactive = u%isactive
+    r%grid_param = u%grid_param
+    nlocal = r%ix_max-r%ix_min+1
+    halo_size = r%halo_size
+    nz = r%grid_param%nz
+    allocate(r%s(0:nz+1,                       &
+                 1-halo_size:nlocal+halo_size, &
+                 1-halo_size:nlocal+halo_size))
+    call calculate_residual(level,m,b,u,r)
+    log_res_initial = log(l2norm(r))
+#endif
+    ! Carry out nsmooth iterations of the smoother
+    if (smoother_param%smoother == SMOOTHER_LINE_SOR) then
+      do i=1,nsmooth
+        call line_SOR(level,m,direction,b,u)
+      end do
+    else if (smoother_param%smoother == SMOOTHER_LINE_SSOR) then
+      do i=1,nsmooth
+        call line_SSOR(level,m,direction,b,u)
+      end do
+    else if (smoother_param%smoother == SMOOTHER_LINE_JAC) then
+      do i=1,nsmooth
+        call line_jacobi(level,m,b,u)
+      end do
+    end if
+#ifdef MEASURESMOOTHINGRATE
+    call calculate_residual(level,m,b,u,r)
+    log_res_final = log(l2norm(r))
+    log_resreduction(level) = log_resreduction(level) &
+                            + (log_res_final - log_res_initial)
+    nsmooth_total(level) = nsmooth_total(level) + nsmooth
+    deallocate(r%s)
+#endif
+  end subroutine smooth
+
+!==================================================================
+! SOR line smoother
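+!
+! One call performs a single sweep: for every vertical column the
+! tridiagonal system is solved exactly (apply_tridiag_solve) and the
+! result is blended with the old value using the overrelaxation
+! parameter rho, u <- (1-rho)*u + rho*u_tridiag. With red-black
+! ordering two half-sweeps of alternating colour are carried out.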
+!==================================================================
+  subroutine line_SOR(level,m,direction,b,u)
+
+    implicit none
+
+    integer, intent(in)                :: level
+    integer, intent(in)                :: m
+    integer, intent(in)                :: direction
+    type(scalar3d), intent(in)         :: b
+    type(scalar3d), intent(inout)      :: u
+    real(kind=rl), allocatable :: r(:)
+    integer :: nz, nlocal
+    real(kind=rl), allocatable :: c(:), utmp(:)
+    integer :: ixmin(5), ixmax(5), dix
+    integer :: iymin(5), iymax(5), diy
+    integer :: color
+    integer :: nsweeps, isweep
+    integer :: ordering
+    real(kind=rl) :: rho
+    integer, dimension(4) :: send_requests, recv_requests
+    integer :: tmp, ierr
+    integer :: iblock
+    logical :: overlap_comms
+
+    ordering = smoother_param%ordering
+    rho = smoother_param%rho
+
+    nz = u%grid_param%nz
+
+    ! Create residual vector
+    allocate(r(nz))
+    ! Allocate memory for auxiliary vectors for Thomas algorithm
+    allocate(c(nz))
+    allocate(utmp(nz))
+    nlocal = u%ix_max-u%ix_min+1
+#ifdef OVERLAPCOMMS
+    overlap_comms = (nlocal > 2)
+#else
+    overlap_comms = .false.
+#endif
+    ! Block 1 (N)
+    ixmin(1) = 1
+    ixmax(1) = nlocal
+    iymin(1) = 1
+    iymax(1) = 1
+    ! Block 2 (S)
+    ixmin(2) = 1
+    ixmax(2) = nlocal
+    iymin(2) = nlocal
+    iymax(2) = nlocal
+    ! Block 3 (W)
+    ixmin(3) = 1
+    ixmax(3) = 1
+    iymin(3) = 2
+    iymax(3) = nlocal-1
+    ! Block 4 (E)
+    ixmin(4) = nlocal
+    ixmax(4) = nlocal
+    iymin(4) = 2
+    iymax(4) = nlocal-1
+    ! Block 5 (INTERIOR)
+    if (overlap_comms) then
+      ixmin(5) = 2
+      ixmax(5) = nlocal-1
+      iymin(5) = 2
+      iymax(5) = nlocal-1
+    else
+      ! If communications and calculations are not overlapped (e.g.
+      ! because there are no interior cells), block 5 covers the
+      ! entire local domain
+      ixmin(5) = 1
+      ixmax(5) = nlocal
+      iymin(5) = 1
+      iymax(5) = nlocal
+    end if
+    dix = +1
+    diy = +1
+    color = 1
+    ! When iterating backwards over the grid, reverse the direction
+    if (direction == DIRECTION_BACKWARD) then
+      do iblock = 1, 5
+        tmp = ixmax(iblock)
+        ixmax(iblock) = ixmin(iblock)
+        ixmin(iblock) = tmp
+        tmp = iymax(iblock)
+        iymax(iblock) = iymin(iblock)
+        iymin(iblock) = tmp
+      end do
+      dix = -1
+      diy = -1
+      color = 0
+    end if
+    nsweeps = 1
+    if (ordering == ORDERING_LEX) then
+      nsweeps = 1
+    else if (ordering == ORDERING_RB) then
+      nsweeps = 2
+    end if
+    do isweep = 1, nsweeps
+      if (overlap_comms) then
+        ! Loop over cells next to boundary (iblock = 1,...,4)
+        do iblock = 1, 4
+          call loop_over_grid(iblock)
+        end do
+        ! Initiate halo exchange
+        call ihaloswap(level,m,u,send_requests,recv_requests)
+      end if
+      ! Loop over INTERIOR cells
+      iblock = 5
+      call loop_over_grid(iblock)
+      if (overlap_comms) then
+        if (m > 0) then
+          call mpi_waitall(4,recv_requests, MPI_STATUSES_IGNORE, ierr)
+        end if
+      else
+        call haloswap(level,m,u)
+      end if
+      color = 1-color
+    end do
+
+    ! Free memory again
+    deallocate(r)
+    deallocate(c)
+    deallocate(utmp)
+
+    contains
+
+    !------------------------------------------------------------------
+    ! Loop over grid, for a given block
+    !------------------------------------------------------------------
+    subroutine loop_over_grid(iblock)
+      implicit none
+      integer, intent(in) :: iblock
+      integer :: ix,iy,iz
+      do ix=ixmin(iblock),ixmax(iblock),dix
+        do iy=iymin(iblock),iymax(iblock),diy
+          if (ordering == ORDERING_RB) then
+            if (mod((ix+u%ix_min)+(iy+u%iy_min),2) .ne. color) cycle
+          end if
+          call apply_tridiag_solve(ix,iy,r,c,b,         &
+                                   u%s(1:nz,iy  ,ix+1), &
+                                   u%s(1:nz,iy  ,ix-1), &
+                                   u%s(1:nz,iy+1,ix  ), &
+                                   u%s(1:nz,iy-1,ix  ), &
+                                   utmp)
+           ! Add to field with overrelaxation-factor
+          do iz=1,nz
+            u%s(iz,iy,ix) = (1.0_rl-rho)*u%s(iz,iy,ix) + rho*utmp(iz)
+          end do
+        end do
+      end do
+    end subroutine loop_over_grid
+
+  end subroutine line_SOR
+
+!==================================================================
+! SSOR line smoother
+!==================================================================
+  subroutine line_SSOR(level,m,direction,b,u)
+    implicit none
+    integer, intent(in)                :: level
+    integer, intent(in)                :: m
+    integer, intent(in)                :: direction
+    type(scalar3d), intent(in)         :: b
+    type(scalar3d), intent(inout)      :: u
+    if (direction == DIRECTION_FORWARD) then
+      call line_SOR(level,m,DIRECTION_FORWARD,b,u)
+      call line_SOR(level,m,DIRECTION_BACKWARD,b,u)
+    else
+      call line_SOR(level,m,DIRECTION_BACKWARD,b,u)
+      call line_SOR(level,m,DIRECTION_FORWARD,b,u)
+    end if
+  end subroutine line_SSOR
+
+!==================================================================
+! Jacobi line smoother
+!==================================================================
+  subroutine line_Jacobi(level,m,b,u)
+    implicit none
+    integer, intent(in)                :: level
+    integer, intent(in)                :: m
+    type(scalar3d), intent(in)         :: b
+    type(scalar3d), intent(inout)      :: u
+    real(kind=rl), allocatable :: r(:)
+    integer :: ix,iy,iz, nz
+    real(kind=rl), dimension(5) :: alpha_T
+    real(kind=rl), allocatable :: c(:), utmp(:)
+    real(kind=rl), allocatable :: u0(:,:,:)
+    integer :: nlocal, halo_size
+    real(kind=rl) :: rho
+    logical :: overlap_comms
+    integer, dimension(4) :: send_requests, recv_requests
+    integer :: ixmin(5), ixmax(5)
+    integer :: iymin(5), iymax(5)
+    integer :: iblock, ierr
+
+    ! Set optimal smoothing parameter on each level
+    rho = 2.0_rl/(2.0_rl+4.0_rl*model_param%omega2*u%grid_param%n**2 &
+                 /(1.0_rl+4.0_rl*model_param%omega2*u%grid_param%n**2))
+
+    nz = u%grid_param%nz
+    nlocal = u%ix_max -u%ix_min + 1
+    halo_size = u%halo_size
+
+#ifdef OVERLAPCOMMS
+    overlap_comms = (nlocal > 2)
+#else
+    overlap_comms = .false.
+#endif
+
+    ! Block 1 (N)
+    ixmin(1) = 1
+    ixmax(1) = nlocal
+    iymin(1) = 1
+    iymax(1) = 1
+    ! Block 2 (S)
+    ixmin(2) = 1
+    ixmax(2) = nlocal
+    iymin(2) = nlocal
+    iymax(2) = nlocal
+    ! Block 3 (W)
+    ixmin(3) = 1
+    ixmax(3) = 1
+    iymin(3) = 2
+    iymax(3) = nlocal-1
+    ! Block 4 (E)
+    ixmin(4) = nlocal
+    ixmax(4) = nlocal
+    iymin(4) = 2
+    iymax(4) = nlocal-1
+    ! Block 5 (INTERIOR)
+    if (overlap_comms) then
+      ixmin(5) = 2
+      ixmax(5) = nlocal-1
+      iymin(5) = 2
+      iymax(5) = nlocal-1
+    else
+      ! If communications and calculations are not overlapped (e.g.
+      ! because there are no interior cells), block 5 covers the
+      ! entire local domain
+      ixmin(5) = 1
+      ixmax(5) = nlocal
+      iymin(5) = 1
+      iymax(5) = nlocal
+    end if
+
+    ! Temporary vector
+    allocate(u0(0:u%grid_param%nz+1,            &
+                1-halo_size:nlocal+halo_size,   &
+                1-halo_size:nlocal+halo_size) )
+    u0(:,:,:) = u%s(:,:,:)
+    ! Create residual vector
+    allocate(r(nz))
+    ! Allocate memory for auxiliary vectors for Thomas algorithm
+    allocate(c(nz))
+    allocate(utmp(nz))
+
+    ! Loop over grid
+    if (overlap_comms) then
+    ! Loop over cells next to boundary (iblock = 1,...,4)
+      do iblock = 1, 4
+        call loop_over_grid(iblock)
+      end do
+      ! Initiate halo exchange
+      call ihaloswap(level,m,u,send_requests,recv_requests)
+    end if
+    ! Loop over INTERIOR cells
+    iblock = 5
+    call loop_over_grid(iblock)
+    if (overlap_comms) then
+      if (m > 0) then
+        call mpi_waitall(4,recv_requests, MPI_STATUSES_IGNORE, ierr)
+      end if
+    else
+      call haloswap(level,m,u)
+    end if
+
+    ! Free memory again
+    deallocate(r)
+    deallocate(c)
+    deallocate(u0)
+    deallocate(utmp)
+
+  contains
+
+  subroutine loop_over_grid(iblock)
+    implicit none
+    integer, intent(in) :: iblock
+    integer :: ix,iy,iz
+    do ix=ixmin(iblock),ixmax(iblock)
+      do iy=iymin(iblock),iymax(iblock)
+        call apply_tridiag_solve(ix,iy,r,c,b,        &
+                                 u0(1:nz,iy  ,ix+1), &
+                                 u0(1:nz,iy  ,ix-1), &
+                                 u0(1:nz,iy+1,ix  ), &
+                                 u0(1:nz,iy-1,ix  ), &
+                                 utmp)
+        ! Add correction
+        do iz=1,nz
+          u%s(iz,iy,ix) = rho*utmp(iz) + (1.0_rl-rho)*u0(iz,iy,ix)
+        end do
+      end do
+    end do
+  end subroutine loop_over_grid
+
+  end subroutine line_Jacobi
+
+!==================================================================
+! At a given horizontal position (ix,iy) (local coordinates),
+! calculate
+!
+! u_out = T(ix,iy)^{-1} (b_(ix,iy)
+!       - sum_{ix',iy' != ix,iy} A_{(ix,iy),(ix',iy')}*u_in(ix',iy'))
+!
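+! The vertical tridiagonal system is solved directly with the Thomas
+! algorithm: a forward sweep builds modified coefficients c'_k and a
+! partially reduced right hand side, followed by the back substitution
+! u_k <- u_k - c'_k * u_{k+1}.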
+!==================================================================
+  subroutine apply_tridiag_solve(ix,iy,r,c,b, &
+                                 u_in_1,      &
+                                 u_in_2,      &
+                                 u_in_3,      &
+                                 u_in_4,      &
+                                 u_out)
+
+    implicit none
+    integer, intent(in) :: ix
+    integer, intent(in) :: iy
+    real(kind=rl), intent(inout), dimension(:) :: r
+    real(kind=rl), intent(inout), dimension(:) :: c
+    type(scalar3d), intent(in) :: b
+    real(kind=rl), intent(in), dimension(:) :: u_in_1
+    real(kind=rl), intent(in), dimension(:) :: u_in_2
+    real(kind=rl), intent(in), dimension(:) :: u_in_3
+    real(kind=rl), intent(in), dimension(:) :: u_in_4
+    real(kind=rl), intent(inout), dimension(:) :: u_out
+    real(kind=rl), dimension(5) :: alpha_T
+    real(kind=rl) :: Tij
+    real(kind=rl) :: alpha_div_Tij, tmp, b_k_tmp, c_k_tmp
+    integer :: iz, nz
+
+    nz = b%grid_param%nz
+
+    call construct_alpha_T(b%grid_param,  &
+                           ix+b%ix_min-1, &
+                           iy+b%iy_min-1, &
+                           alpha_T,Tij)
+    ! Calculate r_i = b_i - sum_{j in horizontal neighbours} A_{ij} u_j
+    do iz=1,nz
+      r(iz) = b%s(iz,iy,ix) - vert_coeff%d(iz) * ( &
+                alpha_T(1) * u_in_1(iz) + &
+                alpha_T(2) * u_in_2(iz) + &
+                alpha_T(3) * u_in_3(iz) + &
+                alpha_T(4) * u_in_4(iz) )
+    end do
+
+    ! Thomas algorithm
+    ! STEP 1: Create modified coefficients
+    iz = 1
+    alpha_div_Tij = alpha_T(5)/Tij
+    tmp = (vert_coeff%a(iz)-vert_coeff%b(iz)-vert_coeff%c(iz)) &
+             - alpha_div_Tij
+    c(iz) = vert_coeff%b(iz)/tmp
+    u_out(iz) = r(iz) / (tmp*Tij*vert_coeff%d(iz))
+    do iz=2,nz
+      b_k_tmp = vert_coeff%b(iz)
+      c_k_tmp = vert_coeff%c(iz)
+      tmp = ((vert_coeff%a(iz)-b_k_tmp-c_k_tmp)-alpha_div_Tij) &
+          - c(iz-1)*c_k_tmp
+      c(iz) = b_k_tmp / tmp
+      u_out(iz) = (r(iz) / (Tij*vert_coeff%d(iz)) - u_out(iz-1)*c_k_tmp) / tmp
+    end do
+    ! STEP 2: back substitution
+    do iz=nz-1,1,-1
+      u_out(iz) = u_out(iz) - c(iz) * u_out(iz+1)
+    end do
+  end subroutine apply_tridiag_solve
+
+end module discretisation
diff --git a/tensorproductmultigrid_Source/messages.f90 b/tensorproductmultigrid_Source/messages.f90
new file mode 100644
index 000000000..4ade93172
--- /dev/null
+++ b/tensorproductmultigrid_Source/messages.f90
@@ -0,0 +1,100 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!  Module for error/warning/info messages
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+module messages
+
+  use parameters
+  use mpi
+
+  implicit none
+
+contains
+
+!==================================================================
+! Print error message and exit
+!==================================================================
+  subroutine fatalerror(message)
+    implicit none
+    character(len=*), intent(in) :: message
+    integer :: ierr, rank
+    integer, parameter :: errorcode = -1
+    call mpi_comm_rank(MPI_COMM_WORLD,rank,ierr)
+    if (rank == 0) then
+      write(STDERR,'("FATAL ERROR: ",A)') message
+    end if
+    ! Abort all ranks; unlike mpi_finalize, mpi_abort does not require
+    ! the other ranks to participate
+    call mpi_abort(MPI_COMM_WORLD,errorcode,ierr)
+    STOP
+  end subroutine fatalerror
+
+!==================================================================
+! Print error message
+!==================================================================
+  subroutine error(message)
+    implicit none
+    character(len=*), intent(in) :: message
+    integer :: ierr, rank
+    call mpi_comm_rank(MPI_COMM_WORLD,rank,ierr)
+    if (rank == 0) then
+      write(STDERR,'("ERROR: ",A)') message
+    end if
+  end subroutine error
+
+!==================================================================
+! Print warning message
+!==================================================================
+  subroutine warning(message)
+    implicit none
+    character(len=*), intent(in) :: message
+    integer :: ierr, rank
+    call mpi_comm_rank(MPI_COMM_WORLD,rank,ierr)
+    if (rank == 0) then
+      write(STDERR,'("WARNING: ",A)') message
+    end if
+  end subroutine warning
+
+!==================================================================
+! Print info message
+!==================================================================
+  subroutine information(message)
+    implicit none
+    character(len=*), intent(in) :: message
+    integer :: ierr, rank
+    call mpi_comm_rank(MPI_COMM_WORLD,rank,ierr)
+    if (rank == 0) then
+      write(STDERR,'("INFO: ",A)') message
+    end if
+  end subroutine information
+
+end module messages
diff --git a/tensorproductmultigrid_Source/mg_main.f90 b/tensorproductmultigrid_Source/mg_main.f90
new file mode 100644
index 000000000..0903a893a
--- /dev/null
+++ b/tensorproductmultigrid_Source/mg_main.f90
@@ -0,0 +1,700 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!  Main program for multigrid solver code for Helmholtz/Poisson
+!  equation, discretised in the cell centred finite volume scheme
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+
+!==================================================================
+! Main program
+!==================================================================
+
+program mg_main
+
+  use discretisation
+  use parameters
+  use datatypes
+  use multigrid
+  use conjugategradient
+  use solver
+  use profiles
+  use messages
+  use communication
+  use timer
+  use mpi
+
+  implicit none
+
+  type(grid_parameters)     :: grid_param
+  type(comm_parameters)     :: comm_param
+  type(model_parameters)    :: model_param
+  type(smoother_parameters) :: smoother_param
+  type(mg_parameters)       :: mg_param
+  type(cg_parameters)       :: cg_param
+  type(solver_parameters)   :: solver_param
+
+  type(scalar3d) :: u
+  type(scalar3d) :: b
+  type(scalar3d) :: r
+#ifdef TESTCONVERGENCE
+  type(scalar3d) :: uerror
+  real(kind=rl) :: l2error
+#endif
+
+  ! Timers
+  type(time) :: t_solve
+  type(time) :: t_readparam
+  type(time) :: t_initialise
+  type(time) :: t_finalise
+
+  ! --- Parameter file ---
+  character(len=256) :: parameterfile
+
+  ! --- Name of executable ---
+  character(len=256) :: executable
+
+  ! --- General parameters ---
+  logical :: savefields   ! Save fields to disk?
+
+  integer :: ierr
+
+  integer :: x
+
+  integer :: i, int_size
+
+  ! Debugging hook: with x = 1 the loop below is never entered; setting
+  ! x = 0 here makes the program spin until a debugger modifies x
+  x = 1
+  do while(x == 0)
+  end do
+
+  ! Initialise MPI ...
+  call mpi_init(ierr)
+
+  ! ... and pre initialise communication module
+  call comm_preinitialise()
+
+  ! Parse command line arguments
+  if (iargc() .lt. 1) then
+    call getarg(0, executable)
+    if (i_am_master_mpi) then
+      write(STDOUT,*) "Usage: " // trim(executable) // " <parameterfile>"
+    end if
+    call mpi_finalize(ierr)
+    stop
+  end if
+
+  call getarg(1, parameterfile)
+  if (i_am_master_mpi) then
+    write(STDOUT,*) "+--------------------------------------+"
+    write(STDOUT,*) "+-- MULTIGRID SOLVER ------------------+"
+    write(STDOUT,*) "+--------------------------------------+"
+  end if
+
+  if (i_am_master_mpi) then
+    write(STDOUT,*) ''
+    write(STDOUT,*) 'Compile time parameters:'
+    write(STDOUT,*) ''
+#ifdef CARTESIANGEOMETRY
+    write(STDOUT,*) '  Geometry: Cartesian'
+#else
+    write(STDOUT,*) '  Geometry: Spherical'
+#endif
+#ifdef USELAPACK
+    write(STDOUT,*) '  Use Lapack: Yes'
+#else
+    write(STDOUT,*) '  Use Lapack: No'
+#endif
+#ifdef OVERLAPCOMMS
+    write(STDOUT,*) '  Overlap communications and calculation: Yes'
+#else
+    write(STDOUT,*) '  Overlap communications and calculation: No'
+#endif
+    write(STDOUT,*) ''
+    i = huge(i)
+    int_size = 1
+    do while (i > 0)
+      int_size = int_size + 1
+      i = i/2
+    end do
+    write(STDOUT,'("   size of integer: ",I3," bit")') int_size
+    write(STDOUT,*) ''
+  end if
+
+  ! Initialise timing module
+  call initialise_timing("timing.txt")
+
+  ! Read parameter files
+  call initialise_timer(t_readparam,"t_readparam")
+  call start_timer(t_readparam)
+  if (i_am_master_mpi) then
+    write(STDOUT,*) "Reading parameters from file '" // &
+                    trim(parameterfile) // "'"
+  end if
+  call read_general_parameters(parameterfile,savefields)
+  call read_solver_parameters(parameterfile,solver_param)
+  call read_grid_parameters(parameterfile,grid_param)
+  call read_comm_parameters(parameterfile,comm_param)
+  call read_model_parameters(parameterfile,model_param)
+  call read_smoother_parameters(parameterfile,smoother_param)
+  call read_multigrid_parameters(parameterfile,mg_param)
+  call read_conjugategradient_parameters(parameterfile,cg_param)
+  call finish_timer(t_readparam)
+
+  if (i_am_master_mpi) then
+    write(STDOUT,*) ''
+  end if
+
+  ! Initialise discretisation module
+  call discretisation_initialise(grid_param,     &
+                                 model_param,    &
+                                 smoother_param, &
+                                 mg_param%n_lev )
+
+  ! Initialise communication module
+  call initialise_timer(t_initialise,"t_initialise")
+  call start_timer(t_initialise)
+  call comm_initialise(mg_param%n_lev,     &
+                       mg_param%lev_split, &
+                       grid_param,         &
+                       comm_param)
+
+  ! Initialise multigrid
+  call mg_initialise(grid_param,       &
+                     comm_param,       &
+                     model_param,      &
+                     smoother_param,   &
+                     mg_param,         &
+                     cg_param          &
+                    )
+
+  call create_scalar3d(MPI_COMM_HORIZ,grid_param,comm_param%halo_size,u)
+  call create_scalar3d(MPI_COMM_HORIZ,grid_param,comm_param%halo_size,b)
+  call create_scalar3d(MPI_COMM_HORIZ,grid_param,comm_param%halo_size,r)
+  call initialise_rhs(grid_param,model_param,b)
+#ifdef TESTCONVERGENCE
+  call create_scalar3d(MPI_COMM_HORIZ,grid_param,comm_param%halo_size,uerror)
+  call analytical_solution(grid_param,uerror)
+#endif
+  call finish_timer(t_initialise)
+  if (i_am_master_mpi) then
+    write(STDOUT,*) ''
+  end if
+
+  ! Initialise ghosts in initial solution, as mg_solve assumes that they
+  ! are up-to-date
+  call haloswap(mg_param%n_lev,pproc,u)
+
+  ! Solve using multigrid
+  call initialise_timer(t_solve,"t_solve")
+  call start_timer(t_solve)
+  comm_measuretime = .True.
+#ifdef MEASUREHALOSWAP
+  call measurehaloswap()
+#else
+  call mg_solve(b,u,solver_param)
+#endif
+  comm_measuretime = .False.
+  call finish_timer(t_solve)
+
+#ifdef TESTCONVERGENCE
+  call daxpy_scalar3d(-1.0_rl,u,uerror)
+  call haloswap(mg_param%n_lev,pproc,uerror)
+  l2error = l2norm(uerror)
+  if (i_am_master_mpi) then
+    write(STDOUT,'("||error|| = ",E20.12," log_2(||error||) = ",E20.12)') &
+      l2error, log(l2error)/log(2.0_rl)
+  end if
+  if (savefields) then
+    call save_scalar3d(MPI_COMM_HORIZ,uerror,"error")
+  end if
+#endif
+
+  ! Save fields to disk
+  if (savefields) then
+    call haloswap(mg_param%n_lev,pproc,u)
+    call save_scalar3d(MPI_COMM_HORIZ,u,"solution")
+    call volscale_scalar3d(b,1)
+    call calculate_residual(mg_param%n_lev,pproc,b,u,r)
+    call volscale_scalar3d(b,-1)
+    call volscale_scalar3d(r,-1)
+    call haloswap(mg_param%n_lev,pproc,r)
+    call save_scalar3d(MPI_COMM_HORIZ,r,"residual")
+  end if
+
+  if (i_am_master_mpi) then
+    write(STDOUT,*) ''
+  end if
+
+  call discretisation_finalise()
+
+  ! Finalise
+  call initialise_timer(t_finalise,"t_finalise")
+  call start_timer(t_finalise)
+  call mg_finalise()
+  call cg_finalise()
+  ! Deallocate memory
+  call destroy_scalar3d(u)
+  call destroy_scalar3d(b)
+  call destroy_scalar3d(r)
+#ifdef TESTCONVERGENCE
+  call destroy_scalar3d(uerror)
+#endif
+
+
+  ! Finalise communications ...
+  call comm_finalise(mg_param%n_lev,mg_param%lev_split)
+  call finish_timer(t_finalise)
+  call print_timerinfo("# --- Main timing results ---")
+  call print_elapsed(t_readparam,.true.,1.0_rl)
+  call print_elapsed(t_initialise,.true.,1.0_rl)
+  call print_elapsed(t_solve,.true.,1.0_rl)
+  call print_elapsed(t_finalise,.true.,1.0_rl)
+  ! Finalise timing
+  call finalise_timing()
+  ! ... and MPI
+  call mpi_finalize(ierr)
+
+end program mg_main
+
+!==================================================================
+! Read general parameters from namelist file
+!==================================================================
+subroutine read_general_parameters(filename,savefields_out)
+  use parameters
+  use communication
+  use messages
+  use mpi
+  implicit none
+  character(*), intent(in) :: filename
+  logical, intent(out) :: savefields_out
+  integer, parameter :: file_id = 16
+  logical :: savefields
+  integer :: ierr
+  namelist /parameters_general/ savefields
+  if (i_am_master_mpi) then
+    open(file_id,file=filename)
+    read(file_id,NML=parameters_general)
+    close(file_id)
+    write(STDOUT,NML=parameters_general)
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("General parameters")')
+    write(STDOUT,'("    Save fields = ",L6)') savefields
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("")')
+  end if
+  call mpi_bcast(savefields,1,MPI_LOGICAL,master_rank,MPI_COMM_WORLD,ierr)
+  savefields_out = savefields
+end subroutine read_general_parameters
+
+!==================================================================
+! Read solver parameters from namelist file
+!==================================================================
+subroutine read_solver_parameters(filename,solver_param_out)
+  use solver
+  use parameters
+  use communication
+  use messages
+  use mpi
+  implicit none
+  character(*), intent(in) :: filename
+  type(solver_parameters), intent(out) :: solver_param_out
+  integer :: solvertype
+  real(kind=rl) :: resreduction
+  integer :: maxiter
+  integer, parameter :: file_id = 16
+  integer :: ierr
+  namelist /parameters_solver/ solvertype,resreduction, maxiter
+  if (i_am_master_mpi) then
+    open(file_id,file=filename)
+    read(file_id,NML=parameters_solver)
+    close(file_id)
+    write(STDOUT,NML=parameters_solver)
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("Solver parameters ")')
+    if (solvertype == SOLVER_RICHARDSON) then
+      write(STDOUT,'("    solver       = Richardson")')
+    else if (solvertype == SOLVER_CG) then
+      write(STDOUT,'("    solver       = CG")')
+    else
+      call fatalerror("Unknown solver type")
+    end if
+    write(STDOUT,'("    maxiter      = ",I8)') maxiter
+    write(STDOUT,'("    resreduction = ",E15.6)') resreduction
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(*,'("")')
+  end if
+  call mpi_bcast(solvertype,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(maxiter,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(resreduction,1,MPI_DOUBLE_PRECISION,master_rank,MPI_COMM_WORLD,ierr)
+  solver_param_out%solvertype = solvertype
+  solver_param_out%maxiter = maxiter
+  solver_param_out%resreduction = resreduction
+end subroutine read_solver_parameters
+
+!==================================================================
+! Read grid parameters from namelist file
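+! An illustrative &parameters_grid group (values are examples only;
+! vertbc = 1 selects Dirichlet, 2 Neumann):
+!   &parameters_grid
+!     n = 256,
+!     nz = 64,
+!     L = 1.0,
+!     H = 1.0,
+!     vertbc = 1,
+!     graded = .false.
+!   /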
+!==================================================================
+subroutine read_grid_parameters(filename,grid_param)
+  use parameters
+  use datatypes
+  use communication
+  use messages
+  use mpi
+  implicit none
+  character(*), intent(in) :: filename
+  type(grid_parameters), intent(out) :: grid_param
+  ! Grid parameters
+  integer :: n, nz
+  real(kind=rl) :: L, H
+  integer :: vertbc
+  logical :: graded
+  integer, parameter :: file_id = 16
+  integer :: ierr
+  namelist /parameters_grid/ n, nz, L, H, vertbc, graded
+  if (i_am_master_mpi) then
+    open(file_id,file=filename)
+    read(file_id,NML=parameters_grid)
+    close(file_id)
+    write(STDOUT,NML=parameters_grid)
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("Grid parameters")')
+    write(STDOUT,'("    n      = ",I15)') n
+    write(STDOUT,'("    nz     = ",I15)') nz
+    write(STDOUT,'("    L      = ",F8.4)') L
+    write(STDOUT,'("    H      = ",F8.4)') H
+    if (vertbc == VERTBC_DIRICHLET) then
+      write(STDOUT,'("    vertbc = DIRICHLET")')
+    else if (vertbc == VERTBC_NEUMANN) then
+      write(STDOUT,'("    vertbc = NEUMANN")')
+    else
+      vertbc = -1
+    end if
+    write(STDOUT,'("    graded =",L3)') graded
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("")')
+  end if
+  call mpi_bcast(n,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(nz,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(L,1,MPI_DOUBLE_PRECISION,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(H,1,MPI_DOUBLE_PRECISION,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(vertbc,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(graded,1,MPI_LOGICAL,master_rank,MPI_COMM_WORLD,ierr)
+  grid_param%n = n
+  grid_param%nz = nz
+  grid_param%L = L
+  grid_param%H = H
+  grid_param%vertbc = vertbc
+  grid_param%graded = graded
+  if (vertbc == -1) then
+    call fatalerror("vertbc has to be either 1 (Dirichlet) or 2 (Neumann)")
+  end if
+end subroutine read_grid_parameters
+
+!==================================================================
+! Read parallel communication parameters from namelist file
+!==================================================================
+subroutine read_comm_parameters(filename,comm_param)
+  use parameters
+  use communication
+  use messages
+  use mpi
+  implicit none
+  character(*), intent(in) :: filename
+  type(comm_parameters), intent(out) :: comm_param
+  ! Grid parameters
+  integer :: halo_size
+  integer, parameter :: file_id = 16
+  integer :: ierr
+  namelist /parameters_communication/ halo_size
+  if (i_am_master_mpi) then
+    open(file_id,file=filename)
+    read(file_id,NML=parameters_communication)
+    close(file_id)
+    write(STDOUT,NML=parameters_communication)
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("Communication parameters")')
+    write(STDOUT,'("    halosize  = ",I3)') halo_size
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("")')
+    if ( (halo_size .ne. 1) ) then
+      halo_size = -1
+    end if
+  end if
+  call mpi_bcast(halo_size,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  comm_param%halo_size = halo_size
+  if (halo_size == -1) then
+    call fatalerror("Halo size has to be 1.")
+  end if
+end subroutine read_comm_parameters
+
+!==================================================================
+! Read model parameters from namelist file
+!==================================================================
+subroutine read_model_parameters(filename,model_param)
+  use parameters
+  use discretisation
+  use communication
+  use messages
+  use mpi
+  implicit none
+  character(*), intent(in) :: filename
+  type(model_parameters), intent(out) :: model_param
+  real(kind=rl) :: omega2, lambda2, delta
+  integer, parameter :: file_id = 16
+  integer :: ierr
+  namelist /parameters_model/ omega2, lambda2, delta
+  if (i_am_master_mpi) then
+    open(file_id,file=filename)
+    read(file_id,NML=parameters_model)
+    close(file_id)
+    write(STDOUT,NML=parameters_model)
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("Model parameters")')
+    write(STDOUT,'("    omega2  = ",E15.6)') omega2
+    write(STDOUT,'("    lambda2 = ",E15.6)') lambda2
+    write(STDOUT,'("    delta   = ",E15.6)') delta
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("")')
+  end if
+  call mpi_bcast(omega2,1,MPI_DOUBLE_PRECISION,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(lambda2,1,MPI_DOUBLE_PRECISION,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(delta,1,MPI_DOUBLE_PRECISION,master_rank,MPI_COMM_WORLD,ierr)
+  model_param%omega2 = omega2
+  model_param%lambda2 = lambda2
+  model_param%delta = delta
+end subroutine read_model_parameters
+
+!==================================================================
+! Read smoother parameters from namelist file
+!==================================================================
+subroutine read_smoother_parameters(filename,smoother_param)
+  use parameters
+  use discretisation
+  use communication
+  use messages
+  use mpi
+  implicit none
+  character(*), intent(in) :: filename
+  type(smoother_parameters), intent(out) :: smoother_param
+  integer, parameter :: file_id = 16
+  integer :: smoother, ordering
+  real(kind=rl) :: rho
+  integer :: ierr
+  namelist /parameters_smoother/ smoother,           &
+                                  ordering,           &
+                                  rho
+  if (i_am_master_mpi) then
+    open(file_id,file=filename)
+    read(file_id,NML=parameters_smoother)
+    close(file_id)
+    write(STDOUT,NML=parameters_smoother)
+
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("Smoother parameters")')
+    ! Smoother
+    if (smoother == SMOOTHER_LINE_SOR) then
+      write(STDOUT,'("    smoother      = LINE_SOR")')
+    else if (smoother == SMOOTHER_LINE_SSOR) then
+      write(STDOUT,'("    smoother      = LINE_SSOR")')
+    else if (smoother == SMOOTHER_LINE_JAC) then
+      write(STDOUT,'("    smoother      = LINE_JACOBI")')
+    else
+      smoother = -1
+    end if
+
+    if (ordering == ORDERING_LEX) then
+      write(STDOUT,'("    ordering      = LEX")')
+    else if (ordering == ORDERING_RB) then
+      write(STDOUT,'("    ordering      = RB")')
+    else
+      ordering = -1
+    end if
+    write(STDOUT,'("    rho = ",E15.6)') rho
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("")')
+  end if
+  call mpi_bcast(smoother,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(ordering,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(rho,1,MPI_DOUBLE_PRECISION,master_rank,MPI_COMM_WORLD,ierr)
+  smoother_param%smoother = smoother
+  smoother_param%ordering = ordering
+  smoother_param%rho = rho
+  if (smoother == -1) then
+    call fatalerror('Unknown smoother.')
+  end if
+  if (ordering == -1) then
+    call fatalerror('Unknown ordering.')
+  end if
+
+end subroutine read_smoother_parameters
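+! Illustrative namelist block for this routine. The integer codes for
+! smoother and ordering must match the SMOOTHER_* and ORDERING_* constants
+! defined in the discretisation module; the concrete values below are
+! assumptions for illustration only:
+!
+!   &parameters_smoother
+!     smoother = 1,
+!     ordering = 1,
+!     rho      = 1.0
+!   /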
+
+!==================================================================
+! Read multigrid parameters from namelist file
+!==================================================================
+subroutine read_multigrid_parameters(filename,mg_param)
+  use parameters
+  use multigrid
+  use communication
+  use messages
+  use mpi
+  implicit none
+  character(*), intent(in) :: filename
+  type(mg_parameters), intent(out) :: mg_param
+  integer, parameter :: file_id = 16
+  integer :: verbose, n_lev, lev_split, n_presmooth, n_postsmooth,    &
+             prolongation, restriction, n_coarsegridsmooth, &
+             coarsegridsolver
+  integer :: ierr
+  namelist /parameters_multigrid/ verbose,            &
+                                  n_lev,              &
+                                  lev_split,          &
+                                  n_presmooth,        &
+                                  n_postsmooth,       &
+                                  n_coarsegridsmooth, &
+                                  prolongation,       &
+                                  restriction,        &
+                                  coarsegridsolver
+  if (i_am_master_mpi) then
+    open(file_id,file=filename)
+    read(file_id,NML=parameters_multigrid)
+    close(file_id)
+    write(STDOUT,NML=parameters_multigrid)
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("Multigrid parameters")')
+    write(STDOUT,'("    verbose      = ",L6)') verbose
+    write(STDOUT,'("    levels       = ",I3)') n_lev
+    write(STDOUT,'("    splitlevel   = ",I3)') lev_split
+    write(STDOUT,'("    n_presmooth  = ",I6)') n_presmooth
+    write(STDOUT,'("    n_postsmooth = ",I6)') n_postsmooth
+    if (restriction == REST_CELLAVERAGE) then
+      write(STDOUT,'("    restriction   = CELLAVERAGE")')
+    else
+      restriction = -1
+    endif
+    if (prolongation == PROL_CONSTANT) then
+      write(STDOUT,'("    prolongation  = CONSTANT")')
+    else if (prolongation == PROL_TRILINEAR) then
+#ifdef PIECEWISELINEAR
+      write(STDOUT,'("    prolongation  = TRILINEAR (piecewise linear)")')
+#else
+      write(STDOUT,'("    prolongation  = TRILINEAR (regression plane)")')
+#endif
+    else
+      prolongation = -1
+    endif
+    if (coarsegridsolver == COARSEGRIDSOLVER_CG) then
+      write(STDOUT,'("    coarse solver = CG")')
+    else if (coarsegridsolver == COARSEGRIDSOLVER_SMOOTHER) then
+      write(STDOUT,'("    coarse solver = SMOOTHER (",I6," iterations)")') &
+        n_coarsegridsmooth
+    else
+      coarsegridsolver = -1
+    end if
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(*,'("")')
+
+  end if
+  call mpi_bcast(verbose,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(n_lev,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(lev_split,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(n_presmooth,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(n_postsmooth,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(n_coarsegridsmooth,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD, &
+                 ierr)
+  call mpi_bcast(prolongation,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(restriction,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(coarsegridsolver,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  mg_param%verbose = verbose
+  mg_param%n_lev = n_lev
+  mg_param%lev_split = lev_split
+  mg_param%n_presmooth = n_presmooth
+  mg_param%n_postsmooth = n_postsmooth
+  mg_param%n_coarsegridsmooth = n_coarsegridsmooth
+  mg_param%prolongation = prolongation
+  mg_param%restriction = restriction
+  mg_param%coarsegridsolver = coarsegridsolver
+  if (restriction == -1) then
+    call fatalerror('Unknown restriction.')
+  end if
+  if (prolongation == -1) then
+    call fatalerror('Unknown prolongation.')
+  end if
+  if (coarsegridsolver == -1) then
+    call fatalerror('Unknown coarse grid solver.')
+  end if
+end subroutine read_multigrid_parameters
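+! Illustrative namelist block for this routine. The last three entries use
+! the integer codes defined in the multigrid module (REST_CELLAVERAGE = 1,
+! PROL_TRILINEAR = 2, COARSEGRIDSOLVER_CG = 2); all other values are
+! assumptions shown only to document the expected format:
+!
+!   &parameters_multigrid
+!     verbose            = 1,
+!     n_lev              = 4,
+!     lev_split          = 2,
+!     n_presmooth        = 2,
+!     n_postsmooth       = 2,
+!     n_coarsegridsmooth = 10,
+!     prolongation       = 2,
+!     restriction        = 1,
+!     coarsegridsolver   = 2
+!   /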
+
+!==================================================================
+! Read conjugate gradient parameters from namelist file
+!==================================================================
+subroutine read_conjugategradient_parameters(filename,cg_param)
+  use parameters
+  use communication
+  use conjugategradient
+  use messages
+  use mpi
+  implicit none
+  character(*), intent(in) :: filename
+  type(cg_parameters), intent(out) :: cg_param
+  integer, parameter :: file_id = 16
+  integer :: verbose, maxiter, n_prec
+  real(kind=rl) :: resreduction
+  integer :: ierr
+  namelist /parameters_conjugategradient/ verbose,      &
+                                          maxiter,      &
+                                          resreduction, &
+                                          n_prec
+  if (i_am_master_mpi) then
+    open(file_id,file=filename)
+    read(file_id,NML=parameters_conjugategradient)
+    close(file_id)
+    write(STDOUT,NML=parameters_conjugategradient)
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("Conjugate gradient parameters")')
+    write(STDOUT,'("    verbose       = ",I6)') verbose
+    write(STDOUT,'("    maxiter       = ",I6)') maxiter
+    write(STDOUT,'("    resreduction  = ",E15.6)') resreduction
+    write(STDOUT,'("    n_prec        = ",I6)') n_prec
+    write(STDOUT,'("---------------------------------------------- ")')
+    write(STDOUT,'("")')
+  end if
+  call mpi_bcast(verbose,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(maxiter,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(resreduction,1,MPI_DOUBLE_PRECISION,master_rank,MPI_COMM_WORLD,ierr)
+  call mpi_bcast(n_prec,1,MPI_INTEGER,master_rank,MPI_COMM_WORLD,ierr)
+  cg_param%verbose = verbose
+  cg_param%maxiter = maxiter
+  cg_param%resreduction = resreduction
+  cg_param%n_prec = n_prec
+end subroutine read_conjugategradient_parameters
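+! Illustrative namelist block for this routine (all values are assumptions
+! shown only to document the expected format):
+!
+!   &parameters_conjugategradient
+!     verbose      = 1,
+!     maxiter      = 100,
+!     resreduction = 1.0E-8,
+!     n_prec       = 1
+!   /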
+
diff --git a/tensorproductmultigrid_Source/multigrid.f90 b/tensorproductmultigrid_Source/multigrid.f90
new file mode 100644
index 000000000..3fb7f6c9e
--- /dev/null
+++ b/tensorproductmultigrid_Source/multigrid.f90
@@ -0,0 +1,1141 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!  Geometric multigrid module for cell centred finite volume
+!  discretisation.
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+module multigrid
+
+  use mpi
+  use parameters
+  use datatypes
+  use discretisation
+  use messages
+  use solver
+  use conjugategradient
+  use communication
+  use timer
+
+  implicit none
+
+public::mg_parameters
+public::mg_initialise
+public::mg_finalise
+public::mg_solve
+public::measurehaloswap
+public::REST_CELLAVERAGE
+public::PROL_CONSTANT
+public::PROL_TRILINEAR
+public::COARSEGRIDSOLVER_SMOOTHER
+public::COARSEGRIDSOLVER_CG
+
+private
+
+  ! --- multigrid parameter constants ---
+  ! restriction
+  integer, parameter :: REST_CELLAVERAGE = 1
+  ! prolongation method
+  integer, parameter :: PROL_CONSTANT = 1
+  integer, parameter :: PROL_TRILINEAR = 2
+  ! Coarse grid solver
+  integer, parameter :: COARSEGRIDSOLVER_SMOOTHER = 1
+  integer, parameter :: COARSEGRIDSOLVER_CG = 2
+
+  ! --- Multigrid parameters type ---
+  type mg_parameters
+    ! Verbosity level
+    integer :: verbose
+    ! Number of MG levels
+    integer :: n_lev
+    ! First level where data is pulled together
+    integer :: lev_split
+    ! Number of presmoothing steps
+    integer :: n_presmooth
+    ! Number of postsmoothing steps
+    integer :: n_postsmooth
+    ! Number of smoothing steps on coarsest level
+    integer :: n_coarsegridsmooth
+    ! Prolongation (see PROL_... for allowed values)
+    integer :: prolongation
+    ! Restriction (see REST_... for allowed values)
+    integer :: restriction
+    ! Smoother (see SMOOTHER_... for allowed values)
+    integer :: smoother
+    ! Relaxation factor
+    real(kind=rl) :: omega
+    ! Smoother on coarse grid
+    integer :: coarsegridsolver
+    ! ordering of grid points for smoother
+    integer :: ordering
+  end type mg_parameters
+
+! --- Parameters ---
+  type(mg_parameters) :: mg_param
+  type(model_parameters) :: model_param
+  type(smoother_parameters) :: smoother_param
+  type(grid_parameters) :: grid_param
+  type(comm_parameters) :: comm_param
+  type(cg_parameters) :: cg_param
+
+
+! --- Gridded and scalar data structures ---
+  ! Solution vector
+  type(scalar3d), allocatable :: u(:,:)
+  ! RHS vector
+  type(scalar3d), allocatable :: b(:,:)
+  ! residual
+  type(scalar3d), allocatable :: r(:,:)
+
+! --- Timer ---
+  type(time), allocatable, dimension(:,:) :: t_restrict
+  type(time), allocatable, dimension(:,:) :: t_prolongate
+  type(time), allocatable, dimension(:,:) :: t_residual
+  type(time), allocatable, dimension(:,:) :: t_addcorr
+  type(time), allocatable, dimension(:,:) :: t_smooth
+  type(time), allocatable, dimension(:,:) :: t_coarsesolve
+  type(time), allocatable, dimension(:,:) :: t_total
+
+contains
+
+!==================================================================
+! Initialise multigrid module, check and print out parameters
+!==================================================================
+  subroutine mg_initialise(grid_param_in,     &  ! Grid parameters
+                           comm_param_in,     &  ! Comm parameters
+                           model_param_in,    &  ! Model parameters
+                           smoother_param_in, &  ! Smoother parameters
+                           mg_param_in,       &  ! Multigrid parameters
+                           cg_param_in        &  ! CG parameters
+                           )
+    implicit none
+    type(grid_parameters), intent(in)  :: grid_param_in
+    type(comm_parameters), intent(in)  :: comm_param_in
+    type(model_parameters), intent(in) :: model_param_in
+    type(smoother_parameters), intent(in) :: smoother_param_in
+    type(mg_parameters), intent(in)    :: mg_param_in
+    type(cg_parameters), intent(in)    :: cg_param_in
+    real(kind=rl)                      :: L, H
+    integer                            :: n, nz, m, nlocal
+    logical                            :: reduced_m
+    integer                            :: level
+    integer                            :: rank, ierr
+    integer, dimension(2)              :: p_horiz
+    integer, parameter                 :: dim_horiz = 2
+    logical                            :: grid_active
+    integer                            :: ix_min, ix_max, iy_min, iy_max
+    integer                            :: icompx_min, icompx_max, &
+                                          icompy_min, icompy_max
+    integer                            :: halo_size
+    integer                            :: vertbc
+    character(len=32)                  :: t_label
+
+
+
+    if (i_am_master_mpi) &
+      write(STDOUT,*) '*** Initialising multigrid ***'
+    ! Check that cell counts are valid
+    grid_param = grid_param_in
+    comm_param = comm_param_in
+    mg_param = mg_param_in
+    model_param = model_param_in
+    smoother_param = smoother_param_in
+    cg_param = cg_param_in
+    halo_size = comm_param%halo_size
+    vertbc = grid_param%vertbc
+
+    ! Check parameters
+    if (grid_param%n < 2**(mg_param%n_lev-1) ) then
+      call fatalerror('Number of cells in x-/y- direction has to be at least 2^{n_lev-1}.')
+    endif
+
+    if (mod(grid_param%n,2**(mg_param%n_lev-1)) .ne. 0 ) then
+      call fatalerror('Number of cells in x-/y- direction is not a multiple of 2^{n_lev-1}.')
+    end if
+    if (i_am_master_mpi) &
+      write(STDOUT,*) ''
+
+    ! Allocate memory for timers
+    allocate(t_smooth(mg_param%n_lev,0:pproc))
+    allocate(t_total(mg_param%n_lev,0:pproc))
+    allocate(t_restrict(mg_param%n_lev,0:pproc))
+    allocate(t_residual(mg_param%n_lev,0:pproc))
+    allocate(t_prolongate(mg_param%n_lev,0:pproc))
+    allocate(t_addcorr(mg_param%n_lev,0:pproc))
+    allocate(t_coarsesolve(mg_param%n_lev,0:pproc))
+
+    ! Allocate memory for all levels and initialise fields
+    allocate(u(mg_param%n_lev,0:pproc))
+    allocate(b(mg_param%n_lev,0:pproc))
+    allocate(r(mg_param%n_lev,0:pproc))
+    n = grid_param%n
+    nlocal = n/(2**pproc)
+    nz = grid_param%nz
+    L = grid_param%L
+    H = grid_param%H
+    level = mg_param%n_lev
+    m = pproc
+    reduced_m = .false.
+    ! Work out local processor coordinates (this is necessary to identify
+    ! global coordinates)
+    call mpi_comm_rank(MPI_COMM_HORIZ,rank,ierr)
+    call mpi_cart_coords(MPI_COMM_HORIZ,rank,dim_horiz,p_horiz,ierr)
+    if (i_am_master_mpi) then
+      write(STDOUT, &
+        '(" Global gridsize (x,y,z) (pproc = ",I4," )      : ",I8," x ",I8," x ",I8)') &
+        pproc, n, n, nz
+    end if
+    do while (level > 0)
+      if (i_am_master_mpi) &
+        write(STDOUT, &
+          '(" Local gridsize (x,y,z) on level ",I3," m = ",I4," : ",I8," x ",I8," x ",I8)') &
+          level, m, nlocal, nlocal, nz
+      if (nlocal < 1) then
+        call fatalerror('Number of grid points < 1')
+      end if
+
+      ! Set sizes of computational grid (take care at boundaries)
+      if (p_horiz(1) == 0) then
+        icompy_min = 1
+      else
+        icompy_min = 1 - (halo_size - 1)
+      end if
+
+      if (p_horiz(2) == 0) then
+        icompx_min = 1
+      else
+        icompx_min = 1 - (halo_size - 1)
+      end if
+
+      if (p_horiz(1) == 2**pproc-1) then
+        icompy_max = nlocal
+      else
+        icompy_max = nlocal + (halo_size - 1)
+      end if
+
+      if (p_horiz(2) == 2**pproc-1) then
+        icompx_max = nlocal
+      else
+        icompx_max = nlocal + (halo_size - 1)
+      end if
+
+      ! Allocate data
+      allocate(u(level,m)%s(0:nz+1,                       &
+                            1-halo_size:nlocal+halo_size, &
+                            1-halo_size:nlocal+halo_size))
+      allocate(b(level,m)%s(0:nz+1,                       &
+                            1-halo_size:nlocal+halo_size, &
+                            1-halo_size:nlocal+halo_size))
+      allocate(r(level,m)%s(0:nz+1,                       &
+                            1-halo_size:nlocal+halo_size, &
+                            1-halo_size:nlocal+halo_size))
+      u(level,m)%s(:,:,:) = 0.0_rl
+      b(level,m)%s(:,:,:) = 0.0_rl
+      r(level,m)%s(:,:,:) = 0.0_rl
+
+      ! NB: 1st coordinate is in the y-direction of the processor grid,
+      ! second coordinate is in the x-direction (see comments in
+      ! communication module)
+      iy_min = (p_horiz(1)/2**(pproc-m))*nlocal+1
+      iy_max = (p_horiz(1)/2**(pproc-m)+1)*nlocal
+      ix_min = p_horiz(2)/2**(pproc-m)*nlocal+1
+      ix_max = (p_horiz(2)/2**(pproc-m)+1)*nlocal
+      ! Set grid parameters and local data ranges
+      ! Note that only n (and possibly nz) change as we
+      ! move down the levels
+      u(level,m)%grid_param%L = L
+      u(level,m)%grid_param%H = H
+      u(level,m)%grid_param%n = n
+      u(level,m)%grid_param%nz = nz
+      u(level,m)%grid_param%vertbc = vertbc
+      u(level,m)%ix_min = ix_min
+      u(level,m)%ix_max = ix_max
+      u(level,m)%iy_min = iy_min
+      u(level,m)%iy_max = iy_max
+      u(level,m)%icompx_min = icompx_min
+      u(level,m)%icompx_max = icompx_max
+      u(level,m)%icompy_min = icompy_min
+      u(level,m)%icompy_max = icompy_max
+      u(level,m)%halo_size = halo_size
+
+      b(level,m)%grid_param%L = L
+      b(level,m)%grid_param%H = H
+      b(level,m)%grid_param%n = n
+      b(level,m)%grid_param%nz = nz
+      b(level,m)%grid_param%vertbc = vertbc
+      b(level,m)%ix_min = ix_min
+      b(level,m)%ix_max = ix_max
+      b(level,m)%iy_min = iy_min
+      b(level,m)%iy_max = iy_max
+      b(level,m)%icompx_min = icompx_min
+      b(level,m)%icompx_max = icompx_max
+      b(level,m)%icompy_min = icompy_min
+      b(level,m)%icompy_max = icompy_max
+      b(level,m)%halo_size = halo_size
+
+      r(level,m)%grid_param%L = L
+      r(level,m)%grid_param%H = H
+      r(level,m)%grid_param%n = n
+      r(level,m)%grid_param%nz = nz
+      r(level,m)%grid_param%vertbc = vertbc
+      r(level,m)%ix_min = ix_min
+      r(level,m)%ix_max = ix_max
+      r(level,m)%iy_min = iy_min
+      r(level,m)%iy_max = iy_max
+      r(level,m)%icompx_min = icompx_min
+      r(level,m)%icompx_max = icompx_max
+      r(level,m)%icompy_min = icompy_min
+      r(level,m)%icompy_max = icompy_max
+      r(level,m)%halo_size = halo_size
+
+      ! Are these grids active?
+      if ( (m == pproc) .or. &
+           ( (mod(p_horiz(1),2**(pproc-m)) == 0) .and. &
+             (mod(p_horiz(2),2**(pproc-m)) == 0) ) ) then
+        grid_active = .true.
+      else
+        grid_active = .false.
+      end if
+      u(level,m)%isactive = grid_active
+      b(level,m)%isactive = grid_active
+      r(level,m)%isactive = grid_active
+      write(t_label,'("t_total(",I3,",",I3,")")') level, m
+      call initialise_timer(t_total(level,m),t_label)
+      write(t_label,'("t_smooth(",I3,",",I3,")")') level, m
+      call initialise_timer(t_smooth(level,m),t_label)
+      write(t_label,'("t_restrict(",I3,",",I3,")")') level, m
+      call initialise_timer(t_restrict(level,m),t_label)
+      write(t_label,'("t_residual(",I3,",",I3,")")') level, m
+      call initialise_timer(t_residual(level,m),t_label)
+      write(t_label,'("t_prolongate(",I3,",",I3,")")') level, m
+      call initialise_timer(t_prolongate(level,m),t_label)
+      write(t_label,'("t_addcorrection(",I3,",",I3,")")') level, m
+      call initialise_timer(t_addcorr(level,m),t_label)
+      write(t_label,'("t_coarsegridsolver(",I3,",",I3,")")') level, m
+      call initialise_timer(t_coarsesolve(level,m),t_label)
+
+      ! If we are below L_split, split data
+      if ( (level .le. mg_param%lev_split) .and. &
+           (m > 0) .and. (.not. reduced_m) ) then
+        reduced_m = .true.
+        m = m-1
+        nlocal = 2*nlocal
+        cycle
+      end if
+      reduced_m = .false.
+      level = level-1
+      n = n/2
+      nlocal = nlocal/2
+    end do
+    if (i_am_master_mpi) &
+      write(STDOUT,*) ''
+    call cg_initialise(cg_param)
+  end subroutine mg_initialise
+
+!==================================================================
+! Finalise, free memory for all data structures
+!==================================================================
+  subroutine mg_finalise()
+    implicit none
+    integer :: level, m
+    logical :: reduced_m
+    character(len=80) :: s
+    integer :: ierr
+
+    if (i_am_master_mpi) &
+      write(STDOUT,*) '*** Finalising multigrid ***'
+    ! Deallocate memory
+    level = mg_param%n_lev
+    m = pproc
+    reduced_m = .false.
+    call print_timerinfo("--- V-cycle timing results ---")
+    do while (level > 0)
+      write(s,'("level = ",I3,", m = ",I3)') level,m
+      call print_timerinfo(s)
+      call print_elapsed(t_smooth(level,m),.True.,1.0_rl)
+      call print_elapsed(t_restrict(level,m),.True.,1.0_rl)
+      call print_elapsed(t_prolongate(level,m),.True.,1.0_rl)
+      call print_elapsed(t_residual(level,m),.True.,1.0_rl)
+      call print_elapsed(t_addcorr(level,m),.True.,1.0_rl)
+      call print_elapsed(t_coarsesolve(level,m),.True.,1.0_rl)
+      call print_elapsed(t_total(level,m),.True.,1.0_rl)
+      deallocate(u(level,m)%s)
+      deallocate(b(level,m)%s)
+      deallocate(r(level,m)%s)
+      ! If we are below L_split, split data
+      if ( (level .le. mg_param%lev_split) .and. &
+           (m > 0) .and. (.not. reduced_m) ) then
+        reduced_m = .true.
+        m = m-1
+        cycle
+      end if
+      reduced_m = .false.
+      level = level-1
+    end do
+    deallocate(u)
+    deallocate(b)
+    deallocate(r)
+    deallocate(t_total)
+    deallocate(t_smooth)
+    deallocate(t_restrict)
+    deallocate(t_prolongate)
+    deallocate(t_residual)
+    deallocate(t_addcorr)
+    deallocate(t_coarsesolve)
+    if (i_am_master_mpi) write(STDOUT,'("")')
+  end subroutine mg_finalise
+
+!==================================================================
+! Restrict from fine -> coarse
+!==================================================================
+  subroutine restrict(phifine,phicoarse)
+    implicit none
+    type(scalar3d), intent(in)    :: phifine
+    type(scalar3d), intent(inout) :: phicoarse
+    integer :: ix,iy,iz
+    integer :: ix_min, ix_max, iy_min, iy_max
+
+    ix_min = phicoarse%icompx_min
+    ix_max = phicoarse%icompx_max
+    iy_min = phicoarse%icompy_min
+    iy_max = phicoarse%icompy_max
+    ! three dimensional cell average
+    if (mg_param%restriction == REST_CELLAVERAGE) then
+      ! Do not coarsen in z-direction
+      do ix=ix_min,ix_max
+        do iy=iy_min,iy_max
+          do iz=1,phicoarse%grid_param%nz
+            phicoarse%s(iz,iy,ix) =  &
+              phifine%s(iz  ,2*iy  ,2*ix  ) + &
+              phifine%s(iz  ,2*iy-1,2*ix  ) + &
+              phifine%s(iz  ,2*iy  ,2*ix-1) + &
+              phifine%s(iz  ,2*iy-1,2*ix-1)
+          end do
+        end do
+      end do
+    end if
+  end subroutine restrict
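+! Note on the stencil above: each coarse-grid value is the plain sum of its
+! four fine-grid children, without a 1/4 averaging factor. This is
+! presumably consistent with treating the restricted field as a
+! cell-integrated (volume-weighted) quantity, cf. the volume scaling of the
+! right-hand side via volscale_scalar3d in mg_solve.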
+
+
+!==================================================================
+! Prolongate from coarse -> fine
+! level, m correspond to the fine grid level
+!==================================================================
+  subroutine prolongate(level,m,phicoarse,phifine)
+    implicit none
+    integer, intent(in) :: level
+    integer, intent(in) :: m
+    type(scalar3d), intent(in) :: phicoarse
+    type(scalar3d), intent(inout) :: phifine
+    real(kind=rl) :: tmp
+    integer :: nlocal
+    integer, dimension(5) :: ixmin, ixmax, iymin, iymax
+    integer :: n, nz
+    integer :: ix, iy, iz
+    integer :: dix, diy, diz
+    real(kind=rl) :: rhox, rhoy, rhoz
+    real(kind=rl) :: rho_i, sigma_j, h, c1, c2
+    logical :: overlap_comms
+    integer, dimension(4) :: send_requests, recv_requests
+    integer :: ierr
+    integer :: iblock
+
+    ! Needed for interpolation matrix
+#ifdef PIECEWISELINEAR
+#else
+    real(kind=rl) :: dx(4,3), A(3,3), dx_fine(4,2)
+    integer :: i,j,k
+    real(kind=rl) :: dxu(2), grad(2)
+    dx(1,3) = 1.0_rl
+    dx(2,3) = 1.0_rl
+    dx(3,3) = 1.0_rl
+    dx(4,3) = 1.0_rl
+#endif
+
+    nlocal = phicoarse%ix_max-phicoarse%ix_min+1
+    n = phicoarse%grid_param%n
+    nz = phicoarse%grid_param%nz
+
+#ifdef OVERLAPCOMMS
+    overlap_comms = (nlocal > 2)
+#else
+    overlap_comms = .false.
+#endif
+    ! Block 1 (N)
+    ixmin(1) = 1
+    ixmax(1) = nlocal
+    iymin(1) = 1
+    iymax(1) = 1
+    ! Block 2 (S)
+    ixmin(2) = 1
+    ixmax(2) = nlocal
+    iymin(2) = nlocal
+    iymax(2) = nlocal
+    ! Block 3 (W)
+    ixmin(3) = 1
+    ixmax(3) = 1
+    iymin(3) = 2
+    iymax(3) = nlocal-1
+    ! Block 4 (E)
+    ixmin(4) = nlocal
+    ixmax(4) = nlocal
+    iymin(4) = 2
+    iymax(4) = nlocal-1
+    ! Block 5 (INTERIOR)
+    if (overlap_comms) then
+      ixmin(5) = 2
+      ixmax(5) = nlocal-1
+      iymin(5) = 2
+      iymax(5) = nlocal-1
+    else
+      ! If there are no interior cells, do not overlap
+      ! communications and calculations; treat the whole local grid
+      ! as a single block
+      ixmin(5) = 1
+      ixmax(5) = nlocal
+      iymin(5) = 1
+      iymax(5) = nlocal
+    end if
+
+    ! *** Constant prolongation or (tri-) linear prolongation ***
+    if ( (mg_param%prolongation == PROL_CONSTANT) .or. &
+         (mg_param%prolongation == PROL_TRILINEAR) ) then
+      if (overlap_comms) then
+        ! Loop over cells next to boundary (iblock = 1,...,4)
+        do iblock = 1, 4
+          if (mg_param%prolongation == PROL_CONSTANT) then
+            call loop_over_grid_constant(iblock)
+          end if
+          if (mg_param%prolongation == PROL_TRILINEAR) then
+            call loop_over_grid_linear(iblock)
+          end if
+        end do
+        ! Initiate halo exchange
+        call ihaloswap(level,m,phifine,send_requests,recv_requests)
+      end if
+      ! Loop over INTERIOR cells
+      iblock = 5
+      if (mg_param%prolongation == PROL_CONSTANT) then
+        call loop_over_grid_constant(iblock)
+      end if
+      if (mg_param%prolongation == PROL_TRILINEAR) then
+        call loop_over_grid_linear(iblock)
+      end if
+      if (overlap_comms) then
+        if (m > 0) then
+          call mpi_waitall(4,recv_requests, MPI_STATUSES_IGNORE, ierr)
+        end if
+      else
+        call haloswap(level,m,phifine)
+      end if
+    else
+      call fatalerror("Unsupported prolongation.")
+    end if
+
+    contains
+
+    !------------------------------------------------------------------
+    ! The actual loops over the grid for the individual blocks,
+    ! when overlapping calculation and communication
+    !------------------------------------------------------------------
+
+    !------------------------------------------------------------------
+    ! (1) Constant interpolation
+    !------------------------------------------------------------------
+    subroutine loop_over_grid_constant(iblock)
+      implicit none
+      integer, intent(in) :: iblock
+      integer :: ix,iy,iz
+      do ix=ixmin(iblock),ixmax(iblock)
+        do iy=iymin(iblock),iymax(iblock)
+          do dix = -1,0
+            do diy = -1,0
+              do iz=1,phicoarse%grid_param%nz
+                phifine%s(iz,2*iy+diy,2*ix+dix) = phicoarse%s(iz,iy,ix)
+              end do
+            end do
+          end do
+        end do
+      end do
+    end subroutine loop_over_grid_constant
+
+    !------------------------------------------------------------------
+    ! (2) Linear interpolation
+    !------------------------------------------------------------------
+    subroutine loop_over_grid_linear(iblock)
+      implicit none
+      integer, intent(in) :: iblock
+      integer :: ix,iy,iz
+      do ix=ixmin(iblock),ixmax(iblock)
+        do iy=iymin(iblock),iymax(iblock)
+#ifdef PIECEWISELINEAR
+          ! Piecewise linear interpolation
+          do iz=1,phicoarse%grid_param%nz
+            do dix = -1,0
+              do diy = -1,0
+                if ( (ix+(2*dix+1)+phicoarse%ix_min-1  < 1 ) .or. &
+                     (ix+(2*dix+1)+phicoarse%ix_min-1  > n ) ) then
+                  rhox = 0.5_rl
+                else
+                  rhox = 0.25_rl
+                end if
+                if ( (iy+(2*diy+1)+phicoarse%iy_min-1  < 1 ) .or. &
+                     (iy+(2*diy+1)+phicoarse%iy_min-1  > n ) ) then
+                  rhoy = 0.5_rl
+                else
+                  rhoy = 0.25_rl
+                end if
+                 phifine%s(iz,2*iy+diy,2*ix+dix) =      &
+                  phicoarse%s(iz,iy,ix) +                &
+                  rhox*(phicoarse%s(iz,iy,ix+(2*dix+1))  &
+                        - phicoarse%s(iz,iy,ix)) +       &
+                  rhoy*(phicoarse%s(iz,iy+(2*diy+1),ix)  &
+                        - phicoarse%s(iz,iy,ix))
+              end do
+            end do
+          end do
+#else
+          ! Fit a plane through the four neighbours of each
+          ! coarse grid point. Use the gradient of this plane and
+          ! the value of the field on the coarse grid point for
+          ! the linear interpolation
+          ! Calculate the displacement vectors
+#ifdef CARTESIANGEOMETRY
+          ! (ix-1, iy)
+          dx(1,1) = -1.0_rl
+          dx(1,2) =  0.0_rl
+          ! (ix+1, iy)
+          dx(2,1) = +1.0_rl
+          dx(2,2) =  0.0_rl
+          ! (ix, iy-1)
+          dx(3,1) =  0.0_rl
+          dx(3,2) = -1.0_rl
+          ! (ix, iy+1)
+          dx(4,1) =  0.0_rl
+          dx(4,2) = +1.0_rl
+#else
+          rho_i = 2.0_rl*(ix+(phicoarse%ix_min-1)-0.5_rl)/n-1.0_rl
+          sigma_j = 2.0_rl*(iy+(phicoarse%iy_min-1)-0.5_rl)/n-1.0_rl
+          if (abs(rho_i**2+sigma_j**2) > 1.0E-12) then
+            c1 = (1.0_rl+rho_i**2+sigma_j**2)/sqrt(rho_i**2+sigma_j**2)
+            c2 = sqrt(1.0_rl+rho_i**2+sigma_j**2)/sqrt(rho_i**2+sigma_j**2)
+          else
+            rho_i = 1.0_rl
+            sigma_j = 1.0_rl
+            c1 = sqrt(0.5_rl)
+            c2 = sqrt(0.5_rl)
+          end if
+          ! (ix-1, iy)
+          dx(1,1) = -c1*rho_i
+          dx(1,2) = +c2*sigma_j
+          ! (ix+1, iy)
+          dx(2,1) = +c1*rho_i
+          dx(2,2) = -c2*sigma_j
+          ! (ix, iy-1)
+          dx(3,1) = -c1*sigma_j
+          dx(3,2) = -c2*rho_i
+          ! (ix, iy+1)
+          dx(4,1) = +c1*sigma_j
+          dx(4,2) = +c2*rho_i
+          dx_fine(1,1) = 0.25_rl*(dx(1,1)+dx(3,1))
+          dx_fine(1,2) = 0.25_rl*(dx(1,2)+dx(3,2))
+          dx_fine(2,1) = 0.25_rl*(dx(2,1)+dx(3,1))
+          dx_fine(2,2) = 0.25_rl*(dx(2,2)+dx(3,2))
+          dx_fine(3,1) = 0.25_rl*(dx(1,1)+dx(4,1))
+          dx_fine(3,2) = 0.25_rl*(dx(1,2)+dx(4,2))
+          dx_fine(4,1) = 0.25_rl*(dx(2,1)+dx(4,1))
+          dx_fine(4,2) = 0.25_rl*(dx(2,2)+dx(4,2))
+#endif
+          ! Boundaries
+          if (ix-1+phicoarse%ix_min-1  < 1 ) then
+            dx(1,1) = 0.5_rl*dx(1,1)
+            dx(1,2) = 0.5_rl*dx(1,2)
+          end if
+          if (ix+1+phicoarse%ix_min-1  > n ) then
+            dx(2,1) = 0.5_rl*dx(2,1)
+            dx(2,2) = 0.5_rl*dx(2,2)
+          end if
+          if (iy-1+phicoarse%iy_min-1  < 1 ) then
+            dx(3,1) = 0.5_rl*dx(3,1)
+            dx(3,2) = 0.5_rl*dx(3,2)
+          end if
+          if (iy+1+phicoarse%iy_min-1  > n ) then
+            dx(4,1) = 0.5_rl*dx(4,1)
+            dx(4,2) = 0.5_rl*dx(4,2)
+          end if
+          ! Calculate matrix used for least squares linear fit
+          A(:,:) = 0.0_rl
+          do i = 1,4
+            do j=1,3
+              do k=1,3
+                A(j,k) = A(j,k) + dx(i,j)*dx(i,k)
+              end do
+            end do
+          end do
+          ! invert A
+          call invertA(A)
+          do iz=1,phicoarse%grid_param%nz
+            ! Calculate gradient on each level
+            dxu(1:2) = dx(1,1:2)*phicoarse%s(iz,iy  ,ix-1) &
+                     + dx(2,1:2)*phicoarse%s(iz,iy  ,ix+1) &
+                     + dx(3,1:2)*phicoarse%s(iz,iy-1,ix  ) &
+                     + dx(4,1:2)*phicoarse%s(iz,iy+1,ix  )
+            grad(:) = 0.0_rl
+            do j=1,2
+              do k=1,2
+                grad(j) = grad(j) + A(j,k)*dxu(k)
+              end do
+            end do
+            ! Use the gradient and the field value in the centre of
+            ! the coarse grid cell to interpolate to the fine grid
+            ! cells
+#ifdef CARTESIANGEOMETRY
+            do dix=-1,0
+              do diy=-1,0
+                phifine%s(iz,2*iy+diy,2*ix+dix) =       &
+                  phicoarse%s(iz,iy,ix)                 &
+                  + 0.25_rl*( grad(1)*(2.0*dix+1)    &
+                            + grad(2)*(2.0*diy+1))
+              end do
+            end do
+#else
+            phifine%s(iz,2*iy-1, 2*ix-1) = phicoarse%s(iz,iy,ix) + &
+                                         ( grad(1)*dx_fine(1,1)  + &
+                                           grad(2)*dx_fine(1,2) )
+            phifine%s(iz,2*iy-1, 2*ix  ) = phicoarse%s(iz,iy,ix) + &
+                                         ( grad(1)*dx_fine(2,1)  + &
+                                           grad(2)*dx_fine(2,2) )
+            phifine%s(iz,2*iy  , 2*ix-1) = phicoarse%s(iz,iy,ix) + &
+                                         ( grad(1)*dx_fine(3,1)  + &
+                                           grad(2)*dx_fine(3,2) )
+            phifine%s(iz,2*iy  , 2*ix  ) = phicoarse%s(iz,iy,ix) + &
+                                         ( grad(1)*dx_fine(4,1)  + &
+                                           grad(2)*dx_fine(4,2) )
+#endif
+          end do
+#endif
+        end do
+      end do
+    end subroutine loop_over_grid_linear
+
+  end subroutine prolongate
+
+  !------------------------------------------------------------------
+  ! Invert the 3x3 matrix A, either using LAPACK or the explicit
+  ! formula
+  !------------------------------------------------------------------
+  subroutine invertA(A)
+    implicit none
+    real(kind=rl), intent(inout), dimension(3,3) :: A
+    real(kind=rl), dimension(3,3) :: Anew
+    real(kind=rl) :: invdetA
+    integer :: ipiv(3), info
+    real(kind=rl) :: work(3)
+#ifdef USELAPACK
+    call DGETRF( 3, 3, A, 3, ipiv, info )
+    call DGETRI( 3, A, 3, ipiv, work, 3, info )
+#else
+    invdetA = 1.0_rl / ( A(1,1) * (A(3,3)*A(2,2) - A(3,2)*A(2,3)) &
+                       - A(2,1) * (A(3,3)*A(1,2) - A(3,2)*A(1,3)) &
+                       + A(3,1) * (A(2,3)*A(1,2) - A(2,2)*A(1,3)) )
+    Anew(1,1) =     A(3,3)*A(2,2) - A(3,2)*A(2,3)
+    Anew(1,2) = - ( A(3,3)*A(1,2) - A(3,2)*A(1,3) )
+    Anew(1,3) =     A(2,3)*A(1,2) - A(2,2)*A(1,3)
+    Anew(2,1) = - ( A(3,3)*A(2,1) - A(3,1)*A(2,3) )
+    Anew(2,2) =     A(3,3)*A(1,1) - A(3,1)*A(1,3)
+    Anew(2,3) = - ( A(2,3)*A(1,1) - A(2,1)*A(1,3) )
+    Anew(3,1) =     A(3,2)*A(2,1) - A(3,1)*A(2,2)
+    Anew(3,2) = - ( A(3,2)*A(1,1) - A(3,1)*A(1,2) )
+    Anew(3,3) =     A(2,2)*A(1,1) - A(2,1)*A(1,2)
+    A(:,:) = invdetA*Anew(:,:)
+#endif
+  end subroutine invertA
+
+!==================================================================
+! Multigrid V-cycle
+!==================================================================
+  recursive subroutine mg_vcycle(b,u,r,finelevel,splitlevel,level,m)
+    implicit none
+    integer, intent(in)                                     :: finelevel
+    type(scalar3d), intent(inout), dimension(finelevel,0:pproc) :: b
+    type(scalar3d), intent(inout), dimension(finelevel,0:pproc) :: u
+    type(scalar3d), intent(inout), dimension(finelevel,0:pproc) :: r
+    integer, intent(in)                                     :: splitlevel
+    integer, intent(in)                                     :: level
+    integer, intent(in)                                     :: m
+    integer                                                 :: n_gridpoints
+    integer                                                 :: nlocalx, nlocaly
+    integer                                                 :: halo_size
+
+    nlocalx = u(level,m)%ix_max-u(level,m)%ix_min+1
+    nlocaly = u(level,m)%iy_max-u(level,m)%iy_min+1
+    halo_size = u(level,m)%halo_size
+    n_gridpoints = (nlocalx+2*halo_size) &
+                 * (nlocaly+2*halo_size) &
+                 * (u(level,m)%grid_param%nz+2)
+
+    if (level > 1) then
+      ! Perform n_presmooth smoothing steps
+      call start_timer(t_smooth(level,m))
+      call start_timer(t_total(level,m))
+      call smooth(level,m,mg_param%n_presmooth, &
+                  DIRECTION_FORWARD, &
+                  b(level,m),u(level,m))
+      call finish_timer(t_total(level,m))
+      call finish_timer(t_smooth(level,m))
+      ! Calculate residual
+      call start_timer(t_residual(level,m))
+      call start_timer(t_total(level,m))
+      call calculate_residual(level,m,b(level,m),u(level,m),r(level,m))
+      call finish_timer(t_total(level,m))
+      call finish_timer(t_residual(level,m))
+      ! Restrict residual
+      call start_timer(t_restrict(level,m))
+      call start_timer(t_total(level,m))
+      call restrict(r(level,m),b(level-1,m))
+      call finish_timer(t_total(level,m))
+      call finish_timer(t_restrict(level,m))
+      if ( ((level-1) .le. splitlevel) .and. (m > 0) ) then
+        ! Collect data on coarser level
+        call start_timer(t_total(level,m))
+        call collect(level-1,m,b(level-1,m),b(level-1,m-1))
+        call finish_timer(t_total(level,m))
+        ! Set initial solution on coarser level to zero (incl. halos!)
+        u(level-1,m-1)%s(:,:,:) = 0.0_rl
+        ! solve on coarser grid
+        call mg_vcycle(b,u,r,finelevel,splitlevel,level-1,m-1)
+        ! Distribute data on coarser grid
+        call start_timer(t_total(level,m))
+        call distribute(level-1,m,u(level-1,m-1),u(level-1,m))
+        call haloswap(level-1,m,u(level-1,m))
+        call finish_timer(t_total(level,m))
+      else
+        ! Set initial solution on coarser level to zero (incl. halos!)
+        u(level-1,m)%s(:,:,:) = 0.0_rl
+        ! solve on coarser grid
+        call mg_vcycle(b,u,r,finelevel,splitlevel,level-1,m)
+      end if
+      ! Prolongate error
+      call start_timer(t_prolongate(level,m))
+      call start_timer(t_total(level,m))
+      call prolongate(level,m,u(level-1,m),r(level,m))
+      call finish_timer(t_total(level,m))
+      call finish_timer(t_prolongate(level,m))
+      ! Add error to fine grid solution
+      call start_timer(t_addcorr(level,m))
+      call start_timer(t_total(level,m))
+      call daxpy(n_gridpoints,1.0_rl,r(level,m)%s,1,u(level,m)%s,1)
+      call finish_timer(t_total(level,m))
+      call finish_timer(t_addcorr(level,m))
+      ! Perform n_postsmooth smoothing steps
+      call start_timer(t_smooth(level,m))
+      call start_timer(t_total(level,m))
+      call smooth(level,m, &
+                  mg_param%n_postsmooth, &
+                  DIRECTION_BACKWARD, &
+                  b(level,m),u(level,m))
+      call finish_timer(t_total(level,m))
+      call finish_timer(t_smooth(level,m))
+    else
+      call start_timer(t_coarsesolve(level,m))
+      call start_timer(t_total(level,m))
+      if (mg_param%coarsegridsolver == COARSEGRIDSOLVER_CG) then
+        call cg_solve(level,m,b(level,m),u(level,m))
+      else if (mg_param%coarsegridsolver == COARSEGRIDSOLVER_SMOOTHER) then
+        ! Smooth on coarsest level
+        call smooth(level,m, &
+                    mg_param%n_coarsegridsmooth,     &
+                    DIRECTION_FORWARD, &
+                    b(level,m),u(level,m))
+      end if
+      call finish_timer(t_total(level,m))
+      call finish_timer(t_coarsesolve(level,m))
+    end if
+
+  end subroutine mg_vcycle
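+! Schematic of one V-cycle as implemented above (for level > 1):
+!
+!   presmooth (n_presmooth sweeps, forward direction)
+!     -> compute residual -> restrict residual to level-1
+!        [collect onto fewer ranks once level-1 <= splitlevel and m > 0]
+!     -> recursive V-cycle on level-1, starting from a zero initial guess
+!        [distribute + haloswap if the data had been collected]
+!     -> prolongate the coarse-grid correction -> add it to the solution
+!   postsmooth (n_postsmooth sweeps, backward direction)
+!
+! On the coarsest level (level == 1) the problem is solved either with CG
+! or with n_coarsegridsmooth smoother iterations.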
+
+!==================================================================
+! Test halo exchanges
+!==================================================================
+  recursive subroutine mg_vcyclehaloswaponly(b,u,r,finelevel,splitlevel,level,m)
+    implicit none
+    integer, intent(in)                                     :: finelevel
+    type(scalar3d), intent(inout), dimension(finelevel,0:pproc) :: b
+    type(scalar3d), intent(inout), dimension(finelevel,0:pproc) :: u
+    type(scalar3d), intent(inout), dimension(finelevel,0:pproc) :: r
+    integer, intent(in)                                     :: splitlevel
+    integer, intent(in)                                     :: level
+    integer, intent(in)                                     :: m
+    integer, parameter :: nhaloswap = 100
+    integer :: i
+    integer :: ierr
+
+    if (level > 1) then
+      call mpi_barrier(MPI_COMM_HORIZ,ierr)
+      do i=1,nhaloswap
+        call haloswap(level,m,u(level,m))
+      end do
+      call mpi_barrier(MPI_COMM_HORIZ,ierr)
+      if ( ((level-1) .le. splitlevel) .and. (m > 0) ) then
+        call mpi_barrier(MPI_COMM_HORIZ,ierr)
+        do i=1,nhaloswap
+          call haloswap(level-1,m,u(level-1,m))
+        end do
+        call mpi_barrier(MPI_COMM_HORIZ,ierr)
+        call mg_vcyclehaloswaponly(b,u,r,finelevel,splitlevel,level-1,m-1)
+      else
+        call mg_vcyclehaloswaponly(b,u,r,finelevel,splitlevel,level-1,m)
+      end if
+    else
+      ! Haloswap on coarsest level
+      call mpi_barrier(MPI_COMM_HORIZ,ierr)
+      do i=1,nhaloswap
+        call haloswap(level,m,u(level,m))
+      end do
+      call mpi_barrier(MPI_COMM_HORIZ,ierr)
+    end if
+
+  end subroutine mg_vcyclehaloswaponly
+
+!==================================================================
+! Multigrid solve
+! Assumes that ghosts in initial solution are up-to-date
+!==================================================================
+  subroutine mg_solve(bRHS,usolution,solver_param)
+    implicit none
+    type(scalar3d), intent(in)      :: bRHS
+    type(scalar3d), intent(inout)   :: usolution
+    type(solver_parameters), intent(in) :: solver_param
+    integer :: solvertype
+    real(kind=rl)  :: resreduction
+    integer :: maxiter
+    integer :: n_gridpoints
+    integer :: iter, level, finelevel, splitlevel
+    real(kind=rl) :: res_old, res_new, res_initial
+    logical :: solverconverged = .false.
+    integer :: nlocalx, nlocaly
+    integer :: halo_size
+    type(time) :: t_prec, t_res, t_apply, t_l2norm, t_scalprod, t_mainloop
+    type(scalar3d) :: pp
+    type(scalar3d) :: q
+    real(kind=rl) :: alpha, beta, pq, rz, rz_old
+    integer :: ierr
+
+    solvertype = solver_param%solvertype
+    resreduction = solver_param%resreduction
+    maxiter = solver_param%maxiter
+    nlocalx = usolution%ix_max - usolution%ix_min+1
+    nlocaly = usolution%iy_max - usolution%iy_min+1
+    halo_size = usolution%halo_size
+
+    level = mg_param%n_lev
+    finelevel = level
+    splitlevel = mg_param%lev_split
+
+    ! Initialise timers
+    call initialise_timer(t_prec,"t_prec")
+    call initialise_timer(t_apply,"t_apply")
+    call initialise_timer(t_l2norm,"t_l2norm")
+    call initialise_timer(t_scalprod,"t_scalarprod")
+    call initialise_timer(t_res,"t_residual")
+    call initialise_timer(t_mainloop,"t_mainloop")
+
+    ! Copy b to b(1)
+    ! Copy usolution to u(1)
+    n_gridpoints = (nlocalx+2*halo_size) &
+                 * (nlocaly+2*halo_size) &
+                 * (usolution%grid_param%nz+2)
+    call dcopy(n_gridpoints,bRHS%s,1,b(level,pproc)%s,1)
+    call dcopy(n_gridpoints,usolution%s,1,u(level,pproc)%s,1)
+! Scale with volume of grid cells
+    call volscale_scalar3d(b(level,pproc),1)
+    call start_timer(t_res)
+    call calculate_residual(level, pproc, &
+                            b(level,pproc),u(level,pproc),r(level,pproc))
+    call finish_timer(t_res)
+    call start_timer(t_l2norm)
+    res_initial = l2norm(r(level,pproc),.true.)
+    call finish_timer(t_l2norm)
+    res_old = res_initial
+    if (mg_param%verbose > 0) then
+      if (i_am_master_mpi) then
+        write(STDOUT,'(" *** Multigrid solver ***")')
+        write(STDOUT,'(" <MG> Initial residual : ",E10.5)') res_initial
+      end if
+    end if
+    if (mg_param%verbose > 0) then
+      if (i_am_master_mpi) then
+        write(STDOUT,'(" <MG>   iter :   residual         rho")')
+        write(STDOUT,'(" <MG> --------------------------------")')
+      end if
+    end if
+
+    call mpi_barrier(MPI_COMM_WORLD,ierr)
+    call start_timer(t_mainloop)
+    if (solvertype == SOLVER_CG) then
+      ! NB: b(level,pproc) will be used as r in the following
+      call create_scalar3d(MPI_COMM_HORIZ,bRHS%grid_param,halo_size,pp)
+      call create_scalar3d(MPI_COMM_HORIZ,bRHS%grid_param,halo_size,q)
+      ! Apply preconditioner: Calculate p = M^{-1} r
+      ! (1) copy b <- r
+      call dcopy(n_gridpoints,r(level,pproc)%s,1,b(level,pproc)%s,1)
+      ! (2) set u <- 0
+      u(level,pproc)%s(:,:,:) = 0.0_rl
+      ! (3) Call MG Vcycle
+      call start_timer(t_prec)
+      call mg_vcycle(b,u,r,finelevel,splitlevel,level,pproc)
+      call finish_timer(t_prec)
+      ! (4) copy pp <- u (=solution from MG Vcycle)
+      call dcopy(n_gridpoints,u(level,pproc)%s,1,pp%s,1)
+      ! Calculate rz_old = <pp,b>
+      call start_timer(t_scalprod)
+      call scalarprod(pproc,pp,b(level,pproc),rz_old)
+      call finish_timer(t_scalprod)
+      do iter = 1, maxiter
+        ! Apply matrix q <- A.pp
+        call start_timer(t_apply)
+        call apply(pp,q)
+        call finish_timer(t_apply)
+        ! Calculate pq <- <pp,q>
+        call start_timer(t_scalprod)
+        call scalarprod(pproc,pp,q,pq)
+        call finish_timer(t_scalprod)
+        alpha = rz_old/pq
+        ! x <- x + alpha*p
+        call daxpy(n_gridpoints,alpha,pp%s,1,usolution%s,1)
+        ! b <- b - alpha*q
+        call daxpy(n_gridpoints,-alpha,q%s,1,b(level,pproc)%s,1)
+        ! Calculate norm of residual and exit if it has been
+        ! reduced sufficiently
+        call start_timer(t_l2norm)
+        res_new = l2norm(b(level,pproc),.true.)
+        call finish_timer(t_l2norm)
+        if (mg_param%verbose > 1) then
+          if (i_am_master_mpi) then
+            write(STDOUT,'(" <MG> ",I7," : ",E10.5,"  ",F10.5)') iter, res_new, res_new/res_old
+          end if
+        end if
+        if (res_new/res_initial < resreduction) then
+          solverconverged = .true.
+          exit
+        end if
+        res_old = res_new
+        ! Apply preconditioner q <- M^{-1} b
+        ! (1) Initialise solution u <- 0
+        u(level,pproc)%s(:,:,:) = 0.0_rl
+        ! (2) Call MG Vcycle
+        call start_timer(t_prec)
+        call mg_vcycle(b,u,r,finelevel,splitlevel,level,pproc)
+        call finish_timer(t_prec)
+        ! (3) copy q <- u (solution from MG Vcycle)
+        call dcopy(n_gridpoints,u(level,pproc)%s,1,q%s,1)
+        call start_timer(t_scalprod)
+        call scalarprod(pproc,q,b(level,pproc),rz)
+        call finish_timer(t_scalprod)
+        beta = rz/rz_old
+        ! p <- beta*p
+        call dscal(n_gridpoints,beta,pp%s,1)
+        ! p <- p+q
+        call daxpy(n_gridpoints,1.0_rl,q%s,1,pp%s,1)
+        rz_old = rz
+      end do
+      call destroy_scalar3d(pp)
+      call destroy_scalar3d(q)
+    else if (solvertype == SOLVER_RICHARDSON) then
+      ! Iterate until convergence
+      do iter=1,maxiter
+        call start_timer(t_prec)
+        call mg_vcycle(b,u,r,finelevel,splitlevel,level,pproc)
+        call finish_timer(t_prec)
+        call start_timer(t_res)
+        ! Ghosts are up-to-date here, so no need for halo exchange
+        call calculate_residual(level, pproc, &
+                                b(level,pproc),u(level,pproc),r(level,pproc))
+        call finish_timer(t_res)
+        call start_timer(t_l2norm)
+        res_new = l2norm(r(level,pproc),.true.)
+        call finish_timer(t_l2norm)
+        if (mg_param%verbose > 1) then
+          if (i_am_master_mpi) then
+            write(STDOUT,'(" <MG> ",I7," : ",E10.5,"  ",F10.5)') iter, res_new, res_new/res_old
+          end if
+        end if
+        if (res_new/res_initial < resreduction) then
+          solverconverged = .true.
+          exit
+        end if
+        res_old = res_new
+      end do
+      ! Copy u(1) to usolution
+      call dcopy(n_gridpoints,u(level,pproc)%s,1,usolution%s,1)
+    end if
+    call finish_timer(t_mainloop)
+
+    ! Print out solver information
+    if (mg_param%verbose > 0) then
+      if (solverconverged) then
+        if (i_am_master_mpi) then
+          write(STDOUT,'(" <MG> Final residual    ||r|| = ",E12.6)') res_new
+          write(STDOUT,'(" <MG> Solver converged in ",I7," iterations rho_{avg} = ",F10.5)') &
+          iter, (res_new/res_initial)**(1./(iter))
+        end if
+      else
+        if (i_am_master_mpi) then
+          write(STDOUT,'(" <MG> Solver failed to converge after ",I7," iterations rho_{avg} = ",F10.5)') &
+          maxiter, (res_new/res_initial)**(1./(iter))
+        end if
+      end if
+    end if
+    call print_timerinfo("--- Main iteration timing results ---")
+    call print_elapsed(t_apply,.True.,1.0_rl)
+    call print_elapsed(t_res,.True.,1.0_rl)
+    call print_elapsed(t_prec,.True.,1.0_rl)
+    call print_elapsed(t_l2norm,.True.,1.0_rl)
+    call print_elapsed(t_scalprod,.True.,1.0_rl)
+    call print_elapsed(t_mainloop,.True.,1.0_rl)
+    if (i_am_master_mpi) write(STDOUT,'("")')
+  end subroutine mg_solve
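+! Sketch of the preconditioned CG iteration implemented above for
+! solvertype == SOLVER_CG; one multigrid V-cycle acts as the preconditioner
+! M^{-1}, and b(level,pproc) is reused as the residual r:
+!
+!   z = M^{-1} r ;  p = z ;  rz = <p,r>
+!   repeat until ||r|| / ||r_0|| < resreduction or maxiter is reached:
+!     q     = A p
+!     alpha = rz / <p,q>
+!     x     = x + alpha p ;  r = r - alpha q
+!     z     = M^{-1} r
+!     beta  = <z,r> / rz ;   rz = <z,r>
+!     p     = z + beta p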
+
+!==================================================================
+! Test haloswap on all levels
+!==================================================================
+  subroutine measurehaloswap()
+    implicit none
+    integer :: iter, level, finelevel, splitlevel
+
+    level = mg_param%n_lev
+    finelevel = level
+    splitlevel = mg_param%lev_split
+    call mg_vcyclehaloswaponly(b,u,r,finelevel,splitlevel,level,pproc)
+  end subroutine measurehaloswap
+
+end module multigrid
+
diff --git a/tensorproductmultigrid_Source/parameters.f90 b/tensorproductmultigrid_Source/parameters.f90
new file mode 100644
index 000000000..3e7a97603
--- /dev/null
+++ b/tensorproductmultigrid_Source/parameters.f90
@@ -0,0 +1,58 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!  General parameters
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+
+module parameters
+
+  implicit none
+
+! floating point precision. Always use kind=rl in code
+  integer, parameter :: single_precision=4        ! single precision
+  integer, parameter :: double_precision=8        ! double precision
+  integer, parameter :: rl=double_precision       ! global switch between
+                                                  ! single/double precision
+! NOTE: As we still use BLAS subroutines, these need to be
+!       replaced as well when switching between double and
+!       single precision!
+  real(kind=rl), parameter :: Tolerance = 1.0e-15_rl
+
+! Output units
+  integer, parameter :: STDOUT = 6
+  integer, parameter :: STDERR = 0
+
+! Numerical constants
+  real(kind=rl), parameter :: two_pi = 6.2831853071795862_rl
+
+end module parameters
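+! Minimal usage sketch of the kind parameter declared above: real variables
+! and literals throughout the solver carry kind rl, e.g.
+!
+!   real(kind=rl) :: x
+!   x = 0.5_rl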
diff --git a/tensorproductmultigrid_Source/profiles.f90 b/tensorproductmultigrid_Source/profiles.f90
new file mode 100644
index 000000000..7baf05ed1
--- /dev/null
+++ b/tensorproductmultigrid_Source/profiles.f90
@@ -0,0 +1,174 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!  Analytical forms of RHS vectors
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+module profiles
+
+  use communication
+  use parameters
+  use datatypes
+  use discretisation
+
+  implicit none
+
+  public::initialise_rhs
+  public::analytical_solution
+
+private
+  contains
+
+!==================================================================
+! Initialise RHS vector
+!==================================================================
+  subroutine initialise_rhs(grid_param,model_param,b)
+    implicit none
+    type(grid_parameters), intent(in) :: grid_param
+    type(model_parameters), intent(in) :: model_param
+    type(scalar3d), intent(inout) :: b
+    integer :: ix, iy, iz, ix_min, ix_max, iy_min, iy_max
+    real(kind=rl) :: x, y, z
+    real(kind=rl) :: rho, sigma, theta, phi, r, b_low, b_up, pi
+
+#ifdef TESTCONVERGENCE
+    real(kind=rl) :: px,py,pz
+#endif
+
+    ix_min = b%ix_min
+    ix_max = b%ix_max
+    iy_min = b%iy_min
+    iy_max = b%iy_max
+    b_low = 1.0_rl+0.25*b%grid_param%H
+    b_up = 1.0_rl+0.75*b%grid_param%H
+    pi = 4.0_rl*atan2(1.0_rl,1.0_rl)
+    ! Initialise RHS
+    do ix=ix_min, ix_max
+      do iy=iy_min, iy_max
+        do iz=1,b%grid_param%nz
+          x = 1.0_rl*((ix-0.5_rl)/(1.0_rl*b%grid_param%n))
+          y = 1.0_rl*((iy-0.5_rl)/(1.0_rl*b%grid_param%n))
+          z = 1.0_rl*((iz-0.5_rl)/(1.0_rl*b%grid_param%nz))
+#ifdef TESTCONVERGENCE
+          ! RHS for analytical solution x*(1-x)*y*(1-y)*z*(1-z)
+          if (grid_param%vertbc == VERTBC_DIRICHLET) then
+            px = x*(1.0_rl-x)
+            py = y*(1.0_rl-y)
+            pz = z*(1.0_rl-z)
+            b%s(iz,iy-iy_min+1,ix-ix_min+1) = &
+              ( 2.0_rl*model_param%omega2*((px+py)*pz &
+              + model_param%lambda2*px*py)+model_param%delta*px*py*pz)
+          else
+            px = x*(1.0_rl-x)
+            py = y*(1.0_rl-y)
+            pz = 1.0_rl
+            b%s(iz,iy-iy_min+1,ix-ix_min+1) = &
+              ( 2.0_rl*model_param%omega2*((px+py)*pz)+model_param%delta*px*py*pz)
+          end if
+#else
+          b%s(iz,iy-iy_min+1,ix-ix_min+1) = 0.0_rl
+#ifdef CARTESIANGEOMETRY
+          if ( ( x .ge. 0.1_rl ) .and. ( x .le. 0.4_rl ) &
+            .and. (y .ge. 0.3_rl ) .and. ( y .le. 0.6_rl ) &
+            .and. (z .ge. 0.2_rl ) .and. ( z .le. 0.7_rl ) ) &
+            then
+            b%s(iz,iy-iy_min+1,ix-ix_min+1) = 1.0_rl
+          end if
+#else
+          rho = 2.0_rl*(1.0_rl*ix-0.5_rl)/grid_param%n-1.0_rl
+          sigma = 2.0_rl*(1.0_rl*iy-0.5_rl)/grid_param%n-1.0_rl
+          phi = atan(sigma)
+          theta = atan(rho/sqrt(1.0_rl+sigma**2))
+          x = sin(theta)
+          y = cos(theta)*sin(phi)
+          z = cos(theta)*cos(phi)
+          phi = atan2(x,y)
+          theta = atan2(sqrt(x**2+y**2),z)
+          r = 0.5_rl*(r_grid(iz)+r_grid(iz+1))
+          if (( (r > b_low) .and. (r < b_up) ) .and. &
+              (((theta>pi/10.0_rl) .and. (theta<pi/5.0_rl )) .or. &
+               ((theta>3.0_rl*pi/8.0_rl) .and. (theta<5.0_rl*pi/8.0_rl )) .or. &
+               ((theta>4.0_rl*pi/5.0_rl) .and. (theta<9.0_rl*pi/10.0_rl)))) then
+            b%s(iz,iy-iy_min+1,ix-ix_min+1) = 1.0_rl
+          end if
+!    RHS used in GPU code:
+!          if ( (r > b_low) .and. (r < b_up) .and. &
+!               (rho > -0.5) .and. (rho < 0.5) .and. &
+!               (sigma > -0.5).and. (sigma < 0.5) ) then
+!            b%s(iz,iy-iy_min+1,ix-ix_min+1) = 1.0_rl
+!          end if
+#endif
+#endif
+        end do
+      end do
+    end do
+  end subroutine initialise_rhs
+
+!==================================================================
+! Exact solution for test problem
+!  u(x,y,z) = x*(1-x)*y*(1-y)*z*(1-z)
+!==================================================================
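+! (For vertical boundary conditions other than Dirichlet the z-factor is
+!  dropped, i.e. u(x,y,z) = x*(1-x)*y*(1-y), matching the pz = 1 branch used
+!  when the right-hand side is initialised above.)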
+  subroutine analytical_solution(grid_param,u)
+    implicit none
+    type(grid_parameters), intent(in) :: grid_param
+    type(scalar3d), intent(inout) :: u
+    integer :: ix, iy, iz, ix_min, ix_max, iy_min, iy_max
+    real(kind=rl) :: x, y, z
+
+    ix_min = u%ix_min
+    ix_max = u%ix_max
+    iy_min = u%iy_min
+    iy_max = u%iy_max
+
+    ! Evaluate the analytical solution on the local grid
+    do ix=ix_min, ix_max
+      do iy=iy_min, iy_max
+        do iz=1,u%grid_param%nz
+          x = u%grid_param%L*((ix-0.5_rl)/(1.0_rl*u%grid_param%n))
+          y = u%grid_param%L*((iy-0.5_rl)/(1.0_rl*u%grid_param%n))
+          z = u%grid_param%H*((iz-0.5_rl)/(1.0_rl*u%grid_param%nz))
+          if (grid_param%vertbc == VERTBC_DIRICHLET) then
+            u%s(iz,iy-iy_min+1,ix-ix_min+1) &
+              = x*(1.0_rl-x) &
+              * y*(1.0_rl-y) &
+              * z*(1.0_rl-z)
+          else
+            u%s(iz,iy-iy_min+1,ix-ix_min+1) &
+              = x*(1.0_rl-x) &
+              * y*(1.0_rl-y)
+          end if
+        end do
+      end do
+    end do
+  end subroutine analytical_solution
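+
+! Illustrative usage for a convergence test (sketch only; u_exact is a
+! hypothetical variable name, and the set-up of grid_param, model_param,
+! b and u_exact via the datatypes/discretisation modules is assumed):
+!
+!   call initialise_rhs(grid_param,model_param,b)
+!   call analytical_solution(grid_param,u_exact)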
+
+end module profiles
diff --git a/tensorproductmultigrid_Source/timer.f90 b/tensorproductmultigrid_Source/timer.f90
new file mode 100644
index 000000000..32c980f23
--- /dev/null
+++ b/tensorproductmultigrid_Source/timer.f90
@@ -0,0 +1,184 @@
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+!
+!  This file is part of the TensorProductMultigrid code.
+!  
+!  (c) The copyright relating to this work is owned jointly by the
+!  Crown, Met Office and NERC [2014]. However, it has been created
+!  with the help of the GungHo Consortium, whose members are identified
+!  at https://puma.nerc.ac.uk/trac/GungHo/wiki .
+!  
+!  Main Developer: Eike Mueller
+!  
+!  TensorProductMultigrid is free software: you can redistribute it and/or
+!  modify it under the terms of the GNU Lesser General Public License as
+!  published by the Free Software Foundation, either version 3 of the
+!  License, or (at your option) any later version.
+!  
+!  TensorProductMultigrid is distributed in the hope that it will be useful,
+!  but WITHOUT ANY WARRANTY; without even the implied warranty of
+!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+!  GNU Lesser General Public License for more details.
+!  
+!  You should have received a copy of the GNU Lesser General Public License
+!  along with TensorProductMultigrid (see files COPYING and COPYING.LESSER).
+!  If not, see <http://www.gnu.org/licenses/>.
+!
+!=== COPYRIGHT AND LICENSE STATEMENT ===
+
+
+!==================================================================
+!
+!  Timer module
+!
+!    Eike Mueller, University of Bath, Feb 2012
+!
+!==================================================================
+
+module timer
+
+  use mpi
+  use parameters
+
+  implicit none
+
+public::initialise_timing
+public::finalise_timing
+public::time
+public::initialise_timer
+public::start_timer
+public::finish_timer
+public::print_timerinfo
+public::print_elapsed
+
+private
+
+! Timer type
+  type time
+    character(len=32) :: label
+    real(kind=rl) :: start
+    real(kind=rl) :: finish
+    integer :: ncall
+    real(kind=rl) :: elapsed
+  end type time
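+
+! Typical usage (illustrative sketch only; the timer variable and label are
+! example names):
+!
+!   type(time) :: t_smooth
+!   call initialise_timer(t_smooth,"smoother")
+!   call start_timer(t_smooth)
+!   ! ... code section to be timed ...
+!   call finish_timer(t_smooth)
+!   call print_elapsed(t_smooth,.true.,1.0_rl)
+!
+! start_timer/finish_timer can be called repeatedly on the same timer;
+! finish_timer accumulates the elapsed time and increments the call counter.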
+
+  ! id of timer output file
+  integer, parameter :: TIMEROUT = 9
+
+  ! MPI rank and error code used by the MPI calls below
+  integer :: rank, ierr
+
+contains
+
+!==================================================================
+! Initialise timer module
+!==================================================================
+  subroutine initialise_timing(filename)
+    implicit none
+    character(len=*), intent(in) :: filename
+    call mpi_comm_rank(MPI_COMM_WORLD,rank,ierr)
+    if (rank==0) then
+      open(UNIT=TIMEROUT,FILE=trim(filename))
+      write(STDOUT,'("Writing timer information to file ",A40)') filename
+      write(TIMEROUT,'("# ----------------------------------------------")')
+      write(TIMEROUT,'("# Timer information for geometric multigrid code")')
+      write(TIMEROUT,'("# ----------------------------------------------")')
+    end if
+  end subroutine initialise_timing
+
+!==================================================================
+! Finalise timer module
+!==================================================================
+  subroutine finalise_timing()
+    implicit none
+    if (rank==0) then
+      close(TIMEROUT)
+    end if
+  end subroutine finalise_timing
+
+!==================================================================
+! Initialise timer
+!==================================================================
+  subroutine initialise_timer(t,label)
+    implicit none
+    type(time), intent(inout) :: t
+    character(len=*), intent(in) :: label
+    t%label = label
+    t%start = 0.0_rl
+    t%ncall = 0
+    t%finish = 0.0_rl
+    t%elapsed = 0.0_rl
+  end subroutine initialise_timer
+
+!==================================================================
+! Start timer
+!==================================================================
+  subroutine start_timer(t)
+    implicit none
+    type(time), intent(inout) :: t
+    t%start = mpi_wtime()
+  end subroutine start_timer
+
+!==================================================================
+! Finish timer
+!==================================================================
+  subroutine finish_timer(t)
+    implicit none
+    type(time), intent(inout) :: t
+    t%finish = mpi_wtime()
+    t%elapsed = t%elapsed + (t%finish-t%start)
+    t%ncall = t%ncall + 1
+  end subroutine finish_timer
+
+!==================================================================
+! Print to timer file
+!==================================================================
+  subroutine print_timerinfo(msg)
+    implicit none
+    character(len=*), intent(in) :: msg
+    if (rank == 0) then
+      write(TIMEROUT,*) "# " // trim(msg)
+    end if
+  end subroutine print_timerinfo
+
+!==================================================================
+! Print timer information
+!==================================================================
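+! Reports the minimum/average/maximum over all MPI ranks of the elapsed time,
+! scaled by "factor" (e.g. a caller may pass 1/n_iterations to obtain a
+! per-iteration time; this use of factor is an assumption, not prescribed by
+! this module).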
+  subroutine print_elapsed(t,summaryonly,factor)
+    implicit none
+    type(time), intent(in)        :: t
+    logical, intent(in)           :: summaryonly
+    real(kind=rl), intent(in)     :: factor
+    real(kind=rl) :: elapsedtime
+    real(kind=rl) :: t_min
+    real(kind=rl) :: t_max
+    real(kind=rl) :: t_avg
+    integer :: rank, nprocs, ierr
+    integer :: nc
+
+    elapsedtime = (t%elapsed) * factor
+    call mpi_reduce(elapsedtime,t_min,1,MPI_DOUBLE_PRECISION, &
+                    MPI_MIN, 0, MPI_COMM_WORLD,ierr)
+    call mpi_reduce(elapsedtime,t_avg,1,MPI_DOUBLE_PRECISION, &
+                    MPI_SUM, 0, MPI_COMM_WORLD,ierr)
+    call mpi_reduce(elapsedtime,t_max,1,MPI_DOUBLE_PRECISION, &
+                    MPI_MAX, 0, MPI_COMM_WORLD,ierr)
+    call mpi_comm_size(MPI_COMM_WORLD,nprocs,ierr)
+    call mpi_comm_rank(MPI_COMM_WORLD,rank,ierr)
+    t_avg = t_avg/nprocs
+    nc = t%ncall
+    if (nc == 0) nc = 1
+    if (summaryonly) then
+      if (rank == 0) then
+        write(TIMEROUT,'(A32," [",I7,"]: ",E10.4," / ",E10.4," / ",E10.4," (min/avg/max)")') &
+          t%label,t%ncall,t_min,t_avg,t_max
+        write(TIMEROUT,'(A32,"    t/call: ",E10.4," / ",E10.4," / ",E10.4," (min/avg/max)")') &
+          t%label,t_min/nc,t_avg/nc,t_max/nc
+      end if
+    else
+      write(TIMEROUT,'(A32," : ",I8," calls ",E10.4," (rank ",I8,")")') &
+        t%label,t%ncall,elapsedtime,rank
+    end if
+    write(TIMEROUT,'("")')
+  end subroutine print_elapsed
+
+end module timer
-- 
GitLab