Advertisement

基于深度学习的胸部 X 光肺炎检测:从数据到部署的完整指南

阅读量:

目录:

  1. 前言
  2. 问题背景
  3. 理论基础
  4. 数据处理方法
  5. 模型构建与训练流程
  6. 实验
  7. 实验结果与分析
  8. 结论与未来展望
  9. 参考与拓展

引言肺炎是全球范围内危及生命的常见疾病,早期诊断至关重要。本文将详细介绍如何使用卷积神经网络(CNN)和迁移学习技术,构建一个胸部 X 光肺炎检测系统。从数据准备、模型构建、训练调优到结果评估,提供完整的代码示例和实践指南,帮助读者快速上手并构建自己的医学影像 AI 应用。


1. 前言

随着医疗影像数据的爆炸式增长,传统的医生阅片方式面临巨大挑战。人工智能,特别是深度学习技术,在医学影像分析领域展现出巨大潜力。对于肺炎检测而言,AI 可以:

  • 减轻医生负担: 自动筛查大量 X 光片,减少重复劳动。
  • 提高诊断效率: 快速准确地识别肺炎病灶,辅助医生决策。
  • 改善医疗资源分配: 在欠发达地区,AI 可以提供辅助诊断,提高医疗服务的可及性。

本文将以 Kaggle 上的 Chest X-Ray Images (Pneumonia) dataset 为例,结合 NIH Chest X-ray Dataset 的相关知识,演示一个完整的肺炎检测项目。


2. 问题背景

2.1 研究动机

肺炎是一种常见的肺部感染,由细菌、病毒或真菌引起。如果不及时治疗,肺炎可能导致严重的并发症甚至死亡。胸部 X 光检查是诊断肺炎的常用方法,但放射科医生资源有限,尤其是在偏远地区。因此,开发一个自动化的肺炎检测系统具有重要意义。

2.2 任务目标

本项目的目标是构建一个二分类模型,用于区分胸部 X 光片中的正常(Normal)和肺炎(Pneumonia)两种情况。我们将使用以下指标评估模型性能:

  • Accuracy(准确率): 正确分类的样本比例。
  • Precision(精确率): 预测为肺炎的样本中,实际为肺炎的比例。
  • Recall(召回率/敏感度): 实际为肺炎的样本中,被正确预测为肺炎的比例。
  • F1-Score: 精确率和召回率的调和平均数。
  • ROC-AUC: 受试者工作特征曲线下面积,反映模型区分正负样本的能力。
  • Specificity(特异度): 实际为正常的样本中,被正确预测为正常的比例(1 - 假阳性率)。

在医疗诊断中,我们通常更关注 Recall(敏感度),以尽量减少漏诊。同时,Specificity(特异度)也很重要,以避免误诊。


3. 理论基础

3.1 卷积神经网络(CNN)

卷积神经网络(Convolutional Neural Networks, CNNs)是深度学习中用于处理图像数据的强大工具。CNN 通过卷积层、池化层和全连接层的组合,自动提取图像的特征。

卷积层(Convolution Layer): * 使用卷积核(filter)在输入图像上滑动,进行卷积运算。
* 每个卷积核提取一种局部特征(如边缘、纹理)。
* 通过多个卷积核,可以提取多种不同的特征。
* 卷积操作具有平移不变性,即图像中的物体平移后仍能被识别。

池化层(Pooling Layer): * 对卷积层输出的特征图进行降采样。
* 常用的池化操作有最大池化(Max Pooling)和平均池化(Average Pooling)。
* 池化层可以减少计算量,并提高模型的鲁棒性(对微小变化的适应能力)。

全连接层(Fully Connected Layer): * 将池化层输出的特征图展开成一维向量。
* 通过多个全连接层,将特征映射到输出类别。
* 全连接层的作用是将提取的特征进行组合,并进行分类。

激活函数 * 常用的激活函数是ReLU函数

3.2 迁移学习(Transfer Learning)

从头开始训练一个深度 CNN 模型需要大量的计算资源和时间。迁移学习是一种更有效的方法,它利用在大型数据集(如 ImageNet)上预训练好的模型,将其知识迁移到新的任务上。

迁移学习的步骤:

  1. 选择预训练模型: 如 ResNet、VGG、Inception 等。
  2. 移除顶层: 去掉预训练模型的全连接层(通常用于 ImageNet 的 1000 类分类)。
  3. 添加自定义层: 根据新任务的类别数,添加新的全连接层。
  4. 冻结部分层: 冻结预训练模型的部分或全部卷积层,只训练新添加的层。
  5. 微调(Fine-tuning): 解冻部分卷积层,与新添加的层一起训练。

3.3 数据来源

  1. NIH Chest X-ray Dataset: 包含超过 100,000 张胸部 X 光片,标注了 14 种不同的疾病。本研究仅关注正常和肺炎两种情况。
  2. Kaggle Chest X-Ray Images (Pneumonia) Dataset: 包含约 5,863 张胸部 X 光片,已划分为训练集、验证集和测试集。

Kaggle 数据集更易于使用,因此本文主要使用 Kaggle 数据集。


4. 数据处理方法

4.1 数据准备与划分

我们使用 Kaggle Chest X-Ray Images (Pneumonia) Dataset,它已经划分为 trainvaltest 三个文件夹,每个文件夹下包含 NORMALPNEUMONIA 两个子文件夹。

数据划分比例: 通常采用 7:1.5:1.5 或 8:1:1 的比例划分训练集、验证集和测试集。本例中,我们使用 Kaggle 数据集已有的划分。

4.2 数据不平衡问题

在 Kaggle 数据集中,肺炎样本数量明显多于正常样本。这会导致模型偏向于预测肺炎。为了解决这个问题,我们可以采用以下方法:

  1. 数据增强(Data Augmentation): 对训练集中的图像进行随机变换,增加样本数量和多样性。
  2. 类别权重(Class Weight): 在损失函数中,给少数类(正常样本)更高的权重。
  3. 过采样 (Oversampling): 增加少数类样本数量。可以使用 imbalanced-learn 库中的 SMOTE 等过采样方法。
复制代码
    from imblearn.over_sampling import SMOTE

    # 假设 X_train, y_train 是训练集的特征和标签
    smote = SMOTE(random_state=42)
    X_train_resampled, y_train_resampled = smote.fit_resample(X_train.reshape(-1, 224*224*3), y_train)
    # 将 X_train_resampled 重新 reshape 为 (样本数, 224, 224, 3)
    X_train_resampled = X_train_resampled.reshape(-1, 224, 224, 3)
    
    
         
         
         
         
         
         
  1. 欠采样 (Undersampling): 减少多数类样本数量。可以使用 imbalanced-learn 库中的 RandomUnderSampler
复制代码
    from imblearn.under_sampling import RandomUnderSampler

    # 假设 X_train, y_train 是训练集的特征和标签
    rus = RandomUnderSampler(random_state=42)
    X_train_resampled, y_train_resampled = rus.fit_resample(X_train.reshape(-1, 224*224*3), y_train)
     # 将 X_train_resampled 重新 reshape 为 (样本数, 224, 224, 3)
    X_train_resampled = X_train_resampled.reshape(-1, 224, 224, 3)
    
    
         
         
         
         
         
         

注意: 过采样/欠采样应仅在训练集上进行,不要改变验证集和测试集的分布。

4.3 数据增强

我们使用 Keras 的 ImageDataGenerator 进行数据增强。

复制代码
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    
    train_datagen = ImageDataGenerator(
    rescale=1./255,             # 像素值缩放到 [0, 1] 范围
    rotation_range=20,          # 随机旋转角度范围(0-20度)
    width_shift_range=0.2,      # 水平随机平移范围(图像宽度的20%)
    height_shift_range=0.2,     # 垂直随机平移范围(图像高度的20%)
    shear_range=0.2,            # 随机错切变换强度
    zoom_range=0.2,             # 随机缩放范围
    horizontal_flip=True,       # 随机水平翻转
    fill_mode='nearest'         # 填充像素的方式
    )
    
    val_datagen = ImageDataGenerator(rescale=1./255)  # 验证集只需缩放
    test_datagen = ImageDataGenerator(rescale=1./255) # 测试集只需缩放
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
  • rotation_range: 随机旋转图像的角度范围(0-180度)。
    • width_shift_rangeheight_shift_range: 水平和垂直方向随机平移的范围(图像宽度或高度的比例)。
    • shear_range: 随机错切变换的强度。
    • zoom_range: 随机缩放的范围。
    • horizontal_flip: 是否进行随机水平翻转。
    • fill_mode: 填充新像素的方式(当图像旋转或平移时,会出现新的像素)。

5. 模型构建与训练流程

5.1 模型选择:ResNet50

我们选择 ResNet50 作为预训练模型。ResNet50 是一种深度残差网络,通过引入残差连接(shortcut connections)解决了深度网络训练中的梯度消失问题。

5.2 模型构建

复制代码
    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
    
    # 加载预训练的 ResNet50 模型,不包含顶层(全连接层)
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    
    # 添加自定义层
    x = base_model.output
    x = GlobalAveragePooling2D()(x)  # 全局平均池化,减少参数数量
    x = Dense(128, activation='relu')(x)  # 全连接层
    x = Dropout(0.5)(x)  # Dropout 层,防止过拟合
    output = Dense(1, activation='sigmoid')(x)  # 输出层,二分类问题使用 sigmoid 激活函数
    
    # 构建新模型
    model = Model(inputs=base_model.input, outputs=output)
    
    # 冻结 ResNet50 的卷积层
    for layer in base_model.layers:
    layer.trainable = False
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
  • include_top=False: 不加载 ResNet50 原始的用于 ImageNet 分类的全连接层。
    • GlobalAveragePooling2D: 将卷积层的输出进行全局平均池化,减少参数数量,防止过拟合。
    • Dense(128, activation='relu'): 添加一个包含 128 个神经元的全连接层,使用 ReLU 激活函数。
    • Dropout(0.5): 在全连接层后添加 Dropout 层,随机丢弃 50% 的神经元,防止过拟合。
    • Dense(1, activation='sigmoid'): 输出层,包含 1 个神经元,使用 sigmoid 激活函数,输出概率值(0-1之间)。

5.3 超参数设置与优化

学习率(Learning Rate): * 初始阶段(只训练新添加的层):使用较大的学习率,如 1e-3。
* 微调阶段(解冻部分卷积层):使用较小的学习率,如 1e-5 或 1e-6。

优化器(Optimizer): * 选择 Adam 优化器。Adam 是一种自适应学习率优化算法,通常比 SGD 效果更好。

复制代码
    from tensorflow.keras.optimizers import Adam

    optimizer = Adam(learning_rate=1e-3)
    
    
         
         

批大小(Batch Size): * 根据 GPU 显存大小选择合适的批大小,通常为 16、32 或 64。

正则化(Regularization): * 使用 Dropout 防止过拟合。
* 可以考虑使用 L2 正则化(权重衰减)。

回调函数(Callbacks): * EarlyStopping: 当验证集损失不再下降时,提前停止训练,防止过拟合。
* ReduceLROnPlateau: 当验证集损失停止下降时,降低学习率。
* ModelCheckpoint: 保存训练过程中最好的模型。

5.4 训练流程

复制代码
    from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
    
    # 编译模型
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    
    # 设置回调函数
    early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)
    model_checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, verbose=1)
    
    # 训练模型(第一阶段:只训练新添加的层)
    history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=20,  # 增加 epochs 数量
    steps_per_epoch=train_generator.samples // train_generator.batch_size,
    validation_steps=val_generator.samples // val_generator.batch_size,
    callbacks=[early_stopping, reduce_lr, model_checkpoint]
    )
    
    # 解冻部分卷积层(微调)
    for layer in base_model.layers[-20:]:  # 解冻最后 20 层
    layer.trainable = True
    
    # 重新编译模型(使用更小的学习率)
    optimizer = Adam(learning_rate=1e-5)
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    
    # 继续训练模型(第二阶段:微调)
    history_fine_tune = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,  # 微调阶段 epochs 可以少一些
    steps_per_epoch=train_generator.samples // train_generator.batch_size,
    validation_steps=val_generator.samples // val_generator.batch_size,
    callbacks=[early_stopping, reduce_lr, model_checkpoint]
    )
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    

6. 实验

6.1 完整代码示例

复制代码
    import os
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    from tensorflow.keras.applications import ResNet50
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
    from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # 设置随机种子,保证结果可复现
    tf.random.set_seed(42)
    np.random.seed(42)
    
    # 数据路径(根据实际情况修改)
    data_dir = "chest_xray"  # 假设数据集已解压到当前目录下的 chest_xray 文件夹
    train_dir = os.path.join(data_dir, "train")
    val_dir = os.path.join(data_dir, "val")
    test_dir = os.path.join(data_dir, "test")
    
    # 数据增强
    train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
    )
    val_datagen = ImageDataGenerator(rescale=1./255)
    test_datagen = ImageDataGenerator(rescale=1./255)
    
    # 创建数据生成器
    batch_size = 32
    train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='binary'
    )
    val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='binary'
    )
    test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='binary',
    shuffle=False  # 测试集不要打乱顺序
    )
    
    # 构建模型
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.5)(x)
    output = Dense(1, activation='sigmoid')(x)
    model = Model(inputs=base_model.input, outputs=output)
    
    # 冻结 ResNet50 的卷积层
    for layer in base_model.layers:
    layer.trainable = False
    
    # 编译模型
    optimizer = Adam(learning_rate=1e-3)
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    
    # 设置回调函数
    early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)
    model_checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, verbose=1)
    
    # 训练模型(第一阶段)
    history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=20,
    steps_per_epoch=train_generator.samples // batch_size,
    validation_steps=val_generator.samples // batch_size,
    callbacks=[early_stopping, reduce_lr, model_checkpoint]
    )
    
    # 解冻部分卷积层(微调)
    for layer in base_model.layers[-20:]:
    layer.trainable = True
    
    # 重新编译模型(使用更小的学习率)
    optimizer = Adam(learning_rate=1e-5)
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    
    # 继续训练模型(第二阶段)
    history_fine_tune = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,
    steps_per_epoch=train_generator.samples // batch_size,
    validation_steps=val_generator.samples // batch_size,
    callbacks=[early_stopping, reduce_lr, model_checkpoint]
    )
    
    # 绘制训练曲线
    def plot_training_history(history, title):
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], label='Training Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title(f'{title} - Accuracy')
    
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'{title} - Loss')
    
    plt.tight_layout()
    plt.show()
    
    plot_training_history(history, 'Initial Training')
    plot_training_history(history_fine_tune, 'Fine-tuning')
    
    # 加载最佳模型
    model = tf.keras.models.load_model('best_model.h5')
    
    # 评估模型
    test_loss, test_accuracy = model.evaluate(test_generator)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    
    # 预测
    y_true = test_generator.classes
    y_pred_prob = model.predict(test_generator)
    y_pred = (y_pred_prob > 0.5).astype(int).flatten()
    
    # 分类报告
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, target_names=['Normal', 'Pneumonia']))
    
    # 混淆矩阵
    cm = confusion_matrix(y_true, y_pred)
    print("Confusion Matrix:\n", cm)
    
    # ROC AUC
    auc = roc_auc_score(y_true, y_pred_prob)
    print(f"ROC AUC: {auc:.4f}")
    
    # 混淆矩阵可视化
    labels = ['Normal', 'Pneumonia']
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()
    
    # 可视化预测结果(选择部分测试样本)
    num_samples_to_show = 10
    sample_indices = np.random.choice(len(test_generator), num_samples_to_show)
    
    plt.figure(figsize=(15, 8))
    for i, index in enumerate(sample_indices):
    image, label = test_generator[index]
    # image 是一个 batch,所以取第一个
    image = image[0]
    label = label[0]
    
    pred_prob = model.predict(np.expand_dims(image, axis=0))[0][0]
    pred_label = 1 if pred_prob > 0.5 else 0
    
    plt.subplot(2, 5, i + 1)
    plt.imshow(image)
    plt.title(f"True: {labels[int(label)]}\nPred: {labels[pred_label]} ({pred_prob:.2f})")
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    

7. 实验结果与分析

7.1 定量指标

经过训练和微调,模型在测试集上通常可以达到以下性能:

  • Accuracy:92% - 95%
  • Precision:91% - 93%
  • Recall(Sensitivity):93% - 95%
  • F1-Score:92% - 94%
  • ROC-AUC:0.96+
  • Specificity:根据混淆矩阵计算

7.2 训练曲线

通过绘制训练过程中的 Accuracy 和 Loss 曲线,我们可以观察到:

  • 模型在训练集和验证集上的 Accuracy 逐渐上升,Loss 逐渐下降。
  • 经过 EarlyStopping 和 ReduceLROnPlateau 回调函数的控制,训练过程在合适的时机停止,避免了过拟合。
  • 微调阶段,模型的性能进一步提升。

7.3 混淆矩阵

混淆矩阵清晰地展示了模型在不同类别上的预测情况。我们可以看到:

  • 模型对肺炎类别的 Recall 较高,说明漏诊较少。
  • 模型对正常类别的 Specificity 也较高,说明误诊较少。

7.4 可视化预测结果

通过可视化部分测试样本的预测结果,我们可以直观地看到模型的预测效果。

7.5 主要挑战

  1. 数据不平衡: 尽管我们采取了数据增强等方法,但数据不平衡问题仍然存在。可以尝试更复杂的过采样或欠采样方法。
  2. 肺炎多样性: 肺炎在 X 光片上的表现多种多样,模型可能难以捕捉所有类型的肺炎特征。
  3. 多标签问题: 如果需要同时检测其他肺部疾病,需要修改模型结构或采用多任务学习。

8. 结论与未来展望

本文详细介绍了如何使用深度学习技术构建一个胸部 X 光肺炎检测系统。通过迁移学习、数据增强、超参数调优等方法,我们可以在较短时间内训练出一个性能良好的模型。

未来展望:

  1. 更多数据: 使用更多的数据进行训练和验证,提高模型的泛化能力。
  2. 模型改进: 尝试其他预训练模型(如 EfficientNet、DenseNet),或设计新的模型结构。
  3. 可解释性: 使用 Grad-CAM、SHAP 等方法,提高模型的可解释性,帮助医生理解模型的决策过程。
复制代码
    import cv2

    # Grad-CAM 示例 (需要根据模型结构进行调整)
    def grad_cam(model, image, class_index):
    grad_model = Model(
        [model.inputs], [model.get_layer('conv5_block3_out').output, model.output] #conv5_block3_out根据模型进行调整
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(np.expand_dims(image, axis=0))
        loss = predictions[:, class_index]
    
    grads = tape.gradient(loss, conv_outputs)
    cast_conv_outputs = tf.cast(conv_outputs > 0, "float32")
    cast_grads = tf.cast(grads > 0, "float32")
    guided_grads = cast_conv_outputs * cast_grads * grads
    
    conv_outputs = conv_outputs[0]
    guided_grads = guided_grads[0]
    
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, conv_outputs), axis=-1)
    
    (w, h) = (image.shape[1], image.shape[0])
    heatmap = cv2.resize(cam.numpy(), (w, h))
    
    numer = heatmap - np.min(heatmap)
    denom = (heatmap.max() - heatmap.min()) + 1e-8 # 避免除以0
    heatmap = numer / denom
    heatmap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    
    output_image = cv2.addWeighted(cv2.cvtColor((image * 255).astype('uint8'), cv2.COLOR_RGB2BGR), 0.7, heatmap, 0.3, 0)
    
    return output_image
    
    # 选择一张测试图片
    image, _ = test_generator[0]
    image = image[0]
    
    # 获取肺炎类别的索引 (假设肺炎类别是 1)
    class_index = 1
    
    # 生成 Grad-CAM 热力图
    heatmap_image = grad_cam(model, image, class_index)
    
    # 显示原图和热力图
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(heatmap_image)
    plt.title('Grad-CAM Heatmap')
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    
    
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
         
  1. 模型部署: 将模型部署到实际的医疗环境中,辅助医生进行肺炎诊断。

参考与拓展

希望本文能帮助你入门医学影像 AI 领域!如果你有任何问题或建议,欢迎留言交流。

全部评论 (0)

还没有任何评论哟~