Advertisement

k means java_数据挖掘-聚类-K-means算法Java实现

阅读量:

K-Means算法是最早也是应用最为广泛的聚类方法之一,在数据挖掘领域得到了广泛应用。该算法通过质心来定义每个簇的典型代表点——质心即一组数据点的平均位置,在n维连续的空间中应用这一方法时表现良好。

K-Means算法流程

step1:选择K个点作为初始质心

step2:repeat

将每个点指派到最近的质心,形成K个簇

重新计算每个簇的质心

until 质心不在变化

如图所示的样本集合,在初步筛选时选择了三个较为紧密分布的质心点;经过三次迭代运算后,其位置逐渐趋于稳定状态,并最终将整个样本集划分为三个子集。

230ebd92b1364e785a6e6896a90aa818.png

我们对每一个步骤都进行分析

step1:选择K个点作为初始质心

在执行这一步骤之前,请明确了解所需计算出的K值及其意义;同时需要区分这里的K值与基于EM算法的那种动态划分方式的不同。

其次,如何选择初始质心

最简单的办法等同于随机选取中心点。然后进行多次运行,在所得结果中选择表现最佳的那个。这个方法虽然简单,但未必能达到理想的效果。然而这个方案存在较大的可能性会获得局部最优解。

另外一种较为复杂的策略是:首先随机选择一个质心中心,在此基础上找出与该质心距离最远的数据样本。接着对后续选择的每一个新质心,则会从之前选定的所有质心中挑选出与之距离最远的一个作为代表。通过这种方式实现,在最终结果中能够保证所选的质心既是随机分布又是分散开来的。

step2:repeat

将每个点指派到最近的质心,形成K个簇

重新计算每个簇的质心

until 质心不在变化

如何确定最邻近的概念?在欧式空间中计算两个点之间的距离可基于欧式空间。在文本处理中,则可采用余弦相似性等方法。给定的数据集能够适应采用多种合适的邻近性度量方法。

其他问题

离群点的处理

离群点可能会严重影响簇的发现。这可能导致最后阶段的结果与我们的预期相差较大。因此,在分析过程中识别并排除这些离群点是非常重要的。

在我的工作中,是利用方差来剔除离群点,结果显示效果非常好。

簇分裂和簇合并

选择较大的值K通常会使得聚类结果更具合理性。然而,在许多实际场景中,我们倾向于维持较低的簇数量。

此时尚常交替使用分割与合并策略。此方法可避免陷入局部最小值,并且能够实现预期数量的聚类结果。

贴上代码java版,以后有时间写个python版的

抽象了点,簇,和距离

Point.class

public class Point {

private double x;

private double y;

private int id;

private boolean beyond;//标识是否属于样本

public Point(int id, double x, double y) {

this.id = id;

this.x = x;

this.y = y;

this.beyond = true;

}

public Point(int id, double x, double y, boolean beyond) {

this.id = id;

this.x = x;

this.y = y;

this.beyond = beyond;

}

public double getX() {

return x;

}

public double getY() {

return y;

}

public int getId() {

return id;

}

public boolean isBeyond() {

return beyond;

}

@Override

public String toString() {

return "Point{" +

"id=" + id +

", x=" + x +

", y=" + y +

'}';

}

@Override

public boolean equals(Object o) {

if (this == o) return true;

if (o == null || getClass() != o.getClass()) return false;

Point point = (Point) o;

if (Double.compare(point.x, x) != 0) return false;

if (Double.compare(point.y, y) != 0) return false;

return true;

}

@Override

public int hashCode() {

int result;

long temp;

temp = x != +0.0d ? Double.doubleToLongBits(x) : 0L;

result = (int) (temp ^ (temp >>> 32));

temp = y != +0.0d ? Double.doubleToLongBits(y) : 0L;

result = 31 * result + (int) (temp ^ (temp >>> 32));

return result;

}

} Cluster.class

public class Cluster {

private int id;//标识

private Point center;//中心

private List members = new ArrayList();//成员

public Cluster(int id, Point center) {

this.id = id;

this.center = center;

}

public Cluster(int id, Point center, List members) {

this.id = id;

this.center = center;

this.members = members;

}

public void addPoint(Point newPoint) {

if (!members.contains(newPoint))

members.add(newPoint);

else

throw new IllegalStateException("试图处理同一个样本数据!");

}

public int getId() {

return id;

}

public Point getCenter() {

return center;

}

public void setCenter(Point center) {

this.center = center;

}

public List getMembers() {

return members;

}

@Override

public String toString() {

return "Cluster{" +

"id=" + id +

", center=" + center +

", members=" + members +

"}";

}

} 抽象的距离,可以具体实现为欧式,曼式或其他距离公式

public abstract class AbstractDistance {

abstract public double getDis(Point p1, Point p2);

} 点对

public class Distence implements Comparable {

private Point source;

private Point dest;

private double dis;

private AbstractDistance distance;

public Distence(Point source, Point dest, AbstractDistance distance) {

this.source = source;

this.dest = dest;

this.distance = distance;

dis = distance.getDis(source, dest);

}

public Point getSource() {

return source;

}

public Point getDest() {

return dest;

}

public double getDis() {

return dis;

}

@Override

public int compareTo(Distence o) {

if (o.getDis() > dis)

return -1;

else

return 1;

}

}

核心实现类

public class KMeansCluster {

private int k;//簇的个数

private int num = 100000;//迭代次数

private List datas;//原始样本集

private String address;//样本集路径

private List data = new ArrayList();

private AbstractDistance distance = new AbstractDistance() {

@Override

public double getDis(Point p1, Point p2) {

//欧几里德距离

return Math.sqrt(Math.pow(p1.getX() - p2.getX(), 2) + Math.pow(p1.getY() - p2.getY(), 2));

}

};

public KMeansCluster(int k, int num, String address) {

this.k = k;

this.num = num;

this.address = address;

}

public KMeansCluster(int k, String address) {

this.k = k;

this.address = address;

}

public KMeansCluster(int k, List datas) {

this.k = k;

this.datas = datas;

}

public KMeansCluster(int k, int num, List datas) {

this.k = k;

this.num = num;

this.datas = datas;

}

private void check() {

if (k == 0)

throw new IllegalArgumentException("k must be the number > 0");

if (address == null && datas == null)

throw new IllegalArgumentException("program can't get real data");

}

/**

  • 初始化数据

  • @throws java.io.FileNotFoundException

*/

public void init() throws FileNotFoundException {

check();

//读取文件,init data

//处理原始数据

for (int i = 0, j = datas.size(); i < j; i++)

data.add(new Point(i, datas.get(i), 0));

}

/**

  • 第一次随机选取中心点

  • @return

*/

public Set chooseCenter() {

Set center = new HashSet();

Random ran = new Random();

int roll = 0;

while (center.size() < k) {

roll = ran.nextInt(data.size());

center.add(data.get(roll));

}

return center;

}

/**

  • @param center

  • @return

*/

public List prepare(Set center) {

List cluster = new ArrayList();

Iterator it = center.iterator();

int id = 0;

while (it.hasNext()) {

Point p = it.next();

if (p.isBeyond()) {

Cluster c = new Cluster(id++, p);

c.addPoint(p);

cluster.add(c);

} else

cluster.add(new Cluster(id++, p));

}

return cluster;

}

/**

  • 第一次运算,中心点为样本值

  • @param center

  • @param cluster

  • @return

*/

public List clustering(Set center, List cluster) {

Point[] p = center.toArray(new Point[0]);

TreeSet distence = new TreeSet();//存放距离信息

Point source;

Point dest;

boolean flag = false;

for (int i = 0, n = data.size(); i < n; i++) {

distence.clear();

for (int j = 0; j < center.size(); j++) {

if (center.contains(data.get(i)))

break;

flag = true;

// 计算距离

source = data.get(i);

dest = p[j];

distence.add(new Distence(source, dest, distance));

}

if (flag == true) {

Distence min = distence.first();

for (int m = 0, k = cluster.size(); m < k; m++) {

if (cluster.get(m).getCenter().equals(min.getDest()))

cluster.get(m).addPoint(min.getSource());

}

}

flag = false;

}

return cluster;

}

/**

  • 迭代运算,中心点为簇内样本均值

  • @param cluster

  • @return

*/

public List cluster(List cluster) {

// double error;

Set lastCenter = new HashSet();

for (int m = 0; m < num; m++) {

// error = 0;

Set center = new HashSet();

// 重新计算聚类中心

for (int j = 0; j < k; j++) {

List ps = cluster.get(j).getMembers();

int size = ps.size();

if (size < 3) {

center.add(cluster.get(j).getCenter());

continue;

}

// 计算距离

double x = 0.0, y = 0.0;

for (int k1 = 0; k1 < size; k1++) {

x += ps.get(k1).getX();

y += ps.get(k1).getY();

}

//得到新的中心点

Point nc = new Point(-1, x / size, y / size, false);

center.add(nc);

}

if (lastCenter.containsAll(center))//中心点不在变化,退出迭代

break;

lastCenter = center;

// 迭代运算

cluster = clustering(center, prepare(center));

// for (int nz = 0; nz < k; nz++) {

// error += cluster.get(nz).getError();//计算误差

// }

}

return cluster;

}

/**

  • 输出聚类信息到控制台

  • @param cs

*/

public void out2console(List cs) {

for (int i = 0; i < cs.size(); i++) {

System.out.println("No." + (i + 1) + " cluster:");

Cluster c = cs.get(i);

List p = c.getMembers();

for (int j = 0; j < p.size(); j++) {

System.out.println("\t" + p.get(j).getX() + " ");

}

System.out.println();

}

}

} 代码还没有仔细优化,执行的效率可能还存在一定的问题

全部评论 (0)

还没有任何评论哟~