Skip to content
Snippets Groups Projects
Commit 85d65ca8 authored by Thomas Steinreiter's avatar Thomas Steinreiter
Browse files

use iterators for traversing node data structures

parent 5ab53430
Branches
No related merge requests found
#include <limits>
#include <cmath>
#include <iostream>
#include <functional>
#include <algorithm>
#include "Node.hpp"
#include "Tree.hpp"
namespace nbody {
Node::Node(Tree* _tree):tree(_tree) {}
Node::Node(Tree* _tree):tree(_tree),
siblingNodesView(_tree) {}
Box Node::getBB() const {
return bb;
......@@ -118,14 +121,10 @@ namespace nbody {
double mass = 0.0;
if (leaf) {
for (const auto& body : bodies) {
mass += body.mass;
position += body.position * body.mass;
}
mass = std::accumulate(std::begin(bodies), std::end(bodies), 0.0, [](double m, const Body& b) { return m + b.mass; });
position = std::accumulate(std::begin(bodies), std::end(bodies), Vec3{}, [](Vec3 p, const Body& b) { return p + (b.position * b.mass); });
} else {
for (Node* node = next; node != nullptr; node = node->nextSibling) {
mass += node->representative.mass;
}
mass = std::accumulate(std::begin(siblingNodesView), std::end(siblingNodesView), 0.0, [](double m, const Node& n) { return m + n.representative.mass; });
}
representative.position = position / mass;
representative.mass = mass;
......@@ -167,4 +166,19 @@ namespace nbody {
std::copy_if(std::begin(bodies), std::end(bodies), std::back_inserter(result), [](const Body& b) {return !b.refinement; });
bodies.clear();
}
Node::SiblingNodesView::SiblingNodesView(Tree* tree) : tree_(tree) {};
Node::SiblingNodeIterator Node::SiblingNodesView::begin() {
return Node::SiblingNodeIterator{ tree_->nodes->next };
}
Node::SiblingNodeIterator Node::SiblingNodesView::end() {
return Node::SiblingNodeIterator{ nullptr };
}
Node::SiblingNodeIterator Node::SiblingNodesView::begin() const {
return Node::SiblingNodeIterator{ tree_->nodes->next };
}
Node::SiblingNodeIterator Node::SiblingNodesView::end() const {
return Node::SiblingNodeIterator{ nullptr };
}
} // namespace nbody
......@@ -2,16 +2,17 @@
#define TREE_NODE_HPP
#include "Body.hpp"
#include <iterator>
#include "Box.hpp"
#include "Tree.hpp"
#include "BarnesHutTree.hpp"
#include <cstdlib>
#include <vector>
namespace nbody {
class Tree;
//class for storing node information
class Node {
friend class Tree;
friend class NodeIterator;
friend class BarnesHutTree;
protected:
Box bb;
......@@ -45,6 +46,60 @@ namespace nbody {
virtual void setBodies(const std::vector<Body>& bodies_);
virtual void setBodies(std::vector<Body>&& bodies_);
virtual void extractLocalBodiesTo(std::vector<Body>& result);
// iterator classes for traversing of pointered node data structures
template<typename Derived>
class BaseNodeIterator {
protected:
Node* node_{ nullptr };
BaseNodeIterator(Node * node) :node_(node) {};
public:
typedef std::ptrdiff_t difference_type;
typedef Node value_type;
typedef Node& reference;
typedef Node* pointer;
typedef std::bidirectional_iterator_tag iterator_category;
const BaseNodeIterator& operator++() {
return static_cast<Derived*>(this)->operator++();
}
BaseNodeIterator operator++(int) {
auto result = *this; ++(*this); return result;
}
BaseNodeIterator& operator--() {
return static_cast<Derived*>(this)->operator--();
}
BaseNodeIterator operator--(int) {
auto result = *this; --(*this); return result;
}
reference operator*() const { return *node_; }
pointer operator->() const { return node_; }
friend bool operator==(const BaseNodeIterator& lhs, const BaseNodeIterator& rhs) { return lhs.node_ == rhs.node_; }
friend bool operator!=(const BaseNodeIterator& lhs, const BaseNodeIterator& rhs) { return !(lhs == rhs); }
};
struct NodeIterator : BaseNodeIterator<NodeIterator>{
NodeIterator(Node * node) : BaseNodeIterator(node) {};
const BaseNodeIterator& operator++() { node_ = node_->next; return *this; }
BaseNodeIterator& operator--() { node_ = node_->prev; return *this; }
};
struct SiblingNodeIterator : BaseNodeIterator<SiblingNodeIterator>{
SiblingNodeIterator(Node * node) : BaseNodeIterator(node) {};
const BaseNodeIterator& operator++() { node_ = node_->nextSibling; return *this; }
BaseNodeIterator& operator--() { node_ = node_->prevSibling; return *this; }
};
private:
class SiblingNodesView {
Tree* tree_;
public:
SiblingNodesView(Tree* tree);
SiblingNodeIterator begin();
SiblingNodeIterator end();
SiblingNodeIterator begin() const;
SiblingNodeIterator end() const;
};
SiblingNodesView siblingNodesView;
};
} // namespace nbody
......
......@@ -11,7 +11,8 @@ namespace nbody {
Tree::Tree(std::size_t _parallelId):
nodes(new Node(this)), //insert dummy root node
maxLeafBodies(16),
parallelId(_parallelId) {}
parallelId(_parallelId),
nodesView(this){}
Tree::~Tree() {
clean();
......@@ -35,12 +36,7 @@ namespace nbody {
}
std::size_t Tree::numberOfNodes() const {
std::size_t noNodes = 0;
for (Node* node = nodes->next; node != nodes; node = node->next) {
noNodes++;
}
return noNodes;
return std::abs(std::distance(std::begin(nodesView), std::end(nodesView)));
}
bool Tree::isCorrect() const {
......@@ -77,9 +73,9 @@ namespace nbody {
//accumulate forces for whole tree (local particles)
void Tree::computeForces() {
for (Node* n = nodes->next; n != nodes; n = n->next) {
if (n->leaf) {
for (auto& b : n->bodies) {
for (auto& n : nodesView) {
if (n.leaf) {
for (auto& b : n.bodies) {
if (!b.refinement) {
accumulateForceOnto(b);
}
......@@ -155,9 +151,9 @@ namespace nbody {
Box Tree::advance() {
Box bb;
for (Node* n = nodes->next; n != nodes; n = n->next) {
if (n->leaf) {
for (auto& b : n->bodies) {
for (auto& n : nodesView) {
if (n.leaf) {
for (auto& b : n.bodies) {
if (!b.refinement) {
b.integrate();
bb.extend(b);
......@@ -171,9 +167,9 @@ namespace nbody {
//determine local tree bounding box
Box Tree::getLocalBB() const {
Box result;
for (Node* n = nodes->next; n != nodes; n = n->next) {
if (n->leaf) {
for (auto& b : n->bodies) {
for (const auto& n : nodesView) {
if (n.leaf) {
for (const auto& b : n.bodies) {
if (!b.refinement) {
result.extend(b);
}
......@@ -184,10 +180,10 @@ namespace nbody {
}
void Tree::print(std::size_t parallelId) const {
for (Node* n = nodes->next; n != nodes; n = n->next) {
n->bb.printBB(parallelId);
if (n->leaf) {
for (auto& b : n->bodies) {
for (const auto& n : nodesView) {
n.bb.printBB(parallelId);
if (n.leaf) {
for (const auto& b : n.bodies) {
if (!b.refinement) {
b.print(this->parallelId);
}
......@@ -195,4 +191,19 @@ namespace nbody {
}
}
}
Tree::NodesView::NodesView(Tree* tree): tree_(tree) {};
Node::NodeIterator Tree::NodesView::begin() {
return Node::NodeIterator{ tree_->nodes->next };
}
Node::NodeIterator Tree::NodesView::end() {
return Node::NodeIterator{ tree_->nodes };
}
Node::NodeIterator Tree::NodesView::begin() const {
return Node::NodeIterator{ tree_->nodes->next };
}
Node::NodeIterator Tree::NodesView::end() const {
return Node::NodeIterator{ tree_->nodes };
}
} // namespace nbody
......@@ -6,16 +6,14 @@
#include <string>
#include "Body.hpp"
#include "Box.hpp"
#include "Node.hpp"
namespace nbody {
class Node;
class Simulation;
//superclass for Barnes-Hut tree
class Tree {
friend class Node;
friend class PthreadSimulation;
friend class ContinuousPthreadSimulation;
protected:
Node* nodes;
std::size_t maxLeafBodies{ 0 };
......@@ -42,9 +40,18 @@ namespace nbody {
virtual void print(std::size_t parallelId) const;
virtual Box advance();
virtual Box getLocalBB() const;
private:
class NodesView {
Tree* tree_;
public:
NodesView(Tree* tree);
Node::NodeIterator begin();
Node::NodeIterator end();
Node::NodeIterator begin() const;
Node::NodeIterator end() const;
};
NodesView nodesView;
};
} // namespace nbody
#endif
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment