Advertisement

数据挖掘--分类之决策树算法ID3

阅读量:

一、决策树:

一棵决策树由一个根节点,一组内部节点和一组叶节点组成。每个内部节点(包括根节点)表示在一个属性上的测试,每个分枝表示一个测试输出,每个叶节点表示一个类,有时不同的叶节点可以表示相同的类。

建立一棵决策树,需要解决的问题主要有:
1)如何选择测试属性?
测试属性的选择顺序影响决策树的结构甚至决策树的准确率,一般使用信息增益度量来选择测试属性。
2)如何停止划分样本?
从根节点测试属性开始,每个内部节点测试属性都把样本空间划分为若干个(子)区域,一般当某个(子)区域的样本同类时,就停止划分样本,有时也通过阈值提前停止划分样本。

二、算法

1. 算法思想及描述
首先,在整个训练数据集S、所有描述属性A1, A2, …, Am上递归地建立决策树。即将S作为根节点;如果S中的样本属于同一类别,则将S作为叶节点并用其中的类别标识,决策树建立完成(递归出口);
否则在S上计算当给定Ak(1≤k≤m)时类别属性C的信息增益G(C, Ak),选择信息增益最大的Ai作为根节点的测试属性;如果Ai的取值个数为v(取值记为a1, a2, …, av),则Ai将S划分为v个子集S1, S2, …, Sv(Sj(1≤j≤v)为S中Ai=aj的样本集合),同时根节点产生v个分枝与之对应。其次,分别在训练数据子集S1, S2, …, Sv、剩余描述属性A1, …, Ai-1, Ai+1, …, Am上采用相同方法递归地建立决策树子树(递归)。

可能出现如下情况,需要停止建立决策(子)树的递归过程。
1)某节点对应的训练数据(子)集为空。此时,该节点作为叶节点并用父节点中占多数的样本类别标识。
2)某节点没有对应的(剩余)描述属性。此时,该节点作为叶节点并用该节点中占多数的样本类别标识。

算法: 决策树分类算法Generate_decision_tree(S, A)
输入:训练数据集S,描述属性集合A
输出:决策树
步骤:
(1)创建对应S的节点Node;
(2)if S中的样本属于同一类别c then
以c标识Node并将Node作为叶节点返回;
(3)if A为空 then
以S中占多数的样本类别c标识Node并将Node作为叶节点返回;

(4)从A中选择对S而言信息增益最大的描述属性Ai作为Node的测试属性;
(5)for Ai的每个可能取值aj(1≤j≤v ) //设Ai的可能取值为a1, a2, …, av
(5.1)产生S的一个子集Sj //Sj(1≤j≤v)为S中Ai=aj的样本集合;
(5.2)if Sj为空 then
创建对应Sj的节点Nj,以S中占多数的样本类别c标识Nj,并将Nj作为叶节点形成Node的一个分枝
(5.3)else 由Generate_decision_tree(Sj, A-{Ai})创建子树形成Node的一个分枝;

三、 信息增益

在决策树分类算法中使用信息增益度量来选择测试属性。
从信息论角度看,通过描述属性可以减少类别属性的不确定性。

3.1 离散型随机变量X的无条件熵 定义为
式中,p(xi)为X= xi的概率;u为X的可能取值个数。

3.2 给定离散型随机变量Y,离散型随机变量X的条件熵 定义为
式中,p(xiyj)为X=xi, Y=yj的联合概率;p(xi/yj)为已知Y=yj时,X= xi的条件概率;u、v分别为X、Y的可能取值个数。
可以证明,H(X/Y)≤H(X)。所以,通过Y可以减少X的不确定性。

3.3 假设训练数据集是关系数据表r1, r2, …, rn,其中描述属性为A1, A2, …, Am、类别属性为C,类别属性C的无条件熵 定义为
式中,u为C的可能取值个数,即类别个数,类别记为c1, c2, …, cu;si为属于类别ci的记录集合,|si|即为属于类别ci的记录总数。

3.4 给定描述属性Ak(1≤k≤m),类别属性C的条件 熵定义为
式中,v为Ak的可能取值个数,取值记为a1, a2, …, av;sj为Ak=aj的记录集合,|sj|即为Ak=aj的记录数目;sij为Ak=aj且属于类别ci的记录集合,|sij|即为Ak=aj且属于类别ci的记录数目。

3.5 采用类别属性的无条件熵与条件熵的差(信息增益)来度量描述属性减少类别属性不确定性的程度。
给定描述属性Ak(1≤k≤m),类别属性C的信息增益定义为:
G(C, Ak)=E(C)-E(C, Ak)
可以看出,G(C, Ak)反映Ak减少C不确定性的程度,G(C, Ak)越大,Ak对减少C不确定性的贡献越大。

代码实现:

复制代码
 package data.decidetree;

    
  
    
 import java.io.BufferedReader;
    
 import java.io.FileInputStream;
    
 import java.io.IOException;
    
 import java.io.InputStreamReader;
    
 import java.util.ArrayList;
    
 import java.util.HashSet;
    
 import java.util.List;
    
 import java.util.Map;
    
 import java.util.Set;
    
 import java.util.Vector;
    
  
    
 public class DecideTreeID3 {
    
  
    
 	private static Vector<String> descripAttriList; // 描述属性
    
 	private static String[] typeAttriValues;   	//类别属性值域
    
 	private static String typeAttri;   			//类别属性
    
 	private static int typeAttriIndex; 			//类别属性下标
    
  
    
 	public DecideTreeID3() {
    
 		descripAttriList = new Vector<String>();
    
 	}
    
  
    
 	public TreeNode create_decidion_tree(Vector<Vector<String>> recordList,
    
 			Vector<String> descripAttris) {
    
 		// 创建recordList对应的节点Node
    
 		TreeNode root = new TreeNode();
    
 		if (this.isSampleDecideAttri(recordList)) {
    
 			// 样本recordList属于同一类别 typeAttri
    
 			root.setAttriName("<is Leaf>");
    
 			root.setDecideAttri(this.getDecideAttriValue(recordList));
    
 			//System.out.println("name="+root.getAttriName()+"   decide=" + root.getDecideAttri());
    
 			return root;
    
 		}
    
  
    
 		if (descripAttris == null || descripAttris.size() == 0) {
    
 			root.setDecideAttri(this.getMostDecideNumsValue(recordList));
    
 			return root;
    
 		}
    
 		root.setRecordList(recordList);
    
 		
    
 		//选择决策属性  并设置其分裂属性值、分裂属性的下标位置
    
 		this.chooseDecideAttri(root);
    
 		
    
 		if (root.getSplitAttri() != null && root.getSplitAttri().size() > 0) {
    
 			// 遍历每个分裂属性值
    
 			for (String attriValue : root.getSplitAttri()) {
    
 				Vector<String> descripAttri = new Vector<String>(descripAttris);
    
 				// 取得下标
    
 				int index = this.getAttriIndex(root.getAttriName());
    
 				
    
 				// 取得recordList的子集
    
 				Vector<Vector<String>> remainRecord = this.getRelatedRecord(recordList, attriValue, index);
    
  
    
 				if (remainRecord == null || remainRecord.size() == 0) {
    
 					TreeNode node = new TreeNode();
    
 					node.setDecideAttri(this.getMostDecideNumsValue(recordList));
    
 					root.getConnectNode().put(node, attriValue);
    
 				} else {
    
 					descripAttri.remove(root.getSplitAttriIndex());
    
 					TreeNode node = new TreeNode(create_decidion_tree(
    
 							remainRecord, descripAttri));
    
 					if (node.getAttriName() != null)
    
 						root.getConnectNode().put(node, attriValue);
    
 				}
    
 			}
    
 		}
    
 		return root;
    
 	}
    
  
    
 	/** * 选择分枝的测试属性 设置 treeNode的分裂属性位置,以及各个分枝
    
 	 * * @param typeAttri
    
 	 *            测试属性
    
 	 * @return 选择的属性 及其位置
    
 	 */
    
 	private void chooseDecideAttri(TreeNode treeNode) {
    
 		double max_entropy = 0.0;
    
 		TreeNode copyNode = new TreeNode(treeNode);
    
 		// 取得E(C)的值
    
 		double Ec = this.gainEValue(treeNode);
    
  
    
 		for (int i = 0; i < descripAttriList.size() - 1; i++) {
    
 			String descripAttri = descripAttriList.get(i);
    
 			if (!typeAttri.equals(descripAttri)) {
    
 				// 计算每个描述属性在测试属性下的熵 ,熵值最大则为下一个测试属性
    
 				double entropy = Ec- this.informationGain(copyNode, descripAttri);
    
 				// System.out.println("G(" + typeAttri + ","+ descripAttri +")=" + entropy);
    
 				if (max_entropy < entropy) {
    
 					max_entropy = entropy;
    
 					// 设置该结点的分裂属性
    
 					treeNode.setAttriName(descripAttriList.get(copyNode
    
 							.getSplitAttriIndex()));
    
 					treeNode.setSplitAttri(copyNode.getSplitAttri());
    
 					treeNode.setSplitAttriIndex(copyNode.getSplitAttriIndex());
    
 				}
    
 			}
    
 		}
    
 	}
    
  
    
 	/** * 根据属性名称 取得其在descripAttriList中的下标位置
    
 	 * * @param attriType
    
 	 * @return
    
 	 */
    
 	private int getAttriIndex(String attriType) {
    
 		int i = 0;
    
 		int index = 0;
    
 		for (String descripAttri : descripAttriList) {
    
 			if (attriType.equals(descripAttri)) {
    
 				index = i;
    
 				break;
    
 			}
    
 			i++;
    
 		}
    
 		return index;
    
 	}
    
  
    
 	/** * 类别属性C的条件熵
    
 	 * * @param typeAttri
    
 	 *            测试属性C
    
 	 * @param descripAttri
    
 	 *            描述属性 Ak E(C,Ak)
    
 	 * @return
    
 	 */
    
 	private double informationGain(TreeNode treeNode, String descripAttri) {
    
  
    
 		/** * 先按描述属性 descripAttri分类,再按测试属性typeAttri分类
    
 		 */
    
 		int n = treeNode.getRecordList().size(); // 该结点拥有的记录数
    
  
    
 		// 取得描述属性descripAttri的下标 ,并求得有多少种
    
 		List<String> type = new ArrayList<String>();
    
 		int descripIndex = this.getAttriIndex(descripAttri);
    
 		// 根据下标 获得该描述属性下的种类,及相应的子记录
    
 		for (List<String> eachRecord : treeNode.getRecordList()) {
    
 			// 取得第index的属性的值
    
 			String attri = eachRecord.get(descripIndex);
    
 			if (!type.contains(attri))
    
 				type.add(attri);
    
 		}
    
 		// 若分裂属性为该描述属性,其各个分裂属性为String[] splitAttris
    
 		String[] splitAttris = type.toArray(new String[0]);
    
  
    
 		double Ecak = 0.0;
    
  
    
 		// 用来存放treeNode的分裂属性
    
 		Vector<String> attr = new Vector<String>();
    
  
    
 		// 计算E(C,Ak)
    
 		// 用 newGroupAttri 存放各个分裂属性下的记录数,长度为分裂属性的种类数
    
 		int[] newGroupAttri = new int[splitAttris.length];
    
 		for (int i = 0; i < splitAttris.length; i++) {
    
 			newGroupAttri[i] = 0;
    
  
    
 			// 存放分裂属性的值
    
 			attr.add(splitAttris[i]);
    
 			for (List<String> eachRecord : treeNode.getRecordList()) {
    
 				// index为描述属性的下标,相应的该描述属性的各种值的下标也是index
    
 				// 取得每条记录在该描述属性下的值
    
 				String attri = eachRecord.get(descripIndex); // 找到该属性下的值
    
 				if (attri.contains(splitAttris[i])) {
    
 					// 是否已经有了该属性下的值
    
 					newGroupAttri[i]++;
    
 				}
    
 			}
    
 		}
    
 		treeNode.setSplitAttri(attr);
    
 		treeNode.setSplitAttriIndex(descripIndex);
    
 		// 再按测试属性treeNode.getAttriName()分类
    
 		for (int i = 0; i < splitAttris.length; i++) {
    
 			double sumIj = 0.0;
    
  
    
 			// 该描述属性的各个类别占记录数的比例
    
 			double multi = (double) newGroupAttri[i] / n;
    
 			int sum = 0;
    
  
    
 			// descripType 用来存放descripAttri
    
 			// 描述属性的各个值下,拥有最终决策属性的值的种类(即typeAttri的种类)
    
 			List<String> descripType = new ArrayList<String>();
    
 			for (List<String> eachRecord : treeNode.getRecordList()) {
    
 				String attri = eachRecord.get(typeAttriIndex);
    
 				String attriB = eachRecord.get(descripIndex); // 找到该属性下的值
    
 				if (!descripType.contains(attri)
    
 						&& attriB.contains(splitAttris[i]))
    
 					descripType.add(attri);
    
 			}
    
 			// descSplitAttris 为 typeAttri属性值的种类
    
 			String[] descSplitAttris = descripType.toArray(new String[0]);
    
  
    
 			for (int j = 0; j < descSplitAttris.length; j++) {
    
 				sum = 0;
    
 				// 在描述属性值为splitAttris[i]下,对拥有typeAttri属性值descSplitAttris[j]的记录计数
    
 				for (List<String> eachRecord : treeNode.getRecordList()) {
    
 					String attri = eachRecord.get(typeAttriIndex);
    
 					String attriB = eachRecord.get(descripIndex); // 找到该属性下的值
    
 					if (attriB.contains(splitAttris[i])
    
 							&& attri.contains(descSplitAttris[j])) {
    
 						sum++;
    
 					}
    
 				}
    
  
    
 				// 描述属性值为splitAttris[i],决策属性值为descSplitAttris[j],
    
 				// 在描述属性值为splitAttris[i]下占的比例
    
 				double Nij = (double) sum / newGroupAttri[i];
    
 				if (Nij > 0.0) {
    
 					sumIj -= Nij * (Math.log10(Nij) / Math.log10((double) 2));
    
 				}
    
 			}
    
 			Ecak = Ecak + multi * sumIj;
    
 		}
    
 		return Ecak;
    
 	}
    
  
    
 	/** * * @param treeNode
    
 	 * @return 类别属性的无条件熵E(C)
    
 	 */
    
 	private double gainEValue(TreeNode treeNode) {
    
 		int n = treeNode.getRecordList().size();
    
 		double Ec = 0.0;
    
 		List<String> rootType = new ArrayList<String>();
    
 		// 根据下标 获得该属性下的种类
    
 		for (List<String> eachRecord : treeNode.getRecordList()) {
    
 			String attri = eachRecord.get(typeAttriIndex);
    
 			if (!rootType.contains(attri))
    
 				rootType.add(attri);
    
 		}
    
 		String[] rootAttris = rootType.toArray(new String[0]);
    
 		int u = rootAttris.length; // 类别个数
    
 		int[] s = new int[u]; // 存放每个类别的记录数
    
 		// 计算E(C)
    
 		for (int i = 0; i < rootAttris.length; i++) {
    
 			s[i] = 0;
    
 			for (List<String> eachRecord : treeNode.getRecordList()) {
    
 				String attri = eachRecord.get(typeAttriIndex); // 找到该属性下的值
    
 				if (attri.contains(rootAttris[i])) {
    
 					// 是否已经有了该属性下的值
    
 					s[i]++;
    
 				}
    
 			}
    
 			double Ni = (double) s[i] / n;
    
 			Ec -= Ni * (Math.log10(Ni) / Math.log10((double) 2));
    
 		}
    
 		// System.out.println("E(" + typeAttri + ")=" + Ec);
    
 		return Ec;
    
 	}
    
  
    
 	/** * @param recordList
    
 	 *            所有记录
    
 	 * @return 在recordList中,类别属性值的个数
    
 	 */
    
 	private int getDecideAttriValueNums(Vector<Vector<String>> recordList) {
    
 		Set<String> set = new HashSet<String>();
    
 		for (Vector<String> vec : recordList) {
    
 			String attri = vec.get(typeAttriIndex);
    
 			if (!set.contains(attri)) {
    
 				set.add(attri);
    
 			}
    
 		}
    
 		return set.size();
    
 	}
    
  
    
 	// S样本属于同一类别C?
    
 	public boolean isSampleDecideAttri(Vector<Vector<String>> recordList) {
    
 		return this.getDecideAttriValueNums(recordList) == 1 ? true : false;
    
 	}
    
  
    
 	/** * 样本recordList中的类别属性值只有一种时,返回该类别的值
    
 	 * * @param recordList
    
 	 * @return
    
 	 */
    
 	public String getDecideAttriValue(Vector<Vector<String>> recordList) {
    
 		return recordList.get(0).get(typeAttriIndex);
    
 	}
    
  
    
 	/** * 根据测试属性的决定规则,取得该结点下的记录
    
 	 * * @param recordList
    
 	 * @param ruleAttri
    
 	 * @return
    
 	 */
    
 	public Vector<Vector<String>> getRelatedRecord(
    
 			Vector<Vector<String>> recordList, String ruleAttri, int index) {
    
 		Vector<Vector<String>> remainList = new Vector<Vector<String>>();
    
 		for (Vector<String> record : recordList) {
    
 			String attri = record.get(index);
    
 			if (ruleAttri.contains(attri)) {
    
 				remainList.add(record);
    
 			}
    
 		}
    
 		return remainList;
    
 	}
    
  
    
 	/** * * @param recordList
    
 	 * @return 返回样本recordList中占多数的类别typeAttri值
    
 	 */
    
 	public String getMostDecideNumsValue(Vector<Vector<String>> recordList) {
    
 		int max_num = 0;
    
 		int index = 0;
    
 		for (int i = 0; i < typeAttriValues.length; i++) {
    
 			int count = 0;
    
 			for (Vector<String> vec : recordList) {
    
 				if (vec.get(typeAttriIndex).equals(typeAttriValues[i])) {
    
 					count++;
    
 					System.out.println(vec.toString());
    
  
    
 				}
    
 			}
    
 			if (count > max_num) {
    
 				max_num = count;
    
 				index = i;
    
 			}
    
  
    
 		}
    
 		return typeAttriValues[index];
    
 	}
    
  
    
 	/** * 从文件中读取数据
    
 	 */
    
 	public Vector<Vector<String>> readDataFromFile(String filePath) {
    
 		descripAttriList = new Vector<String>(); // 描述属性
    
 		Set<String> typeValueSet = new HashSet<String>();
    
 		Vector<Vector<String>> dataList = new Vector<Vector<String>>(); // 所有记录
    
 		try {
    
 			// 一次读一行
    
 			// File file=new File(filePath);
    
 			BufferedReader reader = new BufferedReader(new InputStreamReader(
    
 					new FileInputStream(filePath), "GBK"));
    
 			String tempLine = null;
    
 			int i = 0;
    
 			while ((tempLine = reader.readLine()) != null) {
    
 				String[] items = tempLine.split(",");
    
 				if (i == 0) {
    
 					for (int j = 0; j < items.length; j++) {
    
 						descripAttriList.add(items[j]);
    
 					}
    
 					typeAttri = items[items.length - 1];
    
 					typeAttriIndex = items.length - 1;
    
 				} else {
    
 					Vector<String> eachRecord = new Vector<String>();
    
 					for (String item : items) {
    
 						eachRecord.add(item);
    
 					}
    
 					typeValueSet.add(items[items.length - 1]);
    
 					dataList.add(eachRecord);
    
 				}
    
 				i++;
    
 			}
    
  
    
 			typeAttriValues = new String[typeValueSet.size()];
    
 			typeAttriValues = typeValueSet.toArray(new String[0]);
    
 			reader.close();
    
 		} catch (IOException e) {
    
 			e.printStackTrace();
    
 			return null;
    
 		}
    
 		return dataList;
    
 	}
    
  
    
 	public static void main(String[] args) {
    
 		// String filePath="D:\ indexDir\ decide.txt";
    
 		String filePath = "D:\ indexDir\ buyInfo.txt";
    
 		DecideTreeID3 dd = new DecideTreeID3();
    
 		Vector<Vector<String>> list = dd.readDataFromFile(filePath);
    
 		Vector<String> decAttri = new Vector<String>(descripAttriList);
    
 		if(list!=null){
    
 			TreeNode root = new TreeNode(dd.create_decidion_tree(list, decAttri));
    
 			System.out.println(dd.printNodes(root));
    
 		}
    
  
    
 	}
    
  
    
 	public String printNodes(TreeNode root) {
    
 		StringBuilder sb = new StringBuilder("");
    
 		if (root == null)
    
 			return sb.toString();
    
  
    
 		// 是否为叶子结点
    
 		if (root.getDecideAttri() != null) {
    
 			sb.append("leaf...name=" + root.getAttriName()
    
 					+ ".....decide=" + root.getDecideAttri()+"\n");
    
 			return sb.toString();
    
 		}
    
  
    
 		System.out.println("parent" + root.getAttriName() + "...分裂规则:"
    
 				+ root.getSplitAttri().toString());
    
 		
    
 		for (Map.Entry<TreeNode, String> entry : root.getConnectNode()
    
 				.entrySet()) {
    
 			TreeNode node = new TreeNode(entry.getKey());
    
  
    
 			if (node.getDecideAttri() != null) {
    
 				sb.append("leaf...name=" + node.getAttriName()
    
 						+ "....val=" + entry.getValue() + "...decide="
    
 						+ node.getDecideAttri()+"\n");
    
 			} else {
    
 				sb.append("parent...name=" + node.getAttriName()
    
 						+ "...val=" + entry.getValue()+"\n");
    
 				sb.append(printNodes(node));
    
 			}
    
 		}
    
  
    
 		return sb.toString();
    
 	}
    
 }

TreeNode.java

复制代码
 package data.decidetree;

    
  
    
 import java.util.*;
    
  
    
 public class TreeNode {
    
  
    
 	private String attriName;  //属性名称
    
 	private String decideAttri;  //样本类别属性C
    
 	private int splitAttriIndex;  //分枝属性的下标位置
    
 	private Vector<Vector<String>> recordList; // 该结点下的记录
    
 	private Vector<String> splitAttri;  //分裂属性
    
 	private Map<TreeNode, String> connectNode;  //链接的结点
    
  
    
  
    
 	public TreeNode(){
    
 		recordList=new Vector<Vector<String>>();
    
 		connectNode=new HashMap<TreeNode, String>();
    
 		splitAttri=new Vector<String>();
    
 	}
    
 	
    
 	public TreeNode(TreeNode treeNode){
    
 		this.attriName=treeNode.getAttriName();
    
 		this.recordList=treeNode.getRecordList();
    
 		this.splitAttri=treeNode.getSplitAttri();
    
  
    
 		if(treeNode.getConnectNode()==null)
    
 			connectNode=new HashMap<TreeNode, String>();
    
 		else
    
 			connectNode=treeNode.getConnectNode();
    
 		this.decideAttri=treeNode.getDecideAttri();
    
 		this.splitAttriIndex=treeNode.getSplitAttriIndex();
    
 	}
    
  
    
 	public int getSplitAttriIndex() {
    
 		return splitAttriIndex;
    
 	}
    
  
    
 	public void setSplitAttriIndex(int splitAttriIndex) {
    
 		this.splitAttriIndex = splitAttriIndex;
    
 	}
    
  
    
  
    
 	public Vector<String> getSplitAttri() {
    
 		return splitAttri;
    
 	}
    
  
    
 	public void setSplitAttri(Vector<String> splitAttri) {
    
 		this.splitAttri = splitAttri;
    
 	}
    
  
    
 	public Map<TreeNode, String> getConnectNode() {
    
 		return connectNode;
    
 	}
    
  
    
 	public void setConnectNode(Map<TreeNode, String> connectNode) {
    
 		this.connectNode = connectNode;
    
 	}
    
  
    
 	public Vector<Vector<String>> getRecordList() {
    
 		return recordList;
    
 	}
    
  
    
 	public void setRecordList(Vector<Vector<String>> recordList) {
    
 		this.recordList = recordList;
    
 	}
    
  
    
 	public String getAttriName() {
    
 		return attriName;
    
 	}
    
  
    
 	public void setAttriName(String attriName) {
    
 		this.attriName = attriName;
    
 	}
    
  
    
 	public String getDecideAttri() {
    
 		return decideAttri;
    
 	}
    
  
    
 	public void setDecideAttri(String decideAttri) {
    
 		this.decideAttri = decideAttri;
    
 	}
    
 }

结果:

全部评论 (0)

还没有任何评论哟~