MLpp
Kernels.hpp
1 #pragma once
2 /* (C) 2021 Roman Werpachowski. */
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <Eigen/Core>
#include "dll.hpp"
7 
8 namespace ml
9 {
13  namespace Kernels
14  {
20  class Kernel
21  {
22  public:
24  DLL_DECLSPEC virtual ~Kernel();
25 
34  DLL_DECLSPEC virtual double value(const Eigen::Ref<const Eigen::VectorXd> x1, const Eigen::Ref<const Eigen::VectorXd> x2) const = 0;
35 
39  virtual unsigned int dim() const = 0;
40  protected:
41  DLL_DECLSPEC void validate_arguments(const Eigen::Ref<const Eigen::VectorXd> x1, const Eigen::Ref<const Eigen::VectorXd> x2) const;
42  };
43 
47  class DifferentiableKernel: public virtual Kernel
48  {
56  DLL_DECLSPEC virtual void gradient(const Eigen::Ref<const Eigen::VectorXd> x1, const Eigen::Ref<const Eigen::VectorXd> x2, Eigen::Ref<Eigen::VectorXd> dydx1) const = 0;
57  };
58 
59 
64  {
72  DLL_DECLSPEC virtual void hessian(const Eigen::Ref<const Eigen::VectorXd> x1, const Eigen::Ref<const Eigen::VectorXd> x2, Eigen::Ref<Eigen::MatrixXd> H) const = 0;
73  };
74 
79  {
80  public:
82  DLL_DECLSPEC virtual ~RadialBasisFunction();
83 
90  DLL_DECLSPEC virtual double value(double r2) const = 0;
91  };
92 
97  {
98  public:
105  DLL_DECLSPEC virtual double gradient(double r2) const = 0;
106  };
107 
112  {
113  public:
120  DLL_DECLSPEC virtual double second_derivative(double r2) const = 0;
121  };
122 
129  {
130  public:
131  DLL_DECLSPEC double value(double r2) const override;
132  DLL_DECLSPEC double gradient(double r2) const override;
133  DLL_DECLSPEC double second_derivative(double r2) const override;
134  };
135 
145  template <class RBF = RadialBasisFunction> class RBFKernel : public virtual Kernel
146  {
147  static_assert(std::is_base_of_v<RadialBasisFunction, RBF>);
148  public:
156  RBFKernel(std::unique_ptr<const RBF>&& rbf, unsigned int dim)
157  : rbf_(std::move(rbf)), dim_(dim)
158  {
159  if (rbf_ == nullptr) {
160  throw std::invalid_argument("Null RBF object");
161  }
162  if (!dim) {
163  throw std::domain_error("Kernel dimension must be positive");
164  }
165  }
166 
167  RBFKernel(const RBFKernel& other) = delete;
168  RBFKernel& operator=(const RBFKernel& other) = delete;
169 
170  DLL_DECLSPEC RBFKernel(RBFKernel&& other) = default;
171  DLL_DECLSPEC RBFKernel& operator=(RBFKernel&& other) = default;
172 
173  double value(const Eigen::Ref<const Eigen::VectorXd> x1, const Eigen::Ref<const Eigen::VectorXd> x2) const override
174  {
175  validate_arguments(x1, x2);
176  return rbf_->value((x1 - x2).squaredNorm());
177  }
178 
179  unsigned int dim() const override
180  {
181  return dim_;
182  }
183  protected:
184  std::unique_ptr<const RBF> rbf_;
185  private:
186  unsigned int dim_;
187  };
188 
193  template <class DiffRBF = DifferentiableRadialBasisFunction> class DifferentiableRBFKernel : public virtual DifferentiableKernel, public RBFKernel<DiffRBF>
194  {
195  static_assert(std::is_base_of_v<DifferentiableRadialBasisFunction, DiffRBF>);
196  public:
198 
199  void gradient(const Eigen::Ref<const Eigen::VectorXd> x1, const Eigen::Ref<const Eigen::VectorXd> x2, Eigen::Ref<Eigen::VectorXd> dydx1) const override
200  {
201  this->validate_arguments(x1, x2);
202  if (dydx1.size() != this->dim()) {
203  throw std::invalid_argument("Wrong dimension of dydx1");
204  }
205  const double rbf1der = this->rbf_->gradient((x1 - x2).squaredNorm());
206  for (Eigen::Index i = 0; i < x1.size(); ++i) {
207  dydx1[i] = 2 * (x1[i] - x2[i]) * rbf1der;
208  }
209  }
210  };
211  }
212 }
ml::Kernels::Kernel::~Kernel
virtual ~Kernel()
Virtual destructor.
ml::Kernels::RBFKernel::dim
unsigned int dim() const override
Dimension of the feature space.
Definition: Kernels.hpp:179
ml
Definition: BallTree.hpp:10
ml::Kernels::DifferentiableRadialBasisFunction::gradient
virtual double gradient(double r2) const =0
Derivative of the radial basis function with respect to r2, the squared distance between the feature vectors.
ml::Kernels::GaussianRBF::gradient
double gradient(double r2) const override
Derivative of the radial basis function with respect to r2, the squared distance between the feature vectors.
ml::Kernels::Kernel::value
virtual double value(const Eigen::Ref< const Eigen::VectorXd > x1, const Eigen::Ref< const Eigen::VectorXd > x2) const =0
Value of the kernel k(x1, x2).
dll.hpp
ml::Kernels::RadialBasisFunction::~RadialBasisFunction
virtual ~RadialBasisFunction()
Virtual destructor.
ml::Kernels::RBFKernel
Radial basis function kernel.
Definition: Kernels.hpp:145
ml::Kernels::Kernel::dim
virtual unsigned int dim() const =0
Dimension of the feature space.
ml::Kernels::DifferentiableKernel
Abstract differentiable R^D kernel interface.
Definition: Kernels.hpp:47
ml::Kernels::DoubleDifferentiableRadialBasisFunction
Twice-differentiable radial basis function.
Definition: Kernels.hpp:111
ml::Kernels::RBFKernel::value
double value(const Eigen::Ref< const Eigen::VectorXd > x1, const Eigen::Ref< const Eigen::VectorXd > x2) const override
Value of the kernel k(x1, x2).
Definition: Kernels.hpp:173
ml::Kernels::GaussianRBF::value
double value(double r2) const override
Radial basis function of the RBF kernel.
ml::Kernels::RBFKernel::RBFKernel
RBFKernel(std::unique_ptr< const RBF > &&rbf, unsigned int dim)
Constructor.
Definition: Kernels.hpp:156
ml::Kernels::RadialBasisFunction
Radial basis function.
Definition: Kernels.hpp:78
ml::Kernels::Kernel
Abstract R^D kernel interface.
Definition: Kernels.hpp:20
ml::Kernels::GaussianRBF
Gaussian radial basis function.
Definition: Kernels.hpp:128
ml::Kernels::DifferentiableRadialBasisFunction
Differentiable radial basis function.
Definition: Kernels.hpp:96
ml::Kernels::RadialBasisFunction::value
virtual double value(double r2) const =0
Radial basis function of the RBF kernel.
ml::Kernels::DoubleDifferentiableRadialBasisFunction::second_derivative
virtual double second_derivative(double r2) const =0
Second derivative of the radial basis function of the RBF kernel.
ml::Kernels::DifferentiableRBFKernel
Differentiable radial basis function kernel.
Definition: Kernels.hpp:193
ml::Kernels::DoubleDifferentiableKernel
Abstract twice-differentiable R^D kernel interface.
Definition: Kernels.hpp:63
ml::Kernels::GaussianRBF::second_derivative
double second_derivative(double r2) const override
Second derivative of the radial basis function of the RBF kernel.
ml::Kernels::DifferentiableRBFKernel::gradient
void gradient(const Eigen::Ref< const Eigen::VectorXd > x1, const Eigen::Ref< const Eigen::VectorXd > x2, Eigen::Ref< Eigen::VectorXd > dydx1) const override
Gradient of the kernel over the first feature vector. The gradient over the second vector can be cal...
Definition: Kernels.hpp:199