Advertisement

【ML】基于机器学习的心脏病预测研究(附代码和数据集,XGBoost模型)

阅读量:

写在前面:
首先致谢广大粉丝与读者的关注与支持!正是这份动力促使我在创作过程中力求精雕细琢。无论是日常更新还是特别策划期的制作工作,我都将不懈努力以确保每一期内容都能达到品质要求。如有任何问题或建议,请随时通过私信与我联系,我们将会秉持开放的态度进行沟通与改进。此外,值得注意的是,目前我们的服务仅限于基础用户群体,-tiered VIP会员体系中尚未包含的服务范围也将在未来逐步扩展中呈现给各位读者。如果您对本专栏的内容不感兴趣或者有其他需求,欢迎随时沟通讨论!

ML

本次实战项目的主题是基于机器学习的心脏病预测研究,并提供相关的代码及数据集。其中涵盖了ROC曲线绘制以及PR曲线分析。这一份实用资料值得收藏和深入研究。

心 病 是 人 类 健 康 的 第 一 大 挡 次 ,据世界卫生组织统计数据显示 ,全球大约有三分之一的人口会因心 病 而 死亡 。我国 每 年 大 约 有 数 十万人口因心 病 而 死亡 。如果 我们 可以 提取人 体 相关的指 标 (如既往病史、家族病史、血压情况、血糖情况等),通 过 数据挖 掖 方 式 来 分析不 同 特征 对 心 病 的 影响 ,或者 建立电子病历 ,收集数 �据 集 并 建立预 测模 型 ,将会在 预防心 病 方面 发挥 至 关重 要的作 用 。

本项目基于数据分析、数据挖掘,根据疾病的特征预测是否患有心脏病。

1. 项目介绍

该任务采用机器学习方法进行分类研究。基于患者的多个特征变量进行预测分析。数据集源自UCI机器学习数据库平台。

UCI机器学习库中,一共包含4个关于心脏病诊断的数据集,分别是:

1、cleveland.data

2、hungarian.data

3、long-beach-va.data

4、switzerland.data

任何一个数据集都拥有76个属性;然而,在现有的公开实验中仅采用了其中的14个属性进行建模训练。特别值得注意的是cleveland.data这一特定数据集,在机器学习领域具有广泛的代表性。

2. 数据获取

关注微信公众号:AIexplore,私信【我的心脏】,即可获取。

在这里插入图片描述

3. 数据介绍

我们从Kaggle网站收集了数据集,并包含了总共1025条数据集信息。

在这里插入图片描述

4. 实验

4.1 数据信息

导入数据分析常用库:

复制代码
    # 导入数据分析常用库
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    
      
      
      
      
      
    
    代码解读

导入数据:

复制代码
    # 读取数据
    data_ori= pd.read_csv("heart.csv")
    
    
      
      
    
    代码解读

查看数据:

复制代码
    # 查看所有数据
    data_ori
    
    
      
      
    
    代码解读
在这里插入图片描述

查看前10条数据:

复制代码
    # 查看所有数据
    data_ori.head(10)
    
    
      
      
    
    代码解读
在这里插入图片描述

查看数据基本信息:

复制代码
    data_ori.info()
    
    
      
    
    代码解读
在这里插入图片描述

小结:

该研究采用了包含共约776,896行的大型中文语料库作为基础语料。
每个样本包含多个上下文窗口以及其对应的标注信息。
其中,在这776,896行中包含了大约78,958种不同的实体名词和近296,487种不同的动词短语。
整个语料库覆盖了包括社会关系、时间关系、空间关系等多个主要语义维度。

target取值为0或1
0:没有心脏病
1:患有心脏病

4.2 数据可视化分析

4.2.1 特征之间的相关关系

复制代码
    # 查看特征之间的相关关系
    plt.figure(figsize=(12,10))
    corr =  data_ori.corr() # 皮尔逊相关系数
    sns.heatmap(data=corr,annot=True, square=True,fmt='.3f', cmap='PuBu')
    plt.show()
    
    
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:

  1. 这些指标包括心绞痛类型(cp)、最大心率(thalach)、运动诱发的心绞痛(exang)以及ST段变化量(oldpeak),它们与心脏病存在较强的关联关系
  2. 这些因素包括年龄(age)、性别(sex)、ST波峰值变化程度(slope)、心脏周围血管数量(ca)以及地中海贫血症状(thal),它们与心脏病之间显示出一定的关联
  3. 这两个指标即胆固醇含量(chol)和空腹血糖水平(fbs),它们之间的相关性相对较弱

4.2.2 从整体看下心脏病患病情况

复制代码
    colors=['tomato','lightskyblue']
    
    countNoDisease = len(data_ori[data_ori['target']==0])
    countHaveDisease = len(data_ori[data_ori['target']==1])
    total =  len(data_ori['target'])
    rateNo =  countNoDisease/total*100
    rateHave = countHaveDisease/total*100
    diseaseRate =  pd.Series({'正常':countNoDisease, '患病': countHaveDisease})
    print("rateNo",rateNo)
    print("rateHave",rateHave)
    
    
      
      
      
      
      
      
      
      
      
      
    
    代码解读

rateNo 48.68292682926829
rateHave 51.31707317073171

复制代码
    # 绘图
    plt.bar(diseaseRate.index,diseaseRate.values,color='lightskyblue')
    # 以下2行解决windows下中文不显示问题
    plt.rcParams['font.sans-serif']=['simHei']
    plt.rcParams['axes.unicode_minus']=False
    plt.title('患病和正常人群分布',fontsize=14)
    plt.xlabel('是否患病',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.ylim([0,600])
    
    # 添加数据标签
    for a,b in zip([0,1],[countNoDisease, countHaveDisease]):
    plt.text(a,b+4,b,ha='center',fontsize=12)
    plt.show()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
未患心脏病的人数为:499;未患心脏病的比例达到48.68%;患有心脏疾病的患者人数为526;患有心脏疾病的患者所占比例达到51.32%。

4.2.3 研究心脏病和年龄之间的关系

复制代码
    # 由于年龄过于分散,将年龄拆分成不同段
    '''
    pandas.cut()用来把一组数据分割成离散的区间。
    比如有年龄数据,可以使用pandas.cut将年龄数据分割成不同的年龄段并打上标签。
    pandas.cut(x, bins, right=True, include_lowest=False, labels=None )
    参数介绍: x:被拆分的数据数组,必须是一维的,不能用DataFrame 
          bins:被切割后的区间,比如bins=[1,2,3],则区间为(1,2)和(2,3) 
          right=True,默认包含右区间,改为False,则不包含 
          include_lowest:bool型的参数,表示区间的左边是开还是闭的,默认为false
          labels:给分割后的bins打标签,长度必须和划分后的区间长度相等,
                   比如bins=[1,2,3],划分后有2个区间,则labels的长度必须为2。
    '''
    
    # 定义一个DataFrame接收分割后的年龄
    age_df = pd.DataFrame()
    
    # 用pd.cut() 将年龄进行分割
    age_df['age_range']  = pd.cut(x = data_ori['age'],
                                  bins = [0,18,40,60,100],
                                  include_lowest = True,right=False,
                                  labels = ['儿童','青年','中年','老年'])
    
    # 将原数据集的target合并到age_distDf中
    age_df = pd.concat([age_df['age_range'],data_ori['target']],axis=1)
    age_df.head(10)
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述
复制代码
    # 画柱状图,观察不同年龄段的人患心脏病的情况
    sns.countplot(data=age_df,x='age_range',hue='target',palette='Set2')
    plt.xlabel('年龄段')
    plt.show()
    
    
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
这个数据集的人群集中在中年这个年龄段,且中年人患病的较多

复制代码
    # 绘制饼图比较不同年龄段人群患病情况
    plt.figure(figsize = (3*5,1*4))
    
    # 青年人患病比例
    ax1 = plt.subplot(1,3,1)
    youth =  age_df[age_df['age_range']=='青年']['target'].value_counts()
    plt.pie(youth,explode=[0,0.05],autopct='%.2f%%',labels=['患病','正常'],colors=colors)
    plt.title('青年人患病比例')
    
    # 中年人患病比例
    ax2 =  plt.subplot(1,3,2)
    middle = age_df[age_df['age_range']=='中年']['target'].value_counts()
    plt.pie(middle,explode=(0,0.05),autopct='%.2f%%',labels=['患病','正常'],colors=colors)
    plt.title('中年人患病比例')
    
    # 老年人患病比例
    ax2 = plt.subplot(1,3,3)
    old =  age_df[age_df['age_range']=='老年']['target'].value_counts()
    plt.pie(old,explode=[0,0.05],autopct='%.2f%%',labels=['患病','正常'],colors=colors)
    plt.title('老年人患病比例')
    plt.show()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
青年人群中的患病情况普遍高于老年人群,在很大程度上与其长期承受过大的心理压力、频繁熬夜以及缺乏规律性的身体活动等因素密切相关。

4.2.4 心脏病和性别之间的关系

复制代码
    sex_df = data_ori[['sex','target']]
    plt.figure(figsize=(7, 5))
    sns.countplot(data = sex_df, x = 'sex', hue='target', palette=colors)
    plt.title('不同性别人群患病情况',fontsize=14)
    plt.xlabel('性别(0=女性,1=男性)',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.show()
    
    
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
女性患心脏病的比例高于男性(注意是比例)在这里插入图片描述

4.2.5 心脏病和cp(心绞痛类型)之间的关系

复制代码
    cp_df = data_ori[['cp','target']]
    sns.countplot(data = cp_df, x ='cp',hue='target',palette='Set2')
    plt.xlabel('胸痛类型\n(0=典型心绞痛;1=非典型心绞痛;2=非心绞痛;3=没有症状)',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.show()
    
    
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
只有典型的单纯性心绞痛患者的发病风险较低单纯的单纯性心绞痛并不构成急性冠脉综合征而1型至3型的心绞疼痛者其发病风险显著增加这表明心脏 疾病 与 心 眥 疼 痛 类 型 之间存在一定关联

4.2.6 运动引起的心绞痛(exang)和心脏病之间的关系

复制代码
    exang_df = data_ori[['exang','target']]
    sns.countplot(data=exang_df, x='exang',hue='target',palette='Set2')
    plt.xlabel('运动是否引起心绞痛\n(0=否,1=是)',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.show()
    
    
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
运动导致心绞痛被确诊为心脏病的可能性较低。仅限于心绞痛现象,则可能源于其他疾病。但如果本身患有心脏病,则需遵医嘱进行休息调养。过量运动同样可能导致心绞痛。

4.2.7 心脏病和最大心率(thalach)之间的关系

复制代码
    # 心率和年龄也有一定的关系,可以结合考察心脏病,心率,年龄
    thalch_df = data_ori[['thalach','target']]
    thalch_df['age_range'] = pd.cut( data_ori['age'], bins=[0,18,40,60,100],
                                 labels=['儿童','青年','中年','老年'],
                                 include_lowest=True,right=False)
    
    
      
      
      
      
      
    
    代码解读
在这里插入图片描述
复制代码
    sns.swarmplot(data =thalch_df, x='age_range',y='thalach', hue='target')
    plt.xlabel('年龄段', fontsize=12)
    plt.ylabel('最大心率',fontsize=12)
    plt.show()
    
    
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
最大心率会伴随年龄的增长而下降。 该年龄段内的心脏病患者通常具有较高的心率水平。

4.2.8 静息血压和心脏病的关系

复制代码
    trestbps_df = data_ori[['trestbps','target']]
    sns.boxplot(trestbps_df['target'],trestbps_df['trestbps'])
    plt.xlabel('是否患心脏病(0=否,1=是)',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.title('心脏病与静息血压的关系')
    plt.show()
    
    
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
健康人的静息血压通常略高于心脏病患者。 在考察不同特征之间的两两组合关系时涉及多种可能性,在此不做深入探讨后文将进一步应用多种机器学习算法来评估疾病发生的风险。

4.3 数据预处理

非离散型分类数据的编码(cp、restecg、slope、thal)
通过使用get_dummies()函数对这些离散型特征进行编码。

复制代码
    # 非连续性分类数据处理(cp,restecg,slope,thal)
    # 采用get_dummies()编码方式处理非连续性分类数据
    cp_dummies= pd.get_dummies(data_ori['cp'],prefix = 'cp')
    restecg_dummies =  pd.get_dummies(data_ori['restecg'],prefix='restecg')
    slope_dummies =  pd.get_dummies(data_ori['slope'],prefix='slope')
    thal_dummies = pd.get_dummies(data_ori['thal'],prefix='thal')
    
    # 将原数据中经过独热编码的列删除
    data_ori_new =  data_ori.drop(['cp','restecg','slope','thal'],axis=1)
    data_ori_new = pd.concat([data_ori_new,cp_dummies,restecg_dummies,slope_dummies,thal_dummies],axis=1)
    data_ori_new.head(10)
    
    
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

4.4 数据划分

复制代码
    # 拆分数据集和目标集
    # 分离出数据和标签
    label =  data_ori_new['target']
    data =  data_ori_new.drop('target',axis=1)
    data.shape
    # (1025, 23)
    # 数据集合的不同特征之间数据相差有点大,对于SVM、KNN等算法,会产生权重影响,因此需要标准化处理数据
    from sklearn.preprocessing import StandardScaler
    standardScaler = StandardScaler()
    standardScaler.fit(data)
    data =  standardScaler.transform(data)
    
    # 拆分训练集,测试集
    from sklearn.model_selection import train_test_split
    train_X,test_X,train_y,test_y = train_test_split(data,label,random_state=3)
    train_X.shape
    # (768, 23)
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

4.5 XGBoost模型

本次实战,我们使用XGBoost模型拟合数据集

复制代码
    # -*- coding: utf-8 -*-
    """
    Created on Mon Mar 20 14:07:58 2023
    
    @author: augustqi
    """
    
    
    import pandas as pd
    import numpy as np
    from matplotlib import pyplot as plt
    from sklearn.model_selection import train_test_split
    from xgboost import XGBClassifier
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
    from sklearn.metrics import roc_curve, auc, precision_recall_curve # 多分类不绘制ROC曲线
    from sklearn.metrics import classification_report
    import joblib
    
    
    # 读取数据
    data_ori= pd.read_csv("heart.csv")
    
    # 特征X, 标签y
    X, y = data_ori.drop(['target'], axis=1), data_ori['target']
    
    # 划分数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
    
    # DataFrame -> Numpy
    X_train_np, X_test_np, y_train_np, y_test_np = np.array(X_train), np.array(X_test), np.array(y_train), np.array(y_test)
    
    # 模型
    model  = XGBClassifier(n_estimators=50)
    
    # 训练
    model.fit(X_train_np, y_train_np)
    
    # 查看拟合效果
    y_pred_train = model.predict(X_train)
    acc_train = accuracy_score(y_train_np, y_pred_train)
    print("acc train:", acc_train)
    
    # 保存模型
    joblib.dump(model, "model_xgb.pkl")
    
    
    # 测试
    # 加载训练好的模型
    model = joblib.load("model_xgb.pkl")
    # 预测
    y_pred = model.predict(X_test_np)
    y_pred_proba = model.predict_proba(X_test_np)
    y_score  = y_pred_proba[:,1]
    
    # 统计
    acc = accuracy_score(y_test_np, y_pred)
    pre = precision_score(y_test_np, y_pred)
    rec = recall_score(y_test_np, y_pred)
    f1 = f1_score(y_test_np, y_pred)
    
    print("accuracy:", acc)
    print("precision:", pre)
    print("recall:", rec)
    print("f1_score:", f1)
    
    # roc
    fpr, tpr, thresholds = roc_curve(y_test_np, y_score)
    
    # auc值, 间接计算
    auc_rf = auc(fpr, tpr)
    
    # auc值,直接计算
    auc_rf_2 = roc_auc_score(y_test_np, y_score)
    
    # classification_report
    res_report = classification_report(y_test_np, y_pred)
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

结果:

复制代码
    acc train: 1.0
    accuracy: 0.9805194805194806
    precision: 0.9803921568627451
    recall: 0.9803921568627451
    f1_score: 0.9803921568627451
    
              precision    recall  f1-score   support
    
           0       0.98      0.98      0.98       155
           1       0.98      0.98      0.98       153
    
    accuracy                           0.98       308
       macro avg       0.98      0.98      0.98       308
    weighted avg       0.98      0.98      0.98       308
    
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

绘制ROC曲线,PR曲线:

复制代码
    # ROC曲线
    plt.figure(figsize=(10,6))
    # plt.xlim(0, 1)     # 设定x轴的范围
    # plt.ylim(0, 1)     # 设定y轴的范围
    plt.title('ROC Curve')
    plt.xlabel('False Postive Rate')
    plt.ylabel('True Postive Rate')
    plt.rcParams['font.sans-serif']= ['Times New Roman'] # 设置字体
    plt.rcParams['xtick.direction']='in' # 设置刻度朝向
    plt.rcParams['ytick.direction']='in'
    plt.plot([0,1],[0,1], linewidth=1, linestyle="--", color='black')
    plt.plot(fpr,tpr,linewidth=2, linestyle="--",color='red', label='ROC(area={0:.4f})'.format(auc_rf))
    plt.legend(loc="lower right")
    plt.savefig("roc_curve.png", dpi=600)
    plt.show()
    
    # PR曲线
    precision, recall, thresholds_pr = precision_recall_curve(y_test_np, y_score)
    aupr = auc(recall, precision)
    plt.figure(figsize=(10,6))
    plt.title('PR Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.rcParams['font.sans-serif']= ['Times New Roman'] # 设置字体
    plt.rcParams['xtick.direction']='in' # 设置刻度朝向
    plt.rcParams['ytick.direction']='in'
    plt.plot([0,1],[1,0], linewidth=1, linestyle="--", color='black')
    plt.plot(recall, precision, 'g--', label='aupr=%0.4f'%aupr)
    plt.legend(loc="lower left")
    plt.savefig("pr_curve.png", dpi=600)
    plt.show()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

ROC曲线:

在这里插入图片描述

PR曲线:

在这里插入图片描述

5. 参考资料

https://zhuanlan.zhihu.com/p/143655244

在这里插入图片描述

扫码关注,不定期免费送书,手把手指导请联系下方小助理

在这里插入图片描述

全部评论 (0)

还没有任何评论哟~