Advertisement

朴素贝叶斯算法过滤垃圾邮件

阅读量:

基础知识

  1. 条件概率
    在A发生的条件下B发生的概率记为P(B|A),

  2. 全概率公式
    B_{1},B_{2}...B_{n}为样本空间E的一个划分,则

算法概述

1. 贝叶斯公式
若为样本空间E的一个划分,则

D_{i}为训练集D中第i类样本构成的集合,在假设这些样本相互独立并遵循相同分布的前提下,则参数\theta_{c}在数据集D_{i}上的似然度为:

应用极大似然估计的过程即为确定其参数估计值 \hat{\theta_{c}} ,使得观测数据的概率密度函数 P(D_{i}|\theta_{c}) 达到最大;由于连乘运算在实际计算中容易导致数值下溢问题 ,因此通常会对计算结果取对数以避免数值下溢问题 。

3. 朴素贝叶斯
基于其核心假设为条件独立性,则有
P(A=a \mid B=B_i) = P(A_1=a_1, \dots, A_n=a_n \mid B=B_i) = \prod_{j=1}^{n} P(A_j=a_j \mid B=B_i)

4. 拉普拉斯平滑

其中N表示训练集B中可能的类别数

其中Ni表示第i个属性可能出现的取值数

应用

文本分类–使用朴素贝叶斯过滤垃圾邮件

程序初始化

复制代码
    def loadDataSet():  #数据格式
    postingList=[['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],
                 ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],
                 ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],
                 ['stop', 'posting', 'stupid', 'worthless', 'garbage'],
                 ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],
                 ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]
    classVec = [0,1,0,1,0,1] #1代表侮辱性,0代表正常言论
    return postingList,classVec
    
    def createVocabList(dataSet):#创建词汇表
    vocabSet = set([]) #不重复
    for document in dataSet:
        vocabSet = vocabSet | set(document) #创建并集
    return list(vocabSet)
    
    def bagOfWord2VecMN(vocabList,inputSet):#对照词汇表,将输入句子转化为0,1组成的向量
    returnVec = [0]*len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] += 1
    return returnVec

(1) 获取原始数据文件以完成数据收集任务,并呈现给系统作为输入材料。
邮件数据集
(2) 进行初步处理, 将原始数据文件通过自然语言处理技术转换为词条向量形式.

复制代码
    def textParse(bigString):  #接受大写字符串并将其解析为字符串列表
    import re
    listOfTokens = re.split(r'\W*', bigString) #正则表达式匹配,无论何符号都切分
    return [tok.lower() for tok in listOfTokens if len(tok) > 2]

(3)分析数据:检查词条确保解析的正确性

(4)训练数据:使用建立的trainNB()函数

复制代码
    def trainNB(trainMatrix, trainCategory):
    numTrainDocs = len(trainMatrix) #文档个数
    numWords = len(trainMatrix[0]) #每个文档大小
    pAbusive = sum(trainCategory)/float(numTrainDocs) #求每个类别在样本中的概率
    p0Num = np.ones(numWords)  #拉普拉斯平滑,防止等于0
    p1Num = np.ones(numWords)
    p0Denom = 2.0,  p1Denom = 2.0
    for i in range(numTrainDocs):  #遍历文档集
        if trainCategory[i] == 1:
            p1Num += trainMatrix[i]  #侮辱性文档中每个词汇相加
            p1Denom += sum(trainMatrix[i]) #侮辱性文档总词数相加
        else:
            p0Num += trainMatrix[i]
            p0Denom += sum(trainMatrix[i])
    p1Vect = log(p1Num / p1Denom)  # P(Ai|Bi)组成的向量,并取对数
    p0Vect = log(p0Num / p0Denom)
    return p0Vect, p1Vect, pAbusive
    
    def classifyNB(vec2Classify,p0Vec,p1Vec,pClass1):  #判断是否为侮辱性邮件
    p1 = sum(vec2Classify * p1Vec) + log(pClass1) #朴素贝叶斯
    p0 = sum(vec2Classify * p0Vec) + log(1-pClass1) 
    if p1 > p0:
        return 1
    else:
        return 0

(5)使用算法:构建完整程序对一组文档进行分类,输出错分文档

复制代码
    def spamTest():
    docList = []; classList = []; fullText = []
    for i in range(1, 26): #导入并解析文件
        wordList = textParse(open('#文件地址(spam)' % i).read())
        docList.append(wordList)
        fullText.extend(wordList)
        classList.append(1)
         wordList = textParse(open('#文件地址(ham)' % i).read())
        docList.append(wordList)
        fullText.extend(wordList)
        classList.append(0)
    vocabList = createVocablist(docList)
    trainingSet = list(range(50))
    testSet = []
    for i in range(10):  #随机构建训练集
        randIndex = int(np.random.uniform(0, len(trainingSet)))
        testSet.append(trainingSet[randIndex])
        del(trainingSet[randIndex])
    trainMat = []
    trainClasses = []
    for docIndex in trainingSet:
        trainMat.append(bagOfWords2Vec(vocabList, docList[docIndex]))
        trainClasses.append(classList[docIndex])
    p0V, p1V, pSpam = trainNB(np.array(trainMat), np.array(trainClasses))
    errorCount = 0
    for docIndex in testSet:  #对测试集分类
        wordVector = bagOfWords2Vec(vocabList, docList[docIndex])
        if classifyNB(np.array(wordVector), p0V, p1V, pSpam) != classList[docIndex]:
            errorCount += 1
    print('the error rate is :', float(errorCount)/len(testSet))

Harrington, P. (2013). Practical machine learning. China Machine Press, Beijing.
周志华. Machine learning: A comprehensive textbook[M]. Tsinghua University Press, 2016.

全部评论 (0)

还没有任何评论哟~