Advertisement

朴素贝叶斯算法的 Java 实现

阅读量:

朴素贝叶斯(Naive Bayes)算法相信大家都不陌生,其原理这里不再赘述,只简述判定规则:对每个类别 $C$,计算 $P(C)\prod_i P(x_i \mid C)$,取该值较大的类别作为预测结果。本文的主要目标是用 Java 代码实现这一算法思想,思路如下:

1. 用 JavaBean + ArrayList 存储训练数据;

2. 在训练数据上统计各类别及各属性取值的频数,据此对新样本进行分类。

具体的代码如下:

复制代码
 package NB;

    
 /**
  * One training record: four feature attributes plus the class label
  * {@code buys_computer}.
  */
 public class JavaBean {

     int age;
     String income, student, credit_rating, buys_computer;

     /** Creates an empty record: age is 0 and every string field is null. */
     public JavaBean() {
         this(0, null, null, null, null);
     }

     /**
      * Creates a fully populated record.
      *
      * @param age           age value
      * @param income        income value
      * @param student       student value
      * @param credit_rating credit rating value
      * @param buys_computer class label
      */
     public JavaBean(int age, String income, String student, String credit_rating, String buys_computer) {
         this.age = age;
         this.income = income;
         this.student = student;
         this.credit_rating = credit_rating;
         this.buys_computer = buys_computer;
     }

     // ---- accessors ----

     public int getAge() {
         return this.age;
     }

     public String getIncome() {
         return this.income;
     }

     public String getStudent() {
         return this.student;
     }

     public String getCredit_rating() {
         return this.credit_rating;
     }

     public String getBuys_computer() {
         return this.buys_computer;
     }

     // ---- mutators ----

     public void setAge(int age) {
         this.age = age;
     }

     public void setIncome(String income) {
         this.income = income;
     }

     public void setStudent(String student) {
         this.student = student;
     }

     public void setCredit_rating(String credit_rating) {
         this.credit_rating = credit_rating;
     }

     public void setBuys_computer(String buys_computer) {
         this.buys_computer = buys_computer;
     }

     /** Renders every field; null strings appear as {@code null}. */
     @Override
     public String toString() {
         StringBuilder sb = new StringBuilder("JavaBean [");
         sb.append("age=").append(age)
           .append(", income=").append(income)
           .append(", student=").append(student)
           .append(", credit_rating=").append(credit_rating)
           .append(", buys_computer=").append(buys_computer)
           .append(']');
         return sb.toString();
     }
 }

算法实现部分:

复制代码
 package NB;

    
  
    
 import java.io.BufferedReader;
 import java.io.File;
 import java.io.FileReader;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.util.ArrayList;
    
  
    
 public class TestNB {
    
  
    
 	/**data_length
    
 	 * 算法的思想
    
 	 */
    
 	public static  ArrayList<JavaBean> list = new ArrayList<JavaBean>();;
    
 	static int data_length=0;
    
 	public static void main(String[] args) {
    
 		// 1.读取数据,放入list容器中
    
 		File file = new File("E://test.txt");
    
 		txt2String(file);
    
 		//数据测试样本
    
 		testData(25,"Medium","Yes","Fair");
    
 	}
    
     // 读取样本数据
    
 	public static void txt2String(File file) {
    
 		
    
 		try {
    
 			BufferedReader br = new BufferedReader(new FileReader(file));// 构造一个BufferedReader类来读取文件
    
 			String s = null;
    
 			while ((s = br.readLine()) != null) {// 使用readLine方法,一次读一行
    
 				data_length++; 
    
 				splitt(s);
    
 			}
    
 			br.close();
    
 		} catch (Exception e) {
    
 			e.printStackTrace();
    
 		}
    
 		
    
 	}
    
 	// 存入ArrayList中
    
 	  public static void splitt(String str){
    
 		   
    
 	        String strr = str.trim();
    
 	        String[] abc = strr.split("[\ p{Space}]+");
    
 	        int age=Integer.parseInt(abc[0]);
    
 	        JavaBean bean=new JavaBean(age, abc[1], abc[2], abc[3], abc[4]);
    
 	        list.add(bean);		
    
 	       
    
 	       
    
 	    }
    
 	  // 训练样本,测试
    
 	  public static void testData(int age,String a,String b,String c){
    
 		  //训练样本  
    
 		  int number_yes=0;
    
 		  int bumber_no=0;
    
 		  
    
 		 // age情况 个数
    
 		  int num_age_yes=0;
    
 		  int num_age_no=0;
    
 		  // income 
    
 		  int num_income_yes=0;
    
 		  int num_income_no=0;
    
 		  // student 
    
 		  int num_student_yes=0;
    
 		  int num_stdent_no=0;
    
 		  //credit
    
 		  int num_credit_yes=0;
    
 		  int num_credit_no=0;
    
 		  
    
 		  //遍历List 获得数据
    
 		  for(int i=0;i<list.size();i++){
    
 		    JavaBean bb=list.get(i);
    
 		    if(bb.getBuys_computer().equals("Yes")){ //Yes
    
 		    	number_yes++;
    
 	            if(bb.getIncome().equals(a)){//income
    
 	            	num_income_yes++;
    
 	            }
    
 		    	if(bb.getStudent().equals(b)){//student
    
 		    		num_student_yes++;
    
 		    	}
    
 		    	if(bb.getCredit_rating().equals(c)){//credit
    
 		    		num_credit_yes++;
    
 		    	}
    
 		    	if(bb.getAge()==age){//age
    
 		    		num_age_yes++;
    
 		    	}
    
 		    	
    
 		    	
    
 		    }else {//No
    
 		    	bumber_no++;
    
 		    	if(bb.getIncome().equals(a)){//income
    
 	            	num_income_no++;
    
 	            }
    
 		    	if(bb.getStudent().equals(b)){//student
    
 		    		num_stdent_no++;
    
 		    	}
    
 		    	if(bb.getCredit_rating().equals(c)){//credit
    
 		    		num_credit_no++;
    
 		    	}
    
 		    	if(bb.getAge()==age){//age
    
 		    		num_age_no++;
    
 		    	}
    
 		    	
    
 			}  
    
 		  }
    
 		  
    
 		    System.out.println("购买的历史个数:"+number_yes);
    
 		    System.out.println("不买的历史个数:"+bumber_no);
    
 		    
    
 		    System.out.println("购买+age:"+num_age_yes);
    
 		    System.out.println("不买+age:"+num_age_no);
    
 		    
    
 		    System.out.println("购买+income:"+num_income_yes);
    
 		    System.out.println("不买+income:"+num_income_no);
    
 		    
    
 		    System.out.println("购买+stundent:"+num_student_yes);
    
 		    System.out.println("不买+student:"+num_stdent_no);
    
 		    
    
 		    System.out.println("购买+credit:"+num_credit_yes);
    
 		    System.out.println("不买+credit:"+num_credit_no);
    
 		    
    
 		     概率判断
    
 		    double buy_yes=number_yes*1.0/data_length; // 买的概率
    
 			double buy_no=bumber_no*1.0/data_length; //  不买的概率
    
 		    System.out.println("训练数据中买的概率:"+buy_yes);
    
 		    System.out.println("训练数据中不买的概率:"+buy_no);
    
 			/// 未知用户的判断
    
 		    double nb_buy_yes=(1.0*num_age_yes/number_yes)*(1.0*num_income_yes/number_yes)*(1.0*num_student_yes/number_yes)*(1.0*num_credit_yes/number_yes)*buy_yes;       
    
 		    double nb_buy_no=(1.0*num_age_no/bumber_no)*(1.0*num_income_no/bumber_no)*(1.0*num_stdent_no/bumber_no)*(1.0*num_credit_no/bumber_no)*buy_no;       
    
 		    System.out.println("新用户买的概率:"+nb_buy_yes);
    
 		    System.out.println("新用户不买的概率:"+nb_buy_no);
    
 		    if(nb_buy_yes>nb_buy_no){
    
 		    	System.out.println("新用户买的概率大");
    
 		    }else {
    
 		    	System.out.println("新用户不买的概率大");
    
 			}    
    
 	  }	  
    
 }

样本数据(各列依次为 age、income、student、credit_rating、buys_computer):

复制代码
 25  High    No  Fair       No

    
 25  High    No  Excellent  No
    
 33  High    No  Fair       Yes
    
 41  Medium  No  Fair       Yes     
    
 41  Low     Yes Fair       Yes
    
 41  Low     Yes Excellent  No
    
 33  Low     Yes Excellent  Yes
    
 25  Medium  No  Fair       No
    
 25  Low     Yes Fair       Yes
    
 41  Medium  Yes Fair       Yes
    
 25  Medium  Yes Excellent  Yes
    
 33  Medium  No  Excellent  Yes
    
 33  High    Yes Fair       Yes
    
 41  Medium  No  Excellent  No

对未知用户(age=25, income=Medium, student=Yes, credit_rating=Fair)的判断结果:

复制代码
 购买的历史个数:9

    
 不买的历史个数:5
    
 购买+age:2
    
 不买+age:3
    
 购买+income:4
    
 不买+income:2
    
 购买+stundent:6
    
 不买+student:1
    
 购买+credit:6
    
 不买+credit:2
    
 训练数据中买的概率:0.6428571428571429
    
 训练数据中不买的概率:0.35714285714285715
    
 新用户买的概率:0.028218694885361547
    
 新用户不买的概率:0.006857142857142858
    
 新用户买的概率大

全部评论 (0)

还没有任何评论哟~