#include <algorithm>
#include <array>
#include <iostream>
#include <utility>

#include "Communicator.hpp"

// Stores the outstanding MPI request handles that Wait() later completes.
// `reqs` is taken by value, so it is moved into the member instead of being
// copied a second time.
Communicator::MpiRequest::MpiRequest(DoubleVector<MPI_Request> reqs)
    : _reqs(std::move(reqs)) {}

// Blocks until every request held in _reqs has completed, then records that
// the wait happened so the destructor's "forgot to wait" check passes.
void Communicator::MpiRequest::Wait() {
	MPI_Waitall(static_cast<int>(_reqs.size()), // count
	            _reqs.data(),                   // array_of_requests
	            MPI_STATUSES_IGNORE);           // array_of_statuses
	finished = true;
}
// Reports an error through MpiReportErrorAbort when the object is destroyed
// without Wait() having been called (i.e. `finished` is still false).
Communicator::MpiRequest::~MpiRequest() {
	if (!finished) { MpiReportErrorAbort("Forgot to Wait for MPI_Request"); }
}
// Defines the halo datatypes and the neighbor graph topology.
//
// Parameters:
//   env       - queried for this process's rank via env.worldRank()
//   commMode  - stored in _commMode
//   procsSize - Cols/Rows of the process grid; maps ranks to grid coordinates
//   tileSize  - Cols/Rows of the local tile; sizes the datatypes and the
//               send/receive displacements
//
// NOTE(review): several statements below look truncated (unterminated calls,
// initializer lists, lambdas and loop body). The comments describe only what
// is visible — confirm against the complete file.
Communicator::Communicator(const MpiEnvironment& env,
                           CommunicationMode commMode, const Size& procsSize,
                           const Size& tileSize)
    : _commMode(commMode) {
	// Begin definition of basic types
	// Contiguous run of tileSize.Cols MPI_CHAR elements.
	// NOTE(review): the output datatype argument is not visible here — confirm.
	MPI_Type_contiguous(static_cast<int>(tileSize.Cols), MPI_CHAR,

	// tileSize.Rows blocks of one MPI_CHAR each, with a stride of
	// tileSize.Cols + 2 elements between block starts.
	// NOTE(review): the output datatype argument is not visible here — confirm.
	MPI_Type_vector(static_cast<int>(tileSize.Rows), 1, static_cast<int>(tileSize.Cols + 2), MPI_CHAR,

	// End definition of basic types

	// Begin definition of types/displacements for a general cell somewhere in
	// the middle of the procs grid
	// One datatype per neighbor: corner, row, corner / column, column /
	// corner, row, corner.
	const std::array<MPI_Datatype, NoNeighbors> generalSendTypes{{
	    _haloCornerType, _haloRowType, _haloCornerType, //
	    _haloColumnType, _haloColumnType,               //
	    _haloCornerType, _haloRowType, _haloCornerType  //
	const auto tCols = tileSize.Cols;
	const auto tRows = tileSize.Rows;

	// character coordinates to displacement
	// Row-major offset where every row is tCols + 2 elements wide.
	const auto dp = [&](std::size_t x, std::size_t y) {
		return static_cast<MPI_Aint>(y * (tCols + 2) + x);
	// Send displacements use coordinates in the range 1..tCols / 1..tRows.
	const std::array<MPI_Aint, NoNeighbors> generalSendDisplacements{{
	    dp(1, 1), dp(1, 1), dp(tCols, 1),            //
	    dp(1, 1), dp(tCols, 1),                      //
	    dp(1, tRows), dp(1, tRows), dp(tCols, tRows) //
	// Receive displacements use the surrounding ring: column 0 / tCols + 1 and
	// row 0 / tRows + 1.
	const std::array<MPI_Aint, NoNeighbors> generalRecvDisplacements{{
	    dp(0, 0), dp(1, 0), dp(tCols + 1, 0),                        //
	    dp(0, 1), dp(tCols + 1, 1),                                  //
	    dp(0, tRows + 1), dp(1, tRows + 1), dp(tCols + 1, tRows + 1) //
	// One element of the corresponding datatype per neighbor.
	const std::array<int, NoNeighbors> generalSizes{{
	    1, 1, 1, //
	    1, 1,    //
	    1, 1, 1  //
	// End definition of datastructures for a general cell

	// Begin definition of datastructures for this particular cell (handle the
	// border cases)
	// Rank -> grid coordinate: X = rank % Cols, Y = rank / Cols.
	const auto rank2coord = [&](std::size_t rank) {
		return Coord{
		    rank % procsSize.Cols, //
		    rank / procsSize.Cols  //
	// Grid coordinate -> rank (inverse of rank2coord).
	const auto coord2rank = [&](Coord c) { return static_cast<int>(procsSize.Cols * c.Y + c.X); };
	// Only upper bounds are checked.
	// NOTE(review): presumably relies on Coord fields being unsigned so that
	// "-1" wraps to a huge value and fails this test — confirm.
	const auto isInsideProcsGrid = [&](Coord c) {
		return c.X < procsSize.Cols && c.Y < procsSize.Rows;

	const auto myCoord = rank2coord(env.worldRank());
	// Candidate neighbor coordinates, row by row: (-1,-1) (0,-1) (+1,-1),
	// (-1,0) (+1,0), (-1,+1) (0,+1) (+1,+1).
	const std::array<Coord, NoNeighbors> generalNeighborCoords{{
	    {myCoord.X - 1, myCoord.Y - 1}, // intentional wrap-around below zero (NOTE(review): assumes unsigned Coord fields; signed overflow would be UB — confirm)
	    {myCoord.X + 0, myCoord.Y - 1}, //
	    {myCoord.X + 1, myCoord.Y - 1}, //
	    {myCoord.X - 1, myCoord.Y + 0}, //
	    {myCoord.X + 1, myCoord.Y + 0}, //
	    {myCoord.X - 1, myCoord.Y + 1}, //
	    {myCoord.X + 0, myCoord.Y + 1}, //
	    {myCoord.X + 1, myCoord.Y + 1}  //

	// Visit every candidate neighbor and keep only those inside the grid.
	// NOTE(review): the body of the `if` is not visible here.
	for (std::size_t i{0}; i < NoNeighbors; ++i) {
		const auto nbrCoord = generalNeighborCoords[i];
		if (isInsideProcsGrid(nbrCoord)) {

	// NOTE(review): the lines below read like the argument list of a
	// graph-communicator creation call (comm_old ... comm_dist_graph) whose
	// opening line is not visible; the sources/destinations arguments appear
	// to have been swallowed by the trailing comments — confirm.
	    MPI_COMM_WORLD,                         // comm_old
	    static_cast<int>(_neighbors.size()),    // indegree,                      // sources
	    reinterpret_cast<int*>(MPI_UNWEIGHTED), // sourceweights
	    static_cast<int>(_neighbors.size()),    // outdegree,                      // destinations
	    reinterpret_cast<int*>(MPI_UNWEIGHTED), // destweights
	    MPI_INFO_NULL,                          // info
	    0,                                      // reorder
	    &_commDistGraph                         // comm_dist_graph
	// End definition of datastructures for this particular cell

// Releases the graph communicator when this object holds one; an object whose
// handle is MPI_COMM_NULL is left untouched.
// NOTE(review): derived halo datatypes created in the constructor may also
// need MPI_Type_free here; their creation is not fully visible, so they are
// not freed in this block — confirm.
Communicator::~Communicator() {
	if (_commDistGraph != MPI_COMM_NULL) {
		// MPI_Comm_free also resets the handle to MPI_COMM_NULL.
		MPI_Comm_free(&_commDistGraph);
	}
}

// Exchanges the complete state of two communicators member by member.
// Used by the move constructor and move assignment operator.
void Communicator::swap(Communicator& first, Communicator& second) {
	using std::swap; // enable ADL with std::swap as the fallback
	swap(first._commMode, second._commMode);
	swap(first._neighbors, second._neighbors);
	swap(first._sizes, second._sizes);
	swap(first._sendTypes, second._sendTypes);
	// _recvTypes is read by the P2P receive path and must travel with the
	// rest of the state, otherwise a moved-to object receives with stale types.
	swap(first._recvTypes, second._recvTypes);
	swap(first._sendDisplacements, second._sendDisplacements);
	swap(first._recvDisplacements, second._recvDisplacements);
	swap(first._commDistGraph, second._commDistGraph);
	swap(first._haloRowType, second._haloRowType);
	swap(first._haloColumnType, second._haloColumnType);
	swap(first._haloCornerType, second._haloCornerType);
}

// Move constructor: takes over the state of `other` by swapping it with this
// freshly constructed object.
// NOTE(review): assumes the members have in-class default initializers (e.g.
// _commDistGraph starting as MPI_COMM_NULL) so `other` ends up in a
// destructible state — confirm in the header.
Communicator::Communicator(Communicator&& other) noexcept {
	swap(*this, other);
}
// Move assignment: swaps state with `other`, so the previous state of this
// object is released when `other` is destroyed.
// Returns a reference to this object.
Communicator& Communicator::operator=(Communicator&& other) noexcept {
	swap(*this, other);
	return *this;
}

// Performs a blocking halo exchange on `model`.
//
// `model` is used both as send and as receive buffer; the per-neighbor
// displacements and datatypes select the regions that are sent and received.
// Reports an error through MpiReportErrorAbort when the communicator has not
// been set up (_commDistGraph is MPI_COMM_NULL).
void Communicator::Communicate(State* model) {
	if (_commDistGraph == MPI_COMM_NULL) {
		MpiReportErrorAbort("Communicator not initialized");
	}
	switch (_commMode) {
	case CommunicationMode::Collective:
		MPI_Neighbor_alltoallw(model,                     // sendbuf
		                       _sizes.data(),             // sendcounts
		                       _sendDisplacements.data(), // sdispl
		                       _sendTypes.data(),         // sendtypes
		                       model,                     // recvbuf
		                       _sizes.data(),             // recvcounts
		                       _recvDisplacements.data(), // rdispls
		                       _recvTypes.data(),         // recvtypes
		                       _commDistGraph);           // comm
		break;

	case CommunicationMode::P2P: {
		// Reuse the non-blocking point-to-point path and complete it at once.
		AsyncCommunicate(model).Wait();
	} break;
	}
}
// Starts a non-blocking halo exchange on `model` and returns an MpiRequest
// wrapping the resulting MPI request handle(s).
//
// `model` is used both as send and as receive buffer. Reports an error through
// MpiReportErrorAbort when _commDistGraph is MPI_COMM_NULL.
//
// NOTE(review): this definition runs to the end of the visible source and
// parts of it look truncated (blank call arguments, missing closing braces) —
// confirm against the complete file.
Communicator::MpiRequest Communicator::AsyncCommunicate(State* model) {
	if (_commDistGraph == MPI_COMM_NULL)
		MpiReportErrorAbort("Communicator not initialized");
	switch (_commMode) {
	// Collective mode: a single non-blocking neighborhood collective on the
	// graph communicator, yielding one request.
	case CommunicationMode::Collective: {
		MPI_Request req;
		// NOTE(review): the count/displacement/type arguments are blank here.
		MPI_Ineighbor_alltoallw(model,                     // sendbuf
		              ,             // sendcounts
		              , // sdispl
		              ,         // sendtypes
		                        model,                     // recvbuf
		              ,             // recvcounts
		              , // rdispls
		              ,         // recvtypes
		                        _commDistGraph,            // comm
		                        &req);                     // request
		return MpiRequest{{req}};
	// P2P mode: one non-blocking send and one non-blocking receive per
	// neighbor, all on MPI_COMM_WORLD with tag 0.
	case CommunicationMode::P2P: {
		Communicator::MpiRequest::DoubleVector<MPI_Request> reqs;
		for (std::size_t i{0}; i < _neighbors.size(); ++i) {
				// Send one element of _sendTypes[i] starting at the i-th send displacement.
				MPI_Request req;
				MPI_Isend(model + _sendDisplacements[i], // buf
				          1,                             // count
				          _sendTypes[i],                 // datatype
				          _neighbors[i],                 // dest
				          0,                             // tag
				          MPI_COMM_WORLD,                // comm
				          &req);                         // request
				// NOTE(review): `req` is declared a second time in what appears to be
				// the same scope, and no request is visibly appended to `reqs` —
				// scoping braces / push_back calls seem to be missing; confirm.
				MPI_Request req;
				// Receive one element of _recvTypes[i] at the i-th receive displacement.
				MPI_Irecv(model + _recvDisplacements[i], // buf
				          1,                             // count
				          _recvTypes[i],                 // datatype
				          _neighbors[i],                 // source
				          0,                             // tag
				          MPI_COMM_WORLD,                // comm
				          &req);                         // request
		return MpiRequest{reqs};