! Main solver routines for heat equation solver
module core
  use heat

contains

  ! Start the non-blocking exchange of boundary data between MPI tasks.
  !
  ! Posts four send/receive pairs (eight requests in total) and returns
  ! without waiting for them. The request handles are stored in
  ! parallel%requests(1:8); they have to be completed (e.g. with
  ! mpi_waitall) before the received values are read or the sent values
  ! are modified.
  !
  ! Arguments:
  !   field0 (type(field)): field whose boundary data is exchanged
  !   parallel (type(parallel_data)): neighbour ranks, communicator, MPI
  !     datatypes and the request array. Declared intent(inout) because
  !     the MPI calls write the request handles into parallel%requests.
  subroutine exchange_init(field0, parallel)
    use mpi

    implicit none

    type(field), intent(inout) :: field0
    type(parallel_data), intent(inout) :: parallel

    ! Message tags, one per exchange direction
    integer, parameter :: tag_to_left = 11
    integer, parameter :: tag_to_right = 12
    integer, parameter :: tag_to_up = 13
    integer, parameter :: tag_to_down = 14

    integer :: ierr

    ! Send to left, receive from right
    ! (send starts at data(0, 1), receive starts at data(0, ny + 1))
    call mpi_isend(field0%data(0, 1), 1, parallel%columntype, &
         & parallel%nleft, tag_to_left, parallel%comm, &
         & parallel%requests(1), ierr)
    call mpi_irecv(field0%data(0, field0%ny + 1), 1, parallel%columntype, &
         & parallel%nright, tag_to_left, parallel%comm, &
         & parallel%requests(2), ierr)

    ! Send to right, receive from left
    ! (send starts at data(0, ny), receive starts at data(0, 0))
    call mpi_isend(field0%data(0, field0%ny), 1, parallel%columntype, &
         & parallel%nright, tag_to_right, parallel%comm, &
         & parallel%requests(3), ierr)
    call mpi_irecv(field0%data(0, 0), 1, parallel%columntype, &
         & parallel%nleft, tag_to_right, parallel%comm, &
         & parallel%requests(4), ierr)

    ! Send to up, receive from down
    ! (send starts at data(1, 0), receive starts at data(nx + 1, 0))
    call mpi_isend(field0%data(1, 0), 1, parallel%rowtype, &
         & parallel%nup, tag_to_up, parallel%comm, &
         & parallel%requests(5), ierr)
    call mpi_irecv(field0%data(field0%nx + 1, 0), 1, parallel%rowtype, &
         & parallel%ndown, tag_to_up, parallel%comm, &
         & parallel%requests(6), ierr)

    ! Send to down, receive from up
    ! (send starts at data(nx, 0), receive starts at data(0, 0))
    ! NOTE(review): this receive and the receive from nleft above both
    ! start at data(0, 0). Whether the two buffers overlap depends on how
    ! rowtype and columntype are defined elsewhere - confirm.
    call mpi_isend(field0%data(field0%nx, 0), 1, parallel%rowtype, &
         & parallel%ndown, tag_to_down, parallel%comm, &
         & parallel%requests(7), ierr)
    call mpi_irecv(field0%data(0, 0), 1, parallel%rowtype, &
         & parallel%nup, tag_to_down, parallel%comm, &
         & parallel%requests(8), ierr)

  end subroutine exchange_init

  ! Complete the pending non-blocking communication.
  ! Arguments:
  !   parallel (type(parallel_data)): holds the request handles that are
  !     waited on (parallel%requests)
  subroutine exchange_finalize(parallel)
    use mpi

    implicit none

    type(parallel_data), intent(inout) :: parallel

    ! Number of request handles passed to mpi_waitall
    integer, parameter :: n_requests = 8
    integer :: rc

    ! Block until every request has completed; the individual statuses
    ! are not inspected.
    call mpi_waitall(n_requests, parallel%requests, &
         & mpi_statuses_ignore, rc)

  end subroutine exchange_finalize

  ! Compute one time step of temperature evolution for the cells
  ! i = 2..nx-1, j = 2..ny-1, using a five-point stencil on prev.
  ! For these cells the stencil reads prev only at indices 1..nx and
  ! 1..ny, i.e. never at index 0 or nx+1 / ny+1.
  ! Arguments:
  !   curr (type(field)): current temperature values (written)
  !   prev (type(field)): values from previous time step (read only)
  !   a (real(dp)): update equation constant
  !   dt (real(dp)): time step value
  subroutine evolve_interior(curr, prev, a, dt)

    implicit none

    type(field), intent(inout) :: curr
    type(field), intent(in) :: prev
    real(dp), intent(in) :: a, dt
    integer :: i, j, nx, ny
    real(dp) :: coeff

    nx = curr%nx
    ny = curr%ny

    ! Loop-invariant prefactor of the stencil
    coeff = a * dt

    ! Innermost loop runs over the first index for contiguous access
    do j = 2, ny - 1
       do i = 2, nx - 1
          curr%data(i, j) = prev%data(i, j) + coeff * &
               & ((prev%data(i-1, j) - 2.0_dp * prev%data(i, j) + &
               &   prev%data(i+1, j)) / curr%dx**2 + &
               &  (prev%data(i, j-1) - 2.0_dp * prev%data(i, j) + &
               &   prev%data(i, j+1)) / curr%dy**2)
       end do
    end do
  end subroutine evolve_interior

  ! Compute one time step of temperature evolution for the cells on the
  ! rim of the index range 1..nx, 1..ny (j = 1, j = ny, i = 1, i = nx).
  ! Update only the border-dependent part: for these cells the stencil
  ! reads prev at index 0 or nx+1 / ny+1 (presumably the values filled in
  ! by the boundary exchange - confirm against the caller).
  ! Arguments:
  !   curr (type(field)): current temperature values (written)
  !   prev (type(field)): values from previous time step (read only)
  !   a (real(dp)): update equation constant
  !   dt (real(dp)): time step value
  subroutine evolve_edges(curr, prev, a, dt)

    implicit none

    type(field), intent(inout) :: curr
    type(field), intent(in) :: prev
    real(dp), intent(in) :: a, dt
    integer :: i, j, nx, ny
    real(dp) :: coeff

    nx = curr%nx
    ny = curr%ny

    ! Loop-invariant prefactor of the stencil
    coeff = a * dt

    ! Edges j = 1 and j = ny, including the four corner cells
    do i = 1, nx
       call update_cell(i, 1)
    end do
    do i = 1, nx
       call update_cell(i, ny)
    end do

    ! Edges i = 1 and i = nx. The corner cells (j = 1 and j = ny) were
    ! already updated by the loops above, so they are skipped here.
    do j = 2, ny - 1
       call update_cell(1, j)
    end do
    do j = 2, ny - 1
       call update_cell(nx, j)
    end do

  contains

    ! Apply the five-point stencil update to the single cell (ic, jc).
    ! Uses curr, prev and coeff from the enclosing subroutine.
    subroutine update_cell(ic, jc)
      integer, intent(in) :: ic, jc

      curr%data(ic, jc) = prev%data(ic, jc) + coeff * &
           & ((prev%data(ic-1, jc) - 2.0_dp * prev%data(ic, jc) + &
           &   prev%data(ic+1, jc)) / curr%dx**2 + &
           &  (prev%data(ic, jc-1) - 2.0_dp * prev%data(ic, jc) + &
           &   prev%data(ic, jc+1)) / curr%dy**2)
    end subroutine update_cell

  end subroutine evolve_edges


end module core
