// Communicator.hpp
#pragma once

#include <array>
#include <vector>

#include "gsl/multi_span"
#include <mpi.h>

#include "Configuration.hpp"
#include "MpiEnvironment.hpp"
#include "Util.hpp"
#include "state.hpp"

class Communicator {
	/// Number of potential neighbors of a tile: 4 edges plus 4 corners.
	constexpr static std::size_t NoNeighbors{8};

	// The vectors below are parallel arrays: the constructor appends one entry
	// to each of them for every neighbor that lies inside the process grid.
	std::vector<int> _neighbors;          // ranks of the existing neighbors
	std::vector<int> _sizes;              // element count exchanged per neighbor
	std::vector<MPI_Datatype> _sendTypes; // datatype exchanged per neighbor
	// Receiving uses the same datatypes as sending, so this is only an alias
	// of this object's own _sendTypes (which is why swap() never touches it).
	const std::vector<MPI_Datatype>& _recvTypes{_sendTypes};
	std::vector<MPI_Aint> _sendDisplacements; // offsets into the send buffer
	std::vector<MPI_Aint> _recvDisplacements; // offsets into the recv buffer
	// MPI_COMM_NULL marks an object that owns no MPI resources; the destructor
	// and the Communicate functions test for it.
	MPI_Comm _commDistGraph{MPI_COMM_NULL};

	// The row and column types start out as the null handle so that
	// default-constructed and moved-from objects never carry an indeterminate
	// value: the move operations swap these members and therefore read them.
	MPI_Datatype _haloRowType{MPI_DATATYPE_NULL};
	MPI_Datatype _haloColumnType{MPI_DATATYPE_NULL};
	MPI_Datatype _haloCornerType{MPI_CHAR}; // a corner is a single cell

  public:
	/// Creates an empty communicator that owns no MPI resources; its
	/// distributed-graph handle keeps the MPI_COMM_NULL default.
	Communicator() {}
	Communicator(const MpiEnvironment& env, const Size& gridSize,
	             const Size& tileSize) {
		// Begin definition of basic types
		MPI_Type_contiguous(tileSize.Cols, MPI_CHAR, &_haloRowType);
		MPI_Type_commit(&_haloRowType);

		MPI_Type_vector(tileSize.Rows, 1, tileSize.Cols + 2, MPI_CHAR,
		                &_haloColumnType);
		MPI_Type_commit(&_haloColumnType);

		// End definition of basic types

		// Begin definition of datastructures for a general cell
Thomas Steinreiter's avatar
Thomas Steinreiter committed
		const std::array<MPI_Datatype, NoNeighbors> generalSendTypes{
		    _haloCornerType, _haloRowType,    _haloCornerType, //
		    _haloColumnType, _haloColumnType,                  //
		    _haloCornerType, _haloRowType,    _haloCornerType  //
		};
Thomas Steinreiter's avatar
Thomas Steinreiter committed
		const auto tCols = tileSize.Cols;
		const auto tRows = tileSize.Rows;
		const auto dp =
		    [&](std::size_t x,
		        std::size_t y) { // character coordinates to displacement
			    return static_cast<MPI_Aint>(y * (tCols + 2) + x);
			};
		const std::array<MPI_Aint, NoNeighbors> generalSendDisplacements{
		    dp(1, 1),     dp(1, 1),     dp(tCols, 1),    //
		    dp(1, 1),     dp(tCols, 1),                  //
		    dp(1, tRows), dp(1, tRows), dp(tCols, tRows) //
		};

		const std::array<MPI_Aint, NoNeighbors> generalRecvDisplacements{
		    dp(0, 0),         dp(1, 0),         dp(tCols + 1, 0),        //
		    dp(0, 1),         dp(tCols + 1, 1),                          //
		    dp(0, tRows + 1), dp(1, tRows + 1), dp(tCols + 1, tRows + 1) //
		};

		const std::array<int, NoNeighbors> generalSizes{
		    1, 1, 1, //
		    1, 1,    //
		    1, 1, 1  //
		};
		// End definition of datastructures for a general cell

		// Begin definition of datastructures for this particular cell
		const auto rank2coord = [&](std::size_t rank) {
			return Coord{
			    rank % gridSize.Cols, //
			    rank / gridSize.Cols  //
			};
		};

		const auto coord2rank = [&](Coord c) {
			return gridSize.Cols * c.Y + c.X;
		};
		const auto isInsideGrid = [&](Coord c) {
			return c.X < gridSize.Cols && c.Y < gridSize.Rows;
		};

Thomas Steinreiter's avatar
Thomas Steinreiter committed
		const auto myCoord = rank2coord(env.worldRank());
		const std::array<Coord, NoNeighbors> virtualNeighborCoords{{
		    {myCoord.X - 1, myCoord.Y - 1}, // intentional signed underflow
		    {myCoord.X + 0, myCoord.Y - 1}, //
		    {myCoord.X + 1, myCoord.Y - 1}, //
		    {myCoord.X - 1, myCoord.Y + 0}, //
		    {myCoord.X + 1, myCoord.Y + 0}, //
		    {myCoord.X - 1, myCoord.Y + 1}, //
		    {myCoord.X + 0, myCoord.Y + 1}, //
		    {myCoord.X + 1, myCoord.Y + 1}  //
		}};

		for (std::size_t i{0}; i < NoNeighbors; ++i) {
Thomas Steinreiter's avatar
Thomas Steinreiter committed
			const auto nbrCoord = virtualNeighborCoords[i];
			if (isInsideGrid(nbrCoord)) {
				_neighbors.push_back(coord2rank(nbrCoord));
				_sendTypes.push_back(generalSendTypes[i]);
				_sendDisplacements.push_back(generalSendDisplacements[i]);
				_recvDisplacements.push_back(generalRecvDisplacements[i]);
				_sizes.push_back(generalSizes[i]);
			}
		}

		// if (env.worldRank() == 0) {
		//	std::cout << "neighbors:\n";
		//	for (const auto i : neighbors) std::cout << i << ",\n";
		//	std::cout << '\n';
		//}

		MPI_Dist_graph_create_adjacent(
		    MPI_COMM_WORLD,                         // comm_old
		    _neighbors.size(),                      // indegree
		    _neighbors.data(),                      // sources
		    reinterpret_cast<int*>(MPI_UNWEIGHTED), // sourceweights
		    _neighbors.size(),                      // outdegree
		    _neighbors.data(),                      // destinations
		    reinterpret_cast<int*>(MPI_UNWEIGHTED), // destweights
		    MPI_INFO_NULL,                          // info
		    0,                                      // reorder
		    &_commDistGraph                         // comm_dist_graph
		    );
		// End definition of datastructures for this particular cell
	}

	/// Frees the graph communicator and the two derived halo datatypes.
	/// An object whose communicator handle is MPI_COMM_NULL releases nothing.
	~Communicator() {
		if (_commDistGraph == MPI_COMM_NULL) return;

		MPI_Comm_free(&_commDistGraph);
		MPI_Type_free(&_haloColumnType);
		MPI_Type_free(&_haloRowType);
	}

	/// Exchanges the complete state of two communicators.
	///
	/// Declared noexcept because the noexcept move operations of this class
	/// are implemented on top of it; it only swaps vectors and MPI handles.
	/// _recvTypes is deliberately left alone: it is a reference bound to each
	/// object's own _sendTypes and so follows the swapped contents by itself.
	friend void swap(Communicator& first, Communicator& second) noexcept {
		using std::swap; // enable ADL, fall back to std::swap
		swap(first._neighbors, second._neighbors);
		swap(first._sizes, second._sizes);
		swap(first._sendTypes, second._sendTypes);
		swap(first._sendDisplacements, second._sendDisplacements);
		swap(first._recvDisplacements, second._recvDisplacements);
		swap(first._commDistGraph, second._commDistGraph);
		swap(first._haloRowType, second._haloRowType);
		swap(first._haloColumnType, second._haloColumnType);
		swap(first._haloCornerType, second._haloCornerType);
	}

	Communicator(Communicator&) = delete;
	Communicator& operator=(Communicator&) = delete;
	Communicator(Communicator&& other) noexcept { swap(*this, other); }
	Communicator& operator=(Communicator&& other) noexcept {
		swap(*this, other);
		return *this;
	}

	/// Performs the neighbor exchange on `model` with a blocking call.
	///
	/// The same buffer is passed as send and as receive buffer; the
	/// per-neighbor displacements and datatypes prepared by the constructor
	/// select the regions that are read and written.
	///
	/// @param model Buffer handed to MPI through model.data().
	/// @throws std::logic_error if no graph communicator has been created.
	void Communicate(gsl::multi_span<State, -1, -1>& model) {
		if (_commDistGraph == MPI_COMM_NULL) {
			throw std::logic_error("Communicator not initialized");
		}

		auto* const buffer = model.data();
		MPI_Neighbor_alltoallw(buffer,                    // sendbuf
		                       _sizes.data(),             // sendcounts
		                       _sendDisplacements.data(), // sdispl
		                       _sendTypes.data(),         // sendtypes
		                       buffer,                    // recvbuf
		                       _sizes.data(),             // recvcounts
		                       _recvDisplacements.data(), // rdispls
		                       _recvTypes.data(),         // recvtypes
		                       _commDistGraph             // comm
		                       );
	}

	class MpiRequest {
		MPI_Request _req{MPI_REQUEST_NULL};
		bool finished{};

	  public:
		MpiRequest(MPI_Request req) : _req(req) {}
		void Wait() {
			MPI_Wait(&_req, MPI_STATUS_IGNORE);
			finished = true;
		}
		~MpiRequest() {
			if (!finished) {
				std::cerr << "Forgot to Wait for MPI_Request\n";
				MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
			}
		}
	};

	/// Starts the neighbor exchange on `model` with a non-blocking call.
	///
	/// Buffers, counts, displacements and datatypes are the same as in
	/// Communicate(); only the MPI call differs.
	///
	/// @param model Buffer handed to MPI through model.data().
	/// @return An MpiRequest wrapping the request of the started exchange;
	///         its Wait() must be called before it goes out of scope.
	/// @throws std::logic_error if no graph communicator has been created.
	auto AsyncCommunicate(gsl::multi_span<State, -1, -1>& model) {
		if (_commDistGraph == MPI_COMM_NULL) {
			throw std::logic_error("Communicator not initialized");
		}

		auto* const buffer = model.data();
		MPI_Request pending;
		MPI_Ineighbor_alltoallw(buffer,                    // sendbuf
		                        _sizes.data(),             // sendcounts
		                        _sendDisplacements.data(), // sdispl
		                        _sendTypes.data(),         // sendtypes
		                        buffer,                    // recvbuf
		                        _sizes.data(),             // recvcounts
		                        _recvDisplacements.data(), // rdispls
		                        _recvTypes.data(),         // recvtypes
		                        _commDistGraph,            // comm
		                        &pending                   // request
		                        );
		return MpiRequest{pending};
	}