#include "BarnesHutTree.hpp"
#include "Node.hpp"
#include "Box.hpp"
#include <iostream>
#include <limits>

namespace nbody {
	//construct an octree-based Barnes-Hut tree; parallelId is handed unchanged to the Tree base class
	BarnesHutTree::BarnesHutTree(std::size_t parallelId)
		: Tree(parallelId) {
	}

	//compute the octree sub-boxes of the given node's bounding box
	std::vector<Box> BarnesHutTree::splitBB(const Node* node) {
		//auto&& binds to whatever getBB() yields (reference or temporary) without copying
		auto&& bounds = node->getBB();
		return bounds.octreeSplit();
	}

	//number of children a split node of this tree type has: eight, as in an octree
	std::size_t BarnesHutTree::numberOfChildren() const {
		const std::size_t octants = 8;
		return octants;
	}

	//update upper tree nodes according to moved particles
	void BarnesHutTree::update() {
		//iterate for updating representatives
		Node* current = nodes->prev;
		while (current != nodes) {
			current->representative.id = std::numeric_limits<std::size_t>::max();
			current->representative.mass = 0.0;
			current->representative.position[0] = 0.0;
			current->representative.position[1] = 0.0;
			current->representative.position[2] = 0.0;
			if (current->leaf) {
				for (std::size_t i = 0; i < current->bodies.size(); i++) {
					current->representative.mass += current->bodies[i].mass;
					for (std::size_t j = 0; j < 3; j++) {
						current->representative.position[j] += current->bodies[i].position[j] * current->bodies[i].mass;
					}
				}
			} else {
				Node* child = current->next;

				do {
					current->representative.mass += child->representative.mass;
					for (std::size_t j = 0; j < 3; j++) {
						current->representative.position[j] += child->representative.position[j] * child->representative.mass;
					}
					child = child->nextSibling;
				} while (child != nullptr);
			}
			for (std::size_t j = 0; j < 3; j++) {
				if (current->representative.mass > std::numeric_limits<float>::epsilon()) {
					current->representative.position[j] /= current->representative.mass;
				} else {
					current->representative.position[j] = 0.0;
				}
			}
			current = current->prev;
		}
	}

	//split a tree node into octree sub-boxes during tree build:
	//one child per sub-box is inserted directly after current, each receiving the copies of
	//current's bodies that copyBodies() selects for its box; current becomes an inner node without bodies.
	void BarnesHutTree::split(Node* current) {
		std::vector<Box> octants = BarnesHutTree::splitBB(current);
		current->leaf = false;
		//new children are inserted in front of this node, i.e. right after current
		Node* successor = current->next;
		bool firstChild = true;

		for (Box& octant : octants) {
			Node* child = new Node(current->tree);

			child->bb = octant;
			child->bodies = octant.copyBodies(current->bodies);
			child->nextSibling = nullptr;
			child->prevSibling = nullptr;
			successor->insertBefore(child);
			if (!firstChild) {
				//child->prev is the previously inserted sibling: link both ways, and since that
				//sibling is a fresh leaf, its subtree ends where this child starts
				Node* sibling = child->prev;
				sibling->nextSibling = child;
				child->prevSibling = sibling;
				sibling->afterSubtree = child;
			}
			firstChild = false;
		}
		//the last child's subtree ends where the split node's subtree ended
		successor->prev->afterSubtree = current->afterSubtree;
		current->bodies.clear();
	}

	//initialize tree for build process
	void BarnesHutTree::init(const std::vector<Body>& bodies, const Box& domain) {
		Node* current;

		clean();
		if (bodies.empty()) return;
		//insert root node
		nodes->insertAfter(new Node(this));
		current = nodes->next;
		//assign bodies to root node
		current->bodies = bodies;
		//setup proper bounding box
		current->bb = domain;
		current->extendBBforBodies();
		current->extendBBtoCube();
		current->afterSubtree = current->next;
	}

	//split current if isSplitable() says so; returns true exactly when a split was performed
	bool BarnesHutTree::splitNode(Node* current) {
		if (!current->isSplitable()) {
			return false;
		}
		split(current);
		return true;
	}

	//build tree with given domain: create the root, split overfull nodes, then compute representatives
	void BarnesHutTree::build(const std::vector<Body>& bodies, const Box& domain) {
		init(bodies, domain);
		//init() inserts no root for an empty body set; nodes->next is then the list sentinel itself.
		//Splitting "from" the sentinel would walk the circular list against the sentinel's afterSubtree,
		//which may never terminate, and update() has nothing to do on an empty list anyway.
		if (nodes->next == nodes) {
			return;
		}
		//iterate over existing boxes and split those containing too many bodies
		BarnesHutTree::splitSubtree(nodes->next);
		update();
	}

	//build tree
	void BarnesHutTree::build(const std::vector<Body>& bodies) {
		Box bb;
		bb.extendForBodies(bodies);
		build(bodies, bb);
	}

	//merge remote refinement particles into local tree
	//(this are remote particles from other processes needed for force computation on local particles)
	void BarnesHutTree::mergeLET(const std::vector<Body>& bodies) {
		//put all new bodies into fitting leaves, walk through tree and split
		Node* current;

		for (auto it = std::begin(bodies); it != std::end(bodies); it++) {
			current = nodes->next;
			while (!current->leaf) {
				Node* child = current->next;

				while (child != nullptr && !child->getBB().contained(it->position)) {
					child = child->nextSibling;
				}
				//TODO(pheinzlr): check for child == nullptr?
				current = child;
			}
			current->bodies.push_back(*it);
			current->bodies.back().refinement = true;
		}
		current = nodes->next;
		while (current != nodes) {
			splitNode(current);
			current = current->next;
		}
		update();
	}

	//split every splitable node in the subtree starting at root, recursively, in one forward pass.
	//split() inserts the new children directly after the split node (in front of its former successor),
	//so they lie before root->afterSubtree and are visited later in the same pass, where they are split
	//further if needed. A repeat loop is therefore unnecessary; the former do/while never re-entered its
	//inner loop anyway, because current was not reset between passes.
	void BarnesHutTree::splitSubtree(Node* root) {
		for (Node* current = root; current != root->afterSubtree; current = current->next) {
			if (current->isSplitable()) {
				split(current);
			}
		}
	}
} // namespace nbody
