#include <string>
#include <sstream>
#include <iostream>
#include <fstream>
#include <algorithm>
#include "Tree.hpp"
#include "Node.hpp"
#include "Simulation.hpp"

namespace nbody {
	// Creates an empty tree for the given parallel process id.
	// The node list starts out holding only the dummy (sentinel) node.
	// NOTE(review): Tree owns `nodes` through a raw pointer that ~Tree() deletes;
	// unless the header deletes/defines copy operations, copying a Tree would
	// double-delete the sentinel — confirm against Tree.hpp.
	Tree::Tree(int parallelId) {
		// The sentinel is created before any other member is assigned, exactly as
		// before, so Node(this) observes the same (not yet assigned) Tree state.
		nodes = new Node(this);
		simulation = nullptr;
		maxLeafBodies = 16;
		// The parameter shadows the member, hence the explicit this->.
		this->parallelId = parallelId;
	}

	// Frees every node in the tree, including the dummy sentinel.
	// This does not call clean(). clean() would allocate a fresh sentinel only for
	// it to be deleted immediately, and an allocation failure (std::bad_alloc)
	// escaping this implicitly noexcept destructor would call std::terminate.
	Tree::~Tree() {
		while (this->nodes->next != this->nodes) {
			Node* node = this->nodes->next;

			node->unlink();
			delete node;
		}
		delete this->nodes;
	}

	// Stores a non-owning pointer to the simulation this tree belongs to.
	// The constructor initializes it to nullptr; the tree never deletes it.
	void Tree::setSimulation(Simulation* simulation) {
		// The parameter shadows the member of the same name.
		this->simulation = simulation;
	}

	void Tree::clean() {
		//remove all nodes; refresh dummy first node
		while (this->nodes->next != this->nodes) {
			Node* node = this->nodes->next;

			node->unlink();
			delete node;
		}
		delete this->nodes;
		this->nodes = new Node(this);
	}

	size_t Tree::numberOfNodes() const {
		unsigned long nodes = 0;

		for (Node* node = this->nodes->next; node != this->nodes; node = node->next) {
			nodes++;
		}
		return nodes;
	}

	bool Tree::isCorrect() const {
		Node* current = this->nodes->next;

		while (current != this->nodes) {
			if (!current->isCorrect()) {
				return false;
			}
			current = current->next;
		}
		return true;
	}

	// Accumulates the force of the whole local tree onto `body`.
	// The body's acceleration is reset first. The walk then follows a Barnes-Hut
	// style rule:
	//  - A node that is sufficient for `body` contributes only its representative,
	//    and its subtree is skipped.
	//  - A leaf that is not sufficient contributes each of its bodies individually.
	//  - Any other node is descended into.
	void Tree::accumulateForceOnto(Body& body) {
		resetAcceleration(body);

		for (Node* cursor = nodes->next; cursor != nodes; ) {
			if (cursor->sufficientForBody(body)) {
				accumulateForceOntoBody(body, cursor->representative);
				cursor = cursor->afterSubtree;
				continue;
			}
			if (!cursor->leaf) {
				cursor = cursor->next;
				continue;
			}
			// Leaf bodies are passed as copies, exactly as before.
			for (auto other : cursor->bodies) {
				accumulateForceOntoBody(body, other);
			}
			cursor = cursor->afterSubtree;
		}
	}

	// Computes the accumulated force for every local (non-refinement) body stored
	// in the tree's leaves, by calling accumulateForceOnto on it.
	void Tree::computeForces() {
		for (Node* n = this->nodes->next; n != this->nodes; n = n->next) {
			if (!n->leaf) {
				continue;
			}
			// Iterate by reference so the acceleration is written back into the
			// tree. The previous by-value loop updated a temporary copy and
			// discarded every result.
			for (auto& b : n->bodies) {
				if (!b.refinement) {
					this->accumulateForceOnto(b);
				}
			}
		}
	}

	//get local bodies for rebuildiong tree after moving particles
	std::vector<Body> Tree::extractLocalBodies() {
		std::vector<Body> result;

		while (this->nodes->next != this->nodes) {
			if (this->nodes->next->leaf) {
				std::copy_if(std::begin(nodes->next->bodies), std::end(nodes->next->bodies), std::back_inserter(result), [](const Body& b) {return !b.refinement; });
			}
			Node* h = this->nodes->next;
			this->nodes->next->unlink();
			delete(h);
		}
		this->clean();
		return result;
	}

	//get refinement particles required to compute forces within remote domains
	std::vector<Body> Tree::copyRefinements(const Box& domain) const {
		std::vector<Body> result;
		Node* current = this->nodes->next;

		if (!isValid(current->bb)) {
			//empty tree means no refinements
			return result;
		}
		while (current != this->nodes) {
			bool sufficient = current->sufficientForBox(domain);

			if (sufficient) {
				if (current->representative.mass > 0.0) {
					result.push_back(current->representative);
				}
				current = current->afterSubtree;
			} else if (current->leaf) {

				result.insert(std::end(result), std::begin(current->bodies), std::end(current->bodies));
				current = current->next;
			} else {
				current = current->next;
			}
		}
		return result;
	}

	// Returns (a copy of) the bounding box of the first node after the sentinel.
	// For an empty tree this is the sentinel's own box.
	Box Tree::getRootBB() const {
		const Node* root = nodes->next;
		return root->bb;
	}

	//rebuild with predefined root node bounding box
	void Tree::rebuild(const Box& domain) {
		this->build(this->extractLocalBodies(), domain);
	}

	void Tree::rebuild() {
		this->build(this->extractLocalBodies());
	}

	//rebuild with predefined root node bounding box and bodies
	void Tree::rebuild(const Box& domain, const std::vector<Body>& bodies) {
		this->build(bodies, domain);
	}

	// Moves the local (non-refinement) particles according to their accumulated
	// forces by calling integrate() on each of them. Returns the bounding box of
	// the moved particles.
	Box Tree::advance() {
		Box bb;

		// Start from an initialized box, as getLocalBB() does. Previously the box was
		// only default-constructed, so extend() grew it from whatever state that left.
		// NOTE(review): assumes Box's default construction does not itself produce the
		// initBox() state — confirm against Box's definition.
		initBox(bb);
		for (Node* n = this->nodes->next; n != this->nodes; n = n->next) {
			if (n->leaf) {
				for (auto& b : n->bodies) {
					if (!b.refinement) {
						integrate(b);
						extend(bb, b);
					}
				}
			}
		}
		return bb;
	}

	// Computes the bounding box of all local (non-refinement) bodies in leaf nodes.
	// The box starts from initBox() and is grown with extend() for each body.
	Box Tree::getLocalBB() const {
		Box bounds;
		initBox(bounds);

		for (Node* node = nodes->next; node != nodes; node = node->next) {
			if (!node->leaf) {
				continue;
			}
			for (auto& body : node->bodies) {
				if (body.refinement) {
					continue;
				}
				extend(bounds, body);
			}
		}
		return bounds;
	}

	// Prints the bounding box of every node and every local (non-refinement) body
	// in leaf nodes, all tagged with `parallelId`.
	// Previously boxes used the argument while bodies used the member this->parallelId,
	// so a single call could label its output with two different ids. The argument
	// is now used consistently.
	void Tree::print(int parallelId) const {
		for (Node* n = this->nodes->next; n != this->nodes; n = n->next) {
			printBB(parallelId, n->bb);
			if (n->leaf) {
				for (auto& b : n->bodies) {
					if (!b.refinement) {
						printBody(parallelId, b);
					}
				}
			}
		}
	}
}
