5 #include <unordered_set>
11 namespace DecisionTrees
18 template <
class Y>
struct Node
20 typedef Eigen::Ref<const Eigen::VectorXd>
arg_type;
36 throw std::domain_error(
"Node error cannot be negative");
68 virtual bool is_leaf()
const = 0;
80 template <
class Y>
struct SplitNode :
public Node<Y>
109 assert(
this ==
lower->parent);
110 assert(
this ==
higher->parent);
123 assert(
this ==
lower->parent);
124 assert(
this ==
higher->parent);
125 return 2 +
lower->count_lower_nodes() +
higher->count_lower_nodes();
133 assert(
this ==
lower->parent);
134 assert(
this ==
higher->parent);
135 return lower->count_leaf_nodes() +
higher->count_leaf_nodes();
143 assert(
this ==
lower->parent);
144 assert(
this ==
higher->parent);
145 return lower->total_leaf_error() +
higher->total_leaf_error();
153 assert(
this ==
lower->parent);
154 assert(
this ==
higher->parent);
156 copy->lower = std::unique_ptr<Node<Y>>(
lower->clone(copy.get()));
157 copy->higher = std::unique_ptr<Node<Y>>(
higher->clone(copy.get()));
158 return copy.release();
172 assert(
this ==
lower->parent);
173 assert(
this ==
higher->parent);
174 int number_leaves = 0;
175 if (!
lower->is_leaf()) {
176 lower->collect_lowest_split_nodes(s);
181 higher->collect_lowest_split_nodes(s);
185 if (number_leaves == 2) {
186 assert(
lower->is_leaf());
187 assert(
higher->is_leaf());
203 :
Node<Y>(n_error, n_value, n_parent)
double threshold
Definition: DecisionTreeNodes.hpp:84
Terminal node, which returns a constant prediction value for features which ended up on it.
Definition: DecisionTreeNodes.hpp:194
double error
Definition: DecisionTreeNodes.hpp:22
SplitNode< Y > * parent
Definition: DecisionTreeNodes.hpp:24
Non-terminal node, which splits data depending on a threshold value of some feature.
Definition: DecisionTreeNodes.hpp:13
double total_leaf_error() const override
Total error of the training samples seen by the leaf nodes reachable from this node (including its ow...
Definition: DecisionTreeNodes.hpp:139
Definition: BallTree.hpp:10
Y operator()(arg_type x) const override
Returns a prediction given a feature vector.
Definition: DecisionTreeNodes.hpp:105
Y operator()(arg_type) const override
Returns a prediction given a feature vector.
Definition: DecisionTreeNodes.hpp:211
virtual unsigned int count_leaf_nodes() const =0
Total number of leaf nodes reachable from this one, including itself.
unsigned int count_leaf_nodes() const override
Total number of leaf nodes reachable from this one, including itself.
Definition: DecisionTreeNodes.hpp:221
unsigned int feature_index
Definition: DecisionTreeNodes.hpp:85
void collect_lowest_split_nodes(std::unordered_set< SplitNode< Y > * > &s) override
Adds all lowest split nodes.
Definition: DecisionTreeNodes.hpp:168
virtual Node * clone(SplitNode< Y > *cloned_parent) const =0
Make a perfect copy of the node. Function works recursively from root to leafs.
void collect_lowest_split_nodes(std::unordered_set< SplitNode< Y > * > &) override
Adds all lowest split nodes.
Definition: DecisionTreeNodes.hpp:241
Y value
Definition: DecisionTreeNodes.hpp:23
unsigned int count_lower_nodes() const override
Total number of nodes reachable from this one.
Definition: DecisionTreeNodes.hpp:216
bool is_leaf() const override
Return true if node is a leaf.
Definition: DecisionTreeNodes.hpp:162
Eigen::Ref< const Eigen::VectorXd > arg_type
Definition: DecisionTreeNodes.hpp:20
bool is_leaf() const override
Return true if node is a leaf.
Definition: DecisionTreeNodes.hpp:236
unsigned int count_lower_nodes() const override
Total number of nodes reachable from this one.
Definition: DecisionTreeNodes.hpp:119
Tree node. Nodes are split (non-terminal) or leaf (terminal).
Definition: DecisionTreeNodes.hpp:18
Node(double n_error, Y n_value, SplitNode< Y > *n_parent)
Constructor.
Definition: DecisionTreeNodes.hpp:32
SplitNode(double n_error, Y n_value, SplitNode< Y > *n_parent, double n_threshold, unsigned int n_feature_index)
Constructor.
Definition: DecisionTreeNodes.hpp:100
std::unique_ptr< Node< Y > > lower
Definition: DecisionTreeNodes.hpp:82
double total_leaf_error() const override
Total error of the training samples seen by the leaf nodes reachable from this node (including its ow...
Definition: DecisionTreeNodes.hpp:226
unsigned int count_leaf_nodes() const override
Total number of leaf nodes reachable from this one, including itself.
Definition: DecisionTreeNodes.hpp:129
SplitNode< Y > * clone(SplitNode< Y > *cloned_parent) const override
Make a perfect copy of the node. Function works recursively from root to leafs.
Definition: DecisionTreeNodes.hpp:149
virtual bool is_leaf() const =0
Return true if node is a leaf.
virtual Y operator()(arg_type x) const =0
Returns a prediction given a feature vector.
LeafNode * clone(SplitNode< Y > *cloned_parent) const override
Make a perfect copy of the node. Function works recursively from root to leafs.
Definition: DecisionTreeNodes.hpp:231
LeafNode(double n_error, Y n_value, SplitNode< Y > *n_parent)
Constructor.
Definition: DecisionTreeNodes.hpp:202
std::unique_ptr< Node< Y > > higher
Definition: DecisionTreeNodes.hpp:83
virtual ~Node()
Virtual destructor.
Definition: DecisionTreeNodes.hpp:41
virtual void collect_lowest_split_nodes(std::unordered_set< SplitNode< Y > * > &s)=0
Adds all lowest split nodes.
typename Node< Y >::arg_type arg_type
Definition: DecisionTreeNodes.hpp:87
virtual unsigned int count_lower_nodes() const =0
Total number of nodes reachable from this one.
virtual double total_leaf_error() const =0
Total error of the training samples seen by the leaf nodes reachable from this node (including its ow...