朴素贝叶斯分类算法(Naive Bayes algorithm)
发布时间
阅读量:
阅读量
多变量朴素贝叶斯分类器(MNB)模型

在文本分类任务中, 我们的目的是识别文档归类于最可能的类别. 在NB分类方法中, 判断依据是MAP(最大后验概率)估计值的结果cmap:




下图为多项式NB 的训练和分类算法伪代码

下面是一个例子简单解析这个算法

java代码实现
package test5;
import java.util.ArrayList;
import java.util.Map.Entry;
import java.util.TreeMap;
class Node{
ArrayList<String> V;
double[] prior;
double[][] condprob;
Node(){
V=new ArrayList<String>();
}
}
public class MultinomialNB {
public static Node trainMultinomialNB(ArrayList<String> classSet, TreeMap<String,String> trainingSet){
// 单词出现多次,只算一个,ExtractVocabulary(D),计算词汇表大小
ArrayList<String> V=new ArrayList<String>();
for(int i=0;i<trainingSet.size();i++){
for(Entry<String, String> entry:trainingSet.entrySet()){
String str=entry.getKey();
String[] strs=str.split(" "); //根据空格切割字符串
for(String word:strs){
if(!V.contains(word)){
V.add(word);
}
}
}
}
//countDoc,计算训练集合中的文档总数
double N=trainingSet.size();
double[] prior=new double[classSet.size()];
double[][] condprob=new double[V.size()][classSet.size()];
int i=0;
for(String c:classSet){
//CountDocsInClass,计算Nc,Nc为训练集合中c 类所包含的文档数目
double Nc=0;
String text="";
for(Entry<String, String> entry:trainingSet.entrySet()){
if(entry.getValue().equals(c)){
Nc++;
text+=entry.getKey(); //将类别c下的文档连接成一个大字符串,concatenatetextofalldocinclass
text+=" ";
}
}
prior[i]=Nc/N; //计算先验概率
String[] texts=text.split(" "); //根据空格切割字符串
double[] Tct=new double[V.size()];
int j=0;
// 计算类c下单词t的出现次数Tct,CountTokenOfTerms
for(String t:V){
double count=0;
for(String word:texts){
if(t.equals(word)){ //注意,不能使用==
count++;
}
}
Tct[j]=count;
j++;
}
//计算条件概率 的估计值为t在c 类文档中出现的相对频率
double Sigma=0;
for(int x=0;x<Tct.length;x++){
Sigma+=Tct[x];
}
for(int t=0;t<V.size();t++){
condprob[t][i]=(Tct[t]+1)/(Sigma+V.size());
}
i++;
}
Node node=new Node(); //用结构体返回多个结果
node.V=V;
node.prior=prior;
node.condprob=condprob;
return node; //返回类node
}
public static double[] applyMultinomialNB(ArrayList<String> classSet,ArrayList<String> V,double[] prior,
double[][]condprob,ArrayList<String> testingSet){
double[] scores=new double[classSet.size()];
for(int i=0;i<testingSet.size();i++){ //遍历测试集
ArrayList<String> W=new ArrayList<String>();
// EXTRACTOKENSFROMDoc,将文档d中的单词抽取出来,允许重复,如果单词是全新的,在全局单词表V中都没出现过,则忽略掉
String str=testingSet.get(i);
String[] strs=str.split(" "); //根据空格切割字符串
for(String word:strs){
if(V.contains(word)){
W.add(word);
}
}
//计算后验概率
int index=0;
scores=new double[classSet.size()];
for(int j=0;j<classSet.size();j++){
scores[j]=Math.log(prior[j]);
for(String t:W){
for(String word:V){
if(t.equals(word)){
index=V.indexOf(word);
}
}
scores[j]+=Math.log(condprob[index][j]);
}
}
if(scores[0]>scores[1]){ //比较两个最大后验概率,
System.out.println("测试集属于类yes");
}else{
System.out.println("测试集属于类no");
}
}
return scores;
}
public static void main(String[] args){
ArrayList<String> doc=new ArrayList<String>();
ArrayList<String> testingSet=new ArrayList<String>();
ArrayList<String> classSet=new ArrayList<String>();
//初始化训练集
doc.add("Chinese Beijing Chinese"); //属于类别China
doc.add("Chinese Chinese Shanghai"); //属于类别China
doc.add("Chinese Macao"); //属于类别China
doc.add("Tokyo Japan Chinese"); //不属于类别China
TreeMap<String,String> trainingSet=new TreeMap<String,String>();
for(int i=0;i<doc.size()-1;i++){
trainingSet.put(doc.get(i),"yes");
}
trainingSet.put(doc.get(doc.size()-1),"no");
//初始化测试集
testingSet.add("Chinese Chinese Chinese Tokyo Japan");
//初始化类别集合
classSet.add("yes");
classSet.add("no");
Node node=new Node();
node=trainMultinomialNB(classSet,trainingSet);
double[] scores;
scores=applyMultinomialNB(classSet,node.V,node.prior,node.condprob,testingSet);
//输出结果
for(double score:scores){
System.out.println(score);
}
}
}
结果为:
测试集属于类yes
-8.10769031284391
-8.906681345001262
由于采用了对数运算处理后的数据特征值均为负值这一性质,在计算得到每个样本点与参考点之间的相似度指标时会呈现负值特征;经比较发现yes类样本的C-Map值高于no类样本,则可推断该测试集应归于yes类别
全部评论 (0)
还没有任何评论哟~
