Advertisement

[Data Mining]APriori算法C++实现

阅读量:
复制代码
 // stdafx.h

    
  
    
 #ifndef __STDAFX_H__
    
 #define __STDAFX_H__
    
  
    
 #include <algorithm>
    
 #include <iostream>
    
 #include <string>
    
 #include <vector>
    
 #include <map>
    
  
    
 #include <cstring>
    
 #include <cassert>
    
 #include <cstdlib>
    
 #include <cctype>
    
 #include <cstdio>
    
  
    
 using namespace std;
    
  
    
 #define PAUSE	pause();
    
 #define STOP	while (true) { pause(); }
    
  
    
 #endif
复制代码
 // stdafx.cpp

    
  
    
 #include "stdafx.h"
复制代码
 // assist.h

    
  
    
 #ifndef __ASSIST_H__
    
 #define __ASSIST_H__
    
  
    
 #include "stdafx.h"
    
  
    
 void pause(void); // 中断暂停
    
 void sort_unique(string& str); // 将一个字符串排序后去重
    
  
    
 #endif
复制代码
 // assist.cpp

    
  
    
 #include "stdafx.h"
    
 #include "assist.h"
    
  
    
 void pause(void)
    
 {
    
 	system("pause>nul");
    
 }
    
  
    
 void sort_unique(string& str)
    
 {
    
 	sort(str.begin(), str.end()); // 排序
    
 	string::iterator old_end = str.end(); // 去重
    
 	string::iterator new_end = unique(str.begin(), str.end());
    
 	str.erase(new_end, old_end);
    
 }
复制代码
 // ItemSet.h

    
  
    
 #ifndef __ITEM_SET_H__
    
 #define __ITEM_SET_H__
    
  
    
 struct CItemSet // 项集
    
 {
    
 	string m_seq; // 用string保存项的序列,将string当成整型数组用
    
 	// 项集中的元素统统去重并按字典序升序排列
    
  
    
 	// 构造函数
    
 	CItemSet(void) {} // 默认空的构造函数
    
 	CItemSet(string seq): m_seq(seq) {} // 直接用string构造
    
 	CItemSet(const CItemSet& obj): m_seq(obj.m_seq) {} // 复制构造函数
    
  
    
 	void split_dot_blank(char* buff); // 对输入的m_seq进行解析去掉分隔符得到数字序列
    
 	friend istream& operator>>(istream& in, CItemSet& obj); // 输入
    
 	friend ostream& operator<<(ostream& out, const CItemSet& obj); // 输出
    
 	
    
 	// 比较运算符,满足STL的协议
    
 	bool operator==(const CItemSet& oth) const;
    
 	bool operator<(const CItemSet& oth) const;
    
  
    
 	int size() const; // 返回项集中项的个数
    
  
    
 	bool isEmpty(void) const; // 项集是否为空
    
 	bool isInside(const CItemSet& oth) const; // 项集this是否包含于项集&oth中	
    
  
    
 	// 两个剪枝
    
 	bool operator%(const CItemSet& oth) const; // 判断this和oth是否拥有相同的前缀
    
 	bool split_appear(vector<CItemSet>& set_list) const; // k + 1项集的所有k项子集是否都出现在项集集合set_list中
    
 	// 实际中set_list将传入计算出的k频繁项集集合
    
  
    
 	CItemSet sub_set_at(int i) const; // 将项集中的第i个项拆出来单独形成一个{ ai }项并返回
    
  
    
 	CItemSet operator+(const CItemSet& oth) const; // 集合并
    
 	CItemSet operator-(const CItemSet& oth) const; // 集合减
    
 };
    
  
    
 #endif
复制代码
 // ItemSet.cpp

    
  
    
 #include "stdafx.h"
    
 #include "assist.h"
    
 #include "ItemSet.h"
    
  
    
 void CItemSet::split_dot_blank(char* buff)
    
 {		
    
 #define	IN		0	// 在数字序列中
    
 #define	OUT		1	// 不在数字序列中
    
 	
    
 	char str_num[1024]; // 存放输入序列中的数字部分
    
 	int flag = OUT; // 刚开始初始化为在序列外
    
 	
    
 	char seq_buff[1024]; // 用于暂存映射好的seq
    
 	int k = 0;
    
 	
    
 	// i指示seq,j指示str_num
    
 	for (int i = 0, j = 0; i <= strlen(buff); i++) {
    
 		if (!ispunct(buff[i]) && !isspace(buff[i]) && buff[i] != '\0') { // 如果为数字
    
 			if (flag == OUT) {
    
 				flag = IN; // 从序列外到序列内表示新的数字的开始
    
 				j = 0; // 因此j要清0,表示开始一个新的数字序列
    
 			}
    
 			str_num[j++] = buff[i];				
    
 		}
    
 		else { // 如果为分隔符
    
 			if (flag == IN) {
    
 				flag = OUT; // 从序列内到序列外表示一个数字序列的结束
    
 				str_num[j] = '\0'; // 因此要添加字符串结束符号
    
 				
    
 				int num;
    
 				sscanf(str_num, "%d", &num);
    
 				// 					if (!mp[num]) mp[num] = ++cnt;
    
 				// 					seq_buff[k++] = mp[num] + 'A' - 1;
    
 				seq_buff[k++] = num;
    
 			}
    
 		}
    
 	}
    
 	seq_buff[k] = '\0';
    
 	
    
 	m_seq = string(seq_buff);
    
 	sort_unique(m_seq);		
    
 	
    
 #undef	IN
    
 #undef	OUT
    
 }
    
  
    
 istream& operator>>(istream& in, CItemSet& obj) 
    
 {
    
 	char buff[1024];
    
 	if (!gets(buff)) {
    
 		in.clear();
    
 		return in >> buff;
    
 	}
    
 	obj.split_dot_blank(buff);
    
 	return in;
    
 }
    
  
    
 ostream& operator<<(ostream& out, const CItemSet& obj) 
    
 {
    
 	cout << "{ ";
    
 	for (int i = 0; i < obj.m_seq.length(); i++) {
    
 		out << (int)obj.m_seq[i];
    
 		if (i + 1 < obj.m_seq.length())
    
 			cout << ", ";
    
 	}
    
 	cout << " }";
    
 	return out;
    
 }
    
  
    
 bool CItemSet::operator==(const CItemSet& oth) const
    
 { // 相等就是字符串相等
    
 	return m_seq == oth.m_seq;
    
 }
    
  
    
 bool CItemSet::operator<(const CItemSet& oth) const
    
 { // 小于就是字符串小于
    
 	return m_seq < oth.m_seq;
    
 }
    
  
    
 int CItemSet::size() const
    
 {
    
 	return m_seq.length();
    
 }
    
  
    
 bool CItemSet::isEmpty(void) const
    
 {
    
 	return m_seq.size() == 0;
    
 }
    
  
    
 bool CItemSet::isInside(const CItemSet& oth) const
    
 { // 将this中的每个字符逐个放到oth中查找,必须都存在才返回true
    
 	for (int i = 0; i < m_seq.length(); i++)
    
 		if (oth.m_seq.find(m_seq[i]) == -1)
    
 			return false;
    
 	return true;
    
 }
    
  
    
 bool CItemSet::operator%(const CItemSet& oth) const
    
 { // 判断前缀是否相等
    
 	// 首先序列长度至少要大于1
    
 	if (m_seq.size() != 1 && m_seq.substr(0, m_seq.length() - 1) != oth.m_seq.substr(0, oth.m_seq.length() - 1))
    
 		return false; // 前缀相同才能合并
    
 	return true;
    
 }
    
  
    
 bool CItemSet::split_appear(vector<CItemSet>& set_list) const
    
 {
    
 	for (int i = 0; i < m_seq.length(); i++)
    
 	{
    
 		string tmp_seq = m_seq;
    
 		CItemSet itmset = CItemSet(tmp_seq.erase(i, 1)); // 逐个擦除字符
    
 		// 剩下的序列和set_list中的每个序列都二分查找一下,必须出现一次才能返回true
    
 		vector<CItemSet>::iterator it_find = find(set_list.begin(), set_list.end(), itmset);
    
 		if (it_find == set_list.end())
    
 			return false;
    
 	}
    
 	return true;
    
 }
    
  
    
 CItemSet CItemSet::sub_set_at(int i) const
    
 {
    
 	assert(0 <= i && i <= m_seq.length());
    
 	return CItemSet(m_seq.substr(i, 1));
    
 }
    
  
    
 CItemSet CItemSet::operator+(const CItemSet& oth) const
    
 {
    
 	string seq_cat;
    
 	seq_cat = m_seq + oth.m_seq; // 相加后一定要排序去重
    
 	sort_unique(seq_cat);
    
 	return CItemSet(seq_cat);
    
 }
    
  
    
 CItemSet CItemSet::operator-(const CItemSet& oth) const
    
 {
    
 	int pos = m_seq.find(oth.m_seq);
    
 	string tmp_seq = m_seq;
    
 	return CItemSet(tmp_seq.erase(pos, oth.m_seq.length()));
    
 }
复制代码
 // Assoc.h

    
  
    
 #ifndef __ASSOC_H__
    
 #define __ASSOC_H__
    
  
    
 struct CAssoc // 关联规则
    
 {
    
 	CItemSet m_fst, m_sec; // 前件和后件
    
 	
    
 	CAssoc(CItemSet fst, CItemSet sec):
    
 		m_fst(fst), m_sec(sec) {}
    
 	
    
 	friend ostream& operator<<(ostream& os, const CAssoc& obj);
    
 };
    
  
    
 #endif
复制代码
 // Assoc.cpp

    
  
    
 #include "stdafx.h"
    
 #include "assist.h"
    
 #include "ItemSet.h"
    
 #include "Assoc.h"
    
  
    
 ostream& operator<<(ostream& os, const CAssoc& obj)
    
 {
    
 	os << obj.m_fst << " -> " << obj.m_sec;
    
 	return os;
    
 }
复制代码
 // Tasks.h

    
  
    
 #ifndef __TASKS_H__
    
 #define __TASKS_H__
    
  
    
 struct CTasks 
    
 {
    
 	vector<CItemSet> task_list; // 输入的任务序列
    
  
    
 	// 频繁项集序列
    
 	// fre_set_list[0]就是频繁1项集序列
    
 	// fre_set_list[1]就是频繁2项集序列,以此类推
    
 	vector<vector<CItemSet> > fre_set_list;
    
  
    
 	vector<CAssoc> str_assoc_list; // 强关联规则的列表
    
  
    
 	// 提供"项集 -> 项集出现的次数的映射"
    
 	// 通过get_count将STL map的负责处理过程包装起来
    
 	map<CItemSet, int> mp;
    
  
    
 	int max; // 输入序列中最大的序号
    
  
    
 	// 支持度和置信度的阈值
    
 	double sup;
    
 	double conf;
    
 	
    
 	CTasks(void): max(0), sup(42.85), conf(71.4) {}
    
  
    
 	// 输入输出任务表,都正规化表示
    
 	friend istream& operator>>(istream& is, CTasks& obj);
    
 	friend ostream& operator<<(ostream& os, CTasks& obj);
    
  
    
 	int get_count(const CItemSet& itmset); // 得到mp[itmset]的值(记忆化存储)
    
 	double get_sup(const CItemSet& itmset); // 得到项集itmset的支持度
    
 	double get_conf(const CAssoc& assoc); // 得到关联规则assoc对应的置信度
    
  
    
 	void make_fre_set_list(void); // 构造任务表的频繁项集
    
 	void rec(const CItemSet& fst, const CItemSet& sec, int start); // 递归构造候选关联规则
    
 	void make_str_assoc_list(void); // 构造任务表的强关联规则
    
 };
    
  
    
 #endif
复制代码
 // Tasks.cpp

    
  
    
 #include "stdafx.h"
    
 #include "assist.h"
    
 #include "ItemSet.h"
    
 #include "Assoc.h"
    
 #include "Tasks.h"
    
  
    
 istream& operator>>(istream& is, CTasks& obj)
    
 {
    
 	CItemSet it;
    
 	is >> obj.sup >> obj.conf;
    
 	char ch;
    
 	getchar(ch);
    
 	while (is >> it) obj.task_list.push_back(it);
    
 	for (int i = 0; i < obj.task_list.size(); i++)
    
 		for (int j = 0; j < obj.task_list[i].m_seq.length(); j++)
    
 			if (obj.task_list[i].m_seq[j] > obj.max)
    
 				obj.max = obj.task_list[i].m_seq[j];
    
 	is.clear();	
    
 	return is;
    
 }
    
  
    
 ostream& operator<<(ostream& os, CTasks& obj)
    
 {
    
 	for (int i = 0; i < obj.task_list.size(); i++) {
    
 		os << "task " << i + 1 << " : " << obj.task_list[i] << endl;
    
 	}
    
 	cout << endl;
    
 	cout << "Minimum Support: " << obj.sup << endl;
    
 	cout << "Minimum Confidence: " << obj.conf << endl;
    
 	return os;
    
 }
    
  
    
 int CTasks::get_count(const CItemSet& itmset)
    
 {
    
 	if (mp.find(itmset) == mp.end()) {
    
 		int cnt = 0;
    
 		for (int i = 0; i < task_list.size(); i++)
    
 			if (itmset.isInside(task_list[i]))
    
 				cnt++;
    
 			mp.insert(make_pair<CItemSet, int>(itmset, cnt));
    
 	}
    
 	return mp.find(itmset)->second;
    
 }
    
  
    
 double CTasks::get_sup(const CItemSet& itmset)
    
 {
    
 	return get_count(itmset) * 100.0 / task_list.size();
    
 }
    
  
    
 double CTasks::get_conf(const CAssoc& assoc)
    
 {
    
 	return get_count(assoc.m_fst + assoc.m_sec) * 100.0 / get_count(assoc.m_fst);
    
 }
    
  
    
 void CTasks::make_fre_set_list(void)
    
 {
    
 	vector<CItemSet> cd_list; // 候选集列表
    
 	vector<CItemSet> fi_list; // 最终集列表
    
 	for (int i = 1; i <= max; i++) { // 构造候选1项集
    
 		string seq("0");
    
 		seq[0] = i;
    
 		cd_list.push_back(CItemSet(seq));
    
 	}
    
 	
    
 	PAUSE
    
 	for (int lvl = 1; cd_list.size() > 0; lvl++) // 从lvl项集开始构造
    
 	{
    
 		int i, j;
    
 		
    
 		cout << "-----" << lvl << " Item Sets-----" << endl;
    
 		cout << endl;
    
 		PAUSE
    
 				
    
 		cout << "-Candidate sets:" << endl; // 先打印候选项集列表
    
 		if (cd_list.size() == 0) {
    
 			PAUSE
    
 			cout << "There is no candidate item set!" << endl;
    
 			PAUSE
    
 			return ;
    
 		}
    
 		for (i = 0; i < cd_list.size(); i++) {
    
 			PAUSE
    
 			double tmp_sup = get_sup(cd_list[i]);
    
 			cout << cd_list[i] << " : " << tmp_sup << "  ";
    
 			if (tmp_sup >= sup) { // 打印同时判断是否满足最小支持度要求
    
 				cout << "√" << endl;
    
 				fi_list.push_back(cd_list[i]); // 满足的顺便压入最终列表
    
 			}
    
 			else cout << "×" << endl;
    
 		}
    
 		cd_list.clear(); // 结束后将候选列表清空,为生成k + 1项集做准备
    
 		cout << endl;
    
 		PAUSE
    
 			
    
 		cout << "+Final sets:" << endl; // 打印最终列表
    
 		PAUSE
    
 		if (fi_list.size() == 0) {
    
 			cout << "There is no final sets!" << endl;
    
 			PAUSE
    
 			return ;
    
 		}
    
 		fre_set_list.push_back(fi_list);
    
 		for (i = 0; i < fi_list.size(); i++)
    
 			cout << fi_list[i] << endl;
    
 		cout << endl;
    
 		for (i = 0; i < fi_list.size(); i++) // 由k项集最终列表生成候选k + 1项集列表
    
 			for (j = i + 1; j < fi_list.size(); j++) {
    
 				if (!(fi_list[i] % fi_list[j])) continue; // 先剪枝:前缀相同才能合并
    
 					CItemSet tmp_itmset = fi_list[i] + fi_list[j];
    
 					if (!tmp_itmset.isEmpty() && tmp_itmset.split_appear(fi_list))
    
 					// 再剪枝:合并的k + 1项集的所有k项子集必须都出现在k频繁项集列表中
    
 						cd_list.push_back(tmp_itmset);
    
 			}
    
 		fi_list.clear(); // 完毕后k频繁项集清空,为下一轮做准备
    
 		PAUSE
    
 	}
    
 }
    
  
    
 // 由关联规则fst -> sec来构造所有候选关联规则
    
 void CTasks::rec(const CItemSet& fst, const CItemSet& sec, int start)
    
 { // start指向前件中需要挪到后件去的那个项的开始试探值
    
 	if (fst.size() == 1) return ; // 递归终点:前件至少要有一个项
    
 	
    
 	for (int i = start; i < fst.size(); i++) {
    
 		CItemSet single_set = fst.sub_set_at(i);
    
 		CItemSet next_fst = fst - single_set; // 从前件中挪走第i项
    
 		CItemSet next_sec = sec + single_set; // 放到后件中
    
 		CAssoc tmp_assoc(next_fst, next_sec); // 形成新的关联规则
    
 		double tmp_conf = get_conf(tmp_assoc); // 并计算新关联规则的置信度
    
 		cout << tmp_assoc << " : " << tmp_conf << "  ";
    
 		if (tmp_conf > conf) { // 剪枝1:对于置信度小于最低置信度的就不用往下递归了
    
 			cout << "√" << endl;
    
 			str_assoc_list.push_back(tmp_assoc);
    
 			rec(next_fst, next_sec, i); // 剪枝2:往下递归时不用从头(0)开始试探,而是
    
 			// 从去掉的第i项的后面一项开始试探
    
 			// 原因有俩,其一是如果上一层是成功的,则从i之前试探会重复
    
 			// 其二是如果上一层是失败的,则意味着i之前必有失败的案例,如果继续把那个失败的项
    
 			// 加入后件仍然会是失败的
    
 			// 因此后件即使不排序也是自动按照字典序升序排列的
    
 		}
    
 		else cout << "×" << endl;
    
 		PAUSE
    
 	}
    
 }
    
  
    
 void CTasks::make_str_assoc_list(void)
    
 {
    
 	cout << "-Candidate:" << endl; // 先打印候选关联规则
    
 	PAUSE
    
 	for (int i = 1; i < fre_set_list.size(); i++)
    
 		for (int j = 0; j < fre_set_list[i].size(); j++)
    
 			rec(fre_set_list[i][j], CItemSet(string("")), 0);
    
 	cout << endl;
    
 		
    
 	cout << "+Final:" << endl; // 再打印最终得到的强关联规则
    
 	PAUSE
    
 	if (!str_assoc_list.size()) {
    
 		cout << "no strong association rules" << endl;
    
 		PAUSE
    
 		return ;
    
 	}
    
 	PAUSE
    
 	for (int k = 0; k < str_assoc_list.size(); k++) cout << str_assoc_list[k] << endl;
    
 	PAUSE
    
 }
复制代码
 // main.cpp

    
  
    
 #include "stdafx.h"
    
 #include "assist.h"
    
 #include "ItemSet.h"
    
 #include "Assoc.h"
    
 #include "Tasks.h"
    
  
    
 // 输入:第一行是最小支持度和置信度,接着每行一条任务
    
 /*
    
   12. 41.85 71.4
    
 1,2,3
    
 1,2,4
    
 1,3,4
    
 1,2,3,5
    
 1,3,5
    
 2,4,5
    
 1,2,3,4
    
   21. */
    
 /* 最后来三个EOF才行,因为输入流使用C++和C的混编 */
    
 /* 之后程序遇到每一步都会暂停,按任意键继续即可 */
    
  
    
 int main() {
    
 	CTasks tasks;
    
 	cin >> tasks;
    
 	cout << endl;
    
 	
    
 	cout << "==================================" << endl;
    
 	cout << "==============Tasks===============" << endl;
    
 	cout << "==================================" << endl;
    
 	cout << endl;
    
 	cout << tasks;
    
 	cout << endl;
    
 	
    
 	PAUSE
    
 	cout << "==================================" << endl;
    
 	cout << "========Frequent Item Sets========" << endl;
    
 	cout << "==================================" << endl;
    
 	cout << endl;
    
 	tasks.make_fre_set_list();
    
 	
    
 	cout << "==================================" << endl;
    
 	cout << "=====Strong Association Rules=====" << endl;
    
 	cout << "==================================" << endl;
    
 	cout << endl;
    
 	tasks.make_str_assoc_list();
    
 	
    
 	STOP
    
 		
    
 	return 0;
    
 }

全部评论 (0)

还没有任何评论哟~