#include <iostream>
#include <cmath>
#include <cstring>
#include <Body.hpp>
#include <Box.hpp>
#include <Tree.hpp>
#include <Node.hpp>
#include <iostream>
#include <BarnesHutTree.hpp>
#include "MpiSimulation.hpp"

namespace nbody {
	/**
	 * Sets up the MPI side of the simulation.
	 *
	 * Builds and commits the derived datatypes used to transfer Body and Box
	 * objects, queries communicator size and rank, and reads the input data.
	 * The outcome of reading the input on rank 0 is broadcast, and every process
	 * combines it with its own outcome; on failure the whole job is aborted.
	 *
	 * @param inputFile path of the input data file; an empty path is treated as an error
	 */
	MpiSimulation::MpiSimulation(const std::string& inputFile) {
		this->tree = nullptr;
		this->bodyType = MPI_DATATYPE_NULL;
		this->boxType = MPI_DATATYPE_NULL;

		//create MPI datatypes for bodies and domain boxes
		int bodyBlocklengths[6] = { 1, 3, 3, 3, 1, 1 };
		MPI_Datatype bodyDatatypes[6] = { MPI_UINT64_T, MPI_DOUBLE, MPI_DOUBLE, MPI_DOUBLE, MPI_DOUBLE, MPI_CXX_BOOL };
		MPI_Aint bodyOffsets[6];
		bodyOffsets[0] = offsetof(Body, id);
		bodyOffsets[1] = offsetof(Body, position);
		bodyOffsets[2] = offsetof(Body, velocity);
		bodyOffsets[3] = offsetof(Body, acceleration);
		bodyOffsets[4] = offsetof(Body, mass);
		bodyOffsets[5] = offsetof(Body, refinement);
		MPI_Datatype bodyStruct = MPI_DATATYPE_NULL;
		MPI_Type_create_struct(6, bodyBlocklengths, bodyOffsets, bodyDatatypes, &bodyStruct);
		//force the extent to sizeof(Body) so that sending count > 1 bodies strides exactly like a C++ array,
		//independent of trailing padding or members not described by the struct type
		MPI_Type_create_resized(bodyStruct, 0, static_cast<MPI_Aint>(sizeof(Body)), &this->bodyType);
		MPI_Type_free(&bodyStruct);
		MPI_Type_commit(&this->bodyType);

		int boxBlocklengths[2] = { 3, 3 };
		MPI_Datatype boxDatatypes[2] = { MPI_DOUBLE, MPI_DOUBLE };
		MPI_Aint boxOffsets[2];
		boxOffsets[0] = offsetof(Box, min);
		boxOffsets[1] = offsetof(Box, max);
		MPI_Datatype boxStruct = MPI_DATATYPE_NULL;
		MPI_Type_create_struct(2, boxBlocklengths, boxOffsets, boxDatatypes, &boxStruct);
		//same extent fix as for bodies: boxes are gathered into a contiguous array of Box
		MPI_Type_create_resized(boxStruct, 0, static_cast<MPI_Aint>(sizeof(Box)), &this->boxType);
		MPI_Type_free(&boxStruct);
		MPI_Type_commit(&this->boxType);

		//get number of processes and own process id
		MPI_Comm_size(MPI_COMM_WORLD, &this->parallelSize);
		MPI_Comm_rank(MPI_COMM_WORLD, &this->parallelRank);
		//one slot per process: distributeDomains indexes domains[rank] directly, so reserve() is not enough
		this->domains.resize(static_cast<std::size_t>(this->parallelSize));

		//parse input data; an empty file name counts as an error
		this->correctState = !inputFile.empty() && this->readInputData(inputFile);

		//broadcast the state of rank 0 and combine it with the local outcome,
		//so every process terminates if the input could not be read
		int mpiCorrectState = this->correctState ? 1 : 0;
		MPI_Bcast(&mpiCorrectState, 1, MPI_INT, 0, MPI_COMM_WORLD);
		this->correctState = this->correctState && mpiCorrectState != 0;
		if (!this->correctState) {
			std::cerr << "Error occurred: terminating ...\n";
			MPI_Type_free(&this->bodyType);
			MPI_Type_free(&this->boxType);
			MPI_Abort(MPI_COMM_WORLD, -1);
			abort();
		}
	}

	bool MpiSimulation::stateCorrect() {
		return this->correctState;
	}

	/**
	 * Releases the send buffers, the tree and the MPI datatypes.
	 *
	 * Pending non-blocking sends are completed before their buffers are deleted,
	 * because MPI may still read from a send buffer until its request completes.
	 * MPI calls are skipped if MPI has already been finalized.
	 */
	MpiSimulation::~MpiSimulation() {
		int mpiFinalized = 0;
		MPI_Finalized(&mpiFinalized);

		//complete outstanding sends before freeing the memory they read from
		while (!this->sendStores.empty()) {
			if (!mpiFinalized) {
				MPI_Wait(&this->sendStores.back().request, MPI_STATUS_IGNORE);
			}
			delete[] this->sendStores.back().bodies;
			this->sendStores.pop_back();
		}

		delete this->tree;
		this->tree = nullptr;

		//cleanup MPI types; freeing MPI_DATATYPE_NULL is erroneous, so guard each handle
		if (!mpiFinalized) {
			if (this->bodyType != MPI_DATATYPE_NULL) {
				MPI_Type_free(&this->bodyType);
			}
			if (this->boxType != MPI_DATATYPE_NULL) {
				MPI_Type_free(&this->boxType);
			}
		}
	}

	std::size_t MpiSimulation::getNumberOfProcesses() const {
		return this->parallelSize;
	}

	std::size_t MpiSimulation::getProcessId() const {
		return this->parallelRank;
	}

	/**
	 * @return pointer to the committed MPI datatype describing a single Body
	 */
	MPI_Datatype* MpiSimulation::getDatatype() {
		MPI_Datatype* const bodyTypeHandle = &bodyType;
		return bodyTypeHandle;
	}

	//mpi send wrapper
	void MpiSimulation::send(std::vector<Body> bodies, int target) { //TODO(steinret): MPI_BSend, remove SendStore
		std::size_t bodySize = bodies.size();
		SendStore* store = this->availableSendStore(bodySize);

		//do unblocking send
		memcpy(store->bodies, &(bodies[0]), bodySize * sizeof(Body));
		MPI_Isend(store->bodies, bodySize, this->bodyType, target, 0, MPI_COMM_WORLD, &store->request);
	}

	//mpi recv wrapper
	int MpiSimulation::recv(std::vector<Body>& bodies, int source) {
		MPI_Status status;
		int count;

		//do blocking recv; any source receive can be done with source == MPI_ANY_SOURCE
		MPI_Probe(source, 0, MPI_COMM_WORLD, &status);
		MPI_Get_count(&status, this->bodyType, &count);
		bodies.resize(count);
		MPI_Recv(bodies.data(), count, this->bodyType, status.MPI_SOURCE, 0, MPI_COMM_WORLD, &status);
		//return source to determine message source for any source receives
		return status.MPI_SOURCE;
	}

	//initial body distribution
	void MpiSimulation::distributeBodies() {
		//process 0 distributes bodies, others receive
		if (this->parallelRank == 0) {
			std::vector<Node> nodes;
			Box bb;
			nodes.push_back(Node(nullptr));
			nodes.front().setBodies(this->bodies);
			bb.extendForBodies(this->bodies);
			nodes.front().setBB(bb);
			//determine how to distribute bodies to processes
			//split box with most particles by halfing its longest side
			//until number of boxes equals number of processes
			while (nodes.size() < static_cast<std::size_t>(this->parallelSize)) {
				std::size_t mostBodiesIndex = 0;

				for (std::size_t i = 1; i < nodes.size(); i++) {
					if (nodes[i].getBodies().size() > nodes[mostBodiesIndex].getBodies().size()) {
						mostBodiesIndex = i;
					}
				}
				std::vector<Box> subdomains = nodes[mostBodiesIndex].getBB().splitLongestSide();
				std::vector<Body> buf = nodes[mostBodiesIndex].getBodies();
				Node n(nullptr);

				n.setBodies(subdomains[0].extractBodies(buf));
				n.setBB(subdomains[0]);
				nodes.insert(std::begin(nodes) + mostBodiesIndex, n);
				n = Node(nullptr);
				n.setBodies(subdomains[1].extractBodies(buf));
				n.setBB(subdomains[1]);
				nodes.insert(std::begin(nodes) + mostBodiesIndex, n);
				nodes.erase(std::begin(nodes) + mostBodiesIndex + 2);
			}
			this->bodies = nodes[0].getBodies();
			for (std::size_t i = 1; i < nodes.size(); i++) {
				this->send(nodes[i].getBodies(), i);
			}
		} else {
			this->recv(this->bodies, 0);
		}
	}

	void MpiSimulation::distributeDomains(const std::vector<Body>& localBodies) {
		Box localDomain;

		//determine local domain size
		localDomain.extendForBodies(localBodies);

		this->distributeDomains(localDomain);
	}

	void MpiSimulation::distributeDomains() {
		this->distributeDomains(this->bodies);
	}

	//domain distribution, all processes need to know the spatial domains of the others
	void MpiSimulation::distributeDomains(const Box& localDomain) {
		//distribute local domain sizes to all processes through collective MPI operation
		this->domains[this->parallelRank] = localDomain;
		MPI_Allgather(&this->domains[this->parallelRank], 1, this->boxType, &this->domains[0], 1, this->boxType, MPI_COMM_WORLD);
		this->overallDomain = localDomain;
		//determine overall domain size
		for (std::size_t i = 0; i < static_cast<std::size_t>(this->parallelSize); i++) {
			this->overallDomain.extend(this->domains[i]);
		}
	}

	//send stores are needed for unblocking sends, get available one and cleanup unused ones
	/**
	 * Returns a send buffer with room for at least numElems bodies whose previous
	 * non-blocking send (if any) has completed.
	 *
	 * While scanning, completed stores that are too small are freed and removed.
	 * Stores whose send is still in flight are kept. If no completed store is
	 * large enough, a new one is allocated and appended.
	 * The returned pointer points into sendStores; do not keep it across later
	 * modifications of sendStores.
	 *
	 * @param numElems number of bodies the buffer must hold
	 * @return store ready to be used for the next send
	 */
	SendStore* MpiSimulation::availableSendStore(std::size_t numElems) {
		auto it = std::begin(this->sendStores);

		while (it != std::end(this->sendStores)) {
			int mpiCompleted = 0;
			MPI_Test(&it->request, &mpiCompleted, MPI_STATUS_IGNORE);
			const bool completed = mpiCompleted != 0;

			if (completed && it->size >= numElems) {
				return &(*it);
			}
			if (completed) {
				//finished but too small: release it
				delete[] it->bodies;
				it = this->sendStores.erase(it);
			} else {
				++it;
			}
		}

		SendStore store;
		store.bodies = new Body[numElems];
		store.size = numElems;
		//no send is associated yet; a null request is always safe to test or wait on
		store.request = MPI_REQUEST_NULL;
		this->sendStores.push_back(store);
		return &(this->sendStores.back());
	}

	//distribute bodies needed by other processes for their local simlation
	/**
	 * Exchanges locally essential trees (LETs).
	 *
	 * For every other process, the bodies returned by tree->copyRefinements for
	 * that process's domain are sent to it. Then one message from each other
	 * process is received and merged into the local tree. Receives accept any
	 * source, so messages are merged in arrival order.
	 */
	void MpiSimulation::distributeLETs() {
		const int processCount = this->parallelSize;

		//send out locally essential trees (local bodies needed by remote simulations, determined by remote domain size)
		for (int peer = 0; peer < processCount; ++peer) {
			if (peer == this->parallelRank) {
				continue;
			}
			this->send(this->tree->copyRefinements(this->domains[peer]), peer);
		}

		//receive bodies and integrate them into local tree for simulation;
		//blocking any-source receives are fine since all data is needed before continuing
		for (int pending = processCount - 1; pending > 0; --pending) {
			std::vector<Body> incoming;
			this->recv(incoming, MPI_ANY_SOURCE);
			this->tree->mergeLET(incoming);
		}

		if (!this->tree->isCorrect()) {
			std::cerr << "wrong tree\n";
		}
	}

	/**
	 * Builds a new Barnes-Hut tree from this process's bodies over the overall
	 * domain. A tree from an earlier call is released first; the simulation
	 * owns the tree (it is deleted in the destructor).
	 */
	void MpiSimulation::buildTree() {
		//delete any previous tree instead of leaking it on repeated calls
		delete this->tree;
		this->tree = new BarnesHutTree(this->parallelRank);
		this->tree->build(this->bodies, this->overallDomain);
		if (!this->tree->isCorrect()) {
			std::cerr << "wrong tree\n";
		}
	}

	/**
	 * Rebuilds the existing tree after the local particles have moved, using the
	 * current overall domain. Requires buildTree() to have been called.
	 */
	void MpiSimulation::rebuildTree() {
		auto* const activeTree = this->tree;
		activeTree->rebuild(this->overallDomain);
	}

	//run a simulation step
	/**
	 * Advances the simulation by one step.
	 *
	 * Expects the tree to be built already. The step exchanges LETs, computes
	 * forces, advances the local bodies and passes advance()'s result to
	 * distributeDomains, then rebuilds the tree and sanity-checks it.
	 */
	void MpiSimulation::runStep() {
		this->distributeLETs();
		this->tree->computeForces();
		this->distributeDomains(this->tree->advance());
		this->rebuildTree();

		const bool treeOk = this->tree->isCorrect();
		if (!treeOk) {
			std::cerr << "wrong tree\n";
		}
	}
} // namespace nbody
