Advertisement

数据挖掘--聚类之K均值算法

阅读量:

1.聚类概念

聚类 是将物理或抽象对象的集合分成相似的对象类的过程。使得同一个簇中的对象之间具有较高的相似性,而不同簇中的对象具有较高的相异性。 簇 是数据对象的集合,这些对象与同一簇中的对象彼此相似,而与其他簇的对象相异。

聚类可形式描述为:

D={o1, o2, ……, on}表示一个对象集合,

oi表示第i个对象,i={1,2,……,n};

Cx表示第x个簇,CxÍD,x=1,2,…,k;

Similarity(oi,oj)表示对象oi与对象oj之间的相似度。

2.K均值算法

误差平方和准则: 若Nx是第x个簇Cx中的对象数目,mx是这些对象的均值,即

误差平方和准则J就是所有簇的簇中各个对象与均值间的误差平方和之和,即:

J度量了用k个聚类中心m1,m2,…,mk代表k个簇C1,C2,…,Ck时所产生的总的误差平方和。对于不同的聚类,J的值不同,使J值极小的聚类是误差平方和准则下的最优结果。

核心思想: 首先选定k个初始聚类中心,根据最小距离原则将每个数据对象分配到某一簇中,然后不断迭代计算各个簇的聚类中心并依新的聚类中心调整聚类情况,直至收敛(J值不再变化或变化小于指定的阀值)。

当存在噪声和离群数据时,k中心点算法比k均值算法效果好,但是k中心点聚类算法的执行代价比k均值算法高。

K均值算法描述如下:

算法:k均值聚类算法(D,k)

输入:数据对象集合D,簇数目k

输出:k个簇的集合

步骤:

(1) 从D中随机选取k个不同的数据对象作为k个簇C1,C2,…,Ck的中心m1,m2,…,mk;

(2)repeat

(2.1)for D中每个数据对象o

(2.1.1) 寻找i,

(2.1.2)将o分配给簇Ci

(2.2)for 每个簇

计算

//计算新的聚类中心, 为当前簇 中的对象数目

3)计算平方误差

3. Until J 不再发生变化

代码实现:

复制代码
 package data.cluster;

    
  
    
 import java.io.BufferedReader;
    
 import java.io.FileInputStream;
    
 import java.io.IOException;
    
 import java.io.InputStreamReader;
    
 import java.util.*;
    
  
    
 import data.util.FormatDecimal;
    
  
    
 /** * * @author CC
    
  * */
    
 public class KMeans {
    
  
    
 	private static Vector<Node> datas;
    
  
    
 	private static final int RANDOM_SIZE = 2; //随机选取结点的个数
    
  
    
 	public KMeans() {
    
 		datas = new Vector<Node>();
    
 	}
    
  
    
 	public String toString(Vector<Node> datas) {
    
 		String str = "";
    
 		for (int i = 0; i < datas.size(); i++) {
    
 			str += datas.get(i).toString();
    
 		}
    
 		return str;
    
 	}
    
  
    
 	public String toString(Object[] obj) {
    
 		String str = "";
    
 		for (int i = 0; i < obj.length; i++) {
    
 			str += obj[i].toString() + ",";
    
 		}
    
 		return str;
    
 	}
    
  
    
 	/** * @param args
    
 	 */
    
 	public static void main(String[] args) {
    
 		// TODO Auto-generated method stub
    
 		String filePath = "D:\ indexDir\ kmeans.txt";
    
 		KMeans km = new KMeans();
    
 //		datas.add(new Node(1.0, 1.0));
    
 //		datas.add(new Node(1.2, 1.2));
    
 //		datas.add(new Node(0.8, 1.2));
    
 //		datas.add(new Node(0.9, 0.7));
    
 //		datas.add(new Node(1.3, 0.9));
    
 //		datas.add(new Node(1.0, 1.4));
    
 //		datas.add(new Node(3.0, 3.0));
    
 //		datas.add(new Node(3.1, 2.8));
    
 //		datas.add(new Node(3.2, 3.4));
    
 //		datas.add(new Node(2.7, 3.3));
    
 //		datas.add(new Node(2.6, 2.9));
    
 		km.readData(filePath);
    
 		//System.out.println(km.toString(datas));
    
 		Map<Vector<Node>, Double> map = km.squareSum();
    
 		for (Map.Entry<Vector<Node>, Double> entry : map.entrySet()) {
    
 			System.out.println("簇类的中心: \n" + km.toString(entry.getKey())
    
 					+ "\n 平方误差J= " + entry.getValue());
    
 		}
    
 	}
    
 	
    
 	public void readData(String filePath){
    
 		datas = new Vector<Node>();
    
 		try {
    
 			// 一次读一行
    
 			// File file=new File(filePath);
    
 			//数据格式为 X,Y 形成一行
    
 			BufferedReader reader = new BufferedReader(new InputStreamReader(
    
 					new FileInputStream(filePath)));
    
 			String tempLine = null;
    
 			while ((tempLine = reader.readLine()) != null) {
    
 				String[] items = tempLine.split(",");
    
 				datas.add(new Node(Double.parseDouble(items[0]),Double.parseDouble(items[1])));
    
 			}
    
 			reader.close();
    
 		} catch (IOException e) {
    
 			e.printStackTrace();
    
 		}
    
 	}
    
  
    
 	/** * 计算簇类的新中心以及平方误差
    
 	 * * @return
    
 	 */
    
 	public Map<Vector<Node>, Double> squareSum() {
    
 		// 选择的结点
    
 		Vector<Node> selectedNodes = this.pickNodes();
    
 		Vector<Vector<Node>> typeNode;
    
 		double sum = 0.0;
    
 		double calculate = 1.0;
    
 		while (Math.abs(calculate - sum) > 1e-6) {
    
 			// 初始化各个簇类
    
 			typeNode = new Vector<Vector<Node>>();
    
 			for (int i = 0; i < selectedNodes.size(); i++) {
    
 				Vector<Node> v = new Vector<Node>();
    
 				typeNode.add(v);
    
 			}
    
 			// 对结点进行分类 ,形成 selectedNodes.length 个簇
    
 			for (int i = 0; i < datas.size(); i++) {
    
 				double min = 65533.0;
    
 				int index = 0;
    
 				// 取得datas[i]离选择点最近的点
    
 				for (int j = 0; j < selectedNodes.size(); j++) {
    
 					double distance = calculateDistance(datas.get(i),
    
 							selectedNodes.get(j));
    
 					if (min > distance ) {
    
 						min = distance;
    
 						index = j;
    
 					}
    
 				}
    
  
    
 				// 添加结点至第index个簇
    
 				if (typeNode.get(index) == null) {
    
 					Vector<Node> indexNode = new Vector<Node>();
    
 					indexNode.add(datas.get(i));
    
 					typeNode.set(index, indexNode);
    
 				} else {
    
 					typeNode.get(index).add(datas.get(i));
    
 				}
    
 			}
    
 			// 更新新簇的中心值
    
 			for (int i = 0; i < typeNode.size(); i++) {
    
 				selectedNodes.set(i, this.getClusterMidNode(typeNode.get(i)));
    
 			}
    
  
    
 			sum = calculate;
    
 			calculate = this.calculateSquareSum(selectedNodes, typeNode);
    
 		}
    
 		Map<Vector<Node>, Double> map = new HashMap<Vector<Node>, Double>();
    
 		map.put(selectedNodes, calculate);
    
 		return map;
    
 	}
    
  
    
 	/** * 计算两点距离
    
 	 */
    
  
    
 	private double calculateDistance(Node A, Node B) {
    
 		return FormatDecimal.formatDouble(Math.sqrt((A.getX() - B.getX())
    
 				* (A.getX() - B.getX()) + (A.getY() - B.getY())
    
 				* (A.getY() - B.getY())),4);
    
 	}
    
  
    
 	/** * * 计算一个簇的中心值点
    
 	 * * @param nodeList
    
 	 *            该簇拥有的结点
    
 	 * @return
    
 	 */
    
 	private Node getClusterMidNode(List<Node> nodeList) {
    
 		int size = nodeList.size();
    
 		double x = 0.0, y = 0.0;
    
 		for (Node node : nodeList) {
    
 			x += node.getX();
    
 			y += node.getY();
    
 		}
    
 		return new Node(FormatDecimal.formatDouble(x / size,4), FormatDecimal.formatDouble(y / size,4));
    
 	}
    
  
    
 	/** * 计算平方误差
    
 	 * * @param selectedNodes
    
 	 *            各个簇类的中心点
    
 	 * @param typeNode
    
 	 *            对应各个簇类下的点集
    
 	 * @return
    
 	 */
    
 	public double calculateSquareSum(Vector<Node> selectedNodes,
    
 			Vector<Vector<Node>> typeNode) {
    
 		double sum = 0.0;
    
 		for (int i = 0; i < selectedNodes.size(); i++) {
    
 			for (Node node : typeNode.get(i)) {
    
 				sum += Math.pow(this.calculateDistance(selectedNodes.get(i),
    
 						node), 2.0);
    
 			}
    
 		}
    
 		return FormatDecimal.formatDouble(sum,4);
    
 	}
    
  
    
 	/** * 随机生成选择的结点
    
 	 * * @return
    
 	 */
    
 	private Vector<Node> pickNodes() {
    
 		int base = datas.size() - 1;
    
  
    
 		int randomSize = ((int) Math.random() * 1000) % base;
    
 		//int randomSize = RANDOM_SIZE;
    
 		
    
 		// 选择的结点不能小于两个
    
 		while (randomSize < 2) {
    
 			randomSize = ((int) (Math.random() * 1000)) % base;
    
 		}
    
 		
    
 		Integer[] randomIndex = new Integer[randomSize];
    
 		for (int i = 0; i < randomSize; i++) {
    
 			// 随机生成的下标不能重复
    
 			int index = ((int) (Math.random() * 1000)) % base;
    
 			while (hasIndex(randomIndex, index)) {
    
 				index = ((int) (Math.random() * 1000)) % base;
    
 			}
    
 			randomIndex[i] = index;
    
 		}
    
 		
    
 		Vector<Node> selectedNodes = new Vector<Node>();
    
 		// 选择的结点
    
 		for (int i = 0; i < randomIndex.length; i++) {
    
 			selectedNodes.add(datas.get(randomIndex[i]));
    
 		}
    
 		return selectedNodes;
    
 	}
    
  
    
 	/** * 是否已经有该下标了
    
 	 * * @param randomIndex
    
 	 *            下标数组
    
 	 * @param index
    
 	 *            下标
    
 	 * @return
    
 	 */
    
 	public boolean hasIndex(Integer[] randomIndex, Integer index) {
    
 		boolean flag = false;
    
 		if (randomIndex != null && randomIndex.length > 0) {
    
 			for (int i = 0; i < randomIndex.length; i++) {
    
 				if (index == randomIndex[i]) {
    
 					flag = true;
    
 					break;
    
 				}
    
 			}
    
 		}
    
 		return flag;
    
 	}
    
 }
    
    
    
    
复制代码
  public static double formatDouble(double transData,int length) {

    
 		
    
 		if(length<0){
    
 			return transData;
    
 		}else{
    
 			String str="";
    
 			for(int i=0;i<length;i++){
    
 				str+="#";
    
 			}
    
 			DecimalFormat df = new DecimalFormat("#."+str);
    
 			String data = df.format(transData);
    
 			return Double.parseDouble(data);
    
 		}
    
 	
    
 	}
    
    
    
    

测试结果:

全部评论 (0)

还没有任何评论哟~