Advertisement

代码 rbf分类_基于机器学习的心律失常分类(七)——支持向量机SVM[MATLAB]

阅读量:
8fdb50ea90dc03e23a2121104364e370.png

一、支持向量机

支持向量机是一种监督学习的二分类模型,其主要目的是找到一种超平面对样本数据进行分割,从而转化为求解凸二次规划的问题。然而在大部分的分类任务中,一般无法使用一种线性关系将两种分类实例分割,支持向量机模型通过非线性核函数将分类任务从线性不可分的低维空间转换到线性可分的高维空间,然后通过算法分析找出最佳超平面,最后返回到原始空间以获得初始空间的解。
1a9405499f2e0016df70d5be800727e5.png

支持向量机变换图
072778b087fe2ef1bec2c30e5a2b7cbc.png
bd8ab0bfe415583930093b722cf635e0.png
778906c51b79d2b5b5e08daf6a116ef2.png

二、代码

说明:

  1. DATA矩阵为arma系数矩阵,1-6列为ARMA模型系数,第七列为所数类别(1-正常心电,2-左束支阻滞,3-右束支阻滞,4-室性早搏)
  2. 四个类别分别选取各1700个数据,共计6800个样本,取70%作为训练集。
复制代码
 %% I. 清空环境变量

    
 %clear all
    
 %clc
    
  
    
 %% II. 导入数据
    
 %load BreastTissue_data.mat
    
 matrix=DATA(:,1:6);
    
 label=DATA(:,7);
    
 %%
    
 % 1. 随机产生训练集和测试集
    
 n1 = randperm(1700);
    
 n2 = randperm(1700);
    
 n3 = randperm(1700);
    
 n4 = randperm(1700);
    
 %%
    
 % 2. 训练集——4760(6800)个样本
    
 P_train1 = matrix(n1(1:1190),:);
    
 P_train2 = matrix(1700+n2(1:1190),:);
    
 P_train3 = matrix(1700*2+n3(1:1190),:);
    
 P_train4 = matrix(1700*3+n4(1:1190),:);
    
 train_matrix = [P_train1;P_train2;P_train3;P_train4];
    
 T_train1 = label(n1(1:1190),:);
    
 T_train2 = label(1700+n2(1:1190),:);
    
 T_train3 = label(1700*2+n3(1:1190),:);
    
 T_train4 = label(1700*3+n4(1:1190),:);
    
 train_label = [T_train1;T_train2;T_train3;T_train4];
    
  
    
 %%
    
 % 3. 测试集——2040个样本
    
 P_test1 = matrix(n1(1191:end),:);
    
 P_test2 = matrix(1700+n2(1191:end),:);
    
 P_test3 = matrix(1700*2+n3(1191:end),:);
    
 P_test4 = matrix(1700*3+n4(1191:end),:);
    
 test_matrix = [P_test1;P_test2;P_test3;P_test4];
    
 T_test1 = label(n1(1191:end),:);
    
 T_test2 = label(1700+n2(1191:end),:);
    
 T_test3 = label(1700*2+n3(1191:end),:);
    
 T_test4 = label(1700*3+n4(1191:end),:);
    
 test_label = [T_test1;T_test2;T_test3;T_test4];
    
  
    
 %% III. 数据归一化
    
 [Train_matrix,PS] = mapminmax(train_matrix');
    
 Train_matrix = Train_matrix';
    
 Test_matrix = mapminmax('apply',test_matrix',PS);
    
 Test_matrix = Test_matrix';
    
  
    
 %% IV. SVM创建/训练(RBF核函数)
    
 %%
    
 % 1. 寻找最佳c/g参数——交叉验证方法
    
 [c,g] = meshgrid(-10:0.2:10,-10:0.2:10);
    
 [m,n] = size(c);
    
 cg = zeros(m,n);
    
 eps = 10^(-4);
    
 v = 5;
    
 bestc = 1;
    
 bestg = 0.1;
    
 bestacc = 0;
    
 for i = 1:m
    
     for j = 1:n
    
     cmd = ['-v ',num2str(v),' -h 0',' -c ',num2str(2^c(i,j)),' -g ',num2str(2^g(i,j))];
    
     cg(i,j) = svmtrain(train_label,Train_matrix,cmd);     
    
     if cg(i,j) > bestacc
    
         bestacc = cg(i,j);
    
         bestc = 2^c(i,j);
    
         bestg = 2^g(i,j);
    
     end        
    
     if abs( cg(i,j)-bestacc )<=eps && bestc > 2^c(i,j) 
    
         bestacc = cg(i,j);
    
         bestc = 2^c(i,j);
    
         bestg = 2^g(i,j);
    
     end               
    
     end
    
 end
    
 cmd = [' -h 0',' -c ',num2str(bestc),' -g ',num2str(bestg)];
    
  
    
 %%
    
 % 2. 创建/训练SVM模型
    
 model = svmtrain(train_label,Train_matrix,cmd);
    
  
    
 %% V. SVM仿真测试
    
 [predict_label_1,accuracy_1,m] = svmpredict(train_label,Train_matrix,model);
    
 [predict_label_2,accuracy_2,n] = svmpredict(test_label,Test_matrix,model);
    
 result_1 = [train_label predict_label_1];
    
 result_2 = [test_label predict_label_2];
    
  
    
 %% VI. 绘图
    
 figure
    
 plot(1:length(test_label),test_label,'r-*')
    
 hold on
    
 plot(1:length(test_label),predict_label_2,'b:o')
    
 grid on
    
 legend('真实类别','预测类别')
    
 xlabel('测试集样本编号')
    
 ylabel('测试集样本类别')
    
 string = {'测试集SVM预测结果对比(RBF核函数)';
    
       ['accuracy = ' num2str(accuracy_2(1)) '%']};
    
 title(string)
    
  
    
  
    
 %% 混淆矩阵
    
 T_sim=predict_label_2;
    
 T_test=test_label;
    
 H=[];
    
 number_1_1= length(find(T_sim == 1 & T_test == 1));H(1,1)=number_1_1;
    
 number_1_2= length(find(T_sim == 1 & T_test == 2));H(1,2)=number_1_2;
    
 number_1_3= length(find(T_sim == 1 & T_test == 3));H(1,3)=number_1_3;
    
 number_1_4= length(find(T_sim == 1 & T_test == 4));H(1,4)=number_1_4;
    
  
    
 number_2_1= length(find(T_sim == 2 & T_test == 1));H(2,1)=number_2_1;
    
 number_2_2= length(find(T_sim == 2 & T_test == 2));H(2,2)=number_2_2;
    
 number_2_3= length(find(T_sim == 2 & T_test == 3));H(2,3)=number_2_3;
    
 number_2_4= length(find(T_sim == 2 & T_test == 4));H(2,4)=number_2_4;
    
  
    
 number_3_1= length(find(T_sim == 3 & T_test == 1));H(3,1)=number_3_1;
    
 number_3_2= length(find(T_sim == 3 & T_test == 2));H(3,2)=number_3_2;
    
 number_3_3= length(find(T_sim == 3 & T_test == 3));H(3,3)=number_3_3;
    
 number_3_4= length(find(T_sim == 3 & T_test == 4));H(3,4)=number_3_4;
    
  
    
 number_4_1= length(find(T_sim == 4 & T_test == 1));H(4,1)=number_4_1;
    
 number_4_2= length(find(T_sim == 4 & T_test == 2));H(4,2)=number_4_2;
    
 number_4_3= length(find(T_sim == 4 & T_test == 3));H(4,3)=number_4_3;
    
 number_4_4= length(find(T_sim == 4 & T_test == 4));H(4,4)=number_4_4;
    
 H
    
    
    
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-17/9AIq6jy08MfStTXNHYKzludDZ2GQ.png)

三、结果分析

关于评估指标的含义在《基于机器学习的心律失常分类(五)——决策树分类》里已经说过了,这里就直接给出分类结果。

本文数据包括四类不同的心律波形,共计6800个样本心拍数据,随机提取其中的70%(4760个)为训练样本,30%(2040个)为测试样本。在学习训练样本,通过测试样本进行分类验证,分别得到分类模型的混淆矩阵,竖直方向表示预测样本的标签类别,水平方向表示真实样本的标签类别。
bc5ad5192ea3f26ea7460f56d96aed2d.png

支持向量机混淆矩阵
315756ff8939615b692dec7b621eb509.png

支持向量机评估指标

全部评论 (0)

还没有任何评论哟~