Advertisement

均值漂移(mean shift )聚类算法Matlab实现详解

阅读量:

Mean shift 算法是基于核密度估计的爬山算法,可用于聚类、图像分割、跟踪等,其在声呐图像数据处理也有广泛的应用,笔者在网上找了一遍也没有找到关于Mean shift的matlab实现代码,找到的都是关于它的文字描述,无奈笔者只能根据网上找到的文字描述自己动手编写相关的matlab代码,现分享给大家。

1、均值漂移的基本形式

对于N维空间中给定的点集X,则对于空间中的任意点x与点集X中距离小于r的点x_{i}的mean shift向量为:
M_{r}=rac{1}{K}um_{x_{i}n S}^{}, S=eft  y:{T}(y-x)<r2,yn X ight

而漂移的过程,就是通过计算偏移量,然后不断的更新球心的位置,更新公式为:
x=x+M_{r}

直到偏移量的值很小时停止更新。

2、mean shift算法流程文字描述

假设多维空间中的数据点类别数未知,选定搜素半径r,执行如下步骤:

1、在未被标记的数据点中随机选择一个点作为中心x_{0}

2、找出所有离x_{0}距离小于r的点,记作集合M,并认为这些点属于类别c,同时将这些点在类别c上的访问次数加1;

3、以x_{0}为中心点,计算x_{0}到集合M中每个元素的向量,将这些向量相加,得到漂移向量M_{r}

4、更新中心点,x_{0}=x_{0}+M_{r}。表示x_{0}沿着方向M_{r}移动了距离eft  M_{r} ight

5、重复步骤2-4,直到eft  M_{r} ight 的大小很小,小于设置的阈值后,停止迭代,记住此时的x_{0},在这个迭代过程中的遇到的所有的点都属于类别c。

6、如果收敛时当前的类别c的中心于之前已经存在的类别{c}'的中心小于阈值,那么当前的c应该和{c}'属于同一类,并合并成{c}',否则把c作为新的类别,增加一类。

7、重复1-6直到所有的点都被标记访问。

8、分类:根据每个点找出其被访问次数最多的那一类,并将其归属到此类中。

以上就是均值漂移聚类算法流程。

3、mean shift算法matlab实现

下面笔者将给出均值漂移算法的matlab程序:

复制代码
 function [out,category] = mean_shift(radius,threshould,data)

    
 % 均值漂移聚类分析
    
 % 输入参数
    
 % radius    聚类半径
    
 % data      K-by-N    k个N维数据点集
    
 % 输出参数
    
 %%
    
 r2 = radius*radius;
    
 threshould2 = threshould*threshould;
    
 [k,N] = size(data);
    
 access_cnt = zeros(k,1);  %每个点被不同类访问次数计数
    
 center = data(1,:);
    
 dir = zeros(k,N);
    
 cluster_cnt = 1;
    
 density_l = 0;
    
 cnt = 0;
    
  
    
 theta = (0:1:360)/180*pi;
    
 circle_x = radius*cos(theta);
    
 circle_y = radius*sin(theta);
    
  
    
 figure;
    
 h = axes;
    
 plot(h,data(:,1),data(:,2),'k.');
    
 hold(h,'on');grid(h,'on');
    
  
    
 while 1
    
     cnt = cnt + 1;
    
     
    
     for i = 1:N
    
     dir(:,i) = data(:,i)-center(i);
    
     end
    
     dis = sum(dir.^2,2);  %按行求和
    
  
    
     indx = find(dis < r2);  %找到半径r内的数据点
    
     density = length(indx);
    
     shift = sum(dir(indx,:))/density;  %求飘移值
    
     access_cnt(indx,cluster_cnt) = access_cnt(indx,cluster_cnt) + 1; %当前类访问次数累加
    
     
    
 %     if cnt > 1
    
 %         delete(h1);
    
 %         delete(h2);
    
 %     end
    
     h1 = plot(h,circle_x+center(1),circle_y+center(2),'g');
    
     h2 = plot(h,data(indx,1),data(indx,2),'r.');
    
     
    
 %      if shift*shift' < threshould2   %判断是否满足停止收敛条件
    
     if density_l >= density
    
     density_l = 0;
    
     if cluster_cnt == 1
    
         out(cluster_cnt,:) = center;
    
     else
    
         dir_t = out;
    
         for kk = 1:cluster_cnt-1
    
             dir_t(kk,:) = out(kk,:)-center;   %将当前的收敛中心于之前的计算距离
    
         end
    
         dis_t = sum(dir_t.^2,2);
    
         [min_dis,min_indx] = min(dis_t);
    
         if min_dis < threshould2        %判断当前的中心离之前已有的中心距离是否小于阈值
    
             access_cnt(:,min_indx)= access_cnt(:,min_indx) + access_cnt(:,cluster_cnt);
    
             access_cnt(:,cluster_cnt) = 0;   %清零之前的分类访问
    
             cluster_cnt = cluster_cnt - 1;
    
         else
    
             out(cluster_cnt,:) = center;
    
         end
    
     end
    
     cluster_cnt = cluster_cnt +1;      %类别计数
    
     acc_cnt_p = sum(access_cnt,2);     %求每个点已被访问的次数
    
     no_acc_p = find(acc_cnt_p == 0);   %找出还没有被访问的点
    
     if size(no_acc_p,1) > 0
    
         center = data(no_acc_p(1),:);  %初始化成没有被访问点
    
     else
    
         break;
    
     end
    
     if size(access_cnt,2) < cluster_cnt      %判断有没有新增的类,有的话添加一类
    
         access_cnt = [access_cnt,zeros(k,1)];
    
     end
    
      else
    
     density_l = density;
    
     center = center + shift;  %更新中心值
    
     pause(0.02);        
    
     end
    
 end
    
   85. category = zeros(k,1);
    
 for kk = 1:k
    
     [max_acc,max_indx] = max(access_cnt(kk,:));  %找出当前点的最大访问次数及其类别
    
     category(kk) = max_indx;                     %将对应的点表上类别
    
 end
    
    
    
    
    代码解读

笔者生成了一组二维的数据点进行了测试,代码如下:

复制代码
 clear all;

    
 close all;
    
 clc;
    
 %%
    
 num = 500;
    
 radius = 3;
    
 threshould = 0.2;
    
 data1 = [randn(num,1),randn(num,1)];
    
 data2 = [randn(num,1)+6,randn(num,1)+6];
    
 data3 = [randn(num,1)-6,randn(num,1)+6];
    
 data4 = [randn(num,1)-6,randn(num,1)-6];
    
 data5 = [randn(num,1)+6,randn(num,1)-6];
    
 data = [data1;data2;data3;data4;data5];
    
 [out,category] = mean_shift(radius, threshould, data);
    
  
    
 category_num = size(out,1);
    
 figure;
    
 plot(data(:,1),data(:,2),'k.');
    
 hold on; grid on;
    
 plot(out(:,1),out(:,2),'r*');
    
 if category_num == 1
    
     plot(data(:,1),data(:,2),'c.');
    
 elseif category_num == 2
    
     category1 = data(find(category == 1),:);
    
     category2 = data(find(category == 2),:);
    
     plot(category1(:,1),category1(:,2),'c.');
    
     plot(category2(:,1),category2(:,2),'g.');
    
 elseif category_num == 3
    
     category1 = data(find(category == 1),:);
    
     category2 = data(find(category == 2),:);
    
     category3 = data(find(category == 3),:);
    
     plot(category1(:,1),category1(:,2),'c.');
    
     plot(category2(:,1),category2(:,2),'g.');    
    
     plot(category3(:,1),category3(:,2),'y.'); 
    
 elseif category_num == 4
    
     category1 = data(find(category == 1),:);
    
     category2 = data(find(category == 2),:);
    
     category3 = data(find(category == 3),:);
    
     category4 = data(find(category == 4),:);
    
     plot(category1(:,1),category1(:,2),'c.');
    
     plot(category2(:,1),category2(:,2),'g.');    
    
     plot(category3(:,1),category3(:,2),'y.');     
    
     plot(category4(:,1),category4(:,2),'b.'); 
    
 else
    
     category1 = data(find(category == 1),:);
    
     category2 = data(find(category == 2),:);
    
     category3 = data(find(category == 3),:);
    
     category4 = data(find(category == 4),:);
    
     category5 = data(find(category == 5),:);
    
     plot(category1(:,1),category1(:,2),'c.');
    
     plot(category2(:,1),category2(:,2),'g.');    
    
     plot(category3(:,1),category3(:,2),'y.');     
    
     plot(category4(:,1),category4(:,2),'b.');    
    
     plot(category5(:,1),category5(:,2),'k.');
    
 end
    
    
    
    
    代码解读

运行一下代码出现如下结果:

图1

图2

图1为算法运行时,数据点被访问的过程图。图2位分类的结果图,可以看到一共有5类数据被准确的分出来。

全部评论 (0)

还没有任何评论哟~