MLpp
DecisionTreeNodes.hpp
1 #pragma once
2 /* (C) 2020 Roman Werpachowski. */
#include <cassert>
#include <cmath>
#include <memory>
#include <stdexcept>
#include <unordered_set>
#include <utility>
#include <Eigen/Core>
7 
8 namespace ml
9 {
11  namespace DecisionTrees
12  {
13  template <class Y> struct SplitNode;
14 
18  template <class Y> struct Node
19  {
20  typedef Eigen::Ref<const Eigen::VectorXd> arg_type;
22  double error;
23  Y value;
32  Node(double n_error, Y n_value, SplitNode<Y>* n_parent)
33  : error(n_error), value(n_value), parent(n_parent)
34  {
35  if (error < 0) {
36  throw std::domain_error("Node error cannot be negative");
37  }
38  }
39 
41  virtual ~Node() {}
42 
47  virtual Y operator()(arg_type x) const = 0;
48 
50  virtual unsigned int count_lower_nodes() const = 0;
51 
53  virtual unsigned int count_leaf_nodes() const = 0;
54 
59  virtual double total_leaf_error() const = 0;
60 
65  virtual Node* clone(SplitNode<Y>* cloned_parent) const = 0;
66 
68  virtual bool is_leaf() const = 0;
69 
76  virtual void collect_lowest_split_nodes(std::unordered_set<SplitNode<Y>*>& s) = 0;
77  };
78 
80  template <class Y> struct SplitNode : public Node<Y>
81  {
82  std::unique_ptr<Node<Y>> lower;
83  std::unique_ptr<Node<Y>> higher;
84  double threshold;
85  unsigned int feature_index;
87  using arg_type = typename Node<Y>::arg_type;
88  using Node<Y>::error;
89  using Node<Y>::value;
90  using Node<Y>::parent;
91 
100  SplitNode(double n_error, Y n_value, SplitNode<Y>* n_parent, double n_threshold, unsigned int n_feature_index)
101  : Node<Y>(n_error, n_value, n_parent), threshold(n_threshold), feature_index(n_feature_index)
102  {}
103 
105  Y operator()(arg_type x) const override
106  {
107  assert(lower);
108  assert(higher);
109  assert(this == lower->parent);
110  assert(this == higher->parent);
111  if (x[feature_index] < threshold) {
112  return (*lower)(x);
113  } else {
114  return (*higher)(x);
115  }
116  }
117 
119  unsigned int count_lower_nodes() const override
120  {
121  assert(lower);
122  assert(higher);
123  assert(this == lower->parent);
124  assert(this == higher->parent);
125  return 2 + lower->count_lower_nodes() + higher->count_lower_nodes();
126  }
127 
129  unsigned int count_leaf_nodes() const override
130  {
131  assert(lower);
132  assert(higher);
133  assert(this == lower->parent);
134  assert(this == higher->parent);
135  return lower->count_leaf_nodes() + higher->count_leaf_nodes();
136  }
137 
139  double total_leaf_error() const override
140  {
141  assert(lower);
142  assert(higher);
143  assert(this == lower->parent);
144  assert(this == higher->parent);
145  return lower->total_leaf_error() + higher->total_leaf_error();
146  }
147 
149  SplitNode<Y>* clone(SplitNode<Y>* cloned_parent) const override
150  {
151  assert(lower);
152  assert(higher);
153  assert(this == lower->parent);
154  assert(this == higher->parent);
155  auto copy = std::make_unique<SplitNode<Y>>(error, value, cloned_parent, threshold, feature_index);
156  copy->lower = std::unique_ptr<Node<Y>>(lower->clone(copy.get()));
157  copy->higher = std::unique_ptr<Node<Y>>(higher->clone(copy.get()));
158  return copy.release();
159  }
160 
162  bool is_leaf() const override
163  {
164  return false;
165  }
166 
168  void collect_lowest_split_nodes(std::unordered_set<SplitNode<Y>*>& s) override
169  {
170  assert(lower);
171  assert(higher);
172  assert(this == lower->parent);
173  assert(this == higher->parent);
174  int number_leaves = 0;
175  if (!lower->is_leaf()) {
176  lower->collect_lowest_split_nodes(s);
177  } else {
178  ++number_leaves;
179  }
180  if (!higher->is_leaf()) {
181  higher->collect_lowest_split_nodes(s);
182  } else {
183  ++number_leaves;
184  }
185  if (number_leaves == 2) {
186  assert(lower->is_leaf());
187  assert(higher->is_leaf());
188  s.insert(this);
189  }
190  }
191  };
192 
194  template <class Y> struct LeafNode : public Node<Y>
195  {
202  LeafNode(double n_error, Y n_value, SplitNode<Y>* n_parent)
203  : Node<Y>(n_error, n_value, n_parent)
204  {}
205 
206  using arg_type = typename Node<Y>::arg_type;
207  using Node<Y>::error;
208  using Node<Y>::value;
209  using Node<Y>::parent;
210 
211  Y operator()(arg_type) const override
212  {
213  return value;
214  }
215 
216  unsigned int count_lower_nodes() const override
217  {
218  return 0;
219  }
220 
221  unsigned int count_leaf_nodes() const override
222  {
223  return 1;
224  }
225 
226  double total_leaf_error() const override
227  {
228  return error;
229  }
230 
231  LeafNode* clone(SplitNode<Y>* cloned_parent) const override
232  {
233  return new LeafNode<Y>(error, value, cloned_parent);
234  }
235 
236  bool is_leaf() const override
237  {
238  return true;
239  }
240 
241  void collect_lowest_split_nodes(std::unordered_set<SplitNode<Y>*>&) override
242  {}
243  };
244  }
245 }
ml::DecisionTrees::SplitNode::threshold
double threshold
Definition: DecisionTreeNodes.hpp:84
ml::DecisionTrees::LeafNode
Terminal node, which returns a constant prediction value for features which ended up on it.
Definition: DecisionTreeNodes.hpp:194
ml::DecisionTrees::Node::error
double error
Definition: DecisionTreeNodes.hpp:22
ml::DecisionTrees::Node::parent
SplitNode< Y > * parent
Definition: DecisionTreeNodes.hpp:24
ml::DecisionTrees::SplitNode
Non-terminal node, which splits data depending on a threshold value of some feature.
Definition: DecisionTreeNodes.hpp:13
ml::DecisionTrees::SplitNode::total_leaf_error
double total_leaf_error() const override
Total error of the training samples seen by the leaf nodes reachable from this node (including the node itself, if it is a leaf).
Definition: DecisionTreeNodes.hpp:139
ml
Definition: BallTree.hpp:10
ml::DecisionTrees::SplitNode::operator()
Y operator()(arg_type x) const override
Returns a prediction given a feature vector.
Definition: DecisionTreeNodes.hpp:105
ml::DecisionTrees::LeafNode::operator()
Y operator()(arg_type) const override
Returns a prediction given a feature vector.
Definition: DecisionTreeNodes.hpp:211
ml::DecisionTrees::Node::count_leaf_nodes
virtual unsigned int count_leaf_nodes() const =0
Total number of leaf nodes reachable from this one, including itself.
ml::DecisionTrees::LeafNode::count_leaf_nodes
unsigned int count_leaf_nodes() const override
Total number of leaf nodes reachable from this one, including itself.
Definition: DecisionTreeNodes.hpp:221
ml::DecisionTrees::SplitNode::feature_index
unsigned int feature_index
Definition: DecisionTreeNodes.hpp:85
ml::DecisionTrees::SplitNode::collect_lowest_split_nodes
void collect_lowest_split_nodes(std::unordered_set< SplitNode< Y > * > &s) override
Adds all lowest split nodes (split nodes whose two children are both leaves) to the set.
Definition: DecisionTreeNodes.hpp:168
ml::DecisionTrees::Node::clone
virtual Node * clone(SplitNode< Y > *cloned_parent) const =0
Makes a deep copy of the node. The copy is made recursively, from this node down to the leaves.
ml::DecisionTrees::LeafNode::collect_lowest_split_nodes
void collect_lowest_split_nodes(std::unordered_set< SplitNode< Y > * > &) override
Adds all lowest split nodes (split nodes whose two children are both leaves) to the set.
Definition: DecisionTreeNodes.hpp:241
ml::DecisionTrees::Node::value
Y value
Definition: DecisionTreeNodes.hpp:23
ml::DecisionTrees::LeafNode::count_lower_nodes
unsigned int count_lower_nodes() const override
Total number of nodes reachable from this one, not counting itself.
Definition: DecisionTreeNodes.hpp:216
ml::DecisionTrees::SplitNode::is_leaf
bool is_leaf() const override
Returns true if the node is a leaf.
Definition: DecisionTreeNodes.hpp:162
ml::DecisionTrees::Node::arg_type
Eigen::Ref< const Eigen::VectorXd > arg_type
Definition: DecisionTreeNodes.hpp:20
ml::DecisionTrees::LeafNode::is_leaf
bool is_leaf() const override
Returns true if the node is a leaf.
Definition: DecisionTreeNodes.hpp:236
ml::DecisionTrees::SplitNode::count_lower_nodes
unsigned int count_lower_nodes() const override
Total number of nodes reachable from this one, not counting itself.
Definition: DecisionTreeNodes.hpp:119
ml::DecisionTrees::Node
Tree node. Nodes are split (non-terminal) or leaf (terminal).
Definition: DecisionTreeNodes.hpp:18
ml::DecisionTrees::Node::Node
Node(double n_error, Y n_value, SplitNode< Y > *n_parent)
Constructor.
Definition: DecisionTreeNodes.hpp:32
ml::DecisionTrees::SplitNode::SplitNode
SplitNode(double n_error, Y n_value, SplitNode< Y > *n_parent, double n_threshold, unsigned int n_feature_index)
Constructor.
Definition: DecisionTreeNodes.hpp:100
ml::DecisionTrees::SplitNode::lower
std::unique_ptr< Node< Y > > lower
Definition: DecisionTreeNodes.hpp:82
ml::DecisionTrees::LeafNode::total_leaf_error
double total_leaf_error() const override
Total error of the training samples seen by the leaf nodes reachable from this node (including the node itself, if it is a leaf).
Definition: DecisionTreeNodes.hpp:226
ml::DecisionTrees::SplitNode::count_leaf_nodes
unsigned int count_leaf_nodes() const override
Total number of leaf nodes reachable from this one, including itself.
Definition: DecisionTreeNodes.hpp:129
ml::DecisionTrees::SplitNode::clone
SplitNode< Y > * clone(SplitNode< Y > *cloned_parent) const override
Makes a deep copy of the node. The copy is made recursively, from this node down to the leaves.
Definition: DecisionTreeNodes.hpp:149
ml::DecisionTrees::Node::is_leaf
virtual bool is_leaf() const =0
Returns true if the node is a leaf.
ml::DecisionTrees::Node::operator()
virtual Y operator()(arg_type x) const =0
Returns a prediction given a feature vector.
ml::DecisionTrees::LeafNode::clone
LeafNode * clone(SplitNode< Y > *cloned_parent) const override
Makes a deep copy of the node. The copy is made recursively, from this node down to the leaves.
Definition: DecisionTreeNodes.hpp:231
ml::DecisionTrees::LeafNode::LeafNode
LeafNode(double n_error, Y n_value, SplitNode< Y > *n_parent)
Constructor.
Definition: DecisionTreeNodes.hpp:202
ml::DecisionTrees::SplitNode::higher
std::unique_ptr< Node< Y > > higher
Definition: DecisionTreeNodes.hpp:83
ml::DecisionTrees::Node::~Node
virtual ~Node()
Virtual destructor.
Definition: DecisionTreeNodes.hpp:41
ml::DecisionTrees::Node::collect_lowest_split_nodes
virtual void collect_lowest_split_nodes(std::unordered_set< SplitNode< Y > * > &s)=0
Adds all lowest split nodes (split nodes whose two children are both leaves) to the set.
ml::DecisionTrees::SplitNode::arg_type
typename Node< Y >::arg_type arg_type
Definition: DecisionTreeNodes.hpp:87
ml::DecisionTrees::Node::count_lower_nodes
virtual unsigned int count_lower_nodes() const =0
Total number of nodes reachable from this one, not counting itself.
ml::DecisionTrees::Node::total_leaf_error
virtual double total_leaf_error() const =0
Total error of the training samples seen by the leaf nodes reachable from this node (including the node itself, if it is a leaf).