k means java_数据挖掘-聚类-K-means算法Java实现
K-Means算法是最早也是应用最为广泛的聚类方法之一,在数据挖掘领域得到了广泛应用。该算法通过质心来定义每个簇的典型代表点——质心即一组数据点的平均位置,在n维连续的空间中应用这一方法时表现良好。
K-Means算法流程
step1:选择K个点作为初始质心
step2:repeat
将每个点指派到最近的质心,形成K个簇
重新计算每个簇的质心
until 质心不在变化
如图所示的样本集合,在初步筛选时选择了三个较为紧密分布的质心点;经过三次迭代运算后,其位置逐渐趋于稳定状态,并最终将整个样本集划分为三个子集。

我们对每一个步骤都进行分析
step1:选择K个点作为初始质心
在执行这一步骤之前,请明确了解所需计算出的K值及其意义;同时需要区分这里的K值与基于EM算法的那种动态划分方式的不同。
其次,如何选择初始质心
最简单的办法等同于随机选取中心点。然后进行多次运行,在所得结果中选择表现最佳的那个。这个方法虽然简单,但未必能达到理想的效果。然而这个方案存在较大的可能性会获得局部最优解。
另外一种较为复杂的策略是:首先随机选择一个质心中心,在此基础上找出与该质心距离最远的数据样本。接着对后续选择的每一个新质心,则会从之前选定的所有质心中挑选出与之距离最远的一个作为代表。通过这种方式实现,在最终结果中能够保证所选的质心既是随机分布又是分散开来的。
step2:repeat
将每个点指派到最近的质心,形成K个簇
重新计算每个簇的质心
until 质心不在变化
如何确定最邻近的概念?在欧式空间中计算两个点之间的距离可基于欧式空间。在文本处理中,则可采用余弦相似性等方法。给定的数据集能够适应采用多种合适的邻近性度量方法。
其他问题
离群点的处理
离群点可能会严重影响簇的发现。这可能导致最后阶段的结果与我们的预期相差较大。因此,在分析过程中识别并排除这些离群点是非常重要的。
在我的工作中,是利用方差来剔除离群点,结果显示效果非常好。
簇分裂和簇合并
选择较大的值K通常会使得聚类结果更具合理性。然而,在许多实际场景中,我们倾向于维持较低的簇数量。
此时尚常交替使用分割与合并策略。此方法可避免陷入局部最小值,并且能够实现预期数量的聚类结果。
贴上代码java版,以后有时间写个python版的
抽象了点,簇,和距离
Point.class
public class Point {
private double x;
private double y;
private int id;
private boolean beyond;//标识是否属于样本
public Point(int id, double x, double y) {
this.id = id;
this.x = x;
this.y = y;
this.beyond = true;
}
public Point(int id, double x, double y, boolean beyond) {
this.id = id;
this.x = x;
this.y = y;
this.beyond = beyond;
}
public double getX() {
return x;
}
public double getY() {
return y;
}
public int getId() {
return id;
}
public boolean isBeyond() {
return beyond;
}
@Override
public String toString() {
return "Point{" +
"id=" + id +
", x=" + x +
", y=" + y +
'}';
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Point point = (Point) o;
if (Double.compare(point.x, x) != 0) return false;
if (Double.compare(point.y, y) != 0) return false;
return true;
}
@Override
public int hashCode() {
int result;
long temp;
temp = x != +0.0d ? Double.doubleToLongBits(x) : 0L;
result = (int) (temp ^ (temp >>> 32));
temp = y != +0.0d ? Double.doubleToLongBits(y) : 0L;
result = 31 * result + (int) (temp ^ (temp >>> 32));
return result;
}
} Cluster.class
public class Cluster {
private int id;//标识
private Point center;//中心
private List members = new ArrayList();//成员
public Cluster(int id, Point center) {
this.id = id;
this.center = center;
}
public Cluster(int id, Point center, List members) {
this.id = id;
this.center = center;
this.members = members;
}
public void addPoint(Point newPoint) {
if (!members.contains(newPoint))
members.add(newPoint);
else
throw new IllegalStateException("试图处理同一个样本数据!");
}
public int getId() {
return id;
}
public Point getCenter() {
return center;
}
public void setCenter(Point center) {
this.center = center;
}
public List getMembers() {
return members;
}
@Override
public String toString() {
return "Cluster{" +
"id=" + id +
", center=" + center +
", members=" + members +
"}";
}
} 抽象的距离,可以具体实现为欧式,曼式或其他距离公式
public abstract class AbstractDistance {
abstract public double getDis(Point p1, Point p2);
} 点对
public class Distence implements Comparable {
private Point source;
private Point dest;
private double dis;
private AbstractDistance distance;
public Distence(Point source, Point dest, AbstractDistance distance) {
this.source = source;
this.dest = dest;
this.distance = distance;
dis = distance.getDis(source, dest);
}
public Point getSource() {
return source;
}
public Point getDest() {
return dest;
}
public double getDis() {
return dis;
}
@Override
public int compareTo(Distence o) {
if (o.getDis() > dis)
return -1;
else
return 1;
}
}
核心实现类
public class KMeansCluster {
private int k;//簇的个数
private int num = 100000;//迭代次数
private List datas;//原始样本集
private String address;//样本集路径
private List data = new ArrayList();
private AbstractDistance distance = new AbstractDistance() {
@Override
public double getDis(Point p1, Point p2) {
//欧几里德距离
return Math.sqrt(Math.pow(p1.getX() - p2.getX(), 2) + Math.pow(p1.getY() - p2.getY(), 2));
}
};
public KMeansCluster(int k, int num, String address) {
this.k = k;
this.num = num;
this.address = address;
}
public KMeansCluster(int k, String address) {
this.k = k;
this.address = address;
}
public KMeansCluster(int k, List datas) {
this.k = k;
this.datas = datas;
}
public KMeansCluster(int k, int num, List datas) {
this.k = k;
this.num = num;
this.datas = datas;
}
private void check() {
if (k == 0)
throw new IllegalArgumentException("k must be the number > 0");
if (address == null && datas == null)
throw new IllegalArgumentException("program can't get real data");
}
/**
-
初始化数据
-
@throws java.io.FileNotFoundException
*/
public void init() throws FileNotFoundException {
check();
//读取文件,init data
//处理原始数据
for (int i = 0, j = datas.size(); i < j; i++)
data.add(new Point(i, datas.get(i), 0));
}
/**
-
第一次随机选取中心点
-
@return
*/
public Set chooseCenter() {
Set center = new HashSet();
Random ran = new Random();
int roll = 0;
while (center.size() < k) {
roll = ran.nextInt(data.size());
center.add(data.get(roll));
}
return center;
}
/**
-
@param center
-
@return
*/
public List prepare(Set center) {
List cluster = new ArrayList();
Iterator it = center.iterator();
int id = 0;
while (it.hasNext()) {
Point p = it.next();
if (p.isBeyond()) {
Cluster c = new Cluster(id++, p);
c.addPoint(p);
cluster.add(c);
} else
cluster.add(new Cluster(id++, p));
}
return cluster;
}
/**
-
第一次运算,中心点为样本值
-
@param center
-
@param cluster
-
@return
*/
public List clustering(Set center, List cluster) {
Point[] p = center.toArray(new Point[0]);
TreeSet distence = new TreeSet();//存放距离信息
Point source;
Point dest;
boolean flag = false;
for (int i = 0, n = data.size(); i < n; i++) {
distence.clear();
for (int j = 0; j < center.size(); j++) {
if (center.contains(data.get(i)))
break;
flag = true;
// 计算距离
source = data.get(i);
dest = p[j];
distence.add(new Distence(source, dest, distance));
}
if (flag == true) {
Distence min = distence.first();
for (int m = 0, k = cluster.size(); m < k; m++) {
if (cluster.get(m).getCenter().equals(min.getDest()))
cluster.get(m).addPoint(min.getSource());
}
}
flag = false;
}
return cluster;
}
/**
-
迭代运算,中心点为簇内样本均值
-
@param cluster
-
@return
*/
public List cluster(List cluster) {
// double error;
Set lastCenter = new HashSet();
for (int m = 0; m < num; m++) {
// error = 0;
Set center = new HashSet();
// 重新计算聚类中心
for (int j = 0; j < k; j++) {
List ps = cluster.get(j).getMembers();
int size = ps.size();
if (size < 3) {
center.add(cluster.get(j).getCenter());
continue;
}
// 计算距离
double x = 0.0, y = 0.0;
for (int k1 = 0; k1 < size; k1++) {
x += ps.get(k1).getX();
y += ps.get(k1).getY();
}
//得到新的中心点
Point nc = new Point(-1, x / size, y / size, false);
center.add(nc);
}
if (lastCenter.containsAll(center))//中心点不在变化,退出迭代
break;
lastCenter = center;
// 迭代运算
cluster = clustering(center, prepare(center));
// for (int nz = 0; nz < k; nz++) {
// error += cluster.get(nz).getError();//计算误差
// }
}
return cluster;
}
/**
-
输出聚类信息到控制台
-
@param cs
*/
public void out2console(List cs) {
for (int i = 0; i < cs.size(); i++) {
System.out.println("No." + (i + 1) + " cluster:");
Cluster c = cs.get(i);
List p = c.getMembers();
for (int j = 0; j < p.size(); j++) {
System.out.println("\t" + p.get(j).getX() + " ");
}
System.out.println();
}
}
} 代码还没有仔细优化,执行的效率可能还存在一定的问题
