#include <stdlib.h>
#include <string.h>

#include "world.h"


/* Internal helpers used by world_init()/world_free(); defined later in this file. */

/* Create and commit world->transfer.io_type, a subarray datatype covering the
   interior of the local block (the one-cell border is excluded). */
void world_init_io_type(world_t *world);
/* Free world->transfer.io_type. */
void world_free_io_type(world_t *world);

/* Create one send and one receive subarray datatype per existing neighbor on
   the process grid, and the distributed-graph communicator connecting them;
   results are stored in world->transfer. */
void world_init_neighborhood(world_t *world, MPI_Comm cart_comm, int nprocs[], int proc_coord[]);
/* Free the datatypes and the communicator created by world_init_neighborhood(). */
void world_free_neighborhood(world_t *world);

void world_init(world_t *world, MPI_Comm cart_comm, size_t *global_size)
{
   /*
    * Initialise this process's part of the world on a 2-D Cartesian
    * communicator.
    *
    * world       - struct to fill in (sizes, cell buffers, MPI transfer state)
    * cart_comm   - 2-D Cartesian communicator describing the process grid
    * global_size - global extent of the world in each of the two dimensions
    *
    * Each dimension is split into contiguous blocks: the process at grid
    * coordinate p out of P processes owns [p*N/P, (p+1)*N/P). Both cell
    * buffers are sized by world_get_storage_size() and filled with ' '.
    * If either allocation fails, the whole job is aborted with MPI_Abort.
    */
   int nprocs[2], periods[2], proc_coord[2];
   size_t storage_size;

   /* periods is required by MPI_Cart_get but not used here. */
   MPI_Cart_get(cart_comm, 2, nprocs, periods, proc_coord);

   for(int dim = 0; dim < 2; dim++) {
      /* Compute the block bounds in size_t to avoid narrowing to int. */
      const size_t lo = (size_t)proc_coord[dim] * global_size[dim] / (size_t)nprocs[dim];
      const size_t hi = (size_t)(proc_coord[dim] + 1) * global_size[dim] / (size_t)nprocs[dim];

      world->global_size[dim] = global_size[dim];
      world->local_size[dim] = hi - lo;
      world->local_start[dim] = lo;
   }

   storage_size = world_get_storage_size(world);
   world->cells_prev = malloc(storage_size);
   world->cells_next = malloc(storage_size);
   if(world->cells_prev == NULL || world->cells_next == NULL) {
      /* Never memset a NULL buffer; release whichever allocation succeeded
         and take the whole job down, since the other ranks would otherwise
         block waiting for this one. */
      free(world->cells_prev);
      free(world->cells_next);
      world->cells_prev = NULL;
      world->cells_next = NULL;
      MPI_Abort(cart_comm, EXIT_FAILURE);
      return;
   }
   memset(world->cells_prev, ' ', storage_size);
   memset(world->cells_next, ' ', storage_size);

   world_init_io_type(world);
   world_init_neighborhood(world, cart_comm, nprocs, proc_coord);
}

void world_free(world_t *world)
{
   /* Tear down the MPI state first: the committed I/O datatype, then the
      per-neighbor halo datatypes and the graph communicator. */
   world_free_io_type(world);
   world_free_neighborhood(world);

   /* Release both cell buffers, clearing each pointer right after its free
      so the struct never holds a dangling pointer. */
   free(world->cells_prev);
   world->cells_prev = NULL;

   free(world->cells_next);
   world->cells_next = NULL;
}

void world_init_io_type(world_t *world)
{
   /* Describe the interior of the local block as a committed MPI datatype
      stored in world->transfer.io_type. The storage is (nx+2) x (ny+2) chars
      and dimension 0 varies fastest (MPI_ORDER_FORTRAN). The selected region
      is nx x ny, starting at offset 1 in both dimensions, so the one-cell
      border is skipped. */
   int padded[2], interior[2], origin[2];
   int d;

   for(d = 0; d < 2; d++) {
      interior[d] = world->local_size[d];
      padded[d]   = interior[d] + 2;
      origin[d]   = 1;
   }

   MPI_Type_create_subarray(2, padded, interior, origin,
                            MPI_ORDER_FORTRAN, MPI_CHAR,
                            &world->transfer.io_type);
   MPI_Type_commit(&world->transfer.io_type);
}

void world_free_io_type(world_t *world)
{
   /* Release the datatype created by world_init_io_type(); MPI_Type_free
      also resets the handle to MPI_DATATYPE_NULL. */
   MPI_Datatype *io_type = &world->transfer.io_type;

   MPI_Type_free(io_type);
}

void world_init_neighborhood(world_t *world, MPI_Comm cart_comm, int nprocs[], int proc_coord[])
{
   /*
    * Build the halo-exchange description for this process.
    *
    * Local storage is (nx+2) x (ny+2) chars with dimension 0 varying fastest
    * (MPI_ORDER_FORTRAN). Interior cells occupy indices 1..nx / 1..ny, and
    * the one-cell halo sits at index 0 and nx+1 / ny+1. For every neighbor
    * that exists on the process grid, two types are committed:
    *   - a send type covering our outermost interior strip facing it, and
    *   - a receive type covering the halo strip on that same side.
    * Both are stored at the same slot in world->transfer.send_types /
    * recv_types. Neighbor coordinates outside [0, nprocs) are skipped, so
    * the grid is treated as non-periodic here.
    * NOTE(review): assumes send_types/recv_types hold at least 8 entries
    * (declared in world.h) — confirm.
    */
   const int cx = proc_coord[0], cy = proc_coord[1];
   const int grid_x = nprocs[0], grid_y = nprocs[1];
   const int nx = world->local_size[0], ny = world->local_size[1];
   const int padded[] = {nx+2, ny+2};

   struct halo_info_s {
      int proc_coord[2];   /* target process coordinates on the grid */
      int subsizes[2];     /* extent of the exchanged strip */
      int send_starts[2];  /* origin of our interior strip sent to the target */
      int recv_starts[2];  /* origin of our halo strip filled by the target */
   };

   /* One row per direction; "upper" is y-1 and "left" is x-1. */
   const struct halo_info_s halo[] = {
      // Target Proc | Subsize | Send start | Recv start
      { {cx-1, cy-1},  { 1,  1}, { 1,  1},    {   0,    0} }, // left upper
      { {cx,   cy-1},  {nx,  1}, { 1,  1},    {   1,    0} }, // upper
      { {cx+1, cy-1},  { 1,  1}, {nx,  1},    {nx+1,    0} }, // right upper
      { {cx-1, cy  },  { 1, ny}, { 1,  1},    {   0,    1} }, // left
      { {cx+1, cy  },  { 1, ny}, {nx,  1},    {nx+1,    1} }, // right
      { {cx-1, cy+1},  { 1,  1}, { 1, ny},    {   0, ny+1} }, // left lower
      { {cx,   cy+1},  {nx,  1}, { 1, ny},    {   1, ny+1} }, // lower
      { {cx+1, cy+1},  { 1,  1}, {nx, ny},    {nx+1, ny+1} }, // right lower
   };
   const int n_dirs = (int)(sizeof halo / sizeof halo[0]);

   MPI_Datatype *const out_types = world->transfer.send_types;
   MPI_Datatype *const in_types  = world->transfer.recv_types;
   int peers[sizeof halo / sizeof halo[0]];
   int peer_weights[sizeof halo / sizeof halo[0]];
   int count = 0;
   int k;

   for(k = 0; k < n_dirs; k++) {
      const struct halo_info_s *h = &halo[k];
      const int tx = h->proc_coord[0];
      const int ty = h->proc_coord[1];

      /* No neighbor in this direction: this process is on the grid edge. */
      if(tx < 0 || tx >= grid_x || ty < 0 || ty >= grid_y) {
         continue;
      }

      /* Outgoing strip (our interior edge) and incoming strip (our halo). */
      MPI_Type_create_subarray(2, padded, h->subsizes, h->send_starts,
                               MPI_ORDER_FORTRAN, MPI_CHAR, &out_types[count]);
      MPI_Type_commit(&out_types[count]);
      MPI_Type_create_subarray(2, padded, h->subsizes, h->recv_starts,
                               MPI_ORDER_FORTRAN, MPI_CHAR, &in_types[count]);
      MPI_Type_commit(&in_types[count]);

      MPI_Cart_rank(cart_comm, h->proc_coord, &peers[count]);
      /* Edge weight = number of cells exchanged with that neighbor. */
      peer_weights[count] = h->subsizes[0] * h->subsizes[1];
      count++;
   }
   world->transfer.n_neighbors = count;

   /* Sources and destinations are the same list in the same order, so slot
      i of the send/recv type arrays pairs with peers[i]. Rank reordering is
      disabled (reorder argument 0). */
   MPI_Dist_graph_create_adjacent(cart_comm,
                                  count, peers, peer_weights,
                                  count, peers, peer_weights,
                                  MPI_INFO_NULL, 0, &world->transfer.graph_comm);
}

void world_free_neighborhood(world_t *world)
{
   /* Release what world_init_neighborhood() created: one send and one
      receive datatype per neighbor, then the graph communicator. */
   const int count = world->transfer.n_neighbors;
   int k;

   for(k = 0; k < count; k++) {
      MPI_Type_free(&world->transfer.send_types[k]);
      MPI_Type_free(&world->transfer.recv_types[k]);
   }

   MPI_Comm_free(&world->transfer.graph_comm);
}

size_t world_get_storage_size(const world_t *world)
{
   /* Bytes needed for one cell buffer: the local block plus a one-cell
      border on every side, i.e. (nx+2) * (ny+2) chars.
      The product is formed in size_t: computing it in int (as before)
      overflows, which is undefined behaviour, for large local blocks. */
   const size_t padded_x = (size_t)world->local_size[0] + 2;
   const size_t padded_y = (size_t)world->local_size[1] + 2;

   return padded_x * padded_y * sizeof(char);
}

