数据挖掘--分类之决策树算法ID3
一、决策树:
一棵决策树由一个根节点,一组内部节点和一组叶节点组成。每个内部节点(包括根节点)表示在一个属性上的测试,每个分枝表示一个测试输出,每个叶节点表示一个类,有时不同的叶节点可以表示相同的类。

建立一棵决策树,需要解决的问题主要有:
1)如何选择测试属性?
测试属性的选择顺序影响决策树的结构甚至决策树的准确率,一般使用信息增益度量来选择测试属性。
2)如何停止划分样本?
从根节点测试属性开始,每个内部节点测试属性都把样本空间划分为若干个(子)区域,一般当某个(子)区域的样本同类时,就停止划分样本,有时也通过阈值提前停止划分样本。
二、算法:
1. 算法思想及描述
首先,在整个训练数据集S、所有描述属性A1, A2, …, Am上递归地建立决策树。即将S作为根节点;如果S中的样本属于同一类别,则将S作为叶节点并用其中的类别标识,决策树建立完成(递归出口);
否则在S上计算当给定Ak(1≤k≤m)时类别属性C的信息增益G(C, Ak),选择信息增益最大的Ai作为根节点的测试属性;如果Ai的取值个数为v(取值记为a1, a2, …, av),则Ai将S划分为v个子集S1, S2, …, Sv(Sj(1≤j≤v)为S中Ai=aj的样本集合),同时根节点产生v个分枝与之对应。其次,分别在训练数据子集S1, S2, …, Sv、剩余描述属性A1, …, Ai-1, Ai+1, …, Am上采用相同方法递归地建立决策树子树(递归)。
可能出现如下情况,需要停止建立决策(子)树的递归过程。
1)某节点对应的训练数据(子)集为空。此时,该节点作为叶节点并用父节点中占多数的样本类别标识。
2)某节点没有对应的(剩余)描述属性。此时,该节点作为叶节点并用该节点中占多数的样本类别标识。
算法: 决策树分类算法Generate_decision_tree(S, A)
输入:训练数据集S,描述属性集合A
输出:决策树
步骤:
(1)创建对应S的节点Node;
(2)if S中的样本属于同一类别c then
以c标识Node并将Node作为叶节点返回;
(3)if A为空 then
以S中占多数的样本类别c标识Node并将Node作为叶节点返回;
(4)从A中选择对S而言信息增益最大的描述属性Ai作为Node的测试属性;
(5)for Ai的每个可能取值aj(1≤j≤v ) //设Ai的可能取值为a1, a2, …, av
(5.1)产生S的一个子集Sj //Sj(1≤j≤v)为S中Ai=aj的样本集合;
(5.2)if Sj为空 then
创建对应Sj的节点Nj,以S中占多数的样本类别c标识Nj,并将Nj作为叶节点形成Node的一个分枝
(5.3)else 由Generate_decision_tree(Sj, A-{Ai})创建子树形成Node的一个分枝;
三、 信息增益
在决策树分类算法中使用信息增益度量来选择测试属性。
从信息论角度看,通过描述属性可以减少类别属性的不确定性。
3.1 离散型随机变量X的无条件熵 定义为
式中,p(xi)为X= xi的概率;u为X的可能取值个数。
3.2 给定离散型随机变量Y,离散型随机变量X的条件熵 定义为
式中,p(xiyj)为X=xi, Y=yj的联合概率;p(xi/yj)为已知Y=yj时,X= xi的条件概率;u、v分别为X、Y的可能取值个数。
可以证明,H(X/Y)≤H(X)。所以,通过Y可以减少X的不确定性。
3.3 假设训练数据集是关系数据表r1, r2, …, rn,其中描述属性为A1, A2, …, Am、类别属性为C,类别属性C的无条件熵 定义为
式中,u为C的可能取值个数,即类别个数,类别记为c1, c2, …, cu;si为属于类别ci的记录集合,|si|即为属于类别ci的记录总数。
3.4 给定描述属性Ak(1≤k≤m),类别属性C的条件 熵定义为 
式中,v为Ak的可能取值个数,取值记为a1, a2, …, av;sj为Ak=aj的记录集合,|sj|即为Ak=aj的记录数目;sij为Ak=aj且属于类别ci的记录集合,|sij|即为Ak=aj且属于类别ci的记录数目。
3.5 采用类别属性的无条件熵与条件熵的差(信息增益)来度量描述属性减少类别属性不确定性的程度。
给定描述属性Ak(1≤k≤m),类别属性C的信息增益定义为:
G(C, Ak)=E(C)-E(C, Ak)
可以看出,G(C, Ak)反映Ak减少C不确定性的程度,G(C, Ak)越大,Ak对减少C不确定性的贡献越大。
代码实现:
package data.decidetree;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
public class DecideTreeID3 {
private static Vector<String> descripAttriList; // 描述属性
private static String[] typeAttriValues; //类别属性值域
private static String typeAttri; //类别属性
private static int typeAttriIndex; //类别属性下标
public DecideTreeID3() {
descripAttriList = new Vector<String>();
}
public TreeNode create_decidion_tree(Vector<Vector<String>> recordList,
Vector<String> descripAttris) {
// 创建recordList对应的节点Node
TreeNode root = new TreeNode();
if (this.isSampleDecideAttri(recordList)) {
// 样本recordList属于同一类别 typeAttri
root.setAttriName("<is Leaf>");
root.setDecideAttri(this.getDecideAttriValue(recordList));
//System.out.println("name="+root.getAttriName()+" decide=" + root.getDecideAttri());
return root;
}
if (descripAttris == null || descripAttris.size() == 0) {
root.setDecideAttri(this.getMostDecideNumsValue(recordList));
return root;
}
root.setRecordList(recordList);
//选择决策属性 并设置其分裂属性值、分裂属性的下标位置
this.chooseDecideAttri(root);
if (root.getSplitAttri() != null && root.getSplitAttri().size() > 0) {
// 遍历每个分裂属性值
for (String attriValue : root.getSplitAttri()) {
Vector<String> descripAttri = new Vector<String>(descripAttris);
// 取得下标
int index = this.getAttriIndex(root.getAttriName());
// 取得recordList的子集
Vector<Vector<String>> remainRecord = this.getRelatedRecord(recordList, attriValue, index);
if (remainRecord == null || remainRecord.size() == 0) {
TreeNode node = new TreeNode();
node.setDecideAttri(this.getMostDecideNumsValue(recordList));
root.getConnectNode().put(node, attriValue);
} else {
descripAttri.remove(root.getSplitAttriIndex());
TreeNode node = new TreeNode(create_decidion_tree(
remainRecord, descripAttri));
if (node.getAttriName() != null)
root.getConnectNode().put(node, attriValue);
}
}
}
return root;
}
/** * 选择分枝的测试属性 设置 treeNode的分裂属性位置,以及各个分枝
* * @param typeAttri
* 测试属性
* @return 选择的属性 及其位置
*/
private void chooseDecideAttri(TreeNode treeNode) {
double max_entropy = 0.0;
TreeNode copyNode = new TreeNode(treeNode);
// 取得E(C)的值
double Ec = this.gainEValue(treeNode);
for (int i = 0; i < descripAttriList.size() - 1; i++) {
String descripAttri = descripAttriList.get(i);
if (!typeAttri.equals(descripAttri)) {
// 计算每个描述属性在测试属性下的熵 ,熵值最大则为下一个测试属性
double entropy = Ec- this.informationGain(copyNode, descripAttri);
// System.out.println("G(" + typeAttri + ","+ descripAttri +")=" + entropy);
if (max_entropy < entropy) {
max_entropy = entropy;
// 设置该结点的分裂属性
treeNode.setAttriName(descripAttriList.get(copyNode
.getSplitAttriIndex()));
treeNode.setSplitAttri(copyNode.getSplitAttri());
treeNode.setSplitAttriIndex(copyNode.getSplitAttriIndex());
}
}
}
}
/** * 根据属性名称 取得其在descripAttriList中的下标位置
* * @param attriType
* @return
*/
private int getAttriIndex(String attriType) {
int i = 0;
int index = 0;
for (String descripAttri : descripAttriList) {
if (attriType.equals(descripAttri)) {
index = i;
break;
}
i++;
}
return index;
}
/** * 类别属性C的条件熵
* * @param typeAttri
* 测试属性C
* @param descripAttri
* 描述属性 Ak E(C,Ak)
* @return
*/
private double informationGain(TreeNode treeNode, String descripAttri) {
/** * 先按描述属性 descripAttri分类,再按测试属性typeAttri分类
*/
int n = treeNode.getRecordList().size(); // 该结点拥有的记录数
// 取得描述属性descripAttri的下标 ,并求得有多少种
List<String> type = new ArrayList<String>();
int descripIndex = this.getAttriIndex(descripAttri);
// 根据下标 获得该描述属性下的种类,及相应的子记录
for (List<String> eachRecord : treeNode.getRecordList()) {
// 取得第index的属性的值
String attri = eachRecord.get(descripIndex);
if (!type.contains(attri))
type.add(attri);
}
// 若分裂属性为该描述属性,其各个分裂属性为String[] splitAttris
String[] splitAttris = type.toArray(new String[0]);
double Ecak = 0.0;
// 用来存放treeNode的分裂属性
Vector<String> attr = new Vector<String>();
// 计算E(C,Ak)
// 用 newGroupAttri 存放各个分裂属性下的记录数,长度为分裂属性的种类数
int[] newGroupAttri = new int[splitAttris.length];
for (int i = 0; i < splitAttris.length; i++) {
newGroupAttri[i] = 0;
// 存放分裂属性的值
attr.add(splitAttris[i]);
for (List<String> eachRecord : treeNode.getRecordList()) {
// index为描述属性的下标,相应的该描述属性的各种值的下标也是index
// 取得每条记录在该描述属性下的值
String attri = eachRecord.get(descripIndex); // 找到该属性下的值
if (attri.contains(splitAttris[i])) {
// 是否已经有了该属性下的值
newGroupAttri[i]++;
}
}
}
treeNode.setSplitAttri(attr);
treeNode.setSplitAttriIndex(descripIndex);
// 再按测试属性treeNode.getAttriName()分类
for (int i = 0; i < splitAttris.length; i++) {
double sumIj = 0.0;
// 该描述属性的各个类别占记录数的比例
double multi = (double) newGroupAttri[i] / n;
int sum = 0;
// descripType 用来存放descripAttri
// 描述属性的各个值下,拥有最终决策属性的值的种类(即typeAttri的种类)
List<String> descripType = new ArrayList<String>();
for (List<String> eachRecord : treeNode.getRecordList()) {
String attri = eachRecord.get(typeAttriIndex);
String attriB = eachRecord.get(descripIndex); // 找到该属性下的值
if (!descripType.contains(attri)
&& attriB.contains(splitAttris[i]))
descripType.add(attri);
}
// descSplitAttris 为 typeAttri属性值的种类
String[] descSplitAttris = descripType.toArray(new String[0]);
for (int j = 0; j < descSplitAttris.length; j++) {
sum = 0;
// 在描述属性值为splitAttris[i]下,对拥有typeAttri属性值descSplitAttris[j]的记录计数
for (List<String> eachRecord : treeNode.getRecordList()) {
String attri = eachRecord.get(typeAttriIndex);
String attriB = eachRecord.get(descripIndex); // 找到该属性下的值
if (attriB.contains(splitAttris[i])
&& attri.contains(descSplitAttris[j])) {
sum++;
}
}
// 描述属性值为splitAttris[i],决策属性值为descSplitAttris[j],
// 在描述属性值为splitAttris[i]下占的比例
double Nij = (double) sum / newGroupAttri[i];
if (Nij > 0.0) {
sumIj -= Nij * (Math.log10(Nij) / Math.log10((double) 2));
}
}
Ecak = Ecak + multi * sumIj;
}
return Ecak;
}
/** * * @param treeNode
* @return 类别属性的无条件熵E(C)
*/
private double gainEValue(TreeNode treeNode) {
int n = treeNode.getRecordList().size();
double Ec = 0.0;
List<String> rootType = new ArrayList<String>();
// 根据下标 获得该属性下的种类
for (List<String> eachRecord : treeNode.getRecordList()) {
String attri = eachRecord.get(typeAttriIndex);
if (!rootType.contains(attri))
rootType.add(attri);
}
String[] rootAttris = rootType.toArray(new String[0]);
int u = rootAttris.length; // 类别个数
int[] s = new int[u]; // 存放每个类别的记录数
// 计算E(C)
for (int i = 0; i < rootAttris.length; i++) {
s[i] = 0;
for (List<String> eachRecord : treeNode.getRecordList()) {
String attri = eachRecord.get(typeAttriIndex); // 找到该属性下的值
if (attri.contains(rootAttris[i])) {
// 是否已经有了该属性下的值
s[i]++;
}
}
double Ni = (double) s[i] / n;
Ec -= Ni * (Math.log10(Ni) / Math.log10((double) 2));
}
// System.out.println("E(" + typeAttri + ")=" + Ec);
return Ec;
}
/** * @param recordList
* 所有记录
* @return 在recordList中,类别属性值的个数
*/
private int getDecideAttriValueNums(Vector<Vector<String>> recordList) {
Set<String> set = new HashSet<String>();
for (Vector<String> vec : recordList) {
String attri = vec.get(typeAttriIndex);
if (!set.contains(attri)) {
set.add(attri);
}
}
return set.size();
}
// S样本属于同一类别C?
public boolean isSampleDecideAttri(Vector<Vector<String>> recordList) {
return this.getDecideAttriValueNums(recordList) == 1 ? true : false;
}
/** * 样本recordList中的类别属性值只有一种时,返回该类别的值
* * @param recordList
* @return
*/
public String getDecideAttriValue(Vector<Vector<String>> recordList) {
return recordList.get(0).get(typeAttriIndex);
}
/** * 根据测试属性的决定规则,取得该结点下的记录
* * @param recordList
* @param ruleAttri
* @return
*/
public Vector<Vector<String>> getRelatedRecord(
Vector<Vector<String>> recordList, String ruleAttri, int index) {
Vector<Vector<String>> remainList = new Vector<Vector<String>>();
for (Vector<String> record : recordList) {
String attri = record.get(index);
if (ruleAttri.contains(attri)) {
remainList.add(record);
}
}
return remainList;
}
/** * * @param recordList
* @return 返回样本recordList中占多数的类别typeAttri值
*/
public String getMostDecideNumsValue(Vector<Vector<String>> recordList) {
int max_num = 0;
int index = 0;
for (int i = 0; i < typeAttriValues.length; i++) {
int count = 0;
for (Vector<String> vec : recordList) {
if (vec.get(typeAttriIndex).equals(typeAttriValues[i])) {
count++;
System.out.println(vec.toString());
}
}
if (count > max_num) {
max_num = count;
index = i;
}
}
return typeAttriValues[index];
}
/** * 从文件中读取数据
*/
public Vector<Vector<String>> readDataFromFile(String filePath) {
descripAttriList = new Vector<String>(); // 描述属性
Set<String> typeValueSet = new HashSet<String>();
Vector<Vector<String>> dataList = new Vector<Vector<String>>(); // 所有记录
try {
// 一次读一行
// File file=new File(filePath);
BufferedReader reader = new BufferedReader(new InputStreamReader(
new FileInputStream(filePath), "GBK"));
String tempLine = null;
int i = 0;
while ((tempLine = reader.readLine()) != null) {
String[] items = tempLine.split(",");
if (i == 0) {
for (int j = 0; j < items.length; j++) {
descripAttriList.add(items[j]);
}
typeAttri = items[items.length - 1];
typeAttriIndex = items.length - 1;
} else {
Vector<String> eachRecord = new Vector<String>();
for (String item : items) {
eachRecord.add(item);
}
typeValueSet.add(items[items.length - 1]);
dataList.add(eachRecord);
}
i++;
}
typeAttriValues = new String[typeValueSet.size()];
typeAttriValues = typeValueSet.toArray(new String[0]);
reader.close();
} catch (IOException e) {
e.printStackTrace();
return null;
}
return dataList;
}
public static void main(String[] args) {
// String filePath="D:\ indexDir\ decide.txt";
String filePath = "D:\ indexDir\ buyInfo.txt";
DecideTreeID3 dd = new DecideTreeID3();
Vector<Vector<String>> list = dd.readDataFromFile(filePath);
Vector<String> decAttri = new Vector<String>(descripAttriList);
if(list!=null){
TreeNode root = new TreeNode(dd.create_decidion_tree(list, decAttri));
System.out.println(dd.printNodes(root));
}
}
public String printNodes(TreeNode root) {
StringBuilder sb = new StringBuilder("");
if (root == null)
return sb.toString();
// 是否为叶子结点
if (root.getDecideAttri() != null) {
sb.append("leaf...name=" + root.getAttriName()
+ ".....decide=" + root.getDecideAttri()+"\n");
return sb.toString();
}
System.out.println("parent" + root.getAttriName() + "...分裂规则:"
+ root.getSplitAttri().toString());
for (Map.Entry<TreeNode, String> entry : root.getConnectNode()
.entrySet()) {
TreeNode node = new TreeNode(entry.getKey());
if (node.getDecideAttri() != null) {
sb.append("leaf...name=" + node.getAttriName()
+ "....val=" + entry.getValue() + "...decide="
+ node.getDecideAttri()+"\n");
} else {
sb.append("parent...name=" + node.getAttriName()
+ "...val=" + entry.getValue()+"\n");
sb.append(printNodes(node));
}
}
return sb.toString();
}
}
TreeNode.java
package data.decidetree;
import java.util.*;
public class TreeNode {
private String attriName; //属性名称
private String decideAttri; //样本类别属性C
private int splitAttriIndex; //分枝属性的下标位置
private Vector<Vector<String>> recordList; // 该结点下的记录
private Vector<String> splitAttri; //分裂属性
private Map<TreeNode, String> connectNode; //链接的结点
public TreeNode(){
recordList=new Vector<Vector<String>>();
connectNode=new HashMap<TreeNode, String>();
splitAttri=new Vector<String>();
}
public TreeNode(TreeNode treeNode){
this.attriName=treeNode.getAttriName();
this.recordList=treeNode.getRecordList();
this.splitAttri=treeNode.getSplitAttri();
if(treeNode.getConnectNode()==null)
connectNode=new HashMap<TreeNode, String>();
else
connectNode=treeNode.getConnectNode();
this.decideAttri=treeNode.getDecideAttri();
this.splitAttriIndex=treeNode.getSplitAttriIndex();
}
public int getSplitAttriIndex() {
return splitAttriIndex;
}
public void setSplitAttriIndex(int splitAttriIndex) {
this.splitAttriIndex = splitAttriIndex;
}
public Vector<String> getSplitAttri() {
return splitAttri;
}
public void setSplitAttri(Vector<String> splitAttri) {
this.splitAttri = splitAttri;
}
public Map<TreeNode, String> getConnectNode() {
return connectNode;
}
public void setConnectNode(Map<TreeNode, String> connectNode) {
this.connectNode = connectNode;
}
public Vector<Vector<String>> getRecordList() {
return recordList;
}
public void setRecordList(Vector<Vector<String>> recordList) {
this.recordList = recordList;
}
public String getAttriName() {
return attriName;
}
public void setAttriName(String attriName) {
this.attriName = attriName;
}
public String getDecideAttri() {
return decideAttri;
}
public void setDecideAttri(String decideAttri) {
this.decideAttri = decideAttri;
}
}
结果: 
