Loading [MathJax]/extensions/TeX/AMSsymbols.js
MLpp
All Classes Namespaces Files Functions Variables Typedefs Pages
DecisionTree.hpp
#pragma once
/* (C) 2020 Roman Werpachowski. */
#include <cassert>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <unordered_set>
#include <Eigen/Core>
#include "DecisionTreeNodes.hpp"
11 
12 namespace ml
13 {
20  template <class Y> class DecisionTree
21  {
22  public:
23  typedef Eigen::Ref<const Eigen::VectorXd> arg_type ;
24  typedef Y value_type;
34  DecisionTree(std::unique_ptr<Node>&& root)
35  : root_(std::move(root))
36  {
37  if (!root_) {
38  throw std::invalid_argument("Null root");
39  }
40  if (root_->parent) {
41  throw std::invalid_argument("Root has no parent");
42  }
43  root_->collect_lowest_split_nodes(lowest_split_nodes_);
44  }
45 
50  DecisionTree(DecisionTree<Y>&& other) noexcept
51  : root_(std::move(other.root_)), lowest_split_nodes_(std::move(other.lowest_split_nodes_))
52  {}
53 
59  : root_(other.root_->clone(nullptr))
60  {
61  root_->collect_lowest_split_nodes(lowest_split_nodes_);
62  }
63 
69  {
70  if (this != &other) {
71  root_ = std::move(other.root_);
72  lowest_split_nodes_ = std::move(other.lowest_split_nodes_);
73  }
74  return *this;
75  }
76 
82  {
83  if (this != &other) {
84  root_.reset(other.root_->clone(nullptr));
85  lowest_split_nodes_.clear();
86  root_->collect_lowest_split_nodes(lowest_split_nodes_);
87  }
88  return *this;
89  }
90 
97  Y operator()(arg_type x) const
98  {
99  return (*root_)(x);
100  }
101 
103  unsigned int count_nodes() const
104  {
105  return 1 + root_->count_lower_nodes();
106  }
107 
109  unsigned int count_leaf_nodes() const
110  {
111  return root_->count_leaf_nodes();
112  }
113 
115  double original_error() const
116  {
117  return root_->error;
118  }
119 
121  double total_leaf_error() const
122  {
123  return root_->total_leaf_error();
124  }
125 
132  double cost_complexity(double alpha) const
133  {
134  return total_leaf_error() + alpha * static_cast<double>(count_leaf_nodes());
135  }
136 
146  bool remove_weakest_link(const double max_allowed_error_increase)
147  {
148  if (max_allowed_error_increase < 0) {
149  throw std::domain_error("Maximum allowed error increase cannot be negative");
150  }
151  if (lowest_split_nodes_.empty()) {
152  return false;
153  }
154  double lowest_error_increase = std::numeric_limits<double>::infinity();
155  SplitNode* removed = nullptr;
156  for (auto split_node_ptr : lowest_split_nodes_) {
157  assert(split_node_ptr);
158  const SplitNode& split_node = *split_node_ptr;
159  assert(split_node.lower);
160  assert(split_node.higher);
161  assert(split_node.lower->is_leaf());
162  assert(split_node.higher->is_leaf());
163  const double error_increase = split_node.error - (split_node.lower->error + split_node.higher->error);
164  if (error_increase <= lowest_error_increase) {
165  removed = split_node_ptr;
166  lowest_error_increase = error_increase;
167  }
168  }
169  assert(removed);
170  assert(lowest_error_increase >= 0);
171  if (lowest_error_increase > max_allowed_error_increase) {
172  return false;
173  }
174  SplitNode* const parent_of_removed = removed->parent;
175  auto new_leaf = std::make_unique<LeafNode>(removed->error, removed->value, parent_of_removed);
176  if (parent_of_removed) {
177  // Removing a non-root node from pruned tree.
178  bool other_is_leaf;
179  if (removed == parent_of_removed->lower.get()) {
180  parent_of_removed->lower = std::move(new_leaf);
181  other_is_leaf = parent_of_removed->higher->is_leaf();
182  } else {
183  assert(removed == parent_of_removed->higher.get());
184  parent_of_removed->higher = std::move(new_leaf);
185  other_is_leaf = parent_of_removed->lower->is_leaf();
186  }
187  // Update the set of lowest split nodes.
188  lowest_split_nodes_.erase(removed);
189  assert(!lowest_split_nodes_.count(removed));
190  if (other_is_leaf) {
191  lowest_split_nodes_.insert(parent_of_removed);
192  assert(lowest_split_nodes_.count(parent_of_removed));
193  }
194  } else {
195  // We removed the last split. Replace the pruned tree with the new leaf.
196  root_ = std::move(new_leaf);
197  lowest_split_nodes_.clear();
198  }
199  return true;
200  }
201 
205  unsigned int number_lowest_split_nodes() const
206  {
207  return static_cast<unsigned int>(lowest_split_nodes_.size());
208  }
209  private:
210  std::unique_ptr<Node> root_;
211  std::unordered_set<SplitNode*> lowest_split_nodes_;
212  };
213 }
ml::DecisionTree::count_leaf_nodes
unsigned int count_leaf_nodes() const
Counts leaf nodes in the tree.
Definition: DecisionTree.hpp:109
ml::DecisionTrees::LeafNode
Terminal node, which returns a constant prediction value for features which ended up on it.
Definition: DecisionTreeNodes.hpp:194
ml::DecisionTree
Decision tree.
Definition: DecisionTree.hpp:20
ml::DecisionTree::arg_type
Eigen::Ref< const Eigen::VectorXd > arg_type
Definition: DecisionTree.hpp:23
ml::DecisionTree::original_error
double original_error() const
Returns the prediction error for training data before any splits are made.
Definition: DecisionTree.hpp:115
ml::DecisionTree::SplitNode
DecisionTrees::SplitNode< Y > SplitNode
Definition: DecisionTree.hpp:26
ml::DecisionTree::DecisionTree
DecisionTree(DecisionTree< Y > &&other) noexcept
Move constructor.
Definition: DecisionTree.hpp:50
ml::DecisionTrees::SplitNode
Non-terminal node, which splits data depending on a threshold value of some feature.
Definition: DecisionTreeNodes.hpp:13
ml::DecisionTree::value_type
Y value_type
Definition: DecisionTree.hpp:24
ml
Definition: BallTree.hpp:10
ml::DecisionTree::LeafNode
DecisionTrees::LeafNode< Y > LeafNode
Definition: DecisionTree.hpp:27
ml::DecisionTree::DecisionTree
DecisionTree(std::unique_ptr< Node > &&root)
Constructs a decision tree by taking ownership of a root node.
Definition: DecisionTree.hpp:34
ml::DecisionTree::remove_weakest_link
bool remove_weakest_link(const double max_allowed_error_increase)
Finds the weakest link and removes it, if the error does not increase too much.
Definition: DecisionTree.hpp:146
ml::DecisionTree::count_nodes
unsigned int count_nodes() const
Counts nodes in the tree.
Definition: DecisionTree.hpp:103
ml::DecisionTree::operator()
Y operator()(arg_type x) const
Returns a prediction given a feature vector.
Definition: DecisionTree.hpp:97
ml::DecisionTree::total_leaf_error
double total_leaf_error() const
Returns the total prediction error for training data after all splits.
Definition: DecisionTree.hpp:121
ml::DecisionTree::operator=
DecisionTree< Y > & operator=(const DecisionTree< Y > &other)
Copy assignment operator.
Definition: DecisionTree.hpp:81
ml::DecisionTrees::Node
Tree node. Nodes are split (non-terminal) or leaf (terminal).
Definition: DecisionTreeNodes.hpp:18
ml::DecisionTrees::SplitNode::lower
std::unique_ptr< Node< Y > > lower
Definition: DecisionTreeNodes.hpp:82
ml::DecisionTree::operator=
DecisionTree< Y > & operator=(DecisionTree< Y > &&other) noexcept
Move assignment operator.
Definition: DecisionTree.hpp:68
ml::DecisionTrees::SplitNode::higher
std::unique_ptr< Node< Y > > higher
Definition: DecisionTreeNodes.hpp:83
ml::DecisionTree::Node
DecisionTrees::Node< Y > Node
Definition: DecisionTree.hpp:25
ml::DecisionTree::number_lowest_split_nodes
unsigned int number_lowest_split_nodes() const
Counts lowest split nodes.
Definition: DecisionTree.hpp:205
ml::DecisionTree::DecisionTree
DecisionTree(const DecisionTree< Y > &other)
Copy constructor.
Definition: DecisionTree.hpp:58
ml::DecisionTree::cost_complexity
double cost_complexity(double alpha) const
Calculates cost-complexity measure.
Definition: DecisionTree.hpp:132