opencv 2.2.0
ml.hpp
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistributions of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistributions in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#ifndef __OPENCV_ML_HPP__
#define __OPENCV_ML_HPP__

// disable the deprecation warning that appears in Visual Studio 8.0 and later
#if _MSC_VER >= 1400
  #pragma warning( disable : 4996 )
#endif

#ifndef SKIP_INCLUDES

  #include "opencv2/core/core.hpp"
  #include <limits.h>

  #if defined WIN32 || defined _WIN32
    #include <windows.h>
  #endif

#else // SKIP_INCLUDES

  #if defined WIN32 || defined _WIN32
    #define CV_CDECL __cdecl
    #define CV_STDCALL __stdcall
  #else
    #define CV_CDECL
    #define CV_STDCALL
  #endif

  #ifndef CV_EXTERN_C
    #ifdef __cplusplus
      #define CV_EXTERN_C extern "C"
      #define CV_DEFAULT(val) = val
    #else
      #define CV_EXTERN_C
      #define CV_DEFAULT(val)
    #endif
  #endif

  #ifndef CV_EXTERN_C_FUNCPTR
    #ifdef __cplusplus
      #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
    #else
      #define CV_EXTERN_C_FUNCPTR(x) typedef x
    #endif
  #endif

  #ifndef CV_INLINE
    #if defined __cplusplus
      #define CV_INLINE inline
    #elif (defined WIN32 || defined _WIN32) && !defined __GNUC__
      #define CV_INLINE __inline
    #else
      #define CV_INLINE static
    #endif
  #endif /* CV_INLINE */

  #if (defined WIN32 || defined _WIN32) && defined CVAPI_EXPORTS
    #define CV_EXPORTS __declspec(dllexport)
  #else
    #define CV_EXPORTS
  #endif

  #ifndef CVAPI
    #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
  #endif

#endif // SKIP_INCLUDES


#ifdef __cplusplus

// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header
#undef check

/****************************************************************************************\
*                               Main struct definitions                                  *
\****************************************************************************************/

/* log(2*PI) */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* columns of <trainData> matrix are training samples */
#define CV_COL_SAMPLE 0

/* rows of <trainData> matrix are training samples */
#define CV_ROW_SAMPLE 1

#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)

struct CvVectors
{
    int type;
    int dims, count;
    CvVectors* next;
    union
    {
        uchar** ptr;
        float** fl;
        double** db;
    } data;
};

#if 0
/* A structure, representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by the cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1. */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif

/* Variable type */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1

class CV_EXPORTS_W CvStatModel
{
public:
    CvStatModel();
    virtual ~CvStatModel();

    virtual void clear();

    CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
    CV_WRAP virtual void load( const char* filename, const char* name=0 );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    const char* default_model_name;
};
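
/* Editor's sketch (not part of the original header): every model derived from
   CvStatModel can be written to and restored from an XML/YAML file via
   save()/load(). A minimal round trip, assuming a trained CvSVM; the file
   name is arbitrary. */
#if 0
void example_save_load( CvSVM& svm )
{
    svm.save( "model.xml" );          // serialize the trained model
    CvSVM restored;
    restored.load( "model.xml" );     // load it back into a fresh object
}
#endif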

/****************************************************************************************\
*                                 Normal Bayes Classifier                                *
\****************************************************************************************/

/* The structure, representing the grid range of statmodel parameters.
   It is used for optimizing statmodel accuracy by varying model parameters,
   the accuracy estimate being computed by cross-validation.
   The grid is logarithmic, so <step> must be greater than 1. */

class CvMLData;

struct CV_EXPORTS_W_MAP CvParamGrid
{
    // SVM params type
    enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };

    CvParamGrid()
    {
        min_val = max_val = step = 0;
    }

    CvParamGrid( double _min_val, double _max_val, double log_step )
    {
        min_val = _min_val;
        max_val = _max_val;
        step = log_step;
    }
    //CvParamGrid( int param_id );
    bool check() const;

    CV_PROP_RW double min_val;
    CV_PROP_RW double max_val;
    CV_PROP_RW double step;
};
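
/* Editor's note: train_auto() (declared in CvSVM below) walks such a grid as
   min_val, min_val*step, min_val*step^2, ... while the value stays below max_val.
   E.g. CvParamGrid(0.1, 100, 10) yields the candidate values { 0.1, 1, 10 },
   which is why step must be greater than 1. */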

class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
{
public:
    CV_WRAP CvNormalBayesClassifier();
    virtual ~CvNormalBayesClassifier();

    CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx=0, const CvMat* sampleIdx=0 );

    virtual bool train( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );

    virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
    CV_WRAP virtual void clear();

#ifndef SWIG
    CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       bool update=false );
    CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
#endif

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    int     var_count, var_all;
    CvMat*  var_idx;
    CvMat*  cls_labels;
    CvMat** count;
    CvMat** sum;
    CvMat** productsum;
    CvMat** avg;
    CvMat** inv_eigen_values;
    CvMat** cov_rotate_mats;
    double* c;
};

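/* Editor's sketch (assumed usage, not from the original header): training the
   normal Bayes classifier on a toy 2-class, 2-feature problem. */
#if 0
void example_nbayes()
{
    float samplesData[] = { 0.0f,0.1f, 0.2f,0.0f, 0.1f,0.2f,    // class 0
                            1.0f,0.9f, 0.8f,1.0f, 0.9f,0.8f };  // class 1
    float responsesData[] = { 0, 0, 0, 1, 1, 1 };
    CvMat samples = cvMat( 6, 2, CV_32F, samplesData );
    CvMat responses = cvMat( 6, 1, CV_32F, responsesData );

    CvNormalBayesClassifier nbayes;
    nbayes.train( &samples, &responses );

    float testData[] = { 0.1f, 0.1f };
    CvMat test = cvMat( 1, 2, CV_32F, testData );
    float label = nbayes.predict( &test );   // expected: class 0
    (void)label;
}
#endif
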

/****************************************************************************************\
*                          K-Nearest Neighbour Classifier                                *
\****************************************************************************************/

// k Nearest Neighbors
class CV_EXPORTS_W CvKNearest : public CvStatModel
{
public:

    CV_WRAP CvKNearest();
    virtual ~CvKNearest();

    CvKNearest( const CvMat* trainData, const CvMat* responses,
                const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );

    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* sampleIdx=0, bool is_regression=false,
                        int maxK=32, bool updateBase=false );

    virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
        const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;

#ifndef SWIG
    CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
               const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
                       int maxK=32, bool updateBase=false );

    virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
                                const float** neighbors=0, cv::Mat* neighborResponses=0,
                                cv::Mat* dist=0 ) const;
    CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
                                        CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
#endif

    virtual void clear();
    int get_max_k() const;
    int get_var_count() const;
    int get_sample_count() const;
    bool is_regression() const;

protected:

    virtual float write_results( int k, int k1, int start, int end,
        const float* neighbor_responses, const float* dist, CvMat* _results,
        CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;

    virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
        float* neighbor_responses, const float** neighbors, float* dist ) const;


    int max_k, var_count;
    int total;
    bool regression;
    CvVectors* samples;
};

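/* Editor's sketch (assumed usage): k-nearest-neighbour classification; training
   just stores the samples, and find_nearest() takes a majority vote among the
   k closest ones. */
#if 0
void example_knn()
{
    float samplesData[] = { 0.f,0.f, 0.f,1.f, 1.f,0.f, 1.f,1.f };
    float responsesData[] = { 0, 0, 1, 1 };
    CvMat samples = cvMat( 4, 2, CV_32F, samplesData );
    CvMat responses = cvMat( 4, 1, CV_32F, responsesData );

    CvKNearest knn( &samples, &responses );      // classification, max_k = 32

    float testData[] = { 0.9f, 0.1f };
    CvMat test = cvMat( 1, 2, CV_32F, testData );
    float label = knn.find_nearest( &test, 3 );  // vote among the 3 nearest samples
    (void)label;
}
#endif
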
/****************************************************************************************\
*                                Support Vector Machines                                 *
\****************************************************************************************/

// SVM training parameters
struct CV_EXPORTS_W_MAP CvSVMParams
{
    CvSVMParams();
    CvSVMParams( int _svm_type, int _kernel_type,
                 double _degree, double _gamma, double _coef0,
                 double Cvalue, double _nu, double _p,
                 CvMat* _class_weights, CvTermCriteria _term_crit );

    CV_PROP_RW int         svm_type;
    CV_PROP_RW int         kernel_type;
    CV_PROP_RW double      degree; // for poly
    CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid
    CV_PROP_RW double      coef0;  // for poly/sigmoid

    CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    CV_PROP_RW double      p;  // for CV_SVM_EPS_SVR
    CvMat*                 class_weights; // for CV_SVM_C_SVC
    CV_PROP_RW CvTermCriteria term_crit;  // termination criteria
};


struct CV_EXPORTS CvSVMKernel
{
    typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
                                       const float* another, float* results );
    CvSVMKernel();
    CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
    virtual bool create( const CvSVMParams* params, Calc _calc_func );
    virtual ~CvSVMKernel();

    virtual void clear();
    virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );

    const CvSVMParams* params;
    Calc calc_func;

    virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
                                    const float* another, float* results,
                                    double alpha, double beta );

    virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
                           const float* another, float* results );
    virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
                            const float* another, float* results );
    virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
                               const float* another, float* results );
};


struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;
    CvSVMKernelRow* next;
    float* data;
};


struct CvSVMSolutionInfo
{
    double obj;
    double rho;
    double upper_bound_p;
    double upper_bound_n;
    double r;   // for Solver_NU
};

class CV_EXPORTS CvSVMSolver
{
public:
    typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
    typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
    typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );

    CvSVMSolver();

    CvSVMSolver( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual bool create( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual ~CvSVMSolver();

    virtual void clear();
    virtual bool solve_generic( CvSVMSolutionInfo& si );

    virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
                              double Cp, double Cn, CvMemStorage* storage,
                              CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_one_class( int count, int var_count, const float** samples,
                                  CvMemStorage* storage, CvSVMKernel* kernel,
                                  double* alpha, CvSVMSolutionInfo& si );

    virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
                                CvMemStorage* storage, CvSVMKernel* kernel,
                                double* alpha, CvSVMSolutionInfo& si );

    virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );

    virtual float* get_row_base( int i, bool* _existed );
    virtual float* get_row( int i, float* dst );

    int sample_count;
    int var_count;
    int cache_size;
    int cache_line_size;
    const float** samples;
    const CvSVMParams* params;
    CvMemStorage* storage;
    CvSVMKernelRow lru_list;
    CvSVMKernelRow* rows;

    int alpha_count;

    double* G;
    double* alpha;

    // -1 - lower bound, 0 - free, 1 - upper bound
    schar* alpha_status;

    schar* y;
    double* b;
    float* buf[2];
    double eps;
    int max_iter;
    double C[2];  // C[0] == Cn, C[1] == Cp
    CvSVMKernel* kernel;

    SelectWorkingSet select_working_set_func;
    CalcRho calc_rho_func;
    GetRow get_row_func;

    virtual bool select_working_set( int& i, int& j );
    virtual bool select_working_set_nu_svm( int& i, int& j );
    virtual void calc_rho( double& rho, double& r );
    virtual void calc_rho_nu_svm( double& rho, double& r );

    virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
    virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
    virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
};


struct CvSVMDecisionFunc
{
    double rho;
    int sv_count;
    double* alpha;
    int* sv_index;
};


// SVM model
class CV_EXPORTS_W CvSVM : public CvStatModel
{
public:
    // SVM type
    enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };

    // SVM kernel type
    enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };

    // SVM params type
    enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };

    CV_WRAP CvSVM();
    virtual ~CvSVM();

    CvSVM( const CvMat* trainData, const CvMat* responses,
           const CvMat* varIdx=0, const CvMat* sampleIdx=0,
           CvSVMParams params=CvSVMParams() );

    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* varIdx=0, const CvMat* sampleIdx=0,
                        CvSVMParams params=CvSVMParams() );

    virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
        int kfold = 10,
        CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
        CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
        CvParamGrid pGrid      = get_default_grid(CvSVM::P),
        CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
        CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
        bool balanced=false );

    virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;

#ifndef SWIG
    CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
          const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
          CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
                            int k_fold = 10,
                            CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
                            CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
                            CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
                            CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
                            CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
                            bool balanced=false);
    CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
#endif

    CV_WRAP virtual int get_support_vector_count() const;
    virtual const float* get_support_vector(int i) const;
    virtual CvSVMParams get_params() const { return params; };
    CV_WRAP virtual void clear();

    static CvParamGrid get_default_grid( int param_id );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }

protected:

    virtual bool set_params( const CvSVMParams& params );
    virtual bool train1( int sample_count, int var_count, const float** samples,
                    const void* responses, double Cp, double Cn,
                    CvMemStorage* _storage, double* alpha, double& rho );
    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
                    const CvMat* responses, CvMemStorage* _storage, double* alpha );
    virtual void create_kernel();
    virtual void create_solver();

    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvSVMParams params;
    CvMat* class_labels;
    int var_all;
    float** sv;
    int sv_total;
    CvMat* var_idx;
    CvMat* class_weights;
    CvSVMDecisionFunc* decision_func;
    CvMemStorage* storage;

    CvSVMSolver* solver;
    CvSVMKernel* kernel;
};

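/* Editor's sketch (assumed usage): training a C-SVC with an RBF kernel on a toy
   2-class problem, then classifying one sample. The parameter values are
   arbitrary; train_auto() could pick them by cross-validation instead. */
#if 0
void example_svm()
{
    float samplesData[] = { 0.f,0.f, 0.f,1.f, 1.f,0.f, 1.f,1.f };
    float responsesData[] = { -1, -1, 1, 1 };
    CvMat samples = cvMat( 4, 2, CV_32F, samplesData );
    CvMat responses = cvMat( 4, 1, CV_32F, responsesData );

    CvSVMParams params;
    params.svm_type = CvSVM::C_SVC;
    params.kernel_type = CvSVM::RBF;
    params.gamma = 1.;
    params.C = 1.;

    CvSVM svm;
    svm.train( &samples, &responses, 0, 0, params );

    float testData[] = { 0.9f, 0.8f };
    CvMat test = cvMat( 1, 2, CV_32F, testData );
    float label = svm.predict( &test );   // expected: +1
    (void)label;
}
#endif
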
/****************************************************************************************\
*                              Expectation - Maximization                                *
\****************************************************************************************/

struct CV_EXPORTS CvEMParams
{
    CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
        start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
    {
        term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
    }

    CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
                int _start_step=0/*CvEM::START_AUTO_STEP*/,
                CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
                const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
                nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
                probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
    {}

    int nclusters;
    int cov_mat_type;
    int start_step;
    const CvMat* probs;
    const CvMat* weights;
    const CvMat* means;
    const CvMat** covs;
    CvTermCriteria term_crit;
};


class CV_EXPORTS_W CvEM : public CvStatModel
{
public:
    // Type of covariation matrices
    enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };

    // The initial step
    enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };

    CV_WRAP CvEM();
    CvEM( const CvMat* samples, const CvMat* sampleIdx=0,
          CvEMParams params=CvEMParams(), CvMat* labels=0 );
    //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights,
    //      CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);

    virtual ~CvEM();

    virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0,
                        CvEMParams params=CvEMParams(), CvMat* labels=0 );

    virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const;

#ifndef SWIG
    CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(),
                  CvEMParams params=CvEMParams() );

    CV_WRAP virtual bool train( const cv::Mat& samples,
                                const cv::Mat& sampleIdx=cv::Mat(),
                                CvEMParams params=CvEMParams(),
                                CV_OUT cv::Mat* labels=0 );

    CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const;

    CV_WRAP int getNClusters() const;
    CV_WRAP cv::Mat getMeans() const;
    CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const;
    CV_WRAP cv::Mat getWeights() const;
    CV_WRAP cv::Mat getProbs() const;

    CV_WRAP inline double getLikelihood() const { return log_likelihood; };
#endif

    CV_WRAP virtual void clear();

    int get_nclusters() const;
    const CvMat* get_means() const;
    const CvMat** get_covs() const;
    const CvMat* get_weights() const;
    const CvMat* get_probs() const;

    inline double get_log_likelihood() const { return log_likelihood; };

//    inline const CvMat*  get_log_weight_div_det() const { return log_weight_div_det; };
//    inline const CvMat*  get_inv_eigen_values() const { return inv_eigen_values; };
//    inline const CvMat** get_cov_rotate_mats() const { return cov_rotate_mats; };

protected:

    virtual void set_params( const CvEMParams& params,
                             const CvVectors& train_data );
    virtual void init_em( const CvVectors& train_data );
    virtual double run_em( const CvVectors& train_data );
    virtual void init_auto( const CvVectors& samples );
    virtual void kmeans( const CvVectors& train_data, int nclusters,
                         CvMat* labels, CvTermCriteria criteria,
                         const CvMat* means );
    CvEMParams params;
    double log_likelihood;

    CvMat* means;
    CvMat** covs;
    CvMat* weights;
    CvMat* probs;

    CvMat* log_weight_div_det;
    CvMat* inv_eigen_values;
    CvMat** cov_rotate_mats;
};

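/* Editor's sketch (assumed usage): fitting a 2-component Gaussian mixture to
   one-dimensional data; labels receive the most probable component per sample. */
#if 0
void example_em()
{
    float pointsData[] = { -0.5f, 0.f, 0.5f, 9.5f, 10.f, 10.5f };
    CvMat points = cvMat( 6, 1, CV_32F, pointsData );

    CvEMParams params( 2 );                  // 2 clusters, diagonal covariances
    CvMat* labels = cvCreateMat( 6, 1, CV_32SC1 );

    CvEM em;
    em.train( &points, 0, params, labels );  // EM estimation of the mixture
    double logL = em.get_log_likelihood();
    (void)logL;
    cvReleaseMat( &labels );
}
#endif
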
/****************************************************************************************\
*                                      Decision Tree                                     *
\****************************************************************************************/

struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};


#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)

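/* Editor's note: CV_DTREE_CAT_DIR(idx, subset) reads bit (idx & 31) of the 32-bit
   word subset[idx >> 5]: a set bit yields -1 (the sample goes to the left branch),
   a clear bit yields +1 (right branch). E.g. idx = 33 tests bit 1 of subset[1]. */
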
struct CvDTreeSplit
{
    int var_idx;
    int condensed_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;
    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        }
        ord;
    };
};

struct CvDTreeNode
{
    int class_idx;
    int Tn;
    double value;

    CvDTreeNode* parent;
    CvDTreeNode* left;
    CvDTreeNode* right;

    CvDTreeSplit* split;

    int sample_count;
    int depth;
    int* num_valid;
    int offset;
    int buf_idx;
    double maxlr;

    // global pruning data
    int complexity;
    double alpha;
    double node_risk, tree_risk, tree_error;

    // cross-validation pruning data
    int* cv_Tn;
    double* cv_node_risk;
    double* cv_node_error;

    int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
    void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
};


struct CV_EXPORTS_W_MAP CvDTreeParams
{
    CV_PROP_RW int   max_categories;
    CV_PROP_RW int   max_depth;
    CV_PROP_RW int   min_sample_count;
    CV_PROP_RW int   cv_folds;
    CV_PROP_RW bool  use_surrogates;
    CV_PROP_RW bool  use_1se_rule;
    CV_PROP_RW bool  truncate_pruned_tree;
    CV_PROP_RW float regression_accuracy;
    const float* priors;

    CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
        cv_folds(10), use_surrogates(true), use_1se_rule(true),
        truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
    {}

    CvDTreeParams( int _max_depth, int _min_sample_count,
                   float _regression_accuracy, bool _use_surrogates,
                   int _max_categories, int _cv_folds,
                   bool _use_1se_rule, bool _truncate_pruned_tree,
                   const float* _priors ) :
        max_categories(_max_categories), max_depth(_max_depth),
        min_sample_count(_min_sample_count), cv_folds (_cv_folds),
        use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
        truncate_pruned_tree(_truncate_pruned_tree),
        regression_accuracy(_regression_accuracy),
        priors(_priors)
    {}
};


struct CV_EXPORTS CvDTreeTrainData
{
    CvDTreeTrainData();
    CvDTreeTrainData( const CvMat* trainData, int tflag,
                      const CvMat* responses, const CvMat* varIdx=0,
                      const CvMat* sampleIdx=0, const CvMat* varType=0,
                      const CvMat* missingDataMask=0,
                      const CvDTreeParams& params=CvDTreeParams(),
                      bool _shared=false, bool _add_labels=false );
    virtual ~CvDTreeTrainData();

    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    virtual void do_responses_copy();

    virtual void get_vectors( const CvMat* _subsample_idx,
         float* values, uchar* missing, float* responses, bool get_class_idx=false );

    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // release all the data
    virtual void clear();

    int get_num_classes() const;
    int get_var_type(int vi) const;
    int get_work_var_count() const {return work_var_count;}

    virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
    virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                   const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
    virtual int get_child_buf_idx( CvDTreeNode* n );

    ////////////////////////////////////

    virtual bool set_params( const CvDTreeParams& params );
    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
                                   int storage_idx, int offset );

    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality );
    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    virtual void free_node_data( CvDTreeNode* node );
    virtual void free_train_data();
    virtual void free_node( CvDTreeNode* node );

    int sample_count, var_all, var_count, max_c_count;
    int ord_var_count, cat_var_count, work_var_count;
    bool have_labels, have_priors;
    bool is_classifier;
    int tflag;

    const CvMat* train_data;
    const CvMat* responses;
    CvMat* responses_copy; // used in Boosting

    int buf_count, buf_size;
    bool shared;
    int is_buf_16u;

    CvMat* cat_count;
    CvMat* cat_ofs;
    CvMat* cat_map;

    CvMat* counts;
    CvMat* buf;
    CvMat* direction;
    CvMat* split_buf;

    CvMat* var_idx;
    CvMat* var_type; // i-th element =
                     //   k<0  - ordered
                     //   k>=0 - categorical, see k-th element of cat_* arrays
    CvMat* priors;
    CvMat* priors_mult;

    CvDTreeParams params;

    CvMemStorage* tree_storage;
    CvMemStorage* temp_storage;

    CvDTreeNode* data_root;

    CvSet* node_heap;
    CvSet* split_heap;
    CvSet* cvfolds_heap;
    CvSet* nv_heap;

    cv::RNG rng;
};

class CvDTree;
class CvForestTree;

namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}

class CV_EXPORTS_W CvDTree : public CvStatModel
{
public:
    CV_WRAP CvDTree();
    virtual ~CvDTree();

    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );

    // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
    virtual float calc_error( CvMLData* trainData, int type, std::vector<float>* resp = 0 );

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );

    virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
                                  bool preprocessedInput=false ) const;

#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvDTreeParams params=CvDTreeParams() );

    CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
                                  bool preprocessedInput=false ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif

    virtual const CvMat* get_var_importance();
    CV_WRAP virtual void clear();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    // special read & write methods for trees in the tree ensembles
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    virtual void write( CvFileStorage* fs ) const;

    const CvDTreeNode* get_root() const;
    int get_pruned_tree_idx() const;
    CvDTreeTrainData* get_data();

protected:
    friend struct cv::DTreeBestSplitFinder;

    virtual bool do_train( const CvMat* _subsample_idx );

    virtual void try_split_node( CvDTreeNode* n );
    virtual void split_node_data( CvDTreeNode* n );
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual void complete_node_dir( CvDTreeNode* node );
    virtual void cluster_categories( const int* vectors, int vector_count,
        int var_count, int* sums, int k, int* cluster_labels );

    virtual void calc_node_value( CvDTreeNode* node );

    virtual void prune_cv();
    virtual double update_tree_rnc( int T, int fold );
    virtual int cut_tree( int T, int fold, double min_alpha );
    virtual void free_prune_data(bool cut_tree);
    virtual void free_tree();

    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    virtual void write_tree_nodes( CvFileStorage* fs ) const;
    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );

    CvDTreeNode* root;
    CvMat* var_importance;
    CvDTreeTrainData* data;

public:
    int pruned_tree_idx;
};

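/* Editor's sketch (assumed usage): a decision tree on the XOR problem. The response
   is marked categorical via the varType vector; CvDTreeParams relaxes the defaults
   (min_sample_count = 1, no cross-validation folds) so a tiny set can be split. */
#if 0
void example_dtree()
{
    float samplesData[] = { 0.f,0.f, 0.f,1.f, 1.f,0.f, 1.f,1.f };
    float responsesData[] = { 0, 1, 1, 0 };                       // XOR labels
    CvMat samples = cvMat( 4, 2, CV_32F, samplesData );
    CvMat responses = cvMat( 4, 1, CV_32F, responsesData );

    // one entry per variable plus one for the response
    uchar typeData[] = { CV_VAR_NUMERICAL, CV_VAR_NUMERICAL, CV_VAR_CATEGORICAL };
    CvMat varType = cvMat( 3, 1, CV_8U, typeData );

    CvDTree tree;
    tree.train( &samples, CV_ROW_SAMPLE, &responses, 0, 0, &varType, 0,
                CvDTreeParams( 8, 1, 0, false, 10, 0, false, false, 0 ) );

    float testData[] = { 1.f, 0.f };
    CvMat test = cvMat( 1, 2, CV_32F, testData );
    double label = tree.predict( &test )->value;   // expected: 1
    (void)label;
}
#endif
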

/****************************************************************************************\
*                                Random Trees Classifier                                 *
\****************************************************************************************/

class CvRTrees;

class CV_EXPORTS CvForestTree: public CvDTree
{
public:
    CvForestTree();
    virtual ~CvForestTree();

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );

    virtual int get_var_count() const {return data ? data->var_count : 0;}
    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    friend struct cv::ForestTreeBestSplitFinder;

    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    CvRTrees* forest;
};


struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
{
    //Parameters for the forest
    CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
    CV_PROP_RW int nactive_vars;
    CV_PROP_RW CvTermCriteria term_crit;

    CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
        calc_var_importance(false), nactive_vars(0)
    {
        term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
    }

    CvRTParams( int _max_depth, int _min_sample_count,
                float _regression_accuracy, bool _use_surrogates,
                int _max_categories, const float* _priors, bool _calc_var_importance,
                int _nactive_vars, int max_num_of_trees_in_the_forest,
                float forest_accuracy, int termcrit_type ) :
        CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
                       _use_surrogates, _max_categories, 0,
                       false, false, _priors ),
        calc_var_importance(_calc_var_importance),
        nactive_vars(_nactive_vars)
    {
        term_crit = cvTermCriteria(termcrit_type,
            max_num_of_trees_in_the_forest, forest_accuracy);
    }
};


class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );

    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif

    CV_WRAP virtual void clear();

    virtual const CvMat* get_var_importance();
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual float calc_error( CvMLData* _data, int type, std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    virtual float get_train_error();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:

    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;
    int nclasses;
    double oob_error;
    CvMat* var_importance;
    int nsamples;

    cv::RNG rng;
    CvMat* active_var_mask;
};

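/* Editor's sketch (assumed usage): growing a random forest; the term criteria
   bound the ensemble at 100 trees or an out-of-bag error below 0.01. */
#if 0
void example_rtrees( const CvMat* samples, const CvMat* responses, const CvMat* varType )
{
    CvRTParams params( 10, 2, 0, false, 10, 0,
                       true,    // calc_var_importance
                       0,       // nactive_vars: 0 = sqrt(var_count)
                       100, 0.01f, CV_TERMCRIT_ITER+CV_TERMCRIT_EPS );
    CvRTrees forest;
    forest.train( samples, CV_ROW_SAMPLE, responses, 0, 0, varType, 0, params );
    const CvMat* importance = forest.get_var_importance();
    (void)importance;
}
#endif
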
/****************************************************************************************\
*                          Extremely randomized trees Classifier                         *
\****************************************************************************************/
struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
{
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
                                   const float** ord_values, const int** missing, int* sample_buf = 0 );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
                              float* responses, bool get_class_idx=false );
    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    const CvMat* missing_mask;
};

class CV_EXPORTS CvForestERTree : public CvForestTree
{
protected:
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void split_node_data( CvDTreeNode* n );
};

class CV_EXPORTS_W CvERTrees : public CvRTrees
{
public:
    CV_WRAP CvERTrees();
    virtual ~CvERTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams());
#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams());
#endif
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
protected:
    virtual bool grow_forest( const CvTermCriteria term_crit );
};


/****************************************************************************************\
*                                  Boosted tree classifier                               *
\****************************************************************************************/

struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams
{
    CV_PROP_RW int boost_type;
    CV_PROP_RW int weak_count;
    CV_PROP_RW int split_criteria;
    CV_PROP_RW double weight_trim_rate;

    CvBoostParams();
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};


class CvBoost;

class CV_EXPORTS CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    virtual bool train( CvDTreeTrainData* trainData,
                        const CvMat* subsample_idx, CvBoost* ensemble );

    virtual void scale( double s );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:

    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    CvBoost* ensemble;
};


class CV_EXPORTS_W CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CV_WRAP CvBoost();
    virtual ~CvBoost();

    CvBoost( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams() );

    virtual bool train( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    virtual bool train( CvMLData* data,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    virtual float predict( const CvMat* sample, const CvMat* missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false, bool return_sum=false ) const;

#ifndef SWIG
    CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
            const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
            const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
            const cv::Mat& missingDataMask=cv::Mat(),
            CvBoostParams params=CvBoostParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvBoostParams params=CvBoostParams(),
                       bool update=false );

    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
                                   const cv::Range& slice=cv::Range::all(), bool rawMode=false,
                                   bool returnSum=false ) const;
#endif

    virtual float calc_error( CvMLData* _data, int type, std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    CV_WRAP virtual void prune( CvSlice slice );

    CV_WRAP virtual void clear();

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    virtual const CvMat* get_active_vars(bool absolute_idx=true);

    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;
    const CvDTreeTrainData* get_data() const;

protected:

    virtual bool set_params( const CvBoostParams& params );
    virtual void update_weights( CvBoostTree* tree );
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvDTreeTrainData* data;
    CvBoostParams params;
    CvSeq* weak;

    CvMat* active_vars;
    CvMat* active_vars_abs;
    bool have_active_cat_vars;

    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;
    CvMat* subtree_weights;
    bool have_subsample;
};

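/* Editor's sketch (assumed usage): a 2-class Real-AdaBoost ensemble of 100 depth-1
   stumps; weight_trim_rate = 0.95 skips samples whose weights have become tiny. */
#if 0
float example_boost( const CvMat* samples, const CvMat* responses,
                     const CvMat* varType, const CvMat* testSample )
{
    CvBoostParams params( CvBoost::REAL, 100, 0.95, 1, false, 0 );
    CvBoost boost;
    boost.train( samples, CV_ROW_SAMPLE, responses, 0, 0, varType, 0, params );
    return boost.predict( testSample );   // returns the predicted class label
}
#endif
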

/****************************************************************************************\
*                                  Gradient Boosted Trees                                *
\****************************************************************************************/

// DataType: STRUCT CvGBTreesParams
// Parameters of GBT (Gradient Boosted trees model), including single
// tree settings and ensemble parameters.
//
// weak_count          - count of trees in the ensemble
// loss_function_type  - loss function used for ensemble training
// subsample_portion   - portion of the whole training set used for
//                       each single tree training.
//                       subsample_portion value is in (0.0, 1.0].
//                       subsample_portion == 1.0 when the whole dataset is
//                       used on each step. The count of samples used on each
//                       step is computed as
//                       int(total_samples_count * subsample_portion).
// shrinkage           - regularization parameter.
//                       Each tree prediction is multiplied by the shrinkage value.


struct CV_EXPORTS CvGBTreesParams : public CvDTreeParams
{
    CV_PROP_RW int weak_count;
    CV_PROP_RW int loss_function_type;
    CV_PROP_RW float subsample_portion;
    CV_PROP_RW float shrinkage;

    CvGBTreesParams();
    CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
                     float subsample_portion, int max_depth, bool use_surrogates );
};
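
/* Editor's note: e.g. with 1000 training samples and subsample_portion = 0.8, each
   boosting step fits its tree on int(1000 * 0.8) = 800 randomly drawn samples, and
   the tree's output is scaled by the shrinkage factor before it is added to f(x). */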

// DataType: CLASS CvGBTrees
// Gradient Boosting Trees (GBT) algorithm implementation.
//
// data             - training dataset
// params           - parameters of the CvGBTrees
// weak             - array[0..(class_count-1)] of CvSeq
//                    for storing tree ensembles
// orig_response    - original responses of the training set samples
// sum_response     - predictions of the current model on the training dataset.
//                    this matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
//                    step. On every iteration values of sum_responses_tmp are
//                    computed via sum_responses values. When the current
//                    step is complete sum_response values become equal to
//                    sum_responses_tmp.
// sampleIdx        - indices of samples used for training the ensemble.
//                    The CvGBTrees training procedure takes a set of samples
//                    (train_data) and a set of responses (responses).
//                    Only pairs (train_data[i], responses[i]), where i is
//                    in sample_idx, are used for training the ensemble.
// subsample_train  - indices of samples used for training a single decision
//                    tree on the current step. These indices are counted
//                    relative to sample_idx, so that pairs
//                    (train_data[sample_idx[i]], responses[sample_idx[i]])
//                    are used for training a decision tree.
//                    The training set is randomly split
//                    into two parts (subsample_train and subsample_test)
//                    on every iteration according to the portion parameter.
// subsample_test   - relative indices of samples from the training set,
//                    which are not used for training a tree on the current
//                    step.
// missing          - mask of the missing values in the training set. This
//                    matrix has the same size as train_data. 1 - missing
//                    value, 0 - not a missing value.
// class_labels     - output class labels map.
// rng              - random number generator. Used for splitting the
//                    training set.
// class_count      - count of output classes.
//                    class_count == 1 in the case of regression,
//                    and > 1 in the case of classification.
// delta            - Huber loss function parameter.
// base_value       - start point of the gradient descent procedure.
//                    model prediction is
//                    f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
//                    f_0 is the base value.



class CV_EXPORTS_W CvGBTrees : public CvStatModel
{
public:

    /*
    // DataType: ENUM
    // Loss functions implemented in CvGBTrees.
    //
    // SQUARED_LOSS
    // problem: regression
    // loss = (x - x')^2
    //
    // ABSOLUTE_LOSS
    // problem: regression
    // loss = abs(x - x')
    //
    // HUBER_LOSS
    // problem: regression
    // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
    //        1/2*(x - x')^2,                 if abs(x - x') <= delta,
    // where delta is the alpha-quantile of pseudo responses from
    // the training set.
    //
    // DEVIANCE_LOSS
    // problem: classification
    //
    */
    enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};


    /*
    // Default constructor. Creates a model only (without training).
    // Should be followed by one form of the train(...) function.
    //
    // API
    // CvGBTrees();

    // INPUT
    // OUTPUT
    // RESULT
    */
    CV_WRAP CvGBTrees();


    /*
    // Full form constructor. Creates a gradient boosting model and does the
    // train.
    //
    // API
    // CvGBTrees( const CvMat* trainData, int tflag,
    //          const CvMat* responses, const CvMat* varIdx=0,
    //          const CvMat* sampleIdx=0, const CvMat* varType=0,
    //          const CvMat* missingDataMask=0,
    //          CvGBTreesParams params=CvGBTreesParams() );

    // INPUT
    // trainData       - a set of input feature vectors.
    //                   size of matrix is
    //                   <count of samples> x <variables count>
    //                   or <variables count> x <count of samples>
    //                   depending on the tflag parameter.
    //                   matrix values are float.
    // tflag           - a flag showing how the samples are stored in the
    //                   trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
    //                   or column by column (tflag=CV_COL_SAMPLE).
    // responses       - a vector of responses corresponding to the samples
    //                   in trainData.
    // varIdx          - indices of used variables. zero value means that all
    //                   variables are active.
    // sampleIdx       - indices of used samples. zero value means that all
    //                   samples from trainData are in the training set.
    // varType         - vector of <variables count> length. gives every
    //                   variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
    //                   varType = 0 means all variables are numerical.
    // missingDataMask - a mask of missing values in trainData.
    //                   missingDataMask = 0 means that there are no missing
    //                   values.
    // params          - parameters of the GTB algorithm.
    // OUTPUT
    // RESULT
    */
    CvGBTrees( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvGBTreesParams params=CvGBTreesParams() );


    /*
    // Destructor.
    */
    virtual ~CvGBTrees();


    /*
    // Gradient tree boosting model training
    //
    // API
    // virtual bool train( const CvMat* trainData, int tflag,
    //          const CvMat* responses, const CvMat* varIdx=0,
    //          const CvMat* sampleIdx=0, const CvMat* varType=0,
    //          const CvMat* missingDataMask=0,
    //          CvGBTreesParams params=CvGBTreesParams(),
    //          bool update=false );

    // INPUT
    // trainData       - a set of input feature vectors.
    //                   size of matrix is
    //                   <count of samples> x <variables count>
    //                   or <variables count> x <count of samples>
    //                   depending on the tflag parameter.
    //                   matrix values are float.
    // tflag           - a flag showing how the samples are stored in the
    //                   trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
    //                   or column by column (tflag=CV_COL_SAMPLE).
    // responses       - a vector of responses corresponding to the samples
    //                   in trainData.
    // varIdx          - indices of used variables. zero value means that all
    //                   variables are active.
    // sampleIdx       - indices of used samples. zero value means that all
    //                   samples from trainData are in the training set.
    // varType         - vector of <variables count> length. gives every
    //                   variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
    //                   varType = 0 means all variables are numerical.
    // missingDataMask - a mask of missing values in trainData.
    //                   missingDataMask = 0 means that there are no missing
    //                   values.
    // params          - parameters of the GTB algorithm.
    // update          - is not supported now. (!)
    // OUTPUT
    // RESULT
    //   Error state.
    */
    virtual bool train( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvGBTreesParams params=CvGBTreesParams(),
             bool update=false );


    /*
    // Gradient tree boosting model training
    //
    // API
    // virtual bool train( CvMLData* data,
    //          CvGBTreesParams params=CvGBTreesParams(),
    //          bool update=false ) {return false;};

    // INPUT
    // data   - training set.
    // params - parameters of the GTB algorithm.
    // update - is not supported now. (!)
    // OUTPUT
    // RESULT
    //   Error state.
    */
    virtual bool train( CvMLData* data,
             CvGBTreesParams params=CvGBTreesParams(),
             bool update=false );


    /*
    // Response value prediction
    //
    // API
    // virtual float predict( const CvMat* sample, const CvMat* missing=0,
    //          CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
    //          int k=-1 ) const;

    // INPUT
    // sample         - input sample of the same type as in the training set.
    // missing        - missing values mask. missing=0 if there are no
    //                  missing values in the sample vector.
    // weak_responses - predictions of all of the trees.
    //                  not implemented (!)
    // slice          - part of the ensemble used for prediction.
    //                  slice = CV_WHOLE_SEQ when all trees are used.
    // k              - number of the ensemble used.
    //                  k is in {-1,0,1,..,<count of output classes-1>}.
    //                  in the case of a classification problem
    //                  <count of output classes-1> ensembles are built.
    //                  If k = -1 an ordinary prediction is the result,
    //                  otherwise the function gives the prediction of the
    //                  k-th ensemble only.
    // OUTPUT
    // RESULT
    //   Predicted value.
    */
    virtual float predict( const CvMat* sample, const CvMat* missing=0,
            CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
            int k=-1 ) const;

    /*
    // Delete all temporary data.
    //
    // API
    // virtual void clear();

    // INPUT
    // OUTPUT
    //   delete data, weak, orig_response, sum_response,
    //   weak_eval, subsample_train, subsample_test,
    //   sample_idx, missing, class_labels
    //   delta = 0.0
    // RESULT
    */
    CV_WRAP virtual void clear();

    /*
    // Compute error on the train/test set.
    //
    // API
    // virtual float calc_error( CvMLData* _data, int type,
    //        std::vector<float> *resp = 0 );
    //
    // INPUT
    // data - dataset
    // type - defines which error to compute: train (CV_TRAIN_ERROR) or
    //        test (CV_TEST_ERROR).
    // OUTPUT
    // resp - vector of predictions
    // RESULT
    //   Error value.
    */
    virtual float calc_error( CvMLData* _data, int type,
            std::vector<float> *resp = 0 );


    /*
    //
    // Write parameters of the gtb model and data. Write learned model.
    //
    // API
    // virtual void write( CvFileStorage* fs, const char* name ) const;
    //
    // INPUT
    // fs   - file storage to write parameters to.
    // name - model name.
    // OUTPUT
    // RESULT
    */
    virtual void write( CvFileStorage* fs, const char* name ) const;


    /*
    //
    // Read parameters of the gtb model and data. Read learned model.
    //
    // API
    // virtual void read( CvFileStorage* fs, CvFileNode* node );
    //
    // INPUT
    // fs   - file storage to read parameters from.
    // node - file node.
    // OUTPUT
    // RESULT
    */
    virtual void read( CvFileStorage* fs, CvFileNode* node );

1720 
1721  // new-style C++ interface
1722  CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
1723  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1724  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1725  const cv::Mat& missingDataMask=cv::Mat(),
1726  CvGBTreesParams params=CvGBTreesParams() );
1727 
1728  CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1729  const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1730  const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1731  const cv::Mat& missingDataMask=cv::Mat(),
1732  CvGBTreesParams params=CvGBTreesParams(),
1733  bool update=false );
1734 
1735  CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1736  const cv::Range& slice = cv::Range::all(),
1737  int k=-1 ) const;
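 /*
 // Usage sketch for the cv::Mat overloads above; the matrix contents are
 // illustrative placeholders.
 //
 //   cv::Mat train_data = (cv::Mat_<float>(4, 2) << 1, 2,  2, 3,  3, 4,  4, 5);
 //   cv::Mat responses  = (cv::Mat_<float>(4, 1) << 3, 5, 7, 9);
 //
 //   CvGBTrees gbtrees( train_data, CV_ROW_SAMPLE, responses ); // trains on construction
 //
 //   cv::Mat sample = (cv::Mat_<float>(1, 2) << 5, 6);
 //   float prediction = gbtrees.predict( sample );
 */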
1738 
1739 protected:
1740 
1741  /*
1742  // Compute the gradient vector components.
1743  //
1744  // API
1745  // virtual void find_gradient( const int k = 0);
1746 
1747  // INPUT
1748  // k - used for classification problems to select the
1749  // current tree ensemble.
1750  // OUTPUT
1751  // changes the components of data->responses
1752  // that correspond to the samples used for training
1753  // at the current step.
1754  // RESULT
1755  */
1756  virtual void find_gradient( const int k = 0);
1757 
1758 
1759  /*
1760  //
1761  // Change values in tree leaves according to the used loss function.
1762  //
1763  // API
1764  // virtual void change_values(CvDTree* tree, const int k = 0);
1765  //
1766  // INPUT
1767  // tree - decision tree to change.
1768  // k - used for classification problems to select the
1769  // current tree ensemble.
1770  // OUTPUT
1771  // changes 'value' fields of the trees' leaves.
1772  // changes sum_response_tmp.
1773  // RESULT
1774  */
1775  virtual void change_values(CvDTree* tree, const int k = 0);
1776 
1777 
1778  /*
1779  //
1780  // Find optimal constant prediction value according to the used loss
1781  // function.
1782  // The goal is to find a constant which gives the minimal total loss
1783  // on the _Idx samples.
1784  //
1785  // API
1786  // virtual float find_optimal_value( const CvMat* _Idx );
1787  //
1788  // INPUT
1789  // _Idx - indices of the samples from the training set.
1790  // OUTPUT
1791  // RESULT
1792  // optimal constant value.
1793  */
1794  virtual float find_optimal_value( const CvMat* _Idx );
1795 
1796 
1797  /*
1798  //
1799  // Randomly splits the whole training set into two parts according
1800  // to params.portion.
1801  //
1802  // API
1803  // virtual void do_subsample();
1804  //
1805  // INPUT
1806  // OUTPUT
1807  // subsample_train - indices of samples used for training
1808  // subsample_test - indices of samples used for test
1809  // RESULT
1810  */
1811  virtual void do_subsample();
1812 
1813 
1814  /*
1815  //
1816  // Internal recursive function collecting the leaves of a subtree into an array.
1817  //
1818  // API
1819  // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1820  //
1821  // INPUT
1822  // node - root of the current subtree.
1823  // OUTPUT
1824  // count - count of leaves in the subtree.
1825  // leaves - array of pointers to leaves.
1826  // RESULT
1827  */
1828  void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1829 
1830 
1831  /*
1832  //
1833  // Get leaves of the tree.
1834  //
1835  // API
1836  // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1837  //
1838  // INPUT
1839  // dtree - decision tree.
1840  // OUTPUT
1841  // len - count of the leaves.
1842  // RESULT
1843  // CvDTreeNode** - array of pointers to leaves.
1844  */
1845  CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1846 
1847 
1848  /*
1849  //
1850  // Tells whether the problem is regression or classification.
1851  //
1852  // API
1853  // bool problem_type();
1854  //
1855  // INPUT
1856  // OUTPUT
1857  // RESULT
1858  // false if it is a classification problem,
1859  // true if it is a regression problem.
1860  */
1861  virtual bool problem_type() const;
1862 
1863 
1864  /*
1865  //
1866  // Write parameters of the gtb model.
1867  //
1868  // API
1869  // virtual void write_params( CvFileStorage* fs ) const;
1870  //
1871  // INPUT
1872  // fs - file storage to write parameters to.
1873  // OUTPUT
1874  // RESULT
1875  */
1876  virtual void write_params( CvFileStorage* fs ) const;
1877 
1878 
1879  /*
1880  //
1881  // Read parameters of the gtb model and data.
1882  //
1883  // API
1884  // virtual void read_params( CvFileStorage* fs );
1885  //
1886  // INPUT
1887  // fs - file storage to read parameters from.
1888  // OUTPUT
1889  // params - parameters of the gtb model.
1890  // data - contains information about the structure
1891  // of the data set (count of variables,
1892  // their types, etc.).
1893  // class_labels - output class labels map.
1894  // RESULT
1895  */
1896  virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
1897 
1898 
1899  CvDTreeTrainData* data;
1900  CvGBTreesParams params;
1901 
1902  CvSeq** weak;
1903  CvMat* orig_response;
1904  CvMat* sum_response;
1905  CvMat* sum_response_tmp;
1906  CvMat* weak_eval;
1907  CvMat* sample_idx;
1908  CvMat* subsample_train;
1909  CvMat* subsample_test;
1910  CvMat* missing;
1911  CvMat* class_labels;
1912 
1913  cv::RNG* rng;
1914 
1915  int class_count;
1916  float delta;
1917  float base_value;
1918 
1919 };
1920 
1921 
1922 
1923 /****************************************************************************************\
1924 * Artificial Neural Networks (ANN) *
1925 \****************************************************************************************/
1926 
1927 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1928 
1929 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams
1930 {
1931  CvANN_MLP_TrainParams();
1932  CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1933  double param1, double param2=0 );
1934  ~CvANN_MLP_TrainParams();
1935 
1936  enum { BACKPROP=0, RPROP=1 };
1937 
1938  CV_PROP_RW CvTermCriteria term_crit;
1939  CV_PROP_RW int train_method;
1940 
1941  // backpropagation parameters
1942  CV_PROP_RW double bp_dw_scale, bp_moment_scale;
1943 
1944  // rprop parameters
1945  CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1946 };
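 /*
 // Usage sketch: configuring the two supported training methods; the
 // numeric values are illustrative placeholders.
 //
 //   // RPROP, stopping after 1000 iterations or an epsilon of 0.01;
 //   // param1 is the initial update value rp_dw0.
 //   CvANN_MLP_TrainParams rprop_params(
 //       cvTermCriteria( CV_TERMCRIT_ITER + CV_TERMCRIT_EPS, 1000, 0.01 ),
 //       CvANN_MLP_TrainParams::RPROP, 0.1 );
 //
 //   // BACKPROP; param1 is the gradient scale bp_dw_scale,
 //   // param2 the momentum scale bp_moment_scale.
 //   CvANN_MLP_TrainParams bp_params(
 //       cvTermCriteria( CV_TERMCRIT_ITER, 300, 0 ),
 //       CvANN_MLP_TrainParams::BACKPROP, 0.1, 0.1 );
 */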
1947 
1948 
1949 class CV_EXPORTS_W CvANN_MLP : public CvStatModel
1950 {
1951 public:
1952  CV_WRAP CvANN_MLP();
1953  CvANN_MLP( const CvMat* layerSizes,
1954  int activateFunc=CvANN_MLP::SIGMOID_SYM,
1955  double fparam1=0, double fparam2=0 );
1956 
1957  virtual ~CvANN_MLP();
1958 
1959  virtual void create( const CvMat* layerSizes,
1960  int activateFunc=CvANN_MLP::SIGMOID_SYM,
1961  double fparam1=0, double fparam2=0 );
1962 
1963  virtual int train( const CvMat* inputs, const CvMat* outputs,
1964  const CvMat* sampleWeights, const CvMat* sampleIdx=0,
1965  CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1966  int flags=0 );
1967  virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
1968 
1969 #ifndef SWIG
1970  CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
1971  int activateFunc=CvANN_MLP::SIGMOID_SYM,
1972  double fparam1=0, double fparam2=0 );
1973 
1974  CV_WRAP virtual void create( const cv::Mat& layerSizes,
1975  int activateFunc=CvANN_MLP::SIGMOID_SYM,
1976  double fparam1=0, double fparam2=0 );
1977 
1978  CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
1979  const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
1980  CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1981  int flags=0 );
1982 
1983  CV_WRAP virtual float predict( const cv::Mat& inputs, cv::Mat& outputs ) const;
1984 #endif
1985 
1986  CV_WRAP virtual void clear();
1987 
1988  // possible activation functions
1989  enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1990 
1991  // available training flags
1992  enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1993 
1994  virtual void read( CvFileStorage* fs, CvFileNode* node );
1995  virtual void write( CvFileStorage* storage, const char* name ) const;
1996 
1997  int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
1998  const CvMat* get_layer_sizes() { return layer_sizes; }
1999  double* get_weights(int layer)
2000  {
2001  return layer_sizes && weights &&
2002  (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
2003  }
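 /*
 // Usage sketch: a 2-3-1 perceptron trained on four samples with the default
 // symmetric sigmoid activation; all numeric values are illustrative.
 //
 //   int sizes[] = { 2, 3, 1 };                 // input, hidden, output layer sizes
 //   CvMat layer_sizes = cvMat( 1, 3, CV_32SC1, sizes );
 //
 //   float in_vals[]  = { 0, 0,  0, 1,  1, 0,  1, 1 };
 //   float out_vals[] = { 0, 1, 1, 0 };
 //   CvMat inputs  = cvMat( 4, 2, CV_32FC1, in_vals );
 //   CvMat outputs = cvMat( 4, 1, CV_32FC1, out_vals );
 //
 //   CvANN_MLP mlp( &layer_sizes );
 //   mlp.train( &inputs, &outputs, 0 );         // 0: equal sample weights
 //
 //   float result;
 //   CvMat response = cvMat( 1, 1, CV_32FC1, &result );
 //   float test_vals[] = { 1, 0 };
 //   CvMat test_sample = cvMat( 1, 2, CV_32FC1, test_vals );
 //   mlp.predict( &test_sample, &response );
 */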
2004 
2005 protected:
2006 
2007  virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
2008  const CvMat* _sample_weights, const CvMat* sampleIdx,
2009  CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
2010 
2011  // sequential random backpropagation
2012  virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
2013 
2014  // RPROP algorithm
2015  virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
2016 
2017  virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
2018  virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
2019  virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
2020  double _f_param1=0, double _f_param2=0 );
2021  virtual void init_weights();
2022  virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
2023  virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
2024  virtual void calc_input_scale( const CvVectors* vecs, int flags );
2025  virtual void calc_output_scale( const CvVectors* vecs, int flags );
2026 
2027  virtual void write_params( CvFileStorage* fs ) const;
2028  virtual void read_params( CvFileStorage* fs, CvFileNode* node );
2029 
2030  CvMat* layer_sizes;
2031  CvMat* wbuf;
2032  CvMat* sample_weights;
2033  double** weights;
2034  double f_param1, f_param2;
2035  double min_val, max_val, min_val1, max_val1;
2036  int activ_func;
2037  int max_count, max_buf_sz;
2038  CvANN_MLP_TrainParams params;
2039  cv::RNG* rng;
2040 };
2041 
2042 /****************************************************************************************\
2043 * Auxiliary functions declarations *
2044 \****************************************************************************************/
2045 
2046 /* Generates <sample> from a multivariate normal distribution, where <mean> is the
2047  mean row vector and <cov> is the symmetric covariance matrix */
2048 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
2049  CvRNG* rng CV_DEFAULT(0) );
2050 
2051 /* Generates a sample from a Gaussian mixture distribution */
2052 CVAPI(void) cvRandGaussMixture( CvMat* means[],
2053  CvMat* covs[],
2054  float weights[],
2055  int clsnum,
2056  CvMat* sample,
2057  CvMat* sampClasses CV_DEFAULT(0) );
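 /*
 // Usage sketch: drawing 100 samples from a 2-D normal distribution; the
 // mean and covariance values are illustrative placeholders.
 //
 //   float mean_vals[] = { 0.f, 0.f };
 //   float cov_vals[]  = { 1.f, 0.f,
 //                         0.f, 1.f };
 //   CvMat mean = cvMat( 1, 2, CV_32FC1, mean_vals );
 //   CvMat cov  = cvMat( 2, 2, CV_32FC1, cov_vals );
 //   CvMat* samples = cvCreateMat( 100, 2, CV_32FC1 );
 //   cvRandMVNormal( &mean, &cov, samples );   // default CvRNG
 //   // ... use samples ...
 //   cvReleaseMat( &samples );
 */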
2058 
2059 #define CV_TS_CONCENTRIC_SPHERES 0
2060 
2061 /* Creates a test set */
2062 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
2063  int num_samples,
2064  int num_features,
2065  CvMat** responses,
2066  int num_classes, ... );
2067 
2068 
2069 #endif
2070 
2071 /****************************************************************************************\
2072 * Data *
2073 \****************************************************************************************/
2074 
2075 #include <map>
2076 #include <string>
2077 #include <iostream>
2078 
2079 #define CV_COUNT 0
2080 #define CV_PORTION 1
2081 
2082 struct CV_EXPORTS CvTrainTestSplit
2083 {
2084 public:
2085  CvTrainTestSplit();
2086  CvTrainTestSplit( int _train_sample_count, bool _mix = true);
2087  CvTrainTestSplit( float _train_sample_portion, bool _mix = true);
2088 
2089  union
2090  {
2091  int count;
2092  float portion;
2093  } train_sample_part;
2094  int train_sample_part_mode;
2095 
2096  union
2097  {
2098  int *count;
2099  float *portion;
2100  } *class_part;
2101  int class_part_mode;
2102 
2103  bool mix;
2104 };
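 /*
 // Usage sketch: the train part of a split can be given either as an
 // absolute sample count or as a portion; the numbers are illustrative.
 //
 //   CvTrainTestSplit by_count( 100 );         // 100 samples for training, mixed
 //   CvTrainTestSplit by_portion( 0.8f );      // 80% of the samples for training
 //   CvTrainTestSplit in_order( 0.8f, false ); // keep the original sample order
 */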
2105 
2106 class CV_EXPORTS CvMLData
2107 {
2108 public:
2109  CvMLData();
2110  virtual ~CvMLData();
2111 
2112  // returns:
2113  // 0 - OK
2114  // 1 - the file cannot be opened or is not valid
2115  int read_csv(const char* filename);
2116 
2117  const CvMat* get_values(){ return values; };
2118 
2119  const CvMat* get_responses();
2120 
2121  const CvMat* get_missing(){ return missing; };
2122 
2123  void set_response_idx( int idx ); // the old response column becomes a predictor; the new response_idx = idx
2124  // if idx < 0, there is no response
2125  int get_response_idx() { return response_idx; }
2126 
2127  const CvMat* get_train_sample_idx() { return train_sample_idx; };
2128  const CvMat* get_test_sample_idx() { return test_sample_idx; };
2129  void mix_train_and_test_idx();
2130  void set_train_test_split( const CvTrainTestSplit * spl);
2131 
2132  const CvMat* get_var_idx();
2133  void chahge_var_idx( int vi, bool state ); // state == true to use variable vi as a predictor
2134 
2135  const CvMat* get_var_types();
2136  int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; };
2137  // the following two methods enable changing a variable's type
2138  // use them to assign the CV_VAR_CATEGORICAL type to categorical variables
2139  // with numerical labels; in other cases the variable types are determined correctly automatically
2140  void set_var_types( const char* str ); // str examples:
2141  // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
2142  // "cat", "ord" (all vars are categorical/ordered)
2143  void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
2144 
2145  void set_delimiter( char ch );
2146  char get_delimiter() { return delimiter; };
2147 
2148  void set_miss_ch( char ch );
2149  char get_miss_ch() { return miss_ch; };
2150 
2151 protected:
2152  virtual void clear();
2153 
2154  void str_to_flt_elem( const char* token, float& flt_elem, int& type);
2155  void free_train_test_idx();
2156 
2157  char delimiter;
2158  char miss_ch;
2159  //char flt_separator;
2160 
2161  CvMat* values;
2162  CvMat* missing;
2163  CvMat* var_types;
2164  CvMat* var_idx_mask;
2165 
2166  CvMat* response_out; // header
2167  CvMat* var_idx_out; // mat
2168  CvMat* var_types_out; // mat
2169 
2170  int response_idx;
2171 
2172  int train_sample_count;
2173  bool mix;
2174 
2175  int total_class_count;
2176  std::map<std::string, int> *class_map;
2177 
2178  CvMat* train_sample_idx;
2179  CvMat* test_sample_idx;
2180  int* sample_idx; // data of train_sample_idx and test_sample_idx
2181 
2182  CvRNG rng;
2183 };
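 /*
 // Usage sketch: loading a CSV file and preparing it for training; the file
 // name and column indices are illustrative assumptions.
 //
 //   CvMLData data;
 //   data.set_delimiter( ';' );                // must be set before parsing
 //   if( data.read_csv( "dataset.csv" ) == 0 )
 //   {
 //       data.set_response_idx( 0 );           // column 0 becomes the response
 //       data.change_var_type( 0, CV_VAR_CATEGORICAL );
 //
 //       CvTrainTestSplit split( 0.75f );
 //       data.set_train_test_split( &split );
 //       const CvMat* train_idx = data.get_train_sample_idx();
 //       const CvMat* test_idx  = data.get_test_sample_idx();
 //   }
 */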
2184 
2185 
2186 namespace cv
2187 {
2188 
2189 typedef CvStatModel StatModel;
2190 typedef CvParamGrid ParamGrid;
2191 typedef CvNormalBayesClassifier NormalBayesClassifier;
2192 typedef CvKNearest KNearest;
2193 typedef CvSVMParams SVMParams;
2194 typedef CvSVMKernel SVMKernel;
2195 typedef CvSVMSolver SVMSolver;
2196 typedef CvSVM SVM;
2197 typedef CvEMParams EMParams;
2198 typedef CvEM ExpectationMaximization;
2199 typedef CvDTreeParams DTreeParams;
2200 typedef CvMLData TrainData;
2201 typedef CvDTree DecisionTree;
2202 typedef CvForestTree ForestTree;
2203 typedef CvRTParams RandomTreeParams;
2204 typedef CvRTrees RandomTrees;
2205 typedef CvERTreeTrainData ERTreeTRainData;
2206 typedef CvForestERTree ERTree;
2207 typedef CvERTrees ERTrees;
2208 typedef CvBoostParams BoostParams;
2209 typedef CvBoostTree BoostTree;
2210 typedef CvBoost Boost;
2211 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
2212 typedef CvANN_MLP NeuralNet_MLP;
2213 typedef CvGBTreesParams GradientBoostingTreeParams;
2214 typedef CvGBTrees GradientBoostingTrees;
2215 
2216 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj();
2217 
2218 }
2219 
2220 #endif
2221 /* End of file. */