Loading [MathJax]/extensions/TeX/AMSsymbols.js
8 #include <unordered_set>
10 #include "DecisionTreeNodes.hpp"
23 typedef Eigen::Ref<const Eigen::VectorXd>
arg_type ;
35 : root_(std::move(root))
38 throw std::invalid_argument(
"Null root");
41 throw std::invalid_argument(
"Root has no parent");
43 root_->collect_lowest_split_nodes(lowest_split_nodes_);
51 : root_(std::move(other.root_)), lowest_split_nodes_(std::move(other.lowest_split_nodes_))
59 : root_(other.root_->clone(nullptr))
61 root_->collect_lowest_split_nodes(lowest_split_nodes_);
71 root_ = std::move(other.root_);
72 lowest_split_nodes_ = std::move(other.lowest_split_nodes_);
84 root_.reset(other.root_->clone(
nullptr));
85 lowest_split_nodes_.clear();
86 root_->collect_lowest_split_nodes(lowest_split_nodes_);
105 return 1 + root_->count_lower_nodes();
111 return root_->count_leaf_nodes();
123 return root_->total_leaf_error();
148 if (max_allowed_error_increase < 0) {
149 throw std::domain_error(
"Maximum allowed error increase cannot be negative");
151 if (lowest_split_nodes_.empty()) {
154 double lowest_error_increase = std::numeric_limits<double>::infinity();
156 for (
auto split_node_ptr : lowest_split_nodes_) {
157 assert(split_node_ptr);
158 const SplitNode& split_node = *split_node_ptr;
159 assert(split_node.
lower);
160 assert(split_node.
higher);
161 assert(split_node.
lower->is_leaf());
162 assert(split_node.
higher->is_leaf());
163 const double error_increase = split_node.error - (split_node.
lower->error + split_node.
higher->error);
164 if (error_increase <= lowest_error_increase) {
165 removed = split_node_ptr;
166 lowest_error_increase = error_increase;
170 assert(lowest_error_increase >= 0);
171 if (lowest_error_increase > max_allowed_error_increase) {
174 SplitNode*
const parent_of_removed = removed->parent;
175 auto new_leaf = std::make_unique<LeafNode>(removed->error, removed->value, parent_of_removed);
176 if (parent_of_removed) {
179 if (removed == parent_of_removed->
lower.get()) {
180 parent_of_removed->
lower = std::move(new_leaf);
181 other_is_leaf = parent_of_removed->
higher->is_leaf();
183 assert(removed == parent_of_removed->
higher.get());
184 parent_of_removed->
higher = std::move(new_leaf);
185 other_is_leaf = parent_of_removed->
lower->is_leaf();
188 lowest_split_nodes_.erase(removed);
189 assert(!lowest_split_nodes_.count(removed));
191 lowest_split_nodes_.insert(parent_of_removed);
192 assert(lowest_split_nodes_.count(parent_of_removed));
196 root_ = std::move(new_leaf);
197 lowest_split_nodes_.clear();
207 return static_cast<unsigned int>(lowest_split_nodes_.size());
210 std::unique_ptr<Node> root_;
211 std::unordered_set<SplitNode*> lowest_split_nodes_;
unsigned int count_leaf_nodes() const
Counts leaf nodes in the tree.
Definition: DecisionTree.hpp:109
Terminal (leaf) node, which returns a constant prediction value for every feature vector that reaches it.
Definition: DecisionTreeNodes.hpp:194
Decision tree.
Definition: DecisionTree.hpp:20
Eigen::Ref< const Eigen::VectorXd > arg_type
Definition: DecisionTree.hpp:23
double original_error() const
Returns the prediction error for training data before any splits are made.
Definition: DecisionTree.hpp:115
DecisionTrees::SplitNode< Y > SplitNode
Definition: DecisionTree.hpp:26
DecisionTree(DecisionTree< Y > &&other) noexcept
Move constructor.
Definition: DecisionTree.hpp:50
Non-terminal (split) node, which sends data to its lower or higher child by comparing the value of one feature against a threshold.
Definition: DecisionTreeNodes.hpp:13
Y value_type
Definition: DecisionTree.hpp:24
Definition: BallTree.hpp:10
DecisionTrees::LeafNode< Y > LeafNode
Definition: DecisionTree.hpp:27
DecisionTree(std::unique_ptr< Node > &&root)
Constructs a decision tree by taking ownership of a root node.
Definition: DecisionTree.hpp:34
bool remove_weakest_link(const double max_allowed_error_increase)
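Finds the weakest link, i.e. the lowest split node whose removal increases the total error the least. It replaces that node with a leaf only if the error increase does not exceed max_allowed_error_increase.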
Definition: DecisionTree.hpp:146
unsigned int count_nodes() const
Counts nodes in the tree.
Definition: DecisionTree.hpp:103
Y operator()(arg_type x) const
Returns a prediction given a feature vector.
Definition: DecisionTree.hpp:97
double total_leaf_error() const
Returns the total prediction error for training data after all splits.
Definition: DecisionTree.hpp:121
DecisionTree< Y > & operator=(const DecisionTree< Y > &other)
Copy assignment operator.
Definition: DecisionTree.hpp:81
Tree node. Each node is either a split (non-terminal) node or a leaf (terminal) node.
Definition: DecisionTreeNodes.hpp:18
std::unique_ptr< Node< Y > > lower
Definition: DecisionTreeNodes.hpp:82
DecisionTree< Y > & operator=(DecisionTree< Y > &&other) noexcept
Move assignment operator.
Definition: DecisionTree.hpp:68
std::unique_ptr< Node< Y > > higher
Definition: DecisionTreeNodes.hpp:83
DecisionTrees::Node< Y > Node
Definition: DecisionTree.hpp:25
unsigned int number_lowest_split_nodes() const
Counts lowest split nodes, i.e. split nodes whose children are both leaves.
Definition: DecisionTree.hpp:205
DecisionTree(const DecisionTree< Y > &other)
Copy constructor.
Definition: DecisionTree.hpp:58
double cost_complexity(double alpha) const
Calculates cost-complexity measure.
Definition: DecisionTree.hpp:132