MLPACK  1.0.11
ra_search_rules.hpp
Go to the documentation of this file.
1 
24 #ifndef __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
25 #define __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
26 
27 #include "../neighbor_search/ns_traversal_info.hpp"
28 #include "ra_search.hpp" // For friend declaration.
29 
30 namespace mlpack {
31 namespace neighbor {
32 
// RASearchRules: the search rules (BaseCase / Score / Rescore) used for
// rank-approximate search. RASearch is declared a friend at the bottom of the
// class so it can call the private sampling helpers.
// NOTE(review): listing line 34 -- presumably `class RASearchRules` -- is
// missing (numbering jumps from 33 to 35). The class name comes from the
// constructor and the closing `}; // class RASearchRules` comment.
33 template<typename SortPolicy, typename MetricType, typename TreeType>
35 {
36  public:
 // Construct the rules object.
 //   referenceSet, querySet - the reference and query point sets.
 //   neighbors, distances   - matrices the resultant neighbor indices and
 //                            distances should be stored in.
 //   metric                 - the instantiated metric.
 //   tau, alpha             - sampling parameters (defaults 5 and 0.95);
 //                            presumably the rank-approximation and
 //                            success-probability targets that are also the
 //                            inputs of MinimumSamplesReqd() -- TODO confirm.
 //   naive                  - NOTE(review): meaning not visible in this
 //                            header; check ra_search_rules_impl.hpp.
 //   sampleAtLeaves         - whether to sample at leaves or just use all of it.
 //   firstLeafExact         - whether to do exact computation on the first
 //                            leaf before any sampling.
 //   singleSampleLimit      - the limit on the largest node that can be
 //                            approximated by sampling (default 20).
37  RASearchRules(const arma::mat& referenceSet,
38  const arma::mat& querySet,
39  arma::Mat<size_t>& neighbors,
40  arma::mat& distances,
41  MetricType& metric,
42  const double tau = 5,
43  const double alpha = 0.95,
44  const bool naive = false,
45  const bool sampleAtLeaves = false,
46  const bool firstLeafExact = false,
47  const size_t singleSampleLimit = 20);
48 
49 
50 
 // Base case: evaluate query point `queryIndex` against reference point
 // `referenceIndex`. The returned double is presumably the computed
 // distance -- confirm in ra_search_rules_impl.hpp.
51  double BaseCase(const size_t queryIndex, const size_t referenceIndex);
52 
 // Get the score for recursion order: single query point vs. reference node.
75  double Score(const size_t queryIndex, TreeType& referenceNode);
76 
 // As above, but also given a base-case result (presumably so that it does
 // not have to be recomputed).
100  double Score(const size_t queryIndex,
101  TreeType& referenceNode,
102  const double baseCaseResult);
103 
 // Re-evaluate the score for recursion order, given the previously computed
 // score `oldScore` (single query point vs. reference node).
121  double Rescore(const size_t queryIndex,
122  TreeType& referenceNode,
123  const double oldScore);
124 
 // Get the score for recursion order: query node vs. reference node.
143  double Score(TreeType& queryNode, TreeType& referenceNode);
144 
 // Dual-tree score, also given a base-case result.
165  double Score(TreeType& queryNode,
166  TreeType& referenceNode,
167  const double baseCaseResult);
168 
 // Dual-tree re-evaluation of the score, given `oldScore`.
191  double Rescore(TreeType& queryNode,
192  TreeType& referenceNode,
193  const double oldScore);
194 
195 
 // NOTE(review): listing lines 196-197 (the signature of this body) are
 // missing. The body returns the sum of numSamplesMade over all queries, or 0
 // when numSamplesMade has no elements (the empty case is checked explicitly).
198  {
199  if (numSamplesMade.n_elem == 0)
200  return 0;
201  else
202  return arma::sum(numSamplesMade);
203  }
204 
 // Traversal information for NeighborSearch. Per the symbol index,
 // TraversalInfoType is neighbor::NeighborSearchTraversalInfo<TreeType>; its
 // declaration (listing line 205) is missing here.
206 
207  const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
208  TraversalInfoType& TraversalInfo() { return traversalInfo; }
209 
210  private:
 // The reference set.
212  const arma::mat& referenceSet;
213 
 // The query set.
215  const arma::mat& querySet;
216 
 // The matrix the resultant neighbor indices should be stored in.
218  arma::Mat<size_t>& neighbors;
219 
 // The matrix the resultant neighbor distances should be stored in.
221  arma::mat& distances;
222 
 // The instantiated metric.
224  MetricType& metric;
225 
 // NOTE(review): listing lines 226-237 contain no declarations here (only the
 // blank lines 228/231/234/237 survive). The symbol index lists the members
 // sampleAtLeaves, firstLeafExact, singleSampleLimit and numSamplesReqd (the
 // minimum number of samples required per query). They are presumably
 // declared in this gap -- confirm against the real header.
228 
231 
234 
237 
 // The number of samples made for every query.
239  arma::Col<size_t> numSamplesMade;
240 
 // NOTE(review): listing lines 241-242 are missing; presumably the
 // samplingRatio member ("The sampling ratio.") -- confirm.
243 
244  // TO REMOVE: just for testing
 // NOTE(review): the declaration this comment refers to (line 245) is missing.
246 
 // Traversal information returned by TraversalInfo().
247  TraversalInfoType traversalInfo;
248 
 // Insert a point into the neighbors and distances matrices (helper).
 // `pos` is presumably the rank slot given to `neighbor` within the
 // results of `queryIndex` -- confirm in impl.
258  void InsertNeighbor(const size_t queryIndex,
259  const size_t pos,
260  const size_t neighbor,
261  const double distance);
262 
 // Compute the minimum number of samples required to guarantee the given
 // rank-approximation and success probability. Presumably tau and alpha are
 // those two targets, n is the set size and k the number of neighbors.
272  size_t MinimumSamplesReqd(const size_t n,
273  const size_t k,
274  const double tau,
275  const double alpha) const;
276 
 // Compute the success probability of obtaining 'k' neighbors from a set of
 // size 'n' within the top 't' (the index description is truncated). `m` is
 // presumably the number of samples -- confirm in impl.
286  double SuccessProbability(const size_t n,
287  const size_t k,
288  const size_t m,
289  const size_t t) const;
290 
 // Pick `numSamples` samples (with replacement) from a range of integers
 // bounded by `rangeUpperBound`, writing the result to `distinctSamples`.
 // Presumably only distinct values are kept -- suggested by the name; the
 // index text is truncated.
300  void ObtainDistinctSamples(const size_t numSamples,
301  const size_t rangeUpperBound,
302  arma::uvec& distinctSamples) const;
303 
 // Private Score helper (single query point vs. reference node) taking an
 // already-computed `distance` and the current `bestDistance`. It is
 // presumably shared by the public Score() overloads -- confirm in impl.
307  double Score(const size_t queryIndex,
308  TreeType& referenceNode,
309  const double distance,
310  const double bestDistance);
311 
 // Dual-tree counterpart of the private helper above.
315  double Score(TreeType& queryNode,
316  TreeType& referenceNode,
317  const double distance,
318  const double bestDistance);
319 
320  // So that RASearch can access ObtainDistinctSamples() and
321  // MinimumSamplesReqd(). Maybe refactoring is a better solution but this is
322  // okay for now.
323  friend class RASearch<SortPolicy, MetricType, TreeType>;
324 }; // class RASearchRules
325 
326 }; // namespace neighbor
327 }; // namespace mlpack
328 
329 // Include implementation.
330 #include "ra_search_rules_impl.hpp"
331 
332 #endif // __MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
bool firstLeafExact
Whether to do exact computation on the first leaf before any sampling.
double samplingRatio
The sampling ratio.
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:31
size_t numSamplesReqd
The minimum number of samples required per query.
MetricType & metric
The instantiated metric.
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
arma::mat & distances
The matrix the resultant neighbor distances should be stored in.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
neighbor::NeighborSearchTraversalInfo< TreeType > TraversalInfoType
arma::Col< size_t > numSamplesMade
The number of samples made for every query.
The RASearch class: this class provides a generic way to perform rank-approximate search via rando...
Definition: ra_search.hpp:71
const TraversalInfoType & TraversalInfo() const
Traversal information for NeighborSearch.
const arma::mat & querySet
The query set.
bool sampleAtLeaves
Whether to sample at leaves or just use all of it.
RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, arma::Mat< size_t > &neighbors, arma::mat &distances, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20)
size_t MinimumSamplesReqd(const size_t n, const size_t k, const double tau, const double alpha) const
Compute the minimum number of samples required to guarantee the given rank-approximation and success ...
const arma::mat & referenceSet
The reference set.
TraversalInfoType & TraversalInfo()
void InsertNeighbor(const size_t queryIndex, const size_t pos, const size_t neighbor, const double distance)
Insert a point into the neighbors and distances matrices; this is a helper function.
See subsection cli_alt_reg_tut (Alternate DET regularization): the usual regularized error $R_\alpha(t)$ of a node $t$ is given by
Definition: det.txt:367
void ObtainDistinctSamples(const size_t numSamples, const size_t rangeUpperBound, arma::uvec &distinctSamples) const
Pick the desired number of samples (with replacement) from a given range of integers so that only the ...
double SuccessProbability(const size_t n, const size_t k, const size_t m, const size_t t) const
Compute the success probability of obtaining 'k'-neighbors from a set of size 'n' within the top 't' ...
size_t singleSampleLimit
The limit on the largest node that can be approximated by sampling.
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
arma::Mat< size_t > & neighbors
The matrix the resultant neighbor indices should be stored in.