Skip to content
Communicator.cpp 8.07 KiB
Newer Older
#include <algorithm>
#include <array>
#include <iostream>
#include <utility>

#include "Communicator.hpp"

// Takes over a set of outstanding MPI requests. They must be completed with
// Wait() before this object is destroyed (the destructor aborts otherwise).
// `reqs` is a by-value sink parameter, so it is moved rather than copied.
Communicator::MpiRequest::MpiRequest(DoubleVector<MPI_Request> reqs)
    : _reqs(std::move(reqs)) {}

// Blocks until every request held by this object has completed, then records
// that fact so the destructor accepts the object as properly waited for.
void Communicator::MpiRequest::Wait() {
	const auto requestCount = static_cast<int>(_reqs.size());
	MPI_Waitall(requestCount, _reqs.data(), MPI_STATUSES_IGNORE);
	finished = true;
}
// Outstanding non-blocking operations must be completed before their buffers
// may be reused or released, so destroying a request that was never passed
// through Wait() is treated as a fatal programming error.
Communicator::MpiRequest::~MpiRequest() {
	if (!finished) { MpiReportErrorAbort("Forgot to Wait for MPI_Request"); }
}

// Defines the halo datatypes and the graph topology used for the halo exchange.
//
// Every rank owns one tile of a gridSize.Rows x gridSize.Cols process grid
// (rank r sits at column r % gridSize.Cols, row r / gridSize.Cols). The grid is
// not periodic: neighbours outside the grid are simply omitted. A tile is
// addressed as a row-major (tileSize.Rows + 2) x (tileSize.Cols + 2) array,
// i.e. the interior plus a one-cell halo. Displacements are byte offsets and
// all datatypes are built from MPI_CHAR.
// NOTE(review): this presumably requires sizeof(State) == 1 — confirm.
//
// env      - provides this process' rank (worldRank()).
// commMode - selects collective or point-to-point exchange in Communicate().
// gridSize - dimensions of the process grid, in tiles.
// tileSize - dimensions of this rank's tile without the halo.
Communicator::Communicator(const MpiEnvironment& env,
                           CommunicationMode commMode, const Size& gridSize,
                           const Size& tileSize)
    : _commMode(commMode) {
	// Begin definition of basic types
	// One halo row: tileSize.Cols consecutive cells.
	MPI_Type_contiguous(static_cast<int>(tileSize.Cols), MPI_CHAR,
	                    &_haloRowType);
	MPI_Type_commit(&_haloRowType);

	// One halo column: tileSize.Rows single cells, one per row. The stride is the
	// full row width including the two halo cells.
	MPI_Type_vector(static_cast<int>(tileSize.Rows), 1,
	                static_cast<int>(tileSize.Cols + 2), MPI_CHAR,
	                &_haloColumnType);
	MPI_Type_commit(&_haloColumnType);

	// NOTE(review): _haloCornerType is not created here; presumably it is
	// initialised in the class definition (e.g. to a single-cell type) — confirm.
	// End definition of basic types

	// Begin definition of types/displacements for a general cell somewhere in
	// the middle of the grid. Neighbour slots are ordered row by row:
	// top-left, top, top-right, left, right, bottom-left, bottom, bottom-right.
	const std::array<MPI_Datatype, NoNeighbors> generalSendTypes{
	    _haloCornerType, _haloRowType,    _haloCornerType, //
	    _haloColumnType, _haloColumnType,                  //
	    _haloCornerType, _haloRowType,    _haloCornerType  //
	};
	const auto tCols = tileSize.Cols;
	const auto tRows = tileSize.Rows;

	// character coordinates (x = column, y = row, halo included) to displacement
	const auto dp = [&](std::size_t x, std::size_t y) {
		return static_cast<MPI_Aint>(y * (tCols + 2) + x);
	};

	// Outgoing strips are taken from the outermost interior cells ...
	const std::array<MPI_Aint, NoNeighbors> generalSendDisplacements{
	    dp(1, 1),     dp(1, 1),     dp(tCols, 1),    //
	    dp(1, 1),     dp(tCols, 1),                  //
	    dp(1, tRows), dp(1, tRows), dp(tCols, tRows) //
	};

	// ... and incoming strips are written into the halo.
	const std::array<MPI_Aint, NoNeighbors> generalRecvDisplacements{
	    dp(0, 0),         dp(1, 0),         dp(tCols + 1, 0),        //
	    dp(0, 1),         dp(tCols + 1, 1),                          //
	    dp(0, tRows + 1), dp(1, tRows + 1), dp(tCols + 1, tRows + 1) //
	};

	const std::array<int, NoNeighbors> generalSizes{
	    1, 1, 1, //
	    1, 1,    //
	    1, 1, 1  //
	};
	// End definition of datastructures for a general cell

	// Begin definition of datastructures for this particular cell (handle the
	// border cases)
	const auto rank2coord = [&](std::size_t rank) {
		return Coord{
		    rank % gridSize.Cols, //
		    rank / gridSize.Cols  //
		};
	};

	const auto coord2rank = [&](Coord c) { return gridSize.Cols * c.Y + c.X; };

	const auto isInsideGrid = [&](Coord c) {
		return c.X < gridSize.Cols && c.Y < gridSize.Rows;
	};

	const auto myCoord = rank2coord(env.worldRank());
	// Coordinates are unsigned: at the left/top border X - 1 / Y - 1 wraps
	// around to a huge value, which isInsideGrid() then rejects.
	const std::array<Coord, NoNeighbors> generalNeighborCoords{{
	    {myCoord.X - 1, myCoord.Y - 1}, // intentional unsigned wrap-around
	    {myCoord.X + 0, myCoord.Y - 1}, //
	    {myCoord.X + 1, myCoord.Y - 1}, //
	    {myCoord.X - 1, myCoord.Y + 0}, //
	    {myCoord.X + 1, myCoord.Y + 0}, //
	    {myCoord.X - 1, myCoord.Y + 1}, //
	    {myCoord.X + 0, myCoord.Y + 1}, //
	    {myCoord.X + 1, myCoord.Y + 1}  //
	}};

	for (std::size_t i{0}; i < NoNeighbors; ++i) {
		const auto nbrCoord = generalNeighborCoords[i];
		if (isInsideGrid(nbrCoord)) {
			_neighbors.push_back(static_cast<int>(coord2rank(nbrCoord)));
			_sendTypes.push_back(generalSendTypes[i]);
			// The halo slot filled by neighbour i has the same shape as the
			// strip sent to it (corner/row/column), so the send type is also
			// the receive type. Communicate() and AsyncCommunicate() read
			// _recvTypes per neighbour, so it must be filled in lockstep.
			_recvTypes.push_back(generalSendTypes[i]);
			_sendDisplacements.push_back(generalSendDisplacements[i]);
			_recvDisplacements.push_back(generalRecvDisplacements[i]);
			_sizes.push_back(generalSizes[i]);
		}
	}

	// Sources and destinations are the same list in the same order, so slot i
	// of every per-neighbour vector refers to the same rank for both send and
	// receive.
	const auto degree = static_cast<int>(_neighbors.size());
	MPI_Dist_graph_create_adjacent(
	    MPI_COMM_WORLD,                         // comm_old
	    degree,                                 // indegree
	    _neighbors.data(),                      // sources
	    reinterpret_cast<int*>(MPI_UNWEIGHTED), // sourceweights
	    degree,                                 // outdegree
	    _neighbors.data(),                      // destinations
	    reinterpret_cast<int*>(MPI_UNWEIGHTED), // destweights
	    MPI_INFO_NULL,                          // info
	    0,                                      // reorder
	    &_commDistGraph                         // comm_dist_graph
	    );
	// End definition of datastructures for this particular cell
}

// Releases the graph communicator and the two derived halo datatypes committed
// by the constructor. An instance whose graph communicator is MPI_COMM_NULL
// (presumably a moved-from or never-initialised one — depends on the member
// defaults in the class definition) is treated as owning nothing.
// NOTE(review): _haloCornerType is deliberately not freed here; presumably it is
// a predefined type — confirm.
Communicator::~Communicator() {
	if (_commDistGraph == MPI_COMM_NULL) { return; }

	MPI_Comm_free(&_commDistGraph);
	MPI_Type_free(&_haloColumnType);
	MPI_Type_free(&_haloRowType);
}

// Exchanges the complete state of two communicators, including ownership of
// the MPI handles (graph communicator and committed halo datatypes). The move
// constructor and move assignment operator are built on this.
void Communicator::swap(Communicator& first, Communicator& second) {
	using std::swap;
	// The communication mode and the receive types must travel with the rest
	// of the state. Otherwise a moved-to object keeps its previous mode and
	// receive-type list while owning the other object's graph and datatypes.
	swap(first._commMode, second._commMode);
	swap(first._neighbors, second._neighbors);
	swap(first._sizes, second._sizes);
	swap(first._sendTypes, second._sendTypes);
	swap(first._recvTypes, second._recvTypes);
	swap(first._sendDisplacements, second._sendDisplacements);
	swap(first._recvDisplacements, second._recvDisplacements);
	swap(first._commDistGraph, second._commDistGraph);
	swap(first._haloRowType, second._haloRowType);
	swap(first._haloColumnType, second._haloColumnType);
	swap(first._haloCornerType, second._haloCornerType);
}

// Move construction: members start from their default initialisers (as given
// in the class definition), then take over everything `other` owns. `other` is
// left holding that default state.
Communicator::Communicator(Communicator&& other) noexcept {
	Communicator::swap(*this, other);
}

// Move assignment via swap: `other` receives this object's previous resources
// and releases them in its own destructor.
Communicator& Communicator::operator=(Communicator&& other) noexcept {
	Communicator::swap(*this, other);
	return *this;
}

// Exchanges the halo cells of `model` with all neighbouring ranks and returns
// once the exchange has completed.
//
// Collective mode issues a single MPI_Neighbor_alltoallw on the distributed
// graph communicator. P2P mode starts the non-blocking point-to-point exchange
// through AsyncCommunicate() and waits for it.
//
// model - this rank's tile including its halo. It serves as both send and
//         receive buffer: interior border cells are sent, halo cells are
//         overwritten with the neighbours' data.
// Aborts (via MpiReportErrorAbort) if the communicator holds no graph
// communicator or the mode is unknown.
void Communicator::Communicate(gsl::multi_span<State, -1, -1>& model) {
	if (_commDistGraph == MPI_COMM_NULL)
		MpiReportErrorAbort("Communicator not initialized");
	switch (_commMode) {
	case CommunicationMode::Collective:
		MPI_Neighbor_alltoallw(model.data(),              // sendbuf
		                       _sizes.data(),             // sendcounts
		                       _sendDisplacements.data(), // sdispl
		                       _sendTypes.data(),         // sendtypes
		                       model.data(),              // recvbuf
		                       _sizes.data(),             // recvcounts
		                       _recvDisplacements.data(), // rdispls
		                       _recvTypes.data(),         // recvtypes
		                       _commDistGraph             // comm
		                       );
		break;
	case CommunicationMode::P2P:
		AsyncCommunicate(model).Wait();
		break;
	default:
		MpiReportErrorAbort("Invalid Communication mode");
		break;
	}
}

Communicator::MpiRequest
Communicator::AsyncCommunicate(gsl::multi_span<State, -1, -1>& model) {
	if (_commDistGraph == MPI_COMM_NULL)
		MpiReportErrorAbort("Communicator not initialized");
	switch (_commMode) {
	case CommunicationMode::Collective: {
		MPI_Request req;
		MPI_Ineighbor_alltoallw(model.data(),              // sendbuf
		                        _sizes.data(),             // sendcounts
		                        _sendDisplacements.data(), // sdispl
		                        _sendTypes.data(),         // sendtypes
		                        model.data(),              // recvbuf
		                        _sizes.data(),             // recvcounts
		                        _recvDisplacements.data(), // rdispls
		                        _recvTypes.data(),         // recvtypes
		                        _commDistGraph,            // comm
Thomas Steinreiter's avatar
Thomas Steinreiter committed
		                        &req);                     // request
		return MpiRequest{{req}};
	} break;
	case CommunicationMode::P2P: {
		Communicator::MpiRequest::DoubleVector<MPI_Request> reqs;
		for (std::size_t i{0}; i < _neighbors.size(); ++i) {
Thomas Steinreiter's avatar
Thomas Steinreiter committed
			{
				MPI_Request req;
				MPI_Isend(model.data() + _sendDisplacements[i], // buf
				          1,                                    // count
				          _sendTypes[i],                        // datatype
				          _neighbors[i],                        // dest
				          0,                                    // tag
				          MPI_COMM_WORLD,                       // comm
				          &req);                                // request
				reqs.push_back(req);
			}

			{
				MPI_Request req;
				MPI_Irecv(model.data() + _recvDisplacements[i], // buf
				          1,                                    // count
				          _recvTypes[i],                        // datatype
				          _neighbors[i],                        // source
				          0,                                    // tag
				          MPI_COMM_WORLD,                       // comm
				          &req);                                // request
				reqs.push_back(req);
			}
		}
		return MpiRequest{reqs};
	} break;
	default:
		MpiReportErrorAbort("Invalid Communication mode");