MLpp
BallTree.hpp
1 #pragma once
2 
#include <memory>
#include <queue>
#include <utility>
#include <vector>
#include <Eigen/Core>
#include "Features.hpp"
#include "dll.hpp"
9 
10 namespace ml
11 {
18  class BallTree
19  {
20  public:
27  DLL_DECLSPEC BallTree(Eigen::Ref<const Eigen::MatrixXd> X, unsigned int min_split_size);
28 
36  DLL_DECLSPEC BallTree(Eigen::Ref<const Eigen::MatrixXd> X, Eigen::Ref<const Eigen::VectorXd> y, unsigned int min_split_size);
37 
44  DLL_DECLSPEC BallTree(Eigen::MatrixXd&& X, unsigned int min_split_size);
45 
53  DLL_DECLSPEC BallTree(Eigen::MatrixXd&& X, Eigen::VectorXd&& y, unsigned int min_split_size);
54 
58  const Eigen::MatrixXd& data() const
59  {
60  return data_;
61  }
62 
66  const Eigen::VectorXd& labels() const
67  {
68  return labels_;
69  }
70 
79  DLL_DECLSPEC void find_k_nearest_neighbours(Eigen::Ref<const Eigen::VectorXd> x, unsigned int k, std::vector<unsigned int>& nn) const;
80 
89  DLL_DECLSPEC unsigned int find_nearest_neighbour(Eigen::Ref<const Eigen::VectorXd> x) const;
90 
94  auto size() const
95  {
96  return static_cast<unsigned int>(data_.cols());
97  }
98 
102  auto dim() const
103  {
104  return static_cast<unsigned int>(data_.rows());
105  }
106  private:
107  struct Node
108  {
109  double radius;
110  unsigned int pivot_index;
111  unsigned int start_index;
112  unsigned int end_index;
113  std::unique_ptr<Node> left_child;
114  std::unique_ptr<Node> right_child;
115  };
116 
117  Eigen::MatrixXd data_;
118  Eigen::VectorXd labels_;
119  std::unique_ptr<Node> root_;
120  unsigned int min_split_size_;
121 
130  void construct(Eigen::Ref<Eigen::MatrixXd> work, Eigen::Ref<Eigen::VectorXd> labels, unsigned int offset, std::unique_ptr<Node>& node, Features::VectorRange<Features::IndexedFeatureValue> features);
131 
135  typedef std::pair<unsigned int, double> IndexedDistanceFromTarget;
136 
140  struct IndexedDistanceFromTargetComparator
141  {
142  bool operator()(const IndexedDistanceFromTarget& a, const IndexedDistanceFromTarget& b) const
143  {
144  return a.second < b.second;
145  }
146  };
147 
151  typedef std::priority_queue<IndexedDistanceFromTarget, std::vector<IndexedDistanceFromTarget>, IndexedDistanceFromTargetComparator> MaxDistancePriorityQueue;
152 
160  void knn_search(Eigen::Ref<const Eigen::VectorXd> x, unsigned int k, const Node* node, MaxDistancePriorityQueue& q) const;
161 
169  double distance_from_queue(Eigen::Ref<const Eigen::VectorXd> x, unsigned int k, const MaxDistancePriorityQueue& q) const;
170  };
171 }
ml::BallTree::dim
auto dim() const
Dimension of the feature vectors.
Definition: BallTree.hpp:102
ml::BallTree::BallTree
BallTree(Eigen::Ref< const Eigen::MatrixXd > X, unsigned int min_split_size)
Constructor taking only features.
ml
Definition: BallTree.hpp:10
ml::BallTree::find_k_nearest_neighbours
void find_k_nearest_neighbours(Eigen::Ref< const Eigen::VectorXd > x, unsigned int k, std::vector< unsigned int > &nn) const
Finds up to k nearest neighbours for a given target vector. Uses the KNS1 algorithm from http://people....
dll.hpp
ml::BallTree::find_nearest_neighbour
unsigned int find_nearest_neighbour(Eigen::Ref< const Eigen::VectorXd > x) const
Finds the nearest neighbour for a given target vector. Uses the KNS1 algorithm from http://people....
ml::BallTree
Ball tree: an efficient tree structure for nearest-neighbour search in R^D space.
Definition: BallTree.hpp:18
ml::BallTree::size
auto size() const
Size of the tree (number of vectors).
Definition: BallTree.hpp:94
ml::BallTree::data
const Eigen::MatrixXd & data() const
Returns a const reference to the feature vectors (reordered).
Definition: BallTree.hpp:58
ml::BallTree::labels
const Eigen::VectorXd & labels() const
Returns a const reference to the labels (reordered).
Definition: BallTree.hpp:66