SVM Multi-class Classification in Matlab
As is well known, a standard SVM is a binary classifier; it has to be adapted before it can handle multi-class tasks.
This post focuses on how to perform SVM multi-class classification on data in Matlab.
The underlying theory is not covered in detail here; see the linked reference.
1. Binary classification
```matlab
clear, clc
% Training data: 20 two-dimensional points with labels +1 / -1
Train_Data = [-3 0; 4 0; 4 -2; 3 -3; -3 -2; 1 -4; -3 -4; 0 1; -1 0; 2 2; ...
              3 3; -2 -1; -4.5 -4; 2 -1; 5 -4; -2 2; -2 -3; 0 2; 1 -2; 2 0];
Train_labels = [1 -1 -1 -1 1 -1 1 1 1 -1 -1 1 1 -1 -1 1 1 1 -1 -1]';
TestData = [3 -1; 3 1; -2 1; -1 -2; 2 -3; -3 -3];
classifier = fitcsvm(Train_Data, Train_labels);   % train a binary SVM
test_labels = predict(classifier, TestData);      % predict labels for the test points
```
Here, test_labels holds the final classification result.
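As a quick sanity check, a minimal sketch (assuming the classifier, Train_Data, Train_labels, TestData and test_labels variables from the block above) that reports the resubstitution accuracy and lists the predicted test labels:

```matlab
% Minimal sanity check, assuming the variables defined in the block above.
train_pred = predict(classifier, Train_Data);     % predictions on the training set
train_acc  = mean(train_pred == Train_labels);    % resubstitution accuracy
fprintf('Training accuracy: %.1f %%\n', 100 * train_acc);
disp(table(TestData(:,1), TestData(:,2), test_labels, ...
     'VariableNames', {'x1', 'x2', 'predicted'}));  % test points with predicted labels
```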
2. Multi-class classification (without a toolbox)
```matlab
TrainingSet = [1 10; 2 20; 3 30; 4 40; 5 50; 6 66; 3 30; 4.1 42];
TestSet = [3 34; 1 14; 2.2 25; 6.2 63];
GroupTrain = [1; 1; 2; 2; 3; 3; 2; 2];
results = my_MultiSvm(TrainingSet, GroupTrain, TestSet);
disp('multi class problem');
disp(results);
```
Here, results is the final classification result. The code above calls the my_MultiSvm function; its full contents are listed below.
```matlab
function [y_predict, models] = my_MultiSvm(X_train, y_train, X_test)
% One-vs-rest multi-class SVM: train one binary SVM per class, then predict
% by taking the class whose model returns the highest positive-class score.
y_labels = unique(y_train);
n_class = size(y_labels, 1);
models = cell(n_class, 1);
for i = 1:n_class
    % Positive samples: all training points of class i
    class_i_place = find(y_train == y_labels(i));
    svm_train_x = X_train(class_i_place, :);
    sample_num = numel(class_i_place);
    % Negative samples: an equally sized random subset of the other classes
    % (assumes class i accounts for at most half of the training samples)
    class_others = find(y_train ~= y_labels(i));
    randp = randperm(numel(class_others));
    svm_train_minus = class_others(randp(1:sample_num));  % map back to row indices of X_train
    svm_train_x = [svm_train_x; X_train(svm_train_minus, :)];
    svm_train_y = [ones(sample_num, 1); -1 * ones(sample_num, 1)];
    disp(['Training model: ', num2str(i)])
    models{i} = fitcsvm(svm_train_x, svm_train_y);
end
test_num = size(X_test, 1);
y_predict = zeros(test_num, 1);
for i = 1:test_num
    if mod(i, 100) == 0
        disp(['Predicted samples: ', num2str(i)])
    end
    bagging = zeros(n_class, 1);
    for j = 1:n_class
        model = models{j};
        [~, score] = predict(model, X_test(i, :));
        bagging(j) = bagging(j) + score(2);   % score for the positive (+1) class
    end
    [~, maxp] = max(bagging);
    y_predict(i) = y_labels(maxp);
end
end
```
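my_MultiSvm is thus a one-vs-rest scheme with balanced negative sampling: each class gets its own binary fitcsvm model, and at prediction time the class with the highest positive-class score wins. For comparison only (not part of the original post), the Statistics and Machine Learning Toolbox also ships a built-in multi-class wrapper, fitcecoc, which combines binary SVM learners; a minimal sketch on the same toy data:

```matlab
% Alternative sketch using the built-in fitcecoc (error-correcting output codes
% with binary SVM learners), assuming the TrainingSet/GroupTrain/TestSet above.
Mdl = fitcecoc(TrainingSet, GroupTrain);
builtin_results = predict(Mdl, TestSet);
disp(builtin_results);
```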
3. Multi-class classification (calling the libsvm toolbox)
[This toolbox tends to run into odd problems, so it is not particularly recommended.]
The following code shows one way to call the libsvm toolbox from Matlab.
```matlab
TrainingSet = [1 10; 2 20; 3 30; 4 40; 5 50; 6 66; 3 30; 4.1 42];
TestSet = [3 34; 1 14; 2.2 25; 6.2 63];
GroupTrain = [1; 1; 2; 2; 3; 3; 2; 2];
GroupTest = [1; 2; 1; 3];
model = svmtrain(GroupTrain, TrainingSet);   % libsvm interface: labels first, then features
[predict_label] = svmpredict(GroupTest, TestSet, model);
```
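libsvm's svmtrain also accepts an option string for the kernel and hyperparameters; a minimal sketch (the parameter values below are arbitrary examples, not tuned for this data):

```matlab
% Sketch: RBF kernel (-t 2) with cost (-c) and gamma (-g) set explicitly.
% The values are illustrative only.
model = svmtrain(GroupTrain, TrainingSet, '-t 2 -c 1 -g 0.5');
[predict_label, accuracy, ~] = svmpredict(GroupTest, TestSet, model);
```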
This approach is placed last because the libsvm toolbox has to be installed in Matlab first; for the specific steps, see the linked guide on installing the LibSVM toolbox in Matlab.
The content above is adapted, with improvements, from the linked source.
2024-4-23
[Recently I have been trying to use SVM for sorting communication signals and ran into some problems during the simulation experiments, so this post records how I solved them.]
Finally, here is my solution code:
```matlab
% Split the feature matrix into training and test sets.
% numTrain is taken from a variable set earlier in the original script.
numTrain = DataSet;
numTest = size(features, 1) - numTrain;
svmTraining = features(1:numTrain, :);
svmTrainingLabel = featureLabels(1:numTrain, :);
svmTest = features(numTrain+1:end, :);
svmTestLabel = featureLabels(numTrain+1:end, :);

% Shuffle both sets.
randp = randperm(size(svmTraining, 1));
svmTraining = svmTraining(randp, :);
svmTrainingLabel = svmTrainingLabel(randp, :);
randp = randperm(size(svmTest, 1));
svmTest = svmTest(randp, :);
svmTestLabel = svmTestLabel(randp, :);

% Train one binary SVM per class (one-vs-rest).
LabelSet = unique(svmTrainingLabel);
LabelSet = string(LabelSet);
NumClass = size(LabelSet, 1);
svmModels = cell(NumClass, 1);

for i = 1:NumClass
    % Positive samples: all training points of class i
    class_i_place = find(svmTrainingLabel == LabelSet(i));
    svm_train_x = svmTraining(class_i_place, :);
    sample_num1 = numel(class_i_place);
    % Negative samples: at most as many points drawn from the other classes
    sample_num2 = min(sample_num1, numTrain - sample_num1);
    class_others = find(svmTrainingLabel ~= LabelSet(i));
    randp = randperm(numel(class_others));
    svm_train_minus = class_others(randp(1:sample_num2));  % map back to row indices
    svm_train_x = [svm_train_x; svmTraining(svm_train_minus, :)];
    svm_train_y = [ones(sample_num1, 1); -1 * ones(sample_num2, 1)];
    fprintf('Training model for class %s, progress %d/%d\n', LabelSet(i), i, NumClass)
    svmModels{i} = fitcsvm(svm_train_x, svm_train_y);
end

% Predict each test sample: the class whose model gives the highest
% positive-class score wins.
svmTestPred = strings(numTest, 1);
for i = 1:numTest
    bagging = zeros(NumClass, 1);
    for j = 1:NumClass
        model = svmModels{j};
        [~, rat] = predict(model, svmTest(i, :));
        bagging(j) = rat(2);                   % score for the positive (+1) class
    end
    [~, maxp] = max(bagging);
    svmTestPred(i) = LabelSet(maxp);
end

% Evaluate and plot the confusion matrix.
svmTestPred = categorical(svmTestPred);
testAccuracy = mean(svmTestPred == svmTestLabel);
disp("Test accuracy: " + testAccuracy * 100 + " %")
figure
cm = confusionchart(svmTestLabel, svmTestPred);
cm.Title = 'Confusion Matrix for SVM Accuracy: ' + string(testAccuracy * 100) + ' %';
cm.RowSummary = 'row-normalized';
cm.Parent.Position = [cm.Parent.Position(1:2), 740, 424];
saveas(gcf, 'confusion_matrix_restruct.png');
```
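To reuse the trained classifier later without retraining, the per-class models and the label set can be bundled and saved; a minimal sketch (the struct and file names below are just examples, assuming the script above has run):

```matlab
% Bundle the one-vs-rest models with their label set and save them.
% svmBundle and the .mat file name are illustrative only.
svmBundle.models = svmModels;
svmBundle.labels = LabelSet;
save('svm_ovr_models.mat', 'svmBundle');
```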