Advertisement

实战3:基于EfficientNetV2网络实现乳腺 X 光图像乳腺癌检测

阅读量:

项目概述:

本次比赛的主要目的是识别乳腺癌。采用基于定期筛查所得的乳腺癌X线影像数据来训练您的模型。

数据介绍:

注意:本研究涉及女性受试者的放射学乳房影像数据。

  • site_id为来源医院的标识符。
    • patient_id记录患者的唯一标识符。
    • image_id指定图像的具体编号。
    • laterality标识目标图像所处乳房位置(左/右)。
    • view指明成像设备采集的方向参数,默认为双侧检查。
    • age代表患者的年龄值(以整数年表示)。
    • implant说明患者是否接受过乳房植入物治疗(仅在站点1记录群体水平植入情况)。
    • density评估乳房组织密度等级(A=最低、D=最高),高密度区域诊断难度增加。
    • machine_id对应使用的成像设备型号编码信息。
    • cancer标记乳腺癌诊断结果(阳性与否)。
    • biopsy确认后续病理检查是否已经完成。
    • invasive针对乳腺癌原位或侵袭性特征进行标记判断(仅在阳性情况下适用)。
    • BIRADS评估随访难度标准:阴性=1、正常=2、待评估=0。
    • prediction_id分配唯一预测编号(多个图像共享同一编号)。
    • difficult_negative_case标记异常病例判定情况(True/False)。

代码教程:

导入库:

复制代码
 import numpy as np

    
 import pandas as pd
    
 import matplotlib.pyplot as plt
    
 import seaborn as sns
    
 sns.set_style('darkgrid')
    
  
    
 # Manage files
    
 import pydicom
    
 from os import listdir
    
  
    
 from scipy.stats import mode
    
  
    
 # Others
    
 import warnings
    
 warnings.filterwarnings("ignore", category=FutureWarning)
    
  
    
 !pip install -qU python-gdcm pydicom pylibjpeg
    
    
    
    
    代码解读
复制代码
 filepath = 'input/rsna-breast-cancer-detection/train.csv'

    
 data = pd.read_csv(filepath)
    
 data.head()
    
    
    
    
    代码解读

空值数据统计:

复制代码
    data.isna().sum()
    
    代码解读
复制代码
 site_id                        0

    
 patient_id                     0
    
 image_id                       0
    
 laterality                     0
    
 view                           0
    
 age                           37
    
 cancer                         0
    
 biopsy                         0
    
 invasive                       0
    
 BIRADS                     28420
    
 implant                        0
    
 density                    25236
    
 machine_id                     0
    
 difficult_negative_case        0
    
 dtype: int64
    
    
    
    
    代码解读
复制代码
 num_patients = data['patient_id'].nunique()

    
 min_patient_age = int(data['age'].min())
    
 max_patient_age = int(data['age'].max())
    
 groupby_id = data.groupby('patient_id')['cancer'].apply(lambda x: x.unique()[0])
    
 n_negative = (groupby_id == 0).sum()
    
 n_positive = (groupby_id == 1).sum()
    
  
    
 print(f"There are {num_patients} different patients in the train set.\n")
    
 print(f"The younger patient is {min_patient_age} years old.")
    
 print(f"The older patient is {max_patient_age} years old.\n")
    
 print(f"There are {n_negative} patients negative to breast cancer. Ratio = {n_negative / num_patients}")
    
 print(f"There are {n_positive} patients positive to breast cancer. Ratio = {n_positive / num_patients}")
    
    
    
    
    代码解读
复制代码
复制代码
 ages = data.groupby('patient_id')['age'].apply(lambda x: x.unique()[0])

    
 cancer_ages = data[data['cancer'] == 1].groupby('patient_id')['age'].apply(lambda x: x.unique()[0])
    
 no_cancer_ages = data[data['cancer'] == 0].groupby('patient_id')['age'].apply(lambda x: x.unique()[0])
    
  
    
 plt.figure(figsize=(16, 10))
    
  
    
 plt.subplot(1, 2, 1)
    
 sns.histplot(ages, bins=63, color='orange', kde=True)
    
 plt.title("All the patient")
    
 plt.xlim(33, 89)
    
  
    
 plt.subplot(2, 2, 2)
    
 sns.histplot(cancer_ages, bins=51, color='red', kde=True)
    
 plt.title("Patients with cancer")
    
 plt.xlim(33, 89)
    
  
    
 plt.subplot(2, 2, 4)
    
 sns.histplot(no_cancer_ages, bins=63, color='green', kde=True)
    
 plt.title("Patients without cancer")
    
 plt.xlim(33, 89)
    
  
    
 plt.suptitle("Age distribution of the patients")
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 print("Mean:", ages.mean())

    
 print("Std:", ages.std())
    
 print("Q1:", ages.quantile(0.25))
    
 print("Median:", ages.median())
    
 print("Q3:", ages.quantile(0.75))
    
 print("Mode:", ages.mode()[0])
    
    
    
    
    代码解读
复制代码
复制代码
 n_images_per_patient = data['patient_id'].value_counts()

    
 plt.figure(figsize=(16, 6))
    
 sns.countplot(n_images_per_patient, palette='Reds_r')
    
 plt.title("Number of images taken per patients")
    
 plt.xlabel('Number of images taken')
    
 plt.ylabel('Count of patients')
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 fig, ax = plt.subplots(2, 3, figsize=(16, 8))

    
 sns.countplot(data['laterality'], palette='Blues_r', ax=ax[0, 0])
    
 sns.countplot(data['implant'], palette='Greens_r', ax=ax[0, 1])
    
 sns.countplot(data['difficult_negative_case'], palette='Reds_r', ax=ax[0, 2])
    
 sns.countplot(data['view'], palette='Oranges_r', ax=ax[1, 0])
    
 sns.countplot(data['density'], palette='Purples_r', order=['A', 'B', 'C', 'D'], ax=ax[1, 1])
    
 sns.countplot(data['site_id'], palette='Greys_r', ax=ax[1, 2])
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 biopsy_counts = data.groupby('cancer')['biopsy'].value_counts().unstack().fillna(0)

    
 biopsy_perc = biopsy_counts.transpose() / biopsy_counts.sum(axis=1)
    
  
    
 fig, ax = plt.subplots(1, 2, figsize=(10, 4))
    
 sns.countplot(data['cancer'], palette='Greens', ax=ax[0])
    
 sns.heatmap(biopsy_perc, square=True, annot=True, fmt='.1%', cmap='Blues', ax=ax[1])
    
 ax[0].set_title("Number of images showing cancer")
    
 ax[1].set_title("Percentage of images\nresulting in a biopsy")
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 fig, ax = plt.subplots(1, 3, figsize=(16, 4))

    
 sns.countplot(data[data['cancer'] == True]['invasive'], ax=ax[0], palette='Reds')
    
 sns.countplot(data[data['cancer'] == False]['BIRADS'], order=[0, 1, 2], ax=ax[1], palette='Blues')
    
 sns.countplot(data[data['cancer'] == True]['BIRADS'], order=[0, 1, 2], ax=ax[2], palette='Blues')
    
 ax[0].set_title("Count of invasive cancer images")
    
 ax[1].set_title("BIRADS for healthy images")
    
 ax[2].set_title("BIRADS for cancer images")
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 plt.figure(figsize=(14, 5))

    
 sns.countplot(data['machine_id'])
    
 plt.title("Count of images taken by machine ID")
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 train_path = 'input/rsna-breast-cancer-detection/train_images'

    
 test_path = 'input/rsna-breast-cancer-detection/test_images'
    
    
    
    
    代码解读
复制代码
 def load_patient_scans(path, patient_id):

    
     patient_path = path + '/' + str(patient_id)
    
     return [pydicom.dcmread(patient_path + '/' + file) for file in listdir(patient_path)]
    
    
    
    
    代码解读
复制代码
 fig, ax = plt.subplots(1, 2, figsize=(20, 5))

    
 im = ax[0].imshow(scans[0].pixel_array, cmap='bone')
    
 ax[0].grid(False)
    
 fig.colorbar(im, ax=ax[0])
    
 sns.histplot(scans[0].pixel_array.flatten(), ax=ax[1], bins=50)
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 def get_scan_info():

    
     modes, rows, cols = [], [], []
    
     machine_ids = data['machine_id'].unique()
    
     for m_id in machine_ids:
    
     m_id_modes, m_id_rows, m_id_cols = [], [], []
    
     print(f"Machine id {m_id} in progress")
    
     patient_ids = data[data['machine_id'] == m_id]['patient_id'].unique()
    
     for n in range(50):
    
         try:
    
             scan = load_patient_scans(train_path, patient_ids[n])[0]
    
             m_id_modes.append(mode(scan.pixel_array.flatten())[0][0])
    
             m_id_rows.append(scan.Rows)
    
             m_id_cols.append(scan.Columns)
    
         except IndexError:
    
             break
    
     modes.append(m_id_modes)
    
     rows.append(m_id_rows)
    
     cols.append(m_id_cols)
    
     return modes, rows, cols
    
    
    
    
    代码解读
复制代码
 modes, rows, cols = get_scan_info()

    
  
    
 machine_ids = data['machine_id'].unique()
    
 medians = [np.median(x) for x in modes]
    
 stds = [np.std(x) for x in modes]
    
 rows = [np.mean(x) for x in rows]
    
 cols = [np.mean(x) for x in cols]
    
 df = pd.DataFrame(data={'Machine ID': machine_ids, 'Mode (median)': medians, 'Mode (std)': stds, 'Rows (mean)': rows, 'Cols (mean)': cols})
    
 df.astype(int).set_index('Machine ID').T
    
    
    
    
    代码解读
复制代码

根据不同的机器ID值进行分析时会发现像素分布有所差异。实际上,在模式识别中所依据的特征与背景像素特征存在显著关联性。通常情况下,在大多数设备中该参数会被归一化处理后设定为零值状态;但值得注意的是对于某些特定设备(如ID 29和210)其数值分别达到了3000和1000以上水平;此外所有设备都需遵循光度标准设定下的参数规范性要求。在图像处理领域中存在两种主要类型:一种是在单色图像中采用灰度编码方式其中较高数值代表较亮像素而较低数值对应较暗像素;另一种则是在单色图像中采用反向灰度编码其中较高数值对应较暗像素而较低数值则代表亮度较高的区域这可能与特定应用需求有关联关系需要具体场景下加以判断以便获得最佳效果。此外图像分辨率与其所属设备型号之间也存在明显差异;例如设备29和210具有远高于其他型号(如197和216)的分辨率水平这使得在训练阶段选择合适的参数设置显得尤为重要以确保模型能够适应不同规格的输入数据从而实现较高的准确率

复制代码
 plt.figure(figsize=(22, 8))

    
 for i, m_id in enumerate(machine_ids):
    
     patient_ids = data[data['machine_id'] == m_id]['patient_id'].unique()
    
     scan = load_patient_scans(train_path, patient_ids[0])[0] # Load first scan of first patient
    
     plt.subplot(2, 5, i+1)
    
     plt.imshow(scan.pixel_array, cmap='bone')
    
     plt.title(f"Machine {m_id}")
    
     plt.colorbar()
    
     plt.grid(False)
    
 plt.show()
    
    
    
    
    代码解读

在该数据集里,包含了具有假体的哺乳部位影像学检查。观察这些假体及其与其他影像学检查的对比可能非常有趣。

复制代码
 m_id_implants = data[data['implant'] == 1]['machine_id'].unique()

    
 print("Scans showing implents are from machines", m_id_implants)
    
    
    
    
    代码解读
复制代码
 patient_ids = data[data['implant'] == 1]['patient_id'].unique()

    
  
    
 # Display scans showing implants
    
 plt.figure(figsize=(22, 8))
    
 for i in range(10):
    
     scan = load_patient_scans(train_path, patient_ids[i])[0] # Load first scan of the patient
    
     plt.subplot(2, 5, i+1)
    
     plt.imshow(scan.pixel_array, cmap='bone')
    
     plt.title(f"Patient {patient_ids[i]}")
    
     plt.grid(False)
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 plt.figure(figsize=(22, 8))

    
 scans = load_patient_scans(train_path, 13095)
    
 for i in range(10):
    
     plt.subplot(2, 5, i+1)
    
     plt.imshow(scans[i].pixel_array, cmap='bone')
    
     plt.grid(False)
    
 plt.suptitle("All scans of patient 13095")
    
 plt.show()
    
    
    
    
    代码解读
复制代码
 def display_cancer_or_not(cancer=True):

    
     cancer_scans = data[data['cancer'] == int(cancer)].sample(frac=1, random_state=0)
    
     plt.figure(figsize=(22, 10))
    
     for i in range(10):
    
     patient = str(cancer_scans.iloc[i][['patient_id']][0])
    
     file = str(cancer_scans.iloc[i][['image_id']][0]) + '.dcm'
    
     scan = pydicom.dcmread(train_path + '/' + patient + '/' + file)
    
     plt.subplot(2, 5, i+1)
    
     plt.imshow(scan.pixel_array, cmap='bone')
    
     plt.title(f"Patient {patient}\nScan {file}")
    
     plt.grid(False)
    
     plt.suptitle(f"Cancer = {cancer}")
    
     plt.show()
    
    
    
    
    代码解读
复制代码
    display_cancer_or_not(cancer=True)
    
    代码解读

后续待更新。。。。。。

全部评论 (0)

还没有任何评论哟~