Advertisement

[035]Java实现SVM对乳腺癌检测数据分类分析

阅读量:

背景简介:

最近正在学习SVM分类技术,在浏览网络上大多数相关资料时发现主要是关于其原理介绍以及如何通过终端命令行工具使用svm-train和svm-predict等方法的具体指导。然而,在实际的数据分析实现方面却相对较少。进一步查阅资料后发现了一个非常优秀的开发库——LIBSVM(Library for Support Vector Machines)。该软件由台湾国家科学委员会发布并维护支持的版本很好地封装了支持向量机的相关代码结构,并使数据分析更加便捷。此外该软件还提供了大量适用于分类、回归以及标签化任务的数据集集合官方地址:

准备训练和测试数据:

在LibSVM官网上可获取所需的数据集,在此例中我们下载了UCI的乳腺癌数据集,在具体格式方面将依据以下说明进行设置

复制代码
    <label> <index1>:<value1> <index2>:<value2>

例如:


  1. 4.000000 1:1099510.000000 2:10.000000 3:4.000000 4:3.000000 5:1.000000 6:3.000000 7:3.000000 8:6.000000 9:5.000000 10:2.000000
  2. 4.000000 1:1100524.000000 2:6.000000 3:10.000000 4:10.000000 5:2.000000 6:8.000000 7:10.000000 8:7.000000 9:3.000000 10:3.000000
  3. 4.000000 1:1102573.000000 2:5.000000 3:6.000000 4:5.000000 5:6.000000 6:10.000000 7:1.000000 8:3.000000 9:1.000000 10:1.000000

链接:https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#breast-cancer

字段含义:
0.Class: (2 for benign, 4 for malignant)
1. Sample code number: id number
2. Clump Thickness: 1 - 10
3. Uniformity of Cell Size: 1 - 10
4. Uniformity of Cell Shape: 1 - 10
5. Marginal Adhesion: 1 - 10
6. Single Epithelial Cell Size: 1 - 10
7. Bare Nuclei: 1 - 10
8. Bland Chromatin: 1 - 10
9. Normal Nucleoli: 1 - 10
10. Mitoses: 1 - 10

项目部署:

建立一个JAVA工程后,在项目根目录中添加LibSVM的JAR包镜像文件。需要注意的是,在项目中还需引入以下三个特定的Java源代码文件:svm_train.javasvm_scale.java以及 svm_predict.java这三份代码文件。这些类实际上是对原始LibSVM库进行了封装处理,并将原本以命令行方式使用的功能转化为通过String[]数组接口进行操作的方式以提高编程开发效率。另外一份名为srm_tony.java的Java源代码则提供了一个图形用户界面(GUI)功能模块,并非必须直接引入到项目中。

为了实现训练集与测试集数据文件的有效管理与引用,在项目工程目录中合理配置数据存储位置以确保后续开发过程中的便捷访问

在Java编程环境中,请按照以下步骤实现对LibSVM API的调用以完成分类任务的具体代码实现:

复制代码
    import java.io.IOException;
    
    import libsvm.*;
    
    /**JAVA test code for LibSVM
     * @author yangliu
     * @blog 
     * @mail yangliuyx@gmail.com
     */
    
    public class LibSVMTest {
    
    public static void main(String[] args) throws IOException {
        // TODO Auto-generated method stub
        //Test for svm_train and svm_predict
        //svm_train: 
        //    param: String[], parse result of command line parameter of svm-train
        //    return: String, the directory of modelFile
        //svm_predect:
        //    param: String[], parse result of command line parameter of svm-predict, including the modelfile
        //    return: Double, the accuracy of SVM classification
        String[] trainArgs = {"UCI-breast-cancer-tra"};//directory of training file
        String modelFile = svm_train.main(trainArgs);
        String[] testArgs = {"UCI-breast-cancer-test", modelFile, "UCI-breast-cancer-result"};//directory of test file, model file, result file
        Double accuracy = svm_predict.main(testArgs);
        System.out.println("SVM Classification is done! The accuracy is " + accuracy);
    
        //Test for cross validation
        //String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"};// 10 fold cross validation
        //modelFile = svm_train.main(crossValidationTrainArgs);
        //System.out.print("Cross validation is done! The modelFile is " + modelFile);
    }
    
    }

执行结果:

复制代码
    .*
    optimization finished, #iter = 1223
    nu = 0.6996186233933985
    obj = -271.992875483972, rho = 0.4257786283326366
    nSV = 639, nBSV = 222
    Total nSV = 639
    Accuracy = 69.23076923076923% (27/39) (classification)
    SVM Classification is done! The accuracy is 0.6923076923076923

可以看到准确率只有0.69

程序改进:

通过Java程序svm_scale.java对数据进行归一化处理,并将处理后的标准化数据分别存入UCI-breast-cancer-tra-scale和UCI-breast-cancer-test-scale两个文件中,并完成后续的处理流程。

svm_scale.java需要修改几个地方代码:
output_target函数修改为:

复制代码
    private String output_target(double value)
    {
        if(y_scaling)
        {
            if(value == y_min)
                value = y_lower;
            else if(value == y_max)
                value = y_upper;
            else
                value = y_lower + (y_upper-y_lower) *
                (value-y_min) / (y_max-y_min);
        }
    
        System.out.print(value + " ");
        return value + " ";
    }

output函数改为:

复制代码
    private String output(int index, double value)
    {
        /* skip single-valued attribute */
        if(feature_max[index] == feature_min[index])
            return " ";
    
        if(value == feature_min[index])
            value = lower;
        else if(value == feature_max[index])
            value = upper;
        else
            value = lower + (upper-lower) * 
                (value-feature_min[index])/
                (feature_max[index]-feature_min[index]);
    
        if(value != 0)
        {
            System.out.print(index + ":" + value + " ");
            new_num_nonzeros++;
            return index + ":" + value + " ";
        }
        return " ";
    }

run需要修改两部分代码:

复制代码
    switch(argv[i-1].charAt(1))
            {
                case 'l': lower = Double.parseDouble(argv[i]);  break;
                case 'u': upper = Double.parseDouble(argv[i]);  break;
                case 'y':
                      y_lower = Double.parseDouble(argv[i]);
                      ++i;
                      y_upper = Double.parseDouble(argv[i]);
                      y_scaling = true;
                      break;
                case 's': save_filename = argv[i];  break;
                case 'r': restore_filename = argv[i];   break;
                case 'p': save_filePath = argv[i];  break;
                default:
                      System.err.println("unknown option");
                      exit_with_help();
            }
复制代码
        BufferedWriter bw = FileStream.fileWriterStream(save_filePath,  true);
    
        /* pass 3: scale */
        while(readline(fp) != null)
        {
            int next_index = 1;
            double target;
            double value;
            String dataLine = "";
    
            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
            target = Double.parseDouble(st.nextToken());
            dataLine = output_target(target);
            while(st.hasMoreElements())
            {
                index = Integer.parseInt(st.nextToken());
                value = Double.parseDouble(st.nextToken());
                for (i = next_index; i<index; i++)
                    dataLine += output(i, 0);
                dataLine += output(index, value);
                next_index = index + 1;
            }
    
            for(i=next_index;i<= max_index;i++)
                output(i, 0);
            System.out.print("\n");
            dataLine += "\n";
            FileStream.writerData(bw, dataLine);
        }
        if (new_num_nonzeros > num_nonzeros)
            System.err.print(
             "WARNING: original #nonzeros " + num_nonzeros+"\n"
            +"         new      #nonzeros " + new_num_nonzeros+"\n"
            +"Use -l 0 if many original feature values are zeros\n");
    
        fp.close();
        bw.close();

新建FileStream 类,用于数据存储

复制代码
    package com.yuan.util;
    
    import java.io.BufferedWriter;
    import java.io.FileWriter;
    import java.io.IOException;
    
    public class FileStream {
    
    
    public static BufferedWriter fileWriterStream(String fileName, boolean append){
        BufferedWriter fp_save = null;
        try {
            fp_save = new BufferedWriter(new FileWriter(fileName, append));
        } catch(IOException e) {
            System.err.println("can't open file " + fileName);
            System.exit(1);
        }
        return fp_save;
    }
    
    public static void writerData(BufferedWriter bw, String data) throws IOException{
        bw.write(data);
    }
    }

修改SVMClassifierTest类

复制代码
    // TODO Auto-generated method stub  
            //Test for svm_train and svm_predict  
            //svm_train:   
            //    param: String[], parse result of command line parameter of svm-train  
            //    return: String, the directory of modelFile  
            //svm_predect:  
            //    param: String[], parse result of command line parameter of svm-predict, including the modelfile  
            //    return: Double, the accuracy of SVM classification  
            String[] trainArgs = {"UCI-breast-cancer-tra"};//directory of training file  
            svm_scale.main(new String[]{"-p", "UCI-breast-cancer-tra-scale", "UCI-breast-cancer-tra"});//训练数据归一化存储
            svm_scale.main(new String[]{"-p", "UCI-breast-cancer-test-scale", "UCI-breast-cancer-test"});//测试数据归一化存储
    
            String[] scaleTrainArgs = {"UCI-breast-cancer-tra-scale"};//directory of training file  
            String modelFile = svm_train.main(scaleTrainArgs); 
    
            String[] testArgs = {"UCI-breast-cancer-test-scale", modelFile, "UCI-breast-cancer-result"};//directory of test file, model file, result file  
            Double accuracy = svm_predict.main(testArgs);  
            System.out.println("SVM Classification is done! The accuracy is " + accuracy);  
    
            //Test for cross validation  
            //String[] crossValidationTrainArgs = {"-v", "10", "UCI-breast-cancer-tra"};// 10 fold cross validation  
            //modelFile = svm_train.main(crossValidationTrainArgs);  
            //System.out.print("Cross validation is done! The modelFile is " + modelFile);

结果:

复制代码
    *
    optimization finished, #iter = 97
    nu = 0.0711047842614367
    obj = -78.46733678185721, rho = -0.9253740588830286
    nSV = 99, nBSV = 83
    Total nSV = 99
    Accuracy = 89.74358974358975% (70/78) (classification)
    SVM Classification is done! The accuracy is 0.8974358974358975

可以看到准确率大幅度提高。
至此LIBSVM的简单调用及改进就完成了。

引用:

该研究团队开发了一种支持向量机的集合,《智能系统与技术》期刊上的一篇论文发表在《ACM Transactions on Intelligent Systems and Technology》期刊的第4期第5卷(不确定具体卷号),详细信息可参考http://www.csie.ntu.edu.tw/~cjlin/libsvm

全部评论 (0)

还没有任何评论哟~