SVM Multi-Class Classification in MATLAB

As is well known, a plain SVM is a binary classifier; it has to be adapted before it can handle multi-class tasks.
This post focuses on how to perform SVM multi-class classification on data using MATLAB.
The underlying theory is not covered in detail here; see this link.

1. Binary classification

clear, clc
%% Binary classification
% Training data: 20x2. Each of the 20 rows is one training point;
% column 1 is the x coordinate, column 2 the y coordinate.
Train_Data = [-3 0;4 0;4 -2;3 -3;-3 -2;1 -4;-3 -4;0 1;-1 0;2 2;3 3;-2 -1;-4.5 -4;2 -1;5 -4;-2 2;-2 -3;0 2;1 -2;2 0];
% 20 rows: the class (+1 or -1) of each training point
Train_labels = [1 -1 -1 -1 1 -1 1 1 1 -1 -1 1 1 -1 -1 1 1 1 -1 -1]';
TestData = [3 -1;3 1;-2 1;-1 -2;2 -3;-3 -3]; % test data
classifier = fitcsvm(Train_Data, Train_labels); % train
test_labels = predict(classifier, TestData);   % test

Here `test_labels` holds the final classification result.
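To make the train-then-predict step concrete outside of MATLAB, here is a minimal pure-Python sketch of what `fitcsvm`/`predict` do for this linearly separable toy problem: train a linear SVM by sub-gradient descent on the hinge loss, then classify test points by the sign of `w·x + b`. The data mirrors the MATLAB example above; the simple training loop is an illustrative stand-in, not MATLAB's actual solver.

```python
# Toy data copied from the MATLAB example above.
train_x = [(-3, 0), (4, 0), (4, -2), (3, -3), (-3, -2), (1, -4), (-3, -4),
           (0, 1), (-1, 0), (2, 2), (3, 3), (-2, -1), (-4.5, -4), (2, -1),
           (5, -4), (-2, 2), (-2, -3), (0, 2), (1, -2), (2, 0)]
train_y = [1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, 1, -1, -1]
test_x = [(3, -1), (3, 1), (-2, 1), (-1, -2), (2, -3), (-3, -3)]

def train_linear_svm(xs, ys, lam=0.01, lr=0.01, epochs=500):
    """Hinge-loss sub-gradient descent for a 2-D linear SVM."""
    w = [0.0, 0.0]
    b = 0.0
    for _ in range(epochs):
        for (x1, x2), y in zip(xs, ys):
            if y * (w[0] * x1 + w[1] * x2 + b) < 1:
                # Point violates the margin: hinge gradient plus shrinkage.
                w[0] += lr * (y * x1 - lam * w[0])
                w[1] += lr * (y * x2 - lam * w[1])
                b += lr * y
            else:
                # Point satisfies the margin: only regularization acts.
                w[0] -= lr * lam * w[0]
                w[1] -= lr * lam * w[1]
    return w, b

w, b = train_linear_svm(train_x, train_y)
test_labels = [1 if w[0] * x1 + w[1] * x2 + b >= 0 else -1 for x1, x2 in test_x]
print(test_labels)
```

The predicted labels match what the MATLAB snippet produces for this data: the two classes are separated roughly by a vertical line near x = 0.5.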

2. Multi-class classification (without a toolbox)

%% Multi-class classification
TrainingSet = [1 10;2 20;3 30;4 40;5 50;6 66;3 30;4.1 42]; % training data
TestSet = [3 34;1 14;2.2 25;6.2 63]; % test data
GroupTrain = [1;1;2;2;3;3;2;2]; % training labels
results = my_MultiSvm(TrainingSet, GroupTrain, TestSet);
disp('multi class problem');
disp(results);

`results` holds the final classification result. The code above calls the `my_MultiSvm` function; its full source is listed below.
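A key detail of `my_MultiSvm` is that each one-vs-all subproblem is balanced: for every class it samples as many "other-class" points as the class itself has. A minimal pure-Python sketch of that sampling step, using the toy data from the snippet above (the seed is fixed only to make the example reproducible):

```python
import random

random.seed(0)

# Toy data mirroring the MATLAB example above.
X = [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 66], [3, 30], [4.1, 42]]
y = [1, 1, 2, 2, 3, 3, 2, 2]
target = 2  # build the binary problem "class 2 vs the rest"

pos_idx = [i for i, label in enumerate(y) if label == target]
neg_idx = [i for i, label in enumerate(y) if label != target]

# Randomly sample negatives, capped at the number of positives so the
# binary training set is balanced (and never larger than the data allows).
random.shuffle(neg_idx)
neg_idx = neg_idx[:min(len(pos_idx), len(neg_idx))]

svm_x = [X[i] for i in pos_idx + neg_idx]
svm_y = [1] * len(pos_idx) + [-1] * len(neg_idx)
print(len(svm_x), svm_y)
```

Capping at `min(...)` matters when one class dominates the data; the 2024 code further down applies the same cap.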

function [y_predict, models] = my_MultiSvm(X_train, y_train, X_test)
% Multi-class SVM using the one-vs-all scheme
% Input:
%   X_train: n*m matrix, n training samples with m features
%   y_train: n*1 vector of training labels; any number of classes is supported
%   X_test:  n*m matrix, n test samples with m features
% Output:
%   y_predict: n*1 vector of predicted labels for the test set
y_labels = unique(y_train);
n_class = size(y_labels, 1);
models = cell(n_class, 1); % train n_class binary models
for i = 1:n_class
    class_i_place = find(y_train == y_labels(i));
    svm_train_x = X_train(class_i_place, :);
    sample_num = numel(class_i_place);
    class_others = find(y_train ~= y_labels(i));
    % Sample as many "other-class" points as there are class-i points
    randp = randperm(numel(class_others));
    svm_train_minus = class_others(randp(1:sample_num)); % indices into X_train
    svm_train_x = [svm_train_x; X_train(svm_train_minus, :)];
    svm_train_y = [ones(sample_num, 1); -1 * ones(sample_num, 1)];
    disp(['Training model: ', num2str(i)])
    models{i} = fitcsvm(svm_train_x, svm_train_y);
end
test_num = size(X_test, 1);
y_predict = zeros(test_num, 1);
% For each test sample, run all n_class models and pick the class whose
% model assigns the highest positive-class score
for i = 1:test_num
    if mod(i, 100) == 0
        disp(['Samples predicted: ', num2str(i)])
    end
    bagging = zeros(n_class, 1);
    for j = 1:n_class
        model = models{j};
        [~, rat] = predict(model, X_test(i, :));
        bagging(j) = bagging(j) + rat(2);
    end
    [~, maxp] = max(bagging);
    y_predict(i) = y_labels(maxp);
end
end
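The decision rule in the second loop can be summarized outside MATLAB as well. Here is a pure-Python sketch of the one-vs-all vote: each class has its own binary scorer (the hypothetical linear models below stand in for the trained `fitcsvm` models), and a test point gets the label whose scorer returns the highest positive-class score, the analogue of picking the maximum `rat(2)`.

```python
# Hypothetical per-class linear scorers (weights, bias), one per class.
# These numbers are illustrative only, not trained models.
models = {
    1: ([ 1.0, -0.5], -2.0),
    2: ([-0.2,  0.8], -1.0),
    3: ([ 0.5,  0.5], -4.0),
}

def score(model, x):
    """Raw decision value w.x + b of one binary one-vs-all model."""
    w, b = model
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def predict_one_vs_all(models, x):
    # Each binary model votes with its raw score; the largest score wins.
    return max(models, key=lambda label: score(models[label], x))

pred = predict_one_vs_all(models, [3.0, 30.0])
print(pred)
```

With these toy scorers the point [3, 30] gets class 2, because model 2's score (22.4) exceeds models 1 and 3.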

3. Multi-class classification (using the libsvm toolbox)

[This toolbox tends to run into odd problems, so it is not particularly recommended.]

The following code shows one way to call the libsvm toolbox from MATLAB.

TrainingSet = [1 10;2 20;3 30;4 40;5 50;6 66;3 30;4.1 42]; % training data
TestSet = [3 34;1 14;2.2 25;6.2 63]; % test data
GroupTrain = [1;1;2;2;3;3;2;2]; % training labels
GroupTest = [1;2;1;3]; % test labels
% SVM classification
model = svmtrain(GroupTrain, TrainingSet); % train
[predict_label] = svmpredict(GroupTest, TestSet, model); % predict

This method comes last because it requires installing the libsvm toolbox in MATLAB first; for instructions, see Installing the LibSVM toolbox in Matlab.

The content above was adapted from here.

2024-4-23
[I have recently been trying to use SVM for communication-signal sorting and ran into the problem above during simulation experiments, so I am writing this blog post to document the fix.]
Finally, here is my solution code:

%% Classification with a support vector machine (SVM)
% Split into training and test data
numTrain = DataSet; % number of training samples (set earlier in the script)
numTest = size(features, 1) - numTrain;
svmTraining = features(1:numTrain, :);
svmTrainingLabel = featureLabels(1:numTrain, :);
svmTest = features(numTrain+1:end, :);
svmTestLabel = featureLabels(numTrain+1:end, :);
% histogram(svmTestLabel, 'BarWidth', 0.5);
% Random shuffle
randp = randperm(size(svmTraining, 1));
svmTraining = svmTraining(randp, :);
svmTrainingLabel = svmTrainingLabel(randp, :);
randp = randperm(size(svmTest, 1));
svmTest = svmTest(randp, :);
svmTestLabel = svmTestLabel(randp, :);

LabelSet = unique(svmTrainingLabel);
LabelSet = string(LabelSet);
NumClass = size(LabelSet, 1);
svmModels = cell(NumClass, 1);
% Train NumClass models, one per modulation type: each is a binary SVM
% separating "is this type" from "is not this type"
for i = 1:NumClass
    class_i_place = find(svmTrainingLabel == LabelSet(i));
    svm_train_x = svmTraining(class_i_place, :);
    sample_num1 = numel(class_i_place);
    sample_num2 = min(sample_num1, numTrain - sample_num1);
    class_others = find(svmTrainingLabel ~= LabelSet(i));
    randp = randperm(numel(class_others));
    svm_train_minus = class_others(randp(1:sample_num2)); % indices into svmTraining
    svm_train_x = [svm_train_x; svmTraining(svm_train_minus, :)];
    svm_train_y = [ones(sample_num1, 1); -1 * ones(sample_num2, 1)];
    fprintf('Training the model for class %s, progress %d/%d\n', LabelSet(i), i, NumClass)
    svmModels{i} = fitcsvm(svm_train_x, svm_train_y);
end
svmTestPred = strings(numTest, 1);
% For each test sample, run all NumClass models and pick the class whose
% model assigns the highest positive-class score
for i = 1:numTest
    bagging = zeros(NumClass, 1);
    for j = 1:NumClass
        model = svmModels{j};
        [~, rat] = predict(model, svmTest(i, :));
        bagging(j) = rat(2);
    end
    [~, maxp] = max(bagging);
    svmTestPred(i) = LabelSet(maxp);
end
svmTestPred = categorical(svmTestPred);
testAccuracy = mean(svmTestPred == svmTestLabel);
disp("Test accuracy: " + testAccuracy * 100 + " %")
figure
cm = confusionchart(svmTestLabel, svmTestPred);
cm.Title = 'Confusion Matrix for SVM Accuracy: ' + string(testAccuracy*100) + ' %';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2), 740, 424];
saveas(gcf, 'confusion_matrix_restruct.png');
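The evaluation at the end (accuracy plus the confusion matrix that `confusionchart` draws) reduces to simple counting. A pure-Python sketch, with illustrative modulation labels standing in for the real `svmTestLabel`/`svmTestPred` vectors:

```python
# Illustrative true/predicted labels; not real experiment output.
true_labels = ["BPSK", "QPSK", "BPSK", "8PSK", "QPSK", "BPSK"]
pred_labels = ["BPSK", "QPSK", "QPSK", "8PSK", "QPSK", "BPSK"]

# Accuracy: fraction of samples whose prediction matches the true label.
accuracy = sum(t == p for t, p in zip(true_labels, pred_labels)) / len(true_labels)

# confusion[t][p] counts samples with true class t predicted as p;
# each row of this table is what a row-normalized confusionchart displays.
classes = sorted(set(true_labels))
confusion = {t: {p: 0 for p in classes} for t in classes}
for t, p in zip(true_labels, pred_labels):
    confusion[t][p] += 1

print(f"Test accuracy: {accuracy * 100:.1f} %")
for t in classes:
    print(t, [confusion[t][p] for p in classes])
```

Here one BPSK sample is mislabeled as QPSK, giving 5/6 correct, i.e. about 83.3 %.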