Advertisement

朴素贝叶斯——垃圾邮件概率问题

阅读量:

现代社会飞速发展,越来越多的垃圾邮件充斥着我们的邮箱,所以我们通过多个词来判断是否为垃圾邮件,但这个概率难以估计,通过贝叶斯公式,可以转化为求垃圾邮件中这些词出现的概率。

为什么使用朴素贝叶斯:使用“贝叶斯”的方法才使得垃圾邮件的分类达到一个较好的效果,而且随着邮件数目越来越多,贝叶斯分类的效果会更加好。

任务主要思路:

分类标准:当 P(垃圾邮件|文字内容)> P(正常邮件|文字内容)时,我们认为该邮件为垃圾邮件,但是单凭单个词而做出判断误差肯定相当大,因此我们可以将所有的词一起进行联合判断。所有词语彼此之间是不相关的(严格说这个假设不成立;实际上各词语之间不可能完全没有相关性,但可以忽略)。假如我们进行判断的词有“中奖”、“免费”、“无套路”,则需要判断P(垃圾邮件|中奖,免费,无套路)与P(正常|中奖,免费,无套路)。

拉普拉斯平滑:

主要的思想是对词的个数+1,对训练数据进行平滑处理。当训练样本很大时,每个词的个数+1造成的概率变化并不大,在误差允许的范围之内。

数据集需要在百度网盘上进行下载,下载完成并解压之后放到.py同目录中去。

代码实现的主要思路:

首先我们需要导入包,然后进行读取数据,然后将读取到的数据进行预处理,预处理之后进行正式的处理,最后就可以进行预测。

代码如下:

复制代码
 import os

    
 import re
    
 import string
    
 import math
    
  
    
 DATA_DIR = 'enron'
    
 target_names = ['ham', 'spam']
    
  
    
  
    
 def get_data(DATA_DIR):
    
     subfolders = ['enron%d' % i for i in range(1, 7)]
    
     data = []
    
     target = []
    
     for subfolder in subfolders:
    
     # spam
    
     spam_files = os.listdir(os.path.join(DATA_DIR, subfolder, 'spam'))
    
     for spam_file in spam_files:
    
         with open(os.path.join(DATA_DIR, subfolder, 'spam', spam_file), encoding="latin-1") as f:
    
             data.append(f.read())
    
             target.append(1)
    
     ham_files = os.listdir(os.path.join(DATA_DIR, subfolder, 'ham'))
    
     for ham_file in ham_files:
    
         with open(os.path.join(DATA_DIR, subfolder, 'ham', ham_file), encoding="latin-1") as f:
    
             data.append(f.read())
    
             target.append(0)
    
     return data, target
    
  
    
  
    
 X, y = get_data(DATA_DIR)
    
  
    
  
    
 class SpamDetector_1(object):
    
  
    
     def clean(self, s):
    
     translator = str.maketrans("", "", string.punctuation)
    
     return s.translate(translator)
    
  
    
     def tokenize(self, text):
    
     text = self.clean(text).lower()
    
     return re.split("\W+", text)
    
  
    
     def get_word_counts(self, words):
    
     word_counts = {}
    
     for word in words:
    
         word_counts[word] = word_counts.get(word, 0.0) + 1.0
    
     return word_counts
    
  
    
  
    
 class SpamDetector_2(SpamDetector_1):
    
     def fit(self, X, Y):
    
     self.num_messages = {}
    
     self.log_class_priors = {}
    
     self.word_counts = {}
    
     self.vocab = set()
    
     # 统计spam和ham邮件的个数
    
     self.num_messages['spam'] = sum(1 for label in Y if label == 1)
    
     self.num_messages['ham'] = sum(1 for label in Y if label == 0)
    
  
    
  
    
     self.log_class_priors['spam'] = math.log(
    
         self.num_messages['spam'] / (self.num_messages['spam'] + self.num_messages['ham']))
    
     self.log_class_priors['ham'] = math.log(
    
         self.num_messages['ham'] / (self.num_messages['spam'] + self.num_messages['ham']))
    
  
    
     self.word_counts['spam'] = {}
    
     self.word_counts['ham'] = {}
    
  
    
     for x, y in zip(X, Y):
    
         c = 'spam' if y == 1 else 'ham'
    
         counts = self.get_word_counts(self.tokenize(x))
    
         for word, count in counts.items():
    
             if word not in self.vocab:
    
                 self.vocab.add(word) 
    
             if word not in self.word_counts[c]:
    
                 self.word_counts[c][word] = 0.0
    
             self.word_counts[c][word] += count
    
  
    
  
    
 MNB = SpamDetector_2()
    
 MNB.fit(X[100:], y[100:])
    
  
    
  
    
 class SpamDetector(SpamDetector_2):
    
     def predict(self, X):
    
     result = []
    
     flag_1 = 0
    
     for x in X:
    
         counts = self.get_word_counts(self.tokenize(x))  
    
         spam_score = 0
    
         ham_score = 0
    
         flag_2 = 0
    
         for word, _ in counts.items():
    
             if word not in self.vocab:
    
                 continue
    
  
    
             
    
             else:
    
                 if word in self.word_counts['spam'].keys() and word in self.word_counts['ham'].keys():
    
                     log_w_given_spam = math.log(
    
                         (self.word_counts['spam'][word] + 1) / (
    
                                     sum(self.word_counts['spam'].values()) + len(self.vocab)))
    
                     log_w_given_ham = math.log(
    
                         (self.word_counts['ham'][word] + 1) / (sum(self.word_counts['ham'].values()) + len(
    
                             self.vocab)))
    
                 if word in self.word_counts['spam'].keys() and word not in self.word_counts['ham'].keys():
    
                     log_w_given_spam = math.log(
    
                         (self.word_counts['spam'][word] + 1) / (
    
                                     sum(self.word_counts['spam'].values()) + len(self.vocab)))
    
                     log_w_given_ham = math.log(1 / (sum(self.word_counts['ham'].values()) + len(
    
                         self.vocab)))
    
                 if word not in self.word_counts['spam'].keys() and word in self.word_counts['ham'].keys():
    
                     log_w_given_spam = math.log(1 / (sum(self.word_counts['spam'].values()) + len(self.vocab)))
    
                     log_w_given_ham = math.log(
    
                         (self.word_counts['ham'][word] + 1) / (sum(self.word_counts['ham'].values()) + len(
    
                             self.vocab)))
    
  
    
             spam_score += log_w_given_spam
    
             ham_score += log_w_given_ham
    
  
    
             flag_2 += 1
    
  
    
             spam_score += self.log_class_priors['spam']
    
             ham_score += self.log_class_priors['ham']
    
  
    
         if spam_score > ham_score:
    
             result.append(1)
    
         else:
    
             result.append(0)
    
  
    
         flag_1 += 1
    
  
    
     return result
    
  
    
  
    
 MNB = SpamDetector()
    
 MNB.fit(X[100:], y[100:])
    
 pred = MNB.predict(X[:100])
    
 true = y[:100]
    
  
    
 accuracy = 0
    
 for i in range(100):
    
     if pred[i] == true[i]:
    
     accuracy += 1
    
 print(accuracy)

结果如下:

全部评论 (0)

还没有任何评论哟~