The Naive Bayes Formula (Spam Filtering)

Prepare the data: tokenize the text

We are given a spam folder (spam) and a non-spam folder (ham), each containing 25 emails.
The text of each email is split into a list of individual words.

 import re

 def textParse(bigString):
     # Split on runs of non-word characters; keep lowercased tokens longer than 2 chars
     listOfTokens = re.split(r'\W+', bigString)
     return [tok.lower() for tok in listOfTokens if len(tok) > 2]
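As a quick sanity check, here is how the tokenizer behaves on a made-up string (the sample text is mine, not from the original post):

```python
import re

def textParse(bigString):
    # Split on runs of non-word characters; keep lowercased tokens longer than 2 chars
    listOfTokens = re.split(r'\W+', bigString)
    return [tok.lower() for tok in listOfTokens if len(tok) > 2]

tokens = textParse("Hello, Naive Bayes! This is a TEST email.")
print(tokens)  # ['hello', 'naive', 'bayes', 'this', 'test', 'email']
```

Note that short words like "is" and "a" are dropped by the length filter, which also discards the empty strings that re.split can produce.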

Prepare the data: build word vectors from the text

(1) First, collect every unique word across all documents into a vocabulary list.

 def creatVocabList(dataset):
     # Union of the word sets of all documents
     vocabSet = set()
     for document in dataset:
         vocabSet = vocabSet | set(document)
     return list(vocabSet)

(2) Next, convert each document into a vector over the vocabulary. There are two models:
Set-of-words model: each element of the vector is 1 or 0, indicating whether the corresponding vocabulary word appears in the document.
Bag-of-words model: each element of the vector is the number of times the corresponding vocabulary word appears in the document.

 def setOfWords2Vec(vocabList, inputSet):
     # 1 if the vocabulary word occurs in the document, else 0
     returnVec = [0] * len(vocabList)
     for word in inputSet:
         if word in vocabList:
             returnVec[vocabList.index(word)] = 1
     return returnVec

 def bagOfWords2Vec(vocabList, inputSet):
     # Count how many times each vocabulary word occurs in the document
     returnVec = [0] * len(vocabList)
     for word in inputSet:
         if word in vocabList:
             returnVec[vocabList.index(word)] += 1
     return returnVec
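The difference between the two models shows up as soon as a word repeats. A toy comparison (the two example documents are my own, not from the post):

```python
def creatVocabList(dataset):
    # Union of the word sets of all documents
    vocabSet = set()
    for document in dataset:
        vocabSet = vocabSet | set(document)
    return list(vocabSet)

def setOfWords2Vec(vocabList, inputSet):
    # Presence/absence: 1 if the word occurs at all
    returnVec = [0] * len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] = 1
    return returnVec

def bagOfWords2Vec(vocabList, inputSet):
    # Occurrence counts
    returnVec = [0] * len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] += 1
    return returnVec

docs = [['buy', 'cheap', 'buy'], ['meeting', 'tomorrow']]
vocab = creatVocabList(docs)
# 'buy' appears twice in the first document:
print(setOfWords2Vec(vocab, docs[0])[vocab.index('buy')])  # 1
print(bagOfWords2Vec(vocab, docs[0])[vocab.index('buy')])  # 2
```

The rest of this post uses bagOfWords2Vec, so repeated words contribute repeatedly to the class probabilities.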

Train the algorithm: compute probabilities from word vectors
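For reference, the quantities computed below follow the naive Bayes decision rule (a sketch in my own notation, not spelled out in the original post):

```latex
% Posterior of class c given the document's words is proportional to
% the class prior times the per-word likelihoods (naive independence):
P(c \mid w_1,\dots,w_n) \propto P(c)\prod_{i=1}^{n} P(w_i \mid c)
% In log form, with Laplace smoothing as used by trainNB0 below:
\log P(c) + \sum_{i=1}^{n} \log \frac{\mathrm{count}(w_i, c) + 1}{\mathrm{count}(c) + 2}
```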

 import numpy as np

 def trainNB0(trainMatrix, trainCategory):
     numTrainDocs = len(trainMatrix)
     numWords = len(trainMatrix[0])
     # Prior probability of the spam class
     pAbusive = sum(trainCategory) / float(numTrainDocs)
     # Initialize counts to 1 and denominators to 2 (Laplace smoothing),
     # so unseen words never get probability 0
     p0Num = np.ones(numWords); p1Num = np.ones(numWords)
     p0Denom = 2.0; p1Denom = 2.0
     for i in range(numTrainDocs):
         if trainCategory[i] == 1:
             p1Num += trainMatrix[i]
             p1Denom += sum(trainMatrix[i])
         else:
             p0Num += trainMatrix[i]
             p0Denom += sum(trainMatrix[i])
     # Take logs to avoid underflow when many small probabilities are multiplied
     p1Vect = np.log(p1Num / p1Denom)
     p0Vect = np.log(p0Num / p0Denom)
     return p0Vect, p1Vect, pAbusive
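On a tiny hand-built matrix the smoothing is easy to verify by eye (the two toy documents and labels below are my own invention):

```python
import numpy as np

def trainNB0(trainMatrix, trainCategory):
    numTrainDocs = len(trainMatrix)
    numWords = len(trainMatrix[0])
    # Prior probability of the spam class
    pAbusive = sum(trainCategory) / float(numTrainDocs)
    # Laplace smoothing: counts start at 1, denominators at 2
    p0Num = np.ones(numWords); p1Num = np.ones(numWords)
    p0Denom = 2.0; p1Denom = 2.0
    for i in range(numTrainDocs):
        if trainCategory[i] == 1:
            p1Num += trainMatrix[i]
            p1Denom += sum(trainMatrix[i])
        else:
            p0Num += trainMatrix[i]
            p0Denom += sum(trainMatrix[i])
    return np.log(p0Num / p0Denom), np.log(p1Num / p1Denom), pAbusive

trainMat = [[1, 1, 0], [0, 1, 1]]  # two documents over a 3-word vocabulary
labels = [1, 0]                    # first is spam, second is ham
p0v, p1v, pAb = trainNB0(trainMat, labels)
print(pAb)          # 0.5
# Spam word probabilities are (count+1)/(total+2) = (1+1)/4, (1+1)/4, (0+1)/4
print(np.exp(p1v))  # [0.5  0.5  0.25]
```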

Test the algorithm: hold-out validation of the Naive Bayes classifier

 def classifyNB(vec2Classify, p0Vec, p1Vec, pClass):
     # Compare log posteriors: sum of log likelihoods plus log prior
     p1 = sum(vec2Classify * p1Vec) + np.log(pClass)
     p0 = sum(vec2Classify * p0Vec) + np.log(1 - pClass)
     if p1 > p0:
         return 1
     else:
         return 0
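A small demonstration with hand-picked probabilities (the likelihood vectors below are hypothetical, chosen only to make the arithmetic visible):

```python
import numpy as np

def classifyNB(vec2Classify, p0Vec, p1Vec, pClass):
    # Compare log posteriors: sum of log likelihoods plus log prior
    p1 = sum(vec2Classify * p1Vec) + np.log(pClass)
    p0 = sum(vec2Classify * p0Vec) + np.log(1 - pClass)
    return 1 if p1 > p0 else 0

p0v = np.log(np.array([0.6, 0.2, 0.2]))  # hypothetical ham word likelihoods
p1v = np.log(np.array([0.1, 0.1, 0.8]))  # hypothetical spam word likelihoods
vec = np.array([0, 0, 2])                # document containing the third word twice
print(classifyNB(vec, p0v, p1v, 0.5))    # 1 (the spam posterior wins)
```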
    
 def spamTest():
     docList = []
     classList = []
     # Load 25 spam and 25 ham emails
     for i in range(1, 26):
         wordlist = textParse(open('data/spam/{}.txt'.format(i)).read())
         docList.append(wordlist)
         classList.append(1)
         wordlist = textParse(open('data/ham/{}.txt'.format(i)).read())
         docList.append(wordlist)
         classList.append(0)
     vocabList = creatVocabList(docList)
     # Persist the vocabulary for later classification
     import pickle
     file = open('data/vocabList.txt', mode='wb')
     pickle.dump(vocabList, file)
     file.close()
     # Randomly hold out 10 of the 50 emails as the test set
     trainingSet = list(range(50))
     testSet = []
     for i in range(10):
         randIndex = int(np.random.uniform(0, len(trainingSet)))
         testSet.append(trainingSet[randIndex])
         del trainingSet[randIndex]
     trainMat = []
     trainClasses = []
     for docIndex in trainingSet:
         trainMat.append(bagOfWords2Vec(vocabList, docList[docIndex]))
         trainClasses.append(classList[docIndex])
     p0v, p1v, pAb = trainNB0(trainMat, trainClasses)
     # Persist the trained probabilities
     file = open('data/threeRate.txt', mode='wb')
     pickle.dump([p0v, p1v, pAb], file)
     file.close()
     # Classify the held-out emails and report the error rate
     errorCount = 0
     for docIndex in testSet:
         wordVector = bagOfWords2Vec(vocabList, docList[docIndex])
         if classifyNB(wordVector, p0v, p1v, pAb) != classList[docIndex]:
             errorCount += 1
     return float(errorCount) / len(testSet)

Build the classifier

 def fileClassify(filepath):
     import pickle
     fileWordList = textParse(open(filepath, mode='r').read())
     # Load the saved vocabulary
     file = open('data/vocabList.txt', mode='rb')
     vocabList = pickle.load(file)
     file.close()
     fileWordVec = bagOfWords2Vec(vocabList, fileWordList)
     # Load the trained probabilities
     file = open('data/threeRate.txt', mode='rb')
     rate = pickle.load(file)
     file.close()
     p0v = rate[0]; p1v = rate[1]; pAb = rate[2]
     return classifyNB(fileWordVec, p0v, p1v, pAb)

 if __name__ == '__main__':
     print('Naive Bayes error rate: {}'.format(spamTest()))
     filepath = input('Enter the path of the email to classify: ')
     if fileClassify(filepath) == 1:
         print('spam')
     else:
         print('not spam')
