>  기사  >  데이터 베이스  >  利用SVM解决2维空间向量的3级分类问题

利用SVM解决2维空间向量的3级分类问题

WBOY
WBOY원래의
2016-06-07 15:43:031558검색

【原文:http://blog.csdn.net/firefight/article/details/6400060】 为了学习OPENCV SVM分类器, 参考网上的 利用SVM解决2维空间向量的分类问题 实现并改为C代码,仅供参考 环境:OPENCV2.2 VS2008 步骤: 1,生成随机的点,并按一定的空间分布将其归类 2,

【原文:http://blog.csdn.net/firefight/article/details/6400060】

为了学习OPENCV SVM分类器, 参考网上的"利用SVM解决2维空间向量的分类问题"实现并改为C++代码,仅供参考

 

环境:OPENCV2.2 + VS2008

步骤:
1,生成随机的点,并按一定的空间分布将其归类
2,创建SVM并利用随机点样本进行训练
3,将整个空间按SVM分类结果进行划分,并显示支持向量

 

[cpp] view plaincopy

  1. #include "stdafx.h"  
  2. #include   
  3.   
  4. void drawCross(Mat &img, Point center, Scalar color)  
  5. {  
  6.     int col = center.x > 2 ? center.x : 2;  
  7.     int row = center.y> 2 ? center.y : 2;  
  8.   
  9.     line(img, Point(col -2, row - 2), Point(col + 2, row + 2), color);    
  10.     line(img, Point(col + 2, row - 2), Point(col - 2, row + 2), color);    
  11. }  
  12.   
  13. int newSvmTest(int rows, int cols, int testCount)  
  14. {  
  15.     if(testCount > rows * cols)  
  16.         return 0;  
  17.   
  18.     Mat img = Mat::zeros(rows, cols, CV_8UC3);  
  19.     Mat testPoint = Mat::zeros(rows, cols, CV_8UC1);  
  20.     Mat data = Mat::zeros(testCount, 2, CV_32FC1);  
  21.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
  22.   
  23.     //Create random test points  
  24.     for (int i= 0; i
  25.     {   
  26.         int row = rand() % rows;  
  27.         int col = rand() % cols;  
  28.   
  29.         if(testPoint.atchar>(row, col) == 0)  
  30.         {  
  31.             testPoint.atchar>(row, col) = 1;  
  32.             data.atfloat>(i, 0) = float (col) / cols;   
  33.             data.atfloat>(i, 1) = float (row) / rows;   
  34.         }  
  35.         else  
  36.         {  
  37.             i--;  
  38.             continue;  
  39.         }  
  40.   
  41.         if (row > ( 50 * cos(col * CV_PI/ 100) + 200) )  
  42.         {   
  43.             drawCross(img, Point(col, row), CV_RGB(255, 0, 0));  
  44.             res.atint>(i, 0) = 1;   
  45.         }   
  46.         else   
  47.         {   
  48.             if (col > 200)   
  49.             {   
  50.                 drawCross(img, Point(col, row), CV_RGB(0, 255, 0));  
  51.                 res.atint>(i, 0) = 2;   
  52.             }   
  53.             else   
  54.             {   
  55.                 drawCross(img, Point(col, row), CV_RGB(0, 0, 255));  
  56.                 res.atint>(i, 0) = 3;   
  57.             }   
  58.         }   
  59.   
  60.     }  
  61.   
  62.     //Show test points  
  63.     imshow("dst", img);  
  64.     waitKey(0);  
  65.   
  66.     /////////////START SVM TRAINNING//////////////////  
  67.     CvSVM svm = CvSVM();   
  68.     CvSVMParams param;   
  69.     CvTermCriteria criteria;  
  70.   
  71.     criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);  
  72. /* SVM种类:CvSVM::C_SVC 
    Kernel的种类:CvSVM::RBF
    degree:10.0(此次不使用) 
    gamma:8.0 
    coef0:1.0(此次不使用)
    C:10.0 
    nu:0.5(此次不使用) 
    p:0.1(此次不使用) 
    然后对训练数据正规化处理,并放在CvMat型的数组里。*/

  73.     param= CvSVMParams (CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);   
  74.     svm.train(data, res, Mat(), Mat(), param);  
  75.   
  76.     for (int i= 0; i
  77.     {   
  78.         for (int j= 0; j
  79.         {   
  80.             Mat m = Mat::zeros(1, 2, CV_32FC1);  
  81.             m.atfloat>(0,0) = float (j) / cols;  
  82.             m.atfloat>(0,1) = float (i) / rows;  
  83.   
  84.             float ret = 0.0;   
  85.             ret = svm.predict(m);   
  86.             Scalar rcolor;   
  87.   
  88.             switch ((int) ret)   
  89.             {   
  90.                 case 1: rcolor= CV_RGB(100, 0, 0); break;   
  91.                 case 2: rcolor= CV_RGB(0, 100, 0); break;   
  92.                 case 3: rcolor= CV_RGB(0, 0, 100); break;   
  93.             }   
  94.   
  95.             line(img, Point(j,i), Point(j,i), rcolor);  
  96.         }   
  97.     }  
  98.   
  99.     imshow("dst", img);  
  100.     waitKey(0);  
  101.   
  102.     //Show support vectors  
  103.     int sv_num= svm.get_support_vector_count();   
  104.     for (int i= 0; i
  105.     {   
  106.         const float* support = svm.get_support_vector(i);   
  107.         circle(img, Point((int) (support[0] * cols), (int) (support[1] * rows)), 5, CV_RGB(200, 200, 200));   
  108.     }  
  109.   
  110.     imshow("dst", img);  
  111.     waitKey(0);  
  112.   
  113.     return 0;  
  114. }  
  115.   
  116. int main(int argc, char** argv)  
  117. {  
  118.     return newSvmTest(400, 600, 100);  
  119. }  

 

学习样本:

利用SVM解决2维空间向量的3级分类问题

 

分类:

利用SVM解决2维空间向量的3级分类问题

 

支持向量:

利用SVM解决2维空间向量的3级分类问题

성명:
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.