#include <cstdlib>
#include <limits>
#include <cmath>
#include <iostream>
#include <algorithm>
#include "Node.hpp"

namespace nbody {
	//creates a leaf node belonging to `tree`; the node starts detached:
	//prev/next form a one-element circular list (both point to the node itself),
	//and parent, sibling and afterSubtree links are all null
	Node::Node(Tree* tree) {
		//owning tree (the parameter shadows the member, hence this->)
		this->tree = tree;
		leaf = true;

		//self-loop in the circular prev/next list
		next = this;
		prev = this;

		//no structural links yet
		parent = nullptr;
		prevSibling = nullptr;
		nextSibling = nullptr;
		afterSubtree = nullptr;

		initBox(bb);
	}

	//nothing to release explicitly: all links are non-owning raw pointers
	//and member containers clean up after themselves
	Node::~Node() = default;


	//returns a copy of this node's bounding box
	Box Node::getBB() {
		return bb;
	}

	//replaces this node's bounding box with `box`
	void Node::setBB(Box box) {
		bb = box;
	}

	//check if node needs to be splitted during tree build
	bool Node::isSplitable() {
		bool result = true;

		if (this->bodies.size() <= this->tree->maxLeafBodies) {
			result = false;
		}
		//this is to prevent errors with collocated particles
		if (volume(this->bb) <= std::numeric_limits<float>::epsilon()) {
			result = false;
		}
		return result;
	}

	void Node::extendBBforBodies() {
		extendForBodies(this->bb, this->bodies);
	}

	void Node::extendBBtoCube() {
		extendToCube(this->bb);
	}

	std::vector<Body> Node::getBodies() {
		return this->bodies;
	}

	//splices `node` into the circular prev/next list directly before this node;
	//`node`'s previous neighbours are not touched, so it should be detached
	//(e.g. via unlink) before calling this
	void Node::insertBefore(Node* node) {
		Node* const predecessor = prev;

		node->prev = predecessor;
		node->next = this;
		predecessor->next = node;
		prev = node;
	}

	//splices `node` into the circular prev/next list directly after this node,
	//i.e. before this node's current successor
	void Node::insertAfter(Node* node) {
		Node* const successor = next;
		successor->insertBefore(node);
	}

	void Node::unlink() {
		this->next->prev = this->prev;
		this->prev->next = this->next;
		this->next = this;
		this->prev = this;
	}

	//debug consistency check of this node and of the links to its children;
	//writes a short diagnostic to std::cerr and returns false on the first
	//violation found, returns true if every check passes
	bool Node::isCorrect() {
		if (this->afterSubtree == nullptr) {
			std::cerr << "after subtree null\n";
			return false;
		}
		if (!isCorrectBox(this->bb)) {
			std::cerr << "bb wrong\n";
			return false;
		}
		for (int i = 0; i < 3; i++) {
			if (this->bb.min[i] > this->bb.max[i]) {
				std::cerr << "bb " << i << " min " << this->bb.min[i] << " max " << this->bb.max[i] << '\n';
				return false;
			}
		}
		if (std::any_of(std::begin(this->bodies), std::end(this->bodies), [&](const Body& b) {return !isContained(b, bb); })) {
			std::cerr << "bb out of bounds\n";
			return false;
		}
		if (!this->leaf) {
			//the first child is this->next; each afterSubtree hop skips one child
			//subtree, so exactly numberOfChildren() hops must land on this
			//node's afterSubtree. The walk is capped one hop past the expected
			//count so a corrupted (cyclic) chain cannot make the check hang.
			const int expectedChildren = static_cast<int>(this->tree->numberOfChildren());
			Node* current = this->next;
			int children = 0;

			while (current != nullptr && current != this->afterSubtree && children <= expectedChildren) {
				current = current->afterSubtree;
				children++;
			}
			if (current == nullptr) {
				std::cerr << "afterSubtree null\n";
				return false;
			}
			//also covers the capped case (children == expectedChildren + 1)
			if (children != expectedChildren) {
				std::cerr << "wrong number of children " << children << '\n';
				return false;
			}
		}
		if (!this->leaf && !this->bodies.empty()) {
			std::cerr << "non-empty inner node\n";
			return false;
		}
		//for a leaf, the next node in the list must be its next sibling (if any)
		if (this->leaf && this->nextSibling != nullptr && this->next != this->nextSibling) {
			std::cerr << "wrong next sibling\n";
			return false;
		}
		return true;
	}

	//update overall node information: recompute `representative` as the total
	//mass and the mass-weighted mean position (center of mass).
	//Leaves aggregate their own bodies; inner nodes aggregate the
	//representatives of their direct children (first child is this->next,
	//the rest are chained via nextSibling), which must already be up to date.
	void Node::update() {
		double position[3] = {0.0, 0.0, 0.0};
		double mass = 0.0;

		if (this->leaf) {
			for (const auto& body : this->bodies) {
				mass += body.mass;
				for (int i = 0; i < 3; i++) {
					position[i] += body.position[i] * body.mass;
				}
			}
		} else {
			for (Node* node = this->next; node != nullptr; node = node->nextSibling) {
				mass += node->representative.mass;
				//weight each child's center of mass by its mass; without this the
				//inner node's position would always collapse to the origin
				for (int i = 0; i < 3; i++) {
					position[i] += node->representative.position[i] * node->representative.mass;
				}
			}
		}
		for (int i = 0; i < 3; i++) {
			//a massless node (e.g. an empty leaf) has no center of mass;
			//avoid 0/0 = NaN and leave its position at zero instead
			this->representative.position[i] = (mass != 0.0) ? position[i] / mass : 0.0;
		}
		this->representative.mass = mass;
	}

	//get criterion to check if node is sufficient for force evaluation
	double Node::getL() {
		return maxSidelength(this->bb);
	}

	void Node::print(int parallelId) {
		printBB(parallelId, this->bb);
		for (auto& body : bodies) {
			std::cout << "  ";
			printBody(parallelId, body);
		}
	}

	//check if node is sufficient for force evaluation on `body`: true when the
	//Euclidean distance from the representative position to the body's
	//position exceeds the node's size criterion getL()
	bool Node::sufficientForBody(Body body) {
		double squaredDistance = 0.0;

		for (int axis = 0; axis < 3; ++axis) {
			const auto delta = this->representative.position[axis] - body.position[axis];
			squaredDistance += delta * delta;
		}
		const double distance = std::sqrt(squaredDistance);
		return distance > this->getL();
	}

	//check if node is sufficient for force evaluation for all bodies in `box`:
	//true when the distance between this node's box and `box` exceeds getL()
	bool Node::sufficientForBox(Box box) {
		return this->getL() < distanceToBox(this->bb, box);
	}

	void Node::setBodies(std::vector<Body> bodies) {
		this->bodies = bodies;
	}

	//get local bodies
	void Node::extractLocalBodiesTo(std::vector<Body>& result) {
		std::copy_if(std::begin(this->bodies), std::end(this->bodies), std::back_inserter(result), [](const Body& b) {return !b.refinement; });
		this->bodies.clear();
	}
}