Advertisement

【ML】基于机器学习的心脏病预测研究(附代码和数据集,支持向量机SVM模型)

阅读量:

前言:
首先感谢诸位的关注与支持,在创作过程中力求做到最好。如有任何问题或建议,请随时欢迎各位交流探讨。与此同时,请注意以下事项:

  1. 本专栏面向订阅用户提供免费数据集,并附有相关代码放在文章中供参考使用。
  2. 超级VIP会员不在我们的服务范围内。
  3. 如不打算订阅本专栏的朋友有任何意见或建议,请随时通过私信与我们联系。

ML

ML

ML

ML

ML

ML

ML

ML

本次实战任务是:以机器学习为背景的心脏病预测研究(附有代码和数据集),支持向量机算法模型中涵盖了ROC曲线生成、PR曲线生成等内容。建议收藏后深入学习。

本次实战任务是:以机器学习为背景的心脏病预测研究(附有代码和数据集),支持向量机算法模型中涵盖了ROC曲线生成、PR曲线生成等内容。建议收藏后深入学习。

心脏疾病是人类健康状况的主要威胁,在全球范围内约有三分之一的人口因疾病而陷入死亡。我国每年约有数十万人口因心脏疾病而亡。如果能够通过收集人体相关的指标信息(如既往病史、家族病史、血压水平、血糖水平等),运用数据分析手段预判不同特征对心脏疾病的影响,并构建电子病历系统收集数据集并开发预测模型系统,则可对心脏疾病的预防工作发挥关键作用。

本项目基于数据分析、数据挖掘,根据疾病的特征预测是否患有心脏病。

1. 项目介绍

该研究采用了机器学习方法进行分类任务。该研究基于患者的若干特征参数进行分析判断。该研究采用UCI Machine Learning Repository中的数据集作为训练与验证用例。

UCI机器学习库中,一共包含4个关于心脏病诊断的数据集,分别是:

1、cleveland.data

2、hungarian.data

3、long-beach-va.data

4、switzerland.data

每个数据集都拥有76个属性参数,然而在现有公开实验中仅采用了其中14个关键属性指标作为研究重点。其中一项核心研究数据集cleveland.data被广泛应用于该领域的核心研究

2. 数据获取

关注微信公众号:AIexplore,私信【我的心脏】,即可获取。

在这里插入图片描述

3. 数据介绍

我们基于Kaggle网站收集了该研究所需的数据集。该集合共计包含一千零二十五个样本,在每个样本中都包含了十四项指标。其中十三项用于特征提取、一项用于分类目标。

在这里插入图片描述

4. 实验

4.1 数据信息

导入数据分析常用库:

复制代码
    # 导入数据分析常用库
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    
      
      
      
      
      
    
    代码解读

导入数据:

复制代码
    # 读取数据
    data_ori= pd.read_csv("heart.csv")
    
    
      
      
    
    代码解读

查看数据:

复制代码
    # 查看所有数据
    data_ori
    
    
      
      
    
    代码解读
在这里插入图片描述

查看前10条数据:

复制代码
    # 查看所有数据
    data_ori.head(10)
    
    
      
      
    
    代码解读
在这里插入图片描述

查看数据基本信息:

复制代码
    data_ori.info()
    
    
      
    
    代码解读
在这里插入图片描述

小结:

该数据集共计包含1025条数据;每条样本包含13个特征属性和1个标注字段;其中共有12个整数类型和一个浮点数类型字段;标签均为整数类型;所有样本的数据均无缺失信息

target取值为0或1
0:没有心脏病
1:患有心脏病

4.2 数据可视化分析

4.2.1 特征之间的相关关系

复制代码
    # 查看特征之间的相关关系
    plt.figure(figsize=(12,10))
    corr =  data_ori.corr() # 皮尔逊相关系数
    sns.heatmap(data=corr,annot=True, square=True,fmt='.3f', cmap='PuBu')
    plt.show()
    
    
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:

  1. 这些指标间表现出较强的关联性(cp, thalach, exang, oldpeak)
  2. 这些因素与心脏病之间存在一定关联(age, sex, slope, ca, thal)
  3. 这些指标间的关联程度较弱(chol, fbs)

4.2.2 从整体看下心脏病患病情况

复制代码
    colors=['tomato','lightskyblue']
    
    countNoDisease = len(data_ori[data_ori['target']==0])
    countHaveDisease = len(data_ori[data_ori['target']==1])
    total =  len(data_ori['target'])
    rateNo =  countNoDisease/total*100
    rateHave = countHaveDisease/total*100
    diseaseRate =  pd.Series({'正常':countNoDisease, '患病': countHaveDisease})
    print("rateNo",rateNo)
    print("rateHave",rateHave)
    
    
      
      
      
      
      
      
      
      
      
      
    
    代码解读

rateNo 48.68292682926829
rateHave 51.31707317073171

复制代码
    # 绘图
    plt.bar(diseaseRate.index,diseaseRate.values,color='lightskyblue')
    # 以下2行解决windows下中文不显示问题
    plt.rcParams['font.sans-serif']=['simHei']
    plt.rcParams['axes.unicode_minus']=False
    plt.title('患病和正常人群分布',fontsize=14)
    plt.xlabel('是否患病',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.ylim([0,600])
    
    # 添加数据标签
    for a,b in zip([0,1],[countNoDisease, countHaveDisease]):
    plt.text(a,b+4,b,ha='center',fontsize=12)
    plt.show()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
未患有心血管疾病的人数为499人(即占总人口的48.68%);而患有心血管疾病的患者共计526人(即占总人口的51.32%)。

4.2.3 研究心脏病和年龄之间的关系

复制代码
    # 由于年龄过于分散,将年龄拆分成不同段
    '''
    pandas.cut()用来把一组数据分割成离散的区间。
    比如有年龄数据,可以使用pandas.cut将年龄数据分割成不同的年龄段并打上标签。
    pandas.cut(x, bins, right=True, include_lowest=False, labels=None )
    参数介绍: x:被拆分的数据数组,必须是一维的,不能用DataFrame 
          bins:被切割后的区间,比如bins=[1,2,3],则区间为(1,2)和(2,3) 
          right=True,默认包含右区间,改为False,则不包含 
          include_lowest:bool型的参数,表示区间的左边是开还是闭的,默认为false
          labels:给分割后的bins打标签,长度必须和划分后的区间长度相等,
                   比如bins=[1,2,3],划分后有2个区间,则labels的长度必须为2。
    '''
    
    # 定义一个DataFrame接收分割后的年龄
    age_df = pd.DataFrame()
    
    # 用pd.cut() 将年龄进行分割
    age_df['age_range']  = pd.cut(x = data_ori['age'],
                                  bins = [0,18,40,60,100],
                                  include_lowest = True,right=False,
                                  labels = ['儿童','青年','中年','老年'])
    
    # 将原数据集的target合并到age_distDf中
    age_df = pd.concat([age_df['age_range'],data_ori['target']],axis=1)
    age_df.head(10)
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述
复制代码
    # 画柱状图,观察不同年龄段的人患心脏病的情况
    sns.countplot(data=age_df,x='age_range',hue='target',palette='Set2')
    plt.xlabel('年龄段')
    plt.show()
    
    
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
这个数据集的人群集中在中年这个年龄段,且中年人患病的较多

复制代码
    # 绘制饼图比较不同年龄段人群患病情况
    plt.figure(figsize = (3*5,1*4))
    
    # 青年人患病比例
    ax1 = plt.subplot(1,3,1)
    youth =  age_df[age_df['age_range']=='青年']['target'].value_counts()
    plt.pie(youth,explode=[0,0.05],autopct='%.2f%%',labels=['患病','正常'],colors=colors)
    plt.title('青年人患病比例')
    
    # 中年人患病比例
    ax2 =  plt.subplot(1,3,2)
    middle = age_df[age_df['age_range']=='中年']['target'].value_counts()
    plt.pie(middle,explode=(0,0.05),autopct='%.2f%%',labels=['患病','正常'],colors=colors)
    plt.title('中年人患病比例')
    
    # 老年人患病比例
    ax2 = plt.subplot(1,3,3)
    old =  age_df[age_df['age_range']=='老年']['target'].value_counts()
    plt.pie(old,explode=[0,0.05],autopct='%.2f%%',labels=['患病','正常'],colors=colors)
    plt.title('老年人患病比例')
    plt.show()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

总结:总结指出青年人群中出现疾病的比例显著高于老年人群。这与现代青年人面临工作压力、熬夜习惯以及缺乏运动等多种不良生活习惯密切相关。

4.2.4 心脏病和性别之间的关系

复制代码
    sex_df = data_ori[['sex','target']]
    plt.figure(figsize=(7, 5))
    sns.countplot(data = sex_df, x = 'sex', hue='target', palette=colors)
    plt.title('不同性别人群患病情况',fontsize=14)
    plt.xlabel('性别(0=女性,1=男性)',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.show()
    
    
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
女性患心脏病的比例高于男性(注意是比例)在这里插入图片描述

4.2.5 心脏病和cp(心绞痛类型)之间的关系

复制代码
    cp_df = data_ori[['cp','target']]
    sns.countplot(data = cp_df, x ='cp',hue='target',palette='Set2')
    plt.xlabel('胸痛类型\n(0=典型心绞痛;1=非典型心绞痛;2=非心绞痛;3=没有症状)',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.show()
    
    
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
对于典型的胸痛症状(心绞痛)患者而言,在仅出现单纯的心绞痛症状的情况下发生心脏病的概率相对较低;而对于属于1-3类的心绞痛患者的群体而言,在出现上述症状的情况下患心脏病的风险明显增加;这进一步表明了心脏疾病与不同类型的胸痛症状之间存在一定的关联性。

4.2.6 运动引起的心绞痛(exang)和心脏病之间的关系

复制代码
    exang_df = data_ori[['exang','target']]
    sns.countplot(data=exang_df, x='exang',hue='target',palette='Set2')
    plt.xlabel('运动是否引起心绞痛\n(0=否,1=是)',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.show()
    
    
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
运动导致的心绞痛被确诊为心脏病的概率较低。单纯的运动出现心绞痛现象时,应首先怀疑可能是其他疾病;但如果存在心脏病,则需遵医嘱进行休息(静养),过量的高强度运动仍然可能引发心绞痛。

4.2.7 心脏病和最大心率(thalach)之间的关系

复制代码
    # 心率和年龄也有一定的关系,可以结合考察心脏病,心率,年龄
    thalch_df = data_ori[['thalach','target']]
    thalch_df['age_range'] = pd.cut( data_ori['age'], bins=[0,18,40,60,100],
                                 labels=['儿童','青年','中年','老年'],
                                 include_lowest=True,right=False)
    
    
      
      
      
      
      
    
    代码解读
在这里插入图片描述
复制代码
    sns.swarmplot(data =thalch_df, x='age_range',y='thalach', hue='target')
    plt.xlabel('年龄段', fontsize=12)
    plt.ylabel('最大心率',fontsize=12)
    plt.show()
    
    
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
伴随年龄的增长,最大心率逐步下降。 在同一年龄段中,心脏病患者的平均心率整体上高于健康人群。

4.2.8 静息血压和心脏病的关系

复制代码
    trestbps_df = data_ori[['trestbps','target']]
    sns.boxplot(trestbps_df['target'],trestbps_df['trestbps'])
    plt.xlabel('是否患心脏病(0=否,1=是)',fontsize=12)
    plt.ylabel('人数',fontsize=12)
    plt.title('心脏病与静息血压的关系')
    plt.show()
    
    
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

小结:
正常人在静息状态下的血压水平略高于心脏病患者。 不同特征的配对组合分析涉及多种可能性,在此不做深入探讨后方有余地。 下述内容将运用多种机器学习模型来评估疾病风险。

4.3 数据预处理

non-continuous classification data processing (cp, restecg, slope, thal)
通过使用get_dummies()函数来进行非连续性分类数据的转换

复制代码
    # 非连续性分类数据处理(cp,restecg,slope,thal)
    # 采用get_dummies()编码方式处理非连续性分类数据
    cp_dummies= pd.get_dummies(data_ori['cp'],prefix = 'cp')
    restecg_dummies =  pd.get_dummies(data_ori['restecg'],prefix='restecg')
    slope_dummies =  pd.get_dummies(data_ori['slope'],prefix='slope')
    thal_dummies = pd.get_dummies(data_ori['thal'],prefix='thal')
    
    # 将原数据中经过独热编码的列删除
    data_ori_new =  data_ori.drop(['cp','restecg','slope','thal'],axis=1)
    data_ori_new = pd.concat([data_ori_new,cp_dummies,restecg_dummies,slope_dummies,thal_dummies],axis=1)
    data_ori_new.head(10)
    
    
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读
在这里插入图片描述

4.4 数据划分

复制代码
    # 拆分数据集和目标集
    # 分离出数据和标签
    label =  data_ori_new['target']
    data =  data_ori_new.drop('target',axis=1)
    data.shape
    # (1025, 23)
    # 数据集合的不同特征之间数据相差有点大,对于SVM、KNN等算法,会产生权重影响,因此需要标准化处理数据
    from sklearn.preprocessing import StandardScaler
    standardScaler = StandardScaler()
    standardScaler.fit(data)
    data =  standardScaler.transform(data)
    
    # 拆分训练集,测试集
    from sklearn.model_selection import train_test_split
    train_X,test_X,train_y,test_y = train_test_split(data,label,random_state=3)
    train_X.shape
    # (768, 23)
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

4.5 支持向量机SVM模型

本次实战,我们使用支持向量机SVM模型拟合数据集

复制代码
    # -*- coding: utf-8 -*-
    """
    Created on Mon Mar 20 14:07:58 2023
    
    @author: augustqi
    """
    
    
    import pandas as pd
    import numpy as np
    from matplotlib import pyplot as plt
    from sklearn.model_selection import train_test_split
    from sklearn.svm import SVC
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
    from sklearn.metrics import roc_curve, auc, precision_recall_curve # 多分类不绘制ROC曲线
    from sklearn.metrics import classification_report
    import joblib
    
    
    # 读取数据
    data_ori= pd.read_csv("heart.csv")
    
    # 特征X, 标签y
    X, y = data_ori.drop(['target'], axis=1), data_ori['target']
    
    # 划分数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
    
    # DataFrame -> Numpy
    X_train_np, X_test_np, y_train_np, y_test_np = np.array(X_train), np.array(X_test), np.array(y_train), np.array(y_test)
    
    # 模型, SVC
    model  = SVC(kernel="rbf", probability=True) # 需要将probability设为True,否则model.predict_proba无法使用
    
    # 训练
    model.fit(X_train_np, y_train_np)
    
    # 查看拟合效果
    y_pred_train = model.predict(X_train)
    acc_train = accuracy_score(y_train_np, y_pred_train)
    print("acc train:", acc_train)
    
    # 保存模型
    joblib.dump(model, "model_svm.pkl")
    
    # 测试
    # 加载训练好的模型
    model = joblib.load("model_svm.pkl")
    
    # 预测
    y_pred = model.predict(X_test_np)
    y_pred_proba = model.predict_proba(X_test_np)
    y_score  = y_pred_proba[:,1]
    
    # 统计
    acc = accuracy_score(y_test_np, y_pred)
    pre = precision_score(y_test_np, y_pred)
    rec = recall_score(y_test_np, y_pred)
    f1 = f1_score(y_test_np, y_pred)
    
    # print
    print("accuracy:", acc)
    print("precision:", pre)
    print("recall:", rec)
    print("f1_score:", f1)
    
    # roc
    fpr, tpr, thresholds = roc_curve(y_test_np, y_score)
    
    # auc值, 间接计算
    auc_rf = auc(fpr, tpr)
    
    # auc值,直接计算
    auc_rf_2 = roc_auc_score(y_test_np, y_score)
    
    # classification_report
    res_report = classification_report(y_test_np, y_pred)
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

结果:

复制代码
    acc train: 0.6875871687587168
    accuracy: 0.698051948051948
    precision: 0.6627906976744186
    recall: 0.7651006711409396
    f1_score: 0.7102803738317757
    
              precision    recall  f1-score   support
    
           0       0.74      0.64      0.68       159
           1       0.66      0.77      0.71       149
    
    accuracy                           0.70       308
       macro avg       0.70      0.70      0.70       308
    weighted avg       0.70      0.70      0.70       308
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

绘制ROC曲线,PR曲线:

复制代码
    # ROC曲线
    plt.figure(figsize=(10,6))
    # plt.xlim(0, 1)     # 设定x轴的范围
    # plt.ylim(0, 1)     # 设定y轴的范围
    plt.title('ROC Curve')
    plt.xlabel('False Postive Rate')
    plt.ylabel('True Postive Rate')
    plt.rcParams['font.sans-serif']= ['Times New Roman'] # 设置字体
    plt.rcParams['xtick.direction']='in' # 设置刻度朝向
    plt.rcParams['ytick.direction']='in'
    plt.plot([0,1],[0,1], linewidth=1, linestyle="--", color='black')
    plt.plot(fpr,tpr,linewidth=2, linestyle="--",color='red', label='ROC(area={0:.4f})'.format(auc_rf))
    plt.legend(loc="lower right")
    plt.savefig("roc_curve.png", dpi=600)
    plt.show()
    
    # PR曲线
    precision, recall, thresholds_pr = precision_recall_curve(y_test_np, y_score)
    aupr = auc(recall, precision)
    plt.figure(figsize=(10,6))
    plt.title('PR Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.rcParams['font.sans-serif']= ['Times New Roman'] # 设置字体
    plt.rcParams['xtick.direction']='in' # 设置刻度朝向
    plt.rcParams['ytick.direction']='in'
    plt.plot([0,1],[1,0], linewidth=1, linestyle="--", color='black')
    plt.plot(recall, precision, 'g--', label='aupr=%0.4f'%aupr)
    plt.legend(loc="lower left")
    plt.savefig("pr_curve.png", dpi=600)
    plt.show()
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

ROC曲线:

在这里插入图片描述

PR曲线:

在这里插入图片描述

5. 参考资料

https://zhuanlan.zhihu.com/p/143655244

在这里插入图片描述

扫码关注,不定期免费送书,手把手指导请联系下方小助理

在这里插入图片描述

全部评论 (0)

还没有任何评论哟~