Advertisement

【ML】基于机器学习的心脏病预测研究(附代码,lightgbm模型)

阅读量:

写在前面:
首先感谢兄弟们的关注和订阅,让我有创作的动力,在创作过程我会尽最大能力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌。(专栏订阅用户订阅专栏后免费提供数据集,代码贴在博文中,超级VIP用户不在服务范围之内,不想订阅专栏的兄弟们可以私信我详聊)

Hello,大家好,我是augustqi。今天手把手带大家做一个机器学习实战项目:基于机器学习的心脏病预测研究。多的不说,少的不唠,下面开始今天的教程。

以下内容,完全是我根据参考资料和个人理解撰写出来的,不存在滥用原创的问题。

1. 项目介绍

这是一个基于机器学习的二分类任务,根据给定“患者”的某些属性信息,预测是否患有心脏病。本项目使用的数据来源于UCI机器学习库。

UCI机器学习库中,一共包含4个关于心脏病诊断的数据集,分别是:

1、cleveland.data

2、hungarian.data

3、long-beach-va.data

4、switzerland.data

每个数据集都包含76个属性,但是所有已发布的实验都只使用了其中14个属性。其中,cleveland.data是机器学习研究人员最常使用的数据集。

2. 数据获取

2.1 获取方式1

UCI网站获取:

https://archive.ics.uci.edu/ml/datasets/heart+disease

2.2 获取方式2

kaggle网站获取:

https://www.kaggle.com/datasets/johnsmith88/heart-disease-dataset

2.3 获取方式3

Kaggle网站下载的数据集已经经过了预处理 ,关注微信公众号:AIexplore,私信【我的心脏】,即可获取。

3. 数据介绍

我们使用Kaggle网站上提供的数据集,共包含1025条数据,每条数据14个属性(13个特征+1个标签)

属性 解释
age 年龄
sex 性别,1表示男,0表示女
cp 心绞痛病史,1:典型心绞痛,2:非典型心绞痛,3:无心绞痛,4:无症状
trestbps 静息血压,入院时测量得到,单位为毫米汞柱(mm Hg)
chol 胆固醇含量,单位:mgldl
fbs 空腹时是否血糖高,如果空腹血糖大于120 mg/dl,值为1,否则值为0
restecg 静息时的心电图特征。0:正常。1: ST-T波异常(T波倒置和/或ST段抬高或压低>0.05 mV)。2:根据Estes标准显示可能或明确的左心室肥厚
thalach 最大心率
exang 运动是否会导致心绞痛,1表示会,0表示不会
oldpeak 运动相比于静息状态,心电图中的ST-T波是否会被压平。1表示会,0表示不会
slope 心电图中ST波峰值的坡度(1:上升,2:平坦,3:下降)
ca 心脏周边大血管的个数(0-3)
thal 是否患有地中海贫血症(0:未知,1:正常,2:先天缺陷,3:可逆缺陷)
target 标签列。是否有心脏病,0表示没有,1表示有

4. 特征工程

4.1 数据预处理

实际的项目过程肯定是要做特征工程的,但是,我们下载的数据,已经经过了预处理。不需要做缺失值处理、异常值处理、特征编码等。如果使用的是树模型,不需要对数据进行标准化、归一化。使用非树模型,则可以对数据进行标准化、归一化。本项目使用的是LightGBM模型,属于树模型,因此不进行数据标准化、归一化处理。

4.2 数据划分

根据8:2的比例,将数据集划分为训练集和测试集,训练集用于模型训练,测试集用于模型测试。

核心代码:

复制代码
 # 读取数据

    
 data_ori= pd.read_csv("heart.csv")
    
  
    
 # 特征X, 标签y
    
 X, y = data_ori.drop(['target'], axis=1), data_ori['target']
    
  
    
 # 划分数据集
    
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    
  
    
 # DataFrame -> numpy
    
 X_train_np, X_test_np, y_train_np, y_test_np = np.array(X_train), np.array(X_test), np.array(y_train), np.array(y_test)

5. 搭建模型

lightgbm是微软开发的boosting工具包,这个包计算速度更快,占用内存更少,准确度和xgboost相当。首先安装lightgbm:

复制代码
    pip install lightgbm

定义的模型如下,LGBMClassifier中的其他参数使用默认值,

复制代码
 from lightgbm import LGBMClassifier

    
  
    
 # 模型
    
 model_lgb = LGBMClassifier(n_estimators=100)

6. 训练模型

在训练集上训练模型,并保存训练好的模型:

复制代码
 # 训练

    
 model_lgb.fit(X_train_np, y_train_np)
    
  
    
 # 查看拟合效果
    
 y_pred_train = model_lgb.predict(X_train)
    
 acc_train = accuracy_score(y_train_np, y_pred_train)
    
 print("acc train:", acc_train)
    
  
    
 # 保存模型
    
 joblib.dump(model_lgb, "model_lgb.pkl")

输出: acc train: 1.0
训练集上准确率已经100%,说明模型已经学习的很好了。

7. 测试模型

加载训练好的模型,并在测试集上进行测试:

复制代码
 # 测试

    
 # 加载训练好的模型
    
 model_lgb = joblib.load("model_lgb.pkl")
    
 # 预测
    
 y_pred = model_lgb.predict(X_test_np)
    
 y_pred_proba = model_lgb.predict_proba(X_test_np)
    
 y_score  = y_pred_proba[:,1]
    
  
    
 # 统计
    
 acc = accuracy_score(y_test_np, y_pred)
    
 pre = precision_score(y_test_np, y_pred)
    
 rec = recall_score(y_test_np, y_pred)
    
 f1 = f1_score(y_test_np, y_pred)
    
  
    
 # roc
    
 fpr, tpr, thresholds = roc_curve(y_test_np, y_score)
    
 auc_lgb = auc(fpr, tpr)
    
 plt.figure(figsize=(10,6))
    
 # plt.xlim(0, 1)     # 设定x轴的范围
    
 # plt.ylim(0, 1)     # 设定y轴的范围
    
 plt.title('ROC Curve')
    
 plt.xlabel('False Postive Rate')
    
 plt.ylabel('True Postive Rate')
    
 plt.rcParams['font.sans-serif']= ['Times New Roman'] # 设置字体
    
 plt.rcParams['xtick.direction']='in' # 设置刻度朝向
    
 plt.rcParams['ytick.direction']='in'
    
 plt.plot([0,1],[0,1], linewidth=1, linestyle="--", color='black')
    
 plt.plot(fpr,tpr,linewidth=2, linestyle="-",color='red', label='ROC(area={0:.4f})'.format(auc_lgb))
    
 plt.legend(loc="lower right")
    
 plt.savefig("roc_curve_lgb.png", dpi=600)
    
 plt.show()

准确率: 98.54%
精确率:100%
召回率:97.14%
f1得分:98.55%

ROC曲线:

8. 项目完整代码

环境配置:
python版本:3.9.0
pandas版本:1.4.2
numpy版本:1.22.3
matplotlib版本:3.5.1
sklearn版本:1.0.2
lightgbm版本:3.3.2
joblib版本:1.1.0

项目完整代码如下:

复制代码
 # -*- coding: utf-8 -*-

    
 """
    
 Created on Mon Oct 24 09:47:50 2022
    
   5. @author: augustqi
    
 """
    
  
    
 import pandas as pd
    
 import numpy as np
    
 from matplotlib import pyplot as plt
    
 from sklearn.model_selection import train_test_split
    
 from lightgbm import LGBMClassifier
    
 # from sklearn.ensemble import RandomForestClassifier
    
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    
 from sklearn.metrics import roc_curve, auc # 多分类不绘制ROC曲线
    
 import joblib
    
  
    
  
    
 # 读取数据
    
 data_ori= pd.read_csv("heart.csv")
    
  
    
 # 特征X, 标签y
    
 X, y = data_ori.drop(['target'], axis=1), data_ori['target']
    
  
    
 # 划分数据集
    
 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    
  
    
 # DataFrame -> numpy
    
 X_train_np, X_test_np, y_train_np, y_test_np = np.array(X_train), np.array(X_test), np.array(y_train), np.array(y_test)
    
  
    
 # 模型
    
 model_lgb = LGBMClassifier(n_estimators=100)
    
  
    
 # 训练
    
 model_lgb.fit(X_train_np, y_train_np)
    
  
    
 # 查看拟合效果
    
 y_pred_train = model_lgb.predict(X_train)
    
 acc_train = accuracy_score(y_train_np, y_pred_train)
    
 print("acc train:", acc_train)
    
  
    
 # 保存模型
    
 joblib.dump(model_lgb, "model_lgb.pkl")
    
  
    
  
    
 # 测试
    
 # 加载训练好的模型
    
 model_lgb = joblib.load("model_lgb.pkl")
    
 # 预测
    
 y_pred = model_lgb.predict(X_test_np)
    
 y_pred_proba = model_lgb.predict_proba(X_test_np)
    
 y_score  = y_pred_proba[:,1]
    
  
    
 # 统计
    
 acc = accuracy_score(y_test_np, y_pred)
    
 pre = precision_score(y_test_np, y_pred)
    
 rec = recall_score(y_test_np, y_pred)
    
 f1 = f1_score(y_test_np, y_pred)
    
  
    
 # roc
    
 fpr, tpr, thresholds = roc_curve(y_test_np, y_score)
    
 auc_lgb = auc(fpr, tpr)
    
 plt.figure(figsize=(10,6))
    
 # plt.xlim(0, 1)     # 设定x轴的范围
    
 # plt.ylim(0, 1)     # 设定y轴的范围
    
 plt.title('ROC Curve')
    
 plt.xlabel('False Postive Rate')
    
 plt.ylabel('True Postive Rate')
    
 plt.rcParams['font.sans-serif']= ['Times New Roman'] # 设置字体
    
 plt.rcParams['xtick.direction']='in' # 设置刻度朝向
    
 plt.rcParams['ytick.direction']='in'
    
 plt.plot([0,1],[0,1], linewidth=1, linestyle="--", color='black')
    
 plt.plot(fpr,tpr,linewidth=2, linestyle="-",color='red', label='ROC(area={0:.4f})'.format(auc_lgb))
    
 plt.legend(loc="lower right")
    
 plt.savefig("roc_curve_lgb.png", dpi=600)
    
 plt.show()

参考资料

[1]https://www.it610.com/article/1531942863791747072.htm

[2]

[3]

[4]

[5]https://www.freesion.com/article/2415541776/

全部评论 (0)

还没有任何评论哟~