Advertisement

图像分割算法:基于KNN的像素级分割算法

阅读量:

前提假设:图像分割区域较明显。

2.算法步骤:

2-1阶段预处理:从输入图像中按区域提取样本,并按照类别比例将这些样本存储起来。(也可视为构建训练数据集的过程)

算法执行:(1)获取各类别像素的均值,并按通道数量计算平均值;随后将结果分别保存

(2)提取输入图像的每一个像素值(包括RGB)信息。

将输入的每个像素值与各类别对应的平均像素值进行比较,并计算与其对应类均值像素的平方误差,在所有可能类别中找出使总误差最小的那个类别。

(4)通过步骤(3)即可将输入图像的每一个像素值分到每一类别样本中。

3.算法效果:利用该算法大致能够实现图像分割的目的。然而,在图像边缘区域的分割效果尚欠理想。

4.代码实现:

复制代码
 import time

    
 import numpy as np                #导入numpy库
    
 import matplotlib.image as mpimg         #加载pil的包  
    
 from sklearn.cluster import KMeans  
    
 import matplotlib.pyplot as plt
    
 from skimage import io,data
    
  
    
 time_start=time.time()
    
  
    
 num_class = 8 # 七巧板+背景共八个类别
    
 num_channels = 3  #通道数
    
 str_train_path = "./train"  #训练样本集路径
    
 str_test_path = "./test"  #测试样本集路径
    
 centers = np.zeros([num_class,num_channels])  #创建一个8行3列的矩阵
    
  
    
 ########################################################################################################################
    
 # 加载图像数据函数
    
 def loadData(filePath):  
    
     data = []  
    
     img = mpimg.imread(filePath)        #返回图片的像素值(即RGB信息)
    
     img = np.array(img,dtype='float')   #返回图片的浮点型数据(整数型->浮点型)
    
     m,n,_ = img.shape                   #返回图片的大小(即W,H,C)
    
     img = img/255                       #把RGB分量数值归一化到0-1的范围
    
  
    
     for i in range(m):                  #在W的范围内做循环(相当于按列循环)
    
     for j in range(n):              #在Wi的范围下的H的范围内做循环(相当于在某一列下按行循环)
    
         x,y,z = img[i,j,:]
    
         data.append([x,y,z])
    
     return np.mat(data),m,n             #返回原尺寸图片每一个像素点的像素数据、行、列
    
 ########################################################################################################################
    
  
    
 # generating centers - mannually assisted
    
 for ii in range(num_class):
    
     str_train_file = str_train_path +'/'+ str(ii) + ".jpg"    #训练样本文件路径
    
     normalized_imgData,row,col = loadData(str_train_file)     #返回原尺寸图片每一个像素点的像素数据、行、列
    
     centers[ii,:] = np.mean(normalized_imgData,axis=0)        #计算沿指定轴的算术平均值,
    
                                                           # 这里指对RGB三通道的像素数据求平均值
    
  
    
 ########################################################################################################################
    
 # 在这里可以加入将centers保存入文件的程序
    
     
    
 # 在这里可以添加将已存储的centers载入的程序
    
 ########################################################################################################################
    
 # 测试过程
    
 str_test_file = str_test_path +"/test.jpg"             #可以将从camera上得到的图片存储在特定文件夹下
    
  
    
 img = mpimg.imread(str_test_file)                      #返回图片的像素值
    
 img = np.array(img,dtype='float')/255                  #把RGB分量数值归一化到0-1的范围
    
 label_pixel = np.zeros([img.shape[0],img.shape[1]])    #针对被测试图像的shape建立一个基于像素点的label
    
 print(label_pixel.shape)
    
  
    
 for ii in range(img.shape[0]):
    
     for jj in range(img.shape[1]):
    
     pixel = img[ii,jj,:]
    
     square_distance = np.sum(np.square(centers - pixel),axis=1)  #对每一个通道上的像素值的平方差(center-pixel)求和
    
     label_pixel[ii,jj] = np.argmin(square_distance)    #在一个轴上返回最小值的索引。
    
  
    
 label_pixel = np.array(label_pixel,dtype='int32')
    
  
    
 print(label_pixel)
    
  
    
 color_map = [
    
          [252,4,10],
    
          [249,138,8],
    
          [249,246,8],
    
          [14,249,8],
    
          [8,240,249],
    
          [8,70,249],
    
          [106,13,155],
    
          [0,0,0]
    
         ]
    
 color_map = np.array(color_map,dtype='int32')
    
  
    
 img_show = np.zeros([label_pixel.shape[0],label_pixel.shape[1],3])  #创建一个[1160,966,3]的三维矩阵
    
  
    
  
    
 for ii in range(img.shape[0]):
    
     for jj in range(img.shape[1]):
    
     ind = label_pixel[ii,jj]
    
     # print(ind)
    
     img_show[ii,jj,0] = color_map[ind,0]
    
     img_show[ii,jj,1] = color_map[ind,1]
    
     img_show[ii,jj,2] = color_map[ind,2]
    
  
    
 io.imshow(img_show)
    
 plt.show()
    
 time_end = time.time()
    
 use_time = time_end-time_start
    
 print("the time for recognition: %f s " % (use_time))

全部评论 (0)

还没有任何评论哟~