#include <mpi.h>

#include "simulation.h"

// Characters used to encode the state of a cell.
// Any other character is treated as a non-conducting cell and is never
// modified by wireworld_step().
#define ELECTRON_HEAD '@'
#define ELECTRON_TAIL '~'
#define WIRE '#'

// Advance a rectangular block of cells by one generation.
//
//   world   - provides the grid width (local_size[0]) and the two cell
//             buffers: cells_prev is read, cells_next is written.
//   i_start - linear index of the block's first (upper-left) cell; rows are
//             local_size[0]+2 cells apart (+2 for the halo cells).
//   bx, by  - block width and height in cells.
//
// All eight neighbors of every cell in the block are read from cells_prev,
// so the block must not include the outermost ring of the buffer.
void wireworld_step(world_t *world, size_t i_start, size_t bx, size_t by);

// Run `n_generations` Wireworld generations on the local part of the world.
//
// Each generation swaps the two cell buffers, starts a non-blocking halo
// exchange on the (new) cells_prev buffer, updates the interior cells while
// the exchange is in flight, and updates the boundary cells once the halo
// has arrived.
//
//   world         - local world; on entry the current generation is expected
//                   in world->cells_next (it becomes cells_prev after the
//                   first swap). After the call the newest generation is
//                   again in world->cells_next.
//   n_generations - number of generations to compute; 0 leaves the world
//                   untouched.
//
// Local domains that are fewer than 2 cells wide or high are handled: the
// interior is then empty and coinciding boundary rows/columns are only
// computed once.
void do_simulation(world_t *world, size_t n_generations)
{
   const size_t nx = world->local_size[0];
   const size_t ny = world->local_size[1];

   const size_t DOWN = nx+2; // (+2 ... for halo cells)

   const size_t i_leftupper  = 1  +    DOWN;
   const size_t i_rightupper = nx +    DOWN;
   const size_t i_leftlower  = 1  + ny*DOWN;

   // Extent of the region strictly inside the boundary ring.
   // Clamped to 0, because nx-2 / ny-2 would wrap around in size_t
   // for domains that are fewer than 2 cells wide or high.
   const size_t inner_nx = (nx > 2) ? nx-2 : 0;
   const size_t inner_ny = (ny > 2) ? ny-2 : 0;

   // One element of each per-neighbor datatype, all relative to the buffer
   // start. NOTE(review): assumes the graph communicator has 8 neighbors and
   // that the datatypes in world->transfer encode the cell positions;
   // both are set up outside this file.
   const int counts[] = {1, 1, 1, 1, 1, 1, 1, 1};
   const MPI_Aint displs[] = {0, 0, 0, 0, 0, 0, 0, 0};
   const transfer_t *transfer = &world->transfer;

   MPI_Request request;

   for(size_t g = 0; g < n_generations; g++) {
      // The result of the previous generation becomes the input of this one.
      char *tmp = world->cells_prev;
      world->cells_prev = world->cells_next;
      world->cells_next = tmp;

      // Start halo exchange
      MPI_Ineighbor_alltoallw(
         world->cells_prev, counts, displs, transfer->send_types,
         world->cells_prev, counts, displs, transfer->recv_types,
         transfer->graph_comm, &request
      );

      // Compute inner region.
      // Its neighborhood only reaches the boundary ring of the local domain,
      // never the halo cells, so it does not depend on the exchange.
      wireworld_step(world, i_leftupper+1+DOWN, inner_nx, inner_ny);

      // Finish halo exchange
      MPI_Wait(&request, MPI_STATUS_IGNORE);

      // Compute boundary regions (these read the halo cells).
      // The exchange above is still performed for an empty local domain,
      // since the neighboring processes take part in it.
      if(nx > 0 && ny > 0) {
         wireworld_step(world, i_leftupper, nx, 1); // upper
         if(ny > 1) {
            // for ny == 1 the lower row is the upper row
            wireworld_step(world, i_leftlower, nx, 1); // lower
         }
         wireworld_step(world, i_leftupper+DOWN, 1, inner_ny); // left
         if(nx > 1) {
            // for nx == 1 the right column is the left column
            wireworld_step(world, i_rightupper+DOWN, 1, inner_ny); // right
         }
      }

/*    Blocking variant:

      MPI_Neighbor_alltoallw(
         world->cells_prev, counts, displs, transfer->send_types,
         world->cells_prev, counts, displs, transfer->recv_types,
         transfer->graph_comm
      );
      wireworld_step(world, i_leftupper, nx, ny);
*/
   }
}

void wireworld_step(world_t *world, size_t i_start, size_t bx, size_t by)
{
   const size_t L = -1, R = 1;
   const size_t D = world->local_size[0]+2; // (+2 ... for halo cells)
   const size_t U = -D;
   size_t x, y, i;
   int nheads;

   char *prev = world->cells_prev;
   char *next = world->cells_next;

   for(y = 0; y < by; y++) {
      i = i_start;
      for(x = 0; x < bx; x++) {
         switch(prev[i]) {
            // Electron heads become electron tails.
            case ELECTRON_HEAD: next[i] = ELECTRON_TAIL; break;

            // Electron tails become copper.
            case ELECTRON_TAIL: next[i] = WIRE; break;

            // New electron head replacing copper,
            // if 1 or 2 electron heads are in neighborhood.
            case WIRE:
                  nheads =
                     (prev[i+L+U] == ELECTRON_HEAD) +
                     (prev[i  +U] == ELECTRON_HEAD) +
                     (prev[i+R+U] == ELECTRON_HEAD) +
                     (prev[i+L  ] == ELECTRON_HEAD) +
                     (prev[i+R  ] == ELECTRON_HEAD) +
                     (prev[i+L+D] == ELECTRON_HEAD) +
                     (prev[i  +D] == ELECTRON_HEAD) +
                     (prev[i+R+D] == ELECTRON_HEAD);
                  if(nheads == 1 || nheads == 2) {
                     next[i] = ELECTRON_HEAD;
                  } else {
                     next[i] = WIRE;
                  }
                  break;

            default: break;
         }
         i++;
      }
      i_start += D;
   }
}

