Skip to content
Communicator.hpp 5.16 KiB
Newer Older
#pragma once

#include <array>
#include <cstddef>
#include <utility>
#include <vector>

#include "gsl/multi_span"
#include <mpi.h>

#include "Configuration.hpp"
#include "MpiEnvironment.hpp"
#include "Util.hpp"
#include "state.hpp"

class Communicator {
	constexpr static std::size_t NoNeighbors{8};

	std::vector<int> _neighbors;
	std::vector<int> _sizes;
	std::vector<MPI_Datatype> _sendTypes;
	const std::vector<MPI_Datatype>& _recvTypes{_sendTypes};
	std::vector<MPI_Aint> _sendDisplacements;
	std::vector<MPI_Aint> _recvDisplacements;
	MPI_Comm _commDistGraph{MPI_COMM_NULL};

	MPI_Datatype _haloRowType;
	MPI_Datatype _haloColumnType;
	MPI_Datatype _haloCornerType{MPI_CHAR};

  public:
	Communicator(const MpiEnvironment& env, const Size& gridSize,
	             const Size& tileSize) {
		// Begin definition of basic types
		MPI_Type_contiguous(tileSize.Cols, MPI_CHAR, &_haloRowType);
		MPI_Type_commit(&_haloRowType);

		MPI_Type_vector(tileSize.Rows, 1, tileSize.Cols + 2, MPI_CHAR,
		                &_haloColumnType);
		MPI_Type_commit(&_haloColumnType);

		// End definition of basic types

		// Begin definition of datastructures for a general cell
		const auto& generalSendTypes = std::array<MPI_Datatype, NoNeighbors>{
		    _haloCornerType, _haloRowType,    _haloCornerType, //
		    _haloColumnType, _haloColumnType,                  //
		    _haloCornerType, _haloRowType,    _haloCornerType  //
		};
		const auto& generalRecvTypes = generalSendTypes; // same

		const auto& tCols = tileSize.Cols;
		const auto& tRows = tileSize.Rows;

		auto dp = [&](std::size_t x,
		              std::size_t y) { // character coordinates to displacement
			return static_cast<MPI_Aint>(y * (tCols + 2) + x);
		};

		std::array<MPI_Aint, NoNeighbors> generalSendDisplacements{
		    dp(1, 1),     dp(1, 1),     dp(tCols, 1),    //
		    dp(1, 1),     dp(tCols, 1),                  //
		    dp(1, tRows), dp(1, tRows), dp(tCols, tRows) //
		};

		std::array<MPI_Aint, NoNeighbors> generalRecvDisplacements{
		    dp(0, 0),         dp(1, 0),         dp(tCols + 1, 0),        //
		    dp(0, 1),         dp(tCols + 1, 1),                          //
		    dp(0, tRows + 1), dp(1, tRows + 1), dp(tCols + 1, tRows + 1) //
		};

		std::array<int, NoNeighbors> generalSizes{
		    1, 1, 1, //
		    1, 1,    //
		    1, 1, 1  //
		};
		// End definition of datastructures for a general cell

		// Begin definition of datastructures for this particular cell
		auto rank2coord = [&](std::size_t rank) {
			return Coord{
			    rank % gridSize.Cols, //
			    rank / gridSize.Cols  //
			};
		};

		auto coord2rank = [&](Coord c) { return gridSize.Cols * c.Y + c.X; };

		auto isInsideGrid = [&](Coord c) {
			return c.X < gridSize.Cols && c.Y < gridSize.Rows;
		};

		const auto& myCoord = rank2coord(env.worldRank());
		std::array<Coord, NoNeighbors> virtualNeighborCoords{{
		    {myCoord.X - 1, myCoord.Y - 1}, // intentional signed underflow
		    {myCoord.X + 0, myCoord.Y - 1}, //
		    {myCoord.X + 1, myCoord.Y - 1}, //
		    {myCoord.X - 1, myCoord.Y + 0}, //
		    {myCoord.X + 1, myCoord.Y + 0}, //
		    {myCoord.X - 1, myCoord.Y + 1}, //
		    {myCoord.X + 0, myCoord.Y + 1}, //
		    {myCoord.X + 1, myCoord.Y + 1}  //
		}};

		// interate over all neighbors
		for (std::size_t i{0}; i < NoNeighbors; ++i) {
			const auto& nbrCoord = virtualNeighborCoords[i];
			if (isInsideGrid(nbrCoord)) {
				_neighbors.push_back(coord2rank(nbrCoord));
				_sendTypes.push_back(generalSendTypes[i]);
				_sendDisplacements.push_back(generalSendDisplacements[i]);
				_recvDisplacements.push_back(generalRecvDisplacements[i]);
				_sizes.push_back(generalSizes[i]);
			}
		}

		// if (env.worldRank() == 0) {
		//	std::cout << "neighbors:\n";
		//	for (const auto i : neighbors) std::cout << i << ",\n";
		//	std::cout << '\n';
		//}

		MPI_Dist_graph_create_adjacent(
		    MPI_COMM_WORLD,                         // comm_old
		    _neighbors.size(),                      // indegree
		    _neighbors.data(),                      // sources
		    reinterpret_cast<int*>(MPI_UNWEIGHTED), // sourceweights
		    _neighbors.size(),                      // outdegree
		    _neighbors.data(),                      // destinations
		    reinterpret_cast<int*>(MPI_UNWEIGHTED), // destweights
		    MPI_INFO_NULL,                          // info
		    0,                                      // reorder
		    &_commDistGraph                         // comm_dist_graph
		    );
		// End definition of datastructures for this particular cell
	}

	~Communicator() {
		MPI_Comm_free(&_commDistGraph);
		MPI_Type_free(&_haloColumnType);
		MPI_Type_free(&_haloRowType);
	}

	void Communicate(gsl::multi_span<State, -1, -1>& model) {
		MPI_Neighbor_alltoallw(model.data(),              // sendbuf
		                       _sizes.data(),             // sendcounts
		                       _sendDisplacements.data(), // sdispl
		                       _sendTypes.data(),         // sendtypes
		                       model.data(),              // recvbuf
		                       _sizes.data(),             // recvcounts
		                       _recvDisplacements.data(), // rdispls
		                       _recvTypes.data(),         // recvtypes
		                       _commDistGraph             // comm
		                       );
	}
};