Post-hoc Interpretability Analysis of a MobileNetV2 Model with the SHAP Framework
- 0. The MobileNetV2 Model
- 1.1 Predicting a Single Image
  - 1.1.1 Importing Packages
  - 1.1.2 Loading the Pretrained torchvision Model
- 1.2 Batch Prediction
  - 1.2.1 Importing Packages
  - 1.2.2 Loading the Model and Running Batch Prediction
- 2. Prediction and Interpretability Analysis
  - 2.1 Single-Image Prediction Results
  - 2.2 Batch Prediction Results
0. The MobileNetV2 Model
MobileNetV2 is an optimized successor to MobileNetV1, designed as a lightweight neural network. It inherits V1's depthwise separable convolutions and builds on them with two key components, the linear bottleneck and the inverted residual, to improve performance. The architecture uses a fixed expansion rate throughout; the authors' experiments found that expansion rates between 5 and 10 produce nearly identical performance curves, with smaller networks performing slightly better at lower expansion rates and larger networks at higher ones.
MobileNetV2 primarily uses an expansion factor of 6. Concretely, before a feature map enters a bottleneck layer, it is first widened by an intermediate expansion layer. For example, in a bottleneck block that takes 64 input channels and produces 128 output channels, the intermediate expansion layer has 64 × 6 = 384 channels.
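To make this concrete, below is a minimal PyTorch sketch of one inverted residual block as described above (the InvertedResidual class and its argument names are illustrative assumptions, not the torchvision implementation): a 1×1 expansion convolution with ReLU6, a 3×3 depthwise convolution with ReLU6, and a linear 1×1 projection with no activation, plus a skip connection when the stride is 1 and the channel counts match.
import torch
import torch.nn as nn

class InvertedResidual(nn.Module):
    """Illustrative MobileNetV2-style inverted residual block (expansion factor t)."""
    def __init__(self, c_in: int, c_out: int, stride: int = 1, t: int = 6):
        super().__init__()
        c_mid = c_in * t  # e.g. 64 * 6 = 384 expanded channels
        self.use_residual = (stride == 1 and c_in == c_out)
        self.block = nn.Sequential(
            # 1x1 expansion: lift features into a wider space
            nn.Conv2d(c_in, c_mid, 1, bias=False),
            nn.BatchNorm2d(c_mid),
            nn.ReLU6(inplace=True),
            # 3x3 depthwise convolution: per-channel spatial filtering (inherited from V1)
            nn.Conv2d(c_mid, c_mid, 3, stride=stride, padding=1, groups=c_mid, bias=False),
            nn.BatchNorm2d(c_mid),
            nn.ReLU6(inplace=True),
            # 1x1 linear bottleneck projection: no activation, so information in the
            # low-dimensional output space is not destroyed by a nonlinearity
            nn.Conv2d(c_mid, c_out, 1, bias=False),
            nn.BatchNorm2d(c_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.block(x)
        return x + out if self.use_residual else out

# 64 input channels -> 128 output channels: the expansion layer has 64 * 6 = 384 channels
blk = InvertedResidual(64, 128, stride=1, t=6)
print(blk(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 56, 56])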
1.1 Predicting a Single Image
1.1.1 Importing Packages
import numpy as np
import torchvision
import torch
import torch.nn as nn
import shap
from PIL import Image
import json
1.1.2 Loading the Pretrained torchvision Model
'''
Predict a single image
'''
# Use the GPU if available, otherwise fall back to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Note: pretrained= is deprecated since torchvision 0.13; see the warning and the
# weights= alternative in section 2.1
model = torchvision.models.mobilenet_v2(pretrained=True, progress=False)
model.to(device)
model.eval()
X, y = shap.datasets.imagenet50()
# Prepare data transformation pipeline
mean = [0.485, 0.456, 0.406]   # ImageNet channel means
std = [0.229, 0.224, 0.225]    # ImageNet channel standard deviations
topk = 4          # number of top-scoring classes to explain
batch_size = 50   # batch size for masked model evaluations
n_evals = 10000   # max model evaluations used to estimate SHAP values
############################################################################################################
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as file:
class_names = [v[1] for v in json.load(file).values()]
##########################################################################################################
def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 4:
        x = x if x.shape[1] == 3 else x.permute(0, 3, 1, 2)
    elif x.dim() == 3:
        x = x if x.shape[0] == 3 else x.permute(2, 0, 1)
    return x

def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 4:
        x = x if x.shape[3] == 3 else x.permute(0, 2, 3, 1)
    elif x.dim() == 3:
        x = x if x.shape[2] == 3 else x.permute(1, 2, 0)
    return x
transform = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(nhwc_to_nchw),
    torchvision.transforms.Lambda(lambda x: x * (1 / 255)),
    torchvision.transforms.Normalize(mean=mean, std=std),
    torchvision.transforms.Lambda(nchw_to_nhwc),
])
# Inverse transform: undo the normalization so images can be displayed
inv_transform = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(nhwc_to_nchw),
    torchvision.transforms.Normalize(
        mean=(-1 * np.array(mean) / np.array(std)).tolist(),
        std=(1 / np.array(std)).tolist(),
    ),
    torchvision.transforms.Lambda(nchw_to_nhwc),
])
def predict(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img))
    img = img.to(device)
    output = model(img)
    return output
###########################################################################################################
# Check that transformations work correctly
Xtr = transform(torch.Tensor(X))
out = predict(Xtr[1:3])
classes = torch.argmax(out, axis=1).cpu().numpy()
print(f'Classes: {classes}: {np.array(class_names)[classes]}')
########################################################################################################
# Define a masker that blurs out hidden partitions of the input image
masker_blur = shap.maskers.Image("blur(128,128)", Xtr[0].shape)
# Create an explainer from the prediction function and the image masker
explainer = shap.Explainer(predict, masker_blur, output_names=class_names)
# Explain images Xtr[1:4] using up to n_evals evaluations of the underlying
# model to estimate the SHAP values, keeping only the topk output classes
shap_values = explainer(Xtr[1:4], max_evals=n_evals, batch_size=batch_size,
                        outputs=shap.Explanation.argsort.flip[:topk])
# Keep only the first explained image (index [0]) for the single-image plot
shap_values.data = inv_transform(shap_values.data).cpu().numpy()[0]
shap_values.values = [val for val in np.moveaxis(shap_values.values[0], -1, 0)]
shap.image_plot(shap_values=shap_values.values,
                pixel_values=shap_values.data,
                labels=shap_values.output_names,
                true_labels=[class_names[132]])
1.2 Batch Prediction
1.2.1 Importing Packages
import numpy as np
import torchvision
import torch
import torch.nn as nn
import shap
from PIL import Image
import json
1.2.2 Loading the Model and Running Batch Prediction
'''
Batch prediction on the ImageNet sample dataset
'''
# Use the GPU if available, otherwise fall back to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torchvision.models.mobilenet_v2(pretrained=True, progress=False)
model.to(device)
model.eval()
X, y = shap.datasets.imagenet50()
# Prepare data transformation pipeline
mean = [0.485, 0.456, 0.406]   # ImageNet channel means
std = [0.229, 0.224, 0.225]    # ImageNet channel standard deviations
topk = 4          # number of top-scoring classes to explain
batch_size = 50   # batch size for masked model evaluations
n_evals = 10000   # max model evaluations used to estimate SHAP values
############################################################################################################
url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(url)) as file:
class_names = [v[1] for v in json.load(file).values()]
##########################################################################################################
def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 4:
        x = x if x.shape[1] == 3 else x.permute(0, 3, 1, 2)
    elif x.dim() == 3:
        x = x if x.shape[0] == 3 else x.permute(2, 0, 1)
    return x

def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 4:
        x = x if x.shape[3] == 3 else x.permute(0, 2, 3, 1)
    elif x.dim() == 3:
        x = x if x.shape[2] == 3 else x.permute(1, 2, 0)
    return x
transform = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(nhwc_to_nchw),
    torchvision.transforms.Lambda(lambda x: x * (1 / 255)),
    torchvision.transforms.Normalize(mean=mean, std=std),
    torchvision.transforms.Lambda(nchw_to_nhwc),
])
# Inverse transform: undo the normalization so images can be displayed
inv_transform = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(nhwc_to_nchw),
    torchvision.transforms.Normalize(
        mean=(-1 * np.array(mean) / np.array(std)).tolist(),
        std=(1 / np.array(std)).tolist(),
    ),
    torchvision.transforms.Lambda(nchw_to_nhwc),
])
def predict(img: np.ndarray) -> torch.Tensor:
    img = nhwc_to_nchw(torch.Tensor(img))
    img = img.to(device)
    output = model(img)
    return output
###########################################################################################################
# Check that transformations work correctly
Xtr = transform(torch.Tensor(X))
out = predict(Xtr[1:5])
classes = torch.argmax(out, axis=1).cpu().numpy()
print(f'Classes: {classes}: {np.array(class_names)[classes]}')
########################################################################################################
# Define a masker that blurs out hidden partitions of the input image
masker_blur = shap.maskers.Image("blur(128,128)", Xtr[0].shape)
# Create an explainer from the prediction function and the image masker
explainer = shap.Explainer(predict, masker_blur, output_names=class_names)
# Explain the four images Xtr[1:5] using up to n_evals evaluations of the
# underlying model to estimate the SHAP values, keeping the topk classes
shap_values = explainer(Xtr[1:5], max_evals=n_evals, batch_size=batch_size,
                        outputs=shap.Explanation.argsort.flip[:topk])
# Unlike the single-image case, keep all explained images (no [0] indexing)
shap_values.data = inv_transform(shap_values.data).cpu().numpy()
shap_values.values = [val for val in np.moveaxis(shap_values.values, -1, 0)]
shap.image_plot(shap_values=shap_values.values,
                pixel_values=shap_values.data,
                labels=shap_values.output_names)
2. Prediction and Interpretability Analysis
2.1 Single-Image Prediction Results
The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MobileNet_V2_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V2_Weights.DEFAULT` to get the most up-to-date weights.
Classes: [132 814]: ['American_egret' 'speedboat']
Partition explainer: 33%|███▎ | 1/3 [00:00<?, ?it/s]
0%| | 0/9998 [00:00<?, ?it/s]
20%|██ | 2042/9998 [00:00<00:00, 15163.07it/s]
36%|███▌ | 3592/9998 [00:04<00:09, 693.75it/s]
...
99%|█████████▉| 9942/9998 [00:24<00:00, 295.75it/s]
100%|█████████▉| 9992/9998 [00:25<00:00, 295.50it/s]
10042it [00:25, 295.84it/s]
Partition explainer: 4it [01:32, 30.73s/it]
Process finished with exit code 0
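Both deprecation warnings above come from the pretrained=True argument. Going by the warning text itself, a minimal warning-free equivalent on torchvision >= 0.13 would be:
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

# Loads the same weights as pretrained=True; use MobileNet_V2_Weights.DEFAULT
# for the most up-to-date weights instead
model = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1, progress=False)
model.eval()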

2.2 Batch Prediction Results
The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MobileNet_V2_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V2_Weights.DEFAULT` to get the most up-to-date weights.
Classes: [132 814 746 42]: ['American_egret' 'speedboat' 'puck' 'agama']
Partition explainer: 25%|██▌ | 1/4 [00:00<?, ?it/s]
0%| | 0/9998 [00:00<?, ?it/s]
19%|█▉ | 1942/9998 [00:00<00:00, 13616.69it/s]
33%|███▎ | 3342/9998 [00:03<00:09, 708.17it/s]
39%|███▉ | 3942/9998 [00:05<00:10, 576.55it/s]
43%|████▎ | 4292/9998 [00:06<00:10, 521.64it/s]
45%|████▌ | 4542/9998 [00:07<00:11, 487.35it/s]
47%|████▋ | 4742/9998 [00:07<00:11, 460.08it/s]
49%|████▉ | 4892/9998 [00:08<00:11, 440.01it/s]
...
98%|█████████▊| 9842/9998 [00:24<00:00, 293.41it/s]
99%|█████████▉| 9892/9998 [00:25<00:00, 292.31it/s]
99%|█████████▉| 9942/9998 [00:25<00:00, 292.30it/s]
100%|█████████▉| 9992/9998 [00:25<00:00, 292.03it/s]
10042it [00:25, 291.35it/s]
Partition explainer: 5it [02:04, 31.11s/it]
Process finished with exit code 0
