#include <string>
#include <sstream>
#include <iostream>
#include <fstream>
#include <algorithm>
#include "Tree.hpp"
#include "Node.hpp"
#include "Simulation.hpp"

namespace nbody {
	//Builds an empty tree: the node list consists of a single freshly allocated
	//dummy node, maxLeafBodies is set to 16 and the given parallel id is stored.
	//simulation is not assigned here; it is assigned by setSimulation().
	Tree::Tree(std::size_t _parallelId)
		: nodes(new Node(this)) //dummy node heading the node list
		, maxLeafBodies(16)
		, parallelId(_parallelId) {
	}

	//Destroys the tree: unlinks and deletes every node behind the dummy node,
	//then deletes the dummy node itself.
	Tree::~Tree() {
		//tear the list down directly instead of calling clean(): clean() allocates
		//a replacement dummy node, which is wasted work here and whose allocation
		//could throw while the destructor is running
		while (nodes->next != nodes) {
			Node* node = nodes->next;

			node->unlink();
			delete node;
		}
		delete nodes;
	}

	//Stores the given simulation pointer in the tree.
	//The pointer is kept as-is; nothing in this file deletes it.
	void Tree::setSimulation(Simulation* simulation_) {
		this->simulation = simulation_;
	}

	void Tree::clean() {
		//remove all nodes; refresh dummy first node
		while (nodes->next != nodes) {
			Node* node = nodes->next;

			node->unlink();
			delete node;
		}
		delete nodes;
		nodes = new Node(this);
	}

	std::size_t Tree::numberOfNodes() const {
		std::size_t noNodes = 0;

		for (Node* node = nodes->next; node != nodes; node = node->next) {
			noNodes++;
		}
		return noNodes;
	}

	bool Tree::isCorrect() const {
		Node* current = nodes->next;

		while (current != nodes) {
			if (!current->isCorrect()) {
				return false;
			}
			current = current->next;
		}
		return true;
	}

	//Accumulates the force of the whole local tree onto the given body.
	//The body's acceleration is reset first. The node list is then walked from the
	//first node: a node accepted by sufficientForBody() contributes its
	//representative and its subtree is skipped; a leaf that is not sufficient
	//contributes each of its bodies; any other node is descended into via next.
	void Tree::accumulateForceOnto(Body& body) {
		body.resetAcceleration();

		Node* cursor = nodes->next;
		while (cursor != nodes) {
			if (cursor->sufficientForBody(body)) {
				//the representative stands in for everything below this node
				body.accumulateForceOntoBody(cursor->representative);
				cursor = cursor->afterSubtree;
				continue;
			}
			if (!cursor->leaf) {
				//inner node that is not sufficient: continue with the next node in list order
				cursor = cursor->next;
				continue;
			}
			for (const auto& member : cursor->bodies) {
				body.accumulateForceOntoBody(member);
			}
			cursor = cursor->afterSubtree;
		}
	}

	//accumulate forces for whole tree (local particles)
	void Tree::computeForces() {
		for (Node* n = nodes->next; n != nodes; n = n->next) {
			if (n->leaf) {
				for (auto& b : n->bodies) {
					if (!b.refinement) {
						accumulateForceOnto(b);
					}
				}
			}
		}
	}

	//get local bodies for rebuildiong tree after moving particles
	std::vector<Body> Tree::extractLocalBodies() {
		std::vector<Body> result;

		while (nodes->next != nodes) {
			if (nodes->next->leaf) {
				std::copy_if(std::begin(nodes->next->bodies), std::end(nodes->next->bodies), std::back_inserter(result), [](const Body& b) {return !b.refinement; });
			}
			Node* h = nodes->next;
			nodes->next->unlink();
			delete(h);
		}
		clean();
		return result;
	}

	//get refinement particles required to compute forces within remote domains
	std::vector<Body> Tree::copyRefinements(const Box& domain) const {
		std::vector<Body> result;
		Node* current = nodes->next;

		if (!current->bb.isValid()) {
			//empty tree means no refinements
			return result;
		}
		while (current != nodes) {
			bool sufficient = current->sufficientForBox(domain);

			if (sufficient) {
				if (current->representative.mass > 0.0) {
					result.push_back(current->representative);
				}
				current = current->afterSubtree;
			} else if (current->leaf) {

				result.insert(std::end(result), std::begin(current->bodies), std::end(current->bodies));
				current = current->next;
			} else {
				current = current->next;
			}
		}
		return result;
	}

	//Returns a copy of the bounding box of the root node, i.e. the first node
	//behind the dummy node. For an empty list that node is the dummy node itself.
	Box Tree::getRootBB() const {
		const Node* root = nodes->next;

		return root->bb;
	}

	//rebuild with predefined root node bounding box
	void Tree::rebuild(const Box& domain) {
		build(extractLocalBodies(), domain);
	}

	void Tree::rebuild() {
		build(extractLocalBodies());
	}

	//rebuild with predefined root node bounding box and bodies
	void Tree::rebuild(const Box& domain, const std::vector<Body>& bodies) {
		build(bodies, domain);
	}

	//Moves the particles according to the accumulated forces: integrate() is called
	//on every body stored in a leaf whose refinement flag is not set.
	//Returns a box that was extended by each of those bodies after integration.
	Box Tree::advance() {
		Box extent;
		Node* cursor = nodes->next;

		while (cursor != nodes) {
			if (cursor->leaf) {
				for (auto& particle : cursor->bodies) {
					if (particle.refinement) {
						continue;
					}
					particle.integrate();
					extent.extend(particle);
				}
			}
			cursor = cursor->next;
		}
		return extent;
	}

	//Determines the bounding box of the local tree: returns a box extended by
	//every body stored in a leaf whose refinement flag is not set.
	Box Tree::getLocalBB() const {
		Box extent;
		Node* cursor = nodes->next;

		while (cursor != nodes) {
			if (cursor->leaf) {
				for (auto& particle : cursor->bodies) {
					if (particle.refinement) {
						continue;
					}
					extent.extend(particle);
				}
			}
			cursor = cursor->next;
		}
		return extent;
	}

	void Tree::print(std::size_t parallelId) const {
		for (Node* n = nodes->next; n != nodes; n = n->next) {
			n->bb.printBB(parallelId);
			if (n->leaf) {
				for (auto& b : n->bodies) {
					if (!b.refinement) {
						b.print(this->parallelId);
					}
				}
			}
		}
	}
} // namespace nbody
