Advertisement

LIME算法:图像分类解释器(代码实现)

阅读量:

之前在一篇名为《LIME算法:模型的可解释性(代码实现)》的文章中,我详细介绍了该算法的基本原理及其在文本分类模型中的具体应用。在此笔记中,则深入探讨了LIME算法在图像分类模型中的应用,并总结了解决方案以及遇到的问题与应对策略。

一、算法简介

LIME算法源自 Marco Tulio Ribeiro 2016 年出版的一篇论文《Why Should I Trust You? Explaining the Predictions of Any Classifier》,其中介绍了局部可解释性模型算法。该方法主要用于分析文本类与图像类模型中的行为机制。

在这里插入图片描述

在日常进行图像分类模型测试时,经常遇到难以理解的预测结果。通过将我家小猫的照片用于测试,在分析结果显示中出现了‘纸箱、安全带’这样的分类项。

在这里插入图片描述

出于好奇地想知道,在家的小猫究竟有哪些部位与安全带相似?而想要获得这一答案,则可以通过借助LIME解释器来进行说明。

二、LIME解释器代码实现

为了实现LIME解释器在图像分类模型中的应用, 首先需要一个已建立完善的图像分类模型. 在此过程中, 可以借鉴lime算法的GitHub实例, 通过keras深度学习框架获取Google Inception net-v3预训练的深度神经网络模型.

复制代码
    #加载需要的包
    import os
    import keras
    from keras.applications import inception_v3 as inc_net
    from keras.preprocessing import image
    from keras.applications.imagenet_utils import decode_predictions
    from skimage.io import imread
    import matplotlib.pyplot as plt
    import numpy as np
    print('Notebook run using keras:', keras.__version__)
    
    #下载Google Inception net-v3深度神经网络模型
    inet_model = inc_net.InceptionV3()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解释

对待分类图像做数据预处理

复制代码
    def transform_img_fn(path_list):
    out = []
    for img_path in path_list:
        img = image.load_img(img_path, target_size=(299, 299))
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        x = inc_net.preprocess_input(x)
        out.append(x)
    return np.vstack(out)
    
    
      
      
      
      
      
      
      
      
      
    
    代码解释

读取图像,输出预测结果。

复制代码
    images = transform_img_fn([os.path.join('./','cat.jpg')])#加载图像后直接进行数据处理
    plt.imshow(images[0] / 2 + 0.5)
    preds = inet_model.predict(images)
    for x in decode_predictions(preds)[0]:
    print(x)#输出预测结果
    
    
      
      
      
      
      
    
    代码解释

可以看到输出的预测结果TOP5,分别为埃及猫、猞猁、纸箱、窗口屏幕、安全带。

在这里插入图片描述

预测模型以最高概率识别出"埃及猫"与真实猫咪相一致。然而为何模型会将预测结果归类为纸箱、安全带等明显不符之物?这些问题可通过LIME算法得以解析清楚。

基于LIME的方法原理如下:首先我们将原始图像转换为易于理解的关键特征表示。接着通过对这些关键特征施加扰动来生成多个被扰乱后的样本集合。随后我们将这些被扰乱后的样本还原至原始的空间中进行分析,并将其作为目标预测结果的基础。利用这些关键信息构建一个简化的模型或数据表达形式以便观察哪些重要的区域(如超像素)具有较大的权重或影响程度。这种方法的核心在于通过简化分析工具来揭示复杂模型的行为机制。

复制代码
    #加载lime包
    import lime
    from lime import lime_image
    
    
      
      
      
    
    代码解释

建立解释器,explain_instance的参数包括:

辅助理解图像: 代解释图像
分类器函数: 分类器
标签集合: 可分析标签
遮蔽色: 隐藏颜色
高概率标签: 预测概率最高的K个标签生成说明
功能数量: 说明中出现的最大功能数
样本数量: 学习线性模型的领域大小
批量大小: 批处理大小
距离计算方法: 距离度量
模型回归器,默认使用岭回归": 模型回归器,默认为岭回归
分段函数,默认将图像分割成指定尺寸": 分段,将图像分为多少个大小
随机种子,默认用于分割算法": 随机种子,用作分割算法的随机种子

复制代码
    explainer = lime_image.LimeImageExplainer()
    x=images[0].astype(np.double) #lime要求numpy array
    explanation = explainer.explain_instance(x, inet_model.predict, top_labels=5, hide_color=0, num_samples=1000)
    
    
      
      
      
    
    代码解释

解释器进度条跑完说明解释器运行完成。

在这里插入图片描述

对图像分类结果进行解释,首先看“埃及猫”的解释。

复制代码
    from skimage.segmentation import mark_boundaries
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True)
    plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))
    
    
      
      
      
    
    代码解释

显示出的图像是对分类结果最有力的部分。

在这里插入图片描述

可以调整参数hide_rest让图像的其他部分取消隐藏。

复制代码
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
    plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))
    
    
      
      
    
    代码解释
在这里插入图片描述

还可以看到赞成部分和反对部分,赞成为绿色区域,反对为红色区域。

复制代码
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False)
    plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))
    
    
      
      
    
    代码解释
在这里插入图片描述

这部分内容是关于预测结果为'埃及猫'的说明;除此之外, 还可以参考关于预测结果为'安全带'的相关说明.

复制代码
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[4], positive_only=True, num_features=5, hide_rest=True)
    plt.imshow(mark_boundaries(temp / 2 + 0.5, mask))
    
    
      
      
    
    代码解释
在这里插入图片描述

可以看到模型对于图像中“不太像猫“的部分作出了”安全带“的分类。

基于此,借助LIME算法我们可以生成该图像分类结果的可视化说明,并帮助我们更清晰地了解模型如何进行分类。

三、LIME算法的应用条件

为了实现对LIME算法的理解与应用,在使用Keras框架之前

复制代码
    images = transform_img_fn([os.path.join('./','cat.jpg')])
    images.shape
    
    
      
      
    
    代码解释

输出结果呈现为:(1, 29866467774758648650844847608406857764007)
而PyTorch中的张量通道排列模式遵循NCHW格式,在该格式中数值的排列情况表现为(1, C=3,H=256,W=256)。
因此,在基于PyTorch框架实现LIME解释器时,则需考虑通道顺序问题,并应首先封装分类器。

复制代码
    def batch_model(images):
    model.eval()
    images=torch.FloatTensor(images).permute(0, 3, 1, 2)
    #通过permute改变通道顺序适用于pytorch,也就是(batches, channels, height, width)
    return model(images)
    
    
      
      
      
      
      
    
    代码解释

四、参考资料

1、https://github.com/marcotcr/lime

全部评论 (0)

还没有任何评论哟~