Kaggle Heart Disease Classification: A Data Analysis Case Study (Logistic Regression, KNN, Decision Tree, Random Forest...)

This post is a small demo that analyses the "Heart Disease" classification dataset on Kaggle.

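The workflow is: explore the data; clean and preprocess it; build Logistic Regression, KNN and Decision Tree models; evaluate each model with the F1 score, a confusion matrix and a precision/recall-versus-threshold plot; compare the models' ROC curves; and finally try ensembles of several models to improve the predictions.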

Dataset: https://www.kaggle.com/ronitf/heart-disease-uci

Data exploration

    import numpy as np
    import pandas as pd
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import seaborn as sns
    # Let matplotlib render Chinese characters
    mpl.rcParams['font.sans-serif'] = ['SimHei']  # default font (supports Chinese)
    mpl.rcParams['axes.unicode_minus'] = False  # stop the minus sign '-' from showing as a box in saved figures

    # Load the data
    df = pd.read_csv('heart_disease_data/heart.csv')

A quick look at the overall structure of the data

    df.info()

    <class 'pandas.core.frame.DataFrame'>
    RangeIndex: 303 entries, 0 to 302
    Data columns (total 14 columns):
    age         303 non-null int64
    sex         303 non-null int64
    cp          303 non-null int64
    trestbps    303 non-null int64
    chol        303 non-null int64
    fbs         303 non-null int64
    restecg     303 non-null int64
    thalach     303 non-null int64
    exang       303 non-null int64
    oldpeak     303 non-null float64
    slope       303 non-null int64
    ca          303 non-null int64
    thal        303 non-null int64
    target      303 non-null int64
    dtypes: float64(1), int64(13)
    memory usage: 33.2 KB

Feature meanings

    age       age (years)
    sex       sex (1 = male, 0 = female)
    cp        chest pain type, 4 kinds (value 1: typical angina, 2: atypical angina, 3: non-anginal pain, 4: asymptomatic)
    trestbps  resting blood pressure
    chol      serum cholesterol
    fbs       fasting blood sugar > 120 mg/dl (1 = true, 0 = false)
    restecg   resting electrocardiogram result (values 0, 1, 2)
    thalach   maximum heart rate achieved
    exang     exercise-induced angina (1 = yes, 0 = no)
    oldpeak   ST depression induced by exercise relative to rest (ST is a segment of the electrocardiogram)
    slope     slope of the peak-exercise ST segment (value 1: upsloping, 2: flat, 3: downsloping)
    ca        number of major vessels (0-3)
    thal      a blood disorder called thalassemia (3 = normal, 6 = fixed defect, 7 = reversible defect)
    target    has heart disease (0 = no, 1 = yes)

    df.describe().T  # transposed: one row per feature

                count        mean        std    min    25%    50%    75%    max
    age         303.0   54.366337   9.082101   29.0   47.5   55.0   61.0   77.0
    sex         303.0    0.683168   0.466011    0.0    0.0    1.0    1.0    1.0
    cp          303.0    0.966997   1.032052    0.0    0.0    1.0    2.0    3.0
    trestbps    303.0  131.623762  17.538143   94.0  120.0  130.0  140.0  200.0
    chol        303.0  246.264026  51.830751  126.0  211.0  240.0  274.5  564.0
    fbs         303.0    0.148515   0.356198    0.0    0.0    0.0    0.0    1.0
    restecg     303.0    0.528053   0.525860    0.0    0.0    1.0    1.0    2.0
    thalach     303.0  149.646865  22.905161   71.0  133.5  153.0  166.0  202.0
    exang       303.0    0.326733   0.469794    0.0    0.0    0.0    1.0    1.0
    oldpeak     303.0    1.039604   1.161075    0.0    0.0    0.8    1.6    6.2
    slope       303.0    1.399340   0.616226    0.0    1.0    1.0    2.0    2.0
    ca          303.0    0.729373   1.022606    0.0    0.0    0.0    1.0    4.0
    thal        303.0    2.313531   0.612277    0.0    2.0    2.0    3.0    3.0
    target      303.0    0.544554   0.498835    0.0    0.0    1.0    1.0    1.0

Note that in this version of the data cp takes the values 0-3, slope 0-2, thal 0-3 and ca 0-4 (see the min and max columns), so the codes differ from the descriptions above.

A few simple plots to get a feel for the features and the target

    df.target.value_counts()

    1    165
    0    138
    Name: target, dtype: int64

    sns.countplot(x='target',data=df,palette="muted")
    plt.xlabel("Diseased / not diseased")

    Text(0.5,0,'Diseased / not diseased')

[Figure: count plot of target]

    df.sex.value_counts()

    1    207
    0     96
    Name: sex, dtype: int64

    sns.countplot(x='sex',data=df,palette="Set3")
    plt.xlabel("Sex (0 = female, 1 = male)")

    Text(0.5,0,'Sex (0 = female, 1 = male)')

[Figure: count plot of sex]

    # Number of patients at each age, split by target
    plt.figure(figsize=(18,7))
    sns.countplot(x='age',data = df, hue = 'target',palette='PuBuGn',saturation=0.8)
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.show()

[Figure: patient counts by age, coloured by target]
Understanding the data is an essential step, but this post focuses on modelling, so the exploration is kept brief.

Data preprocessing

One-hot encode the features whose values are category codes rather than continuous quantities (cp, slope, thal):

    # dtype=int keeps the dummy columns as 0/1 (recent pandas versions default to bool)
    first = pd.get_dummies(df['cp'], prefix = "cp", dtype=int)
    second = pd.get_dummies(df['slope'], prefix = "slope", dtype=int)
    third = pd.get_dummies(df['thal'], prefix = "thal", dtype=int)

    # Append the dummy columns and drop the original categorical columns
    df = pd.concat([df,first,second,third], axis = 1)
    df = df.drop(columns = ['cp', 'slope', 'thal'])
    df.head(3)

       age  sex  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  ca  ...  cp_1  cp_2  cp_3  slope_0  slope_1  slope_2  thal_0  thal_1  thal_2  thal_3
    0   63    1       145   233    1        0      150      0      2.3   0  ...     0     0     1        1        0        0       0       1       0       0
    1   37    1       130   250    0        1      187      0      3.5   0  ...     0     1     0        1        0        0       0       0       1       0
    2   41    0       130   204    0        0      172      0      1.4   0  ...     1     0     0        0        0        1       0       0       1       0

3 rows × 22 columns
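pd.get_dummies replaces one categorical column by one 0/1 column per distinct value, which is why the three encoded features grow the table from 14 to 22 columns. A minimal pure-Python sketch of the idea; `one_hot` is a hypothetical helper, not part of pandas, and it assumes the categories are ordered by sorted value as in the cp_0 ... cp_3 columns above:

```python
def one_hot(values, prefix):
    """Hypothetical helper mimicking pd.get_dummies for a single column.

    Returns the new column names and one 0/1 row per input value.
    """
    categories = sorted(set(values))                 # one column per distinct value
    columns = [f"{prefix}_{c}" for c in categories]
    rows = [[1 if v == c else 0 for c in categories] for v in values]
    return columns, rows


# The first three cp values in the data are 3, 2, 1 (see the cp_* columns of df.head(3));
# 0 is added here so that all four codes appear
columns, rows = one_hot([3, 2, 1, 0], prefix="cp")
print(columns)  # ['cp_0', 'cp_1', 'cp_2', 'cp_3']
for row in rows:
    print(row)
```

Each row contains exactly one 1, so the model sees every category as its own indicator instead of treating the codes 0 < 1 < 2 < 3 as an ordered scale.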

Preprocessing is done; build the final feature matrix X and label vector y

    y = df.target.values
    X = df.drop(['target'], axis = 1)
    X.shape

    (303, 21)

Split the data into training and test sets, then standardize the features (zero mean, unit variance)

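    from sklearn.model_selection import train_test_split
    # default test_size=0.25: 227 training rows and 76 test rows; random seed 6
    X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=6)

    from sklearn.preprocessing import StandardScaler

    # Fit the scaler on the training set only, then apply the same transform to both sets.
    # X_train and X_test become NumPy arrays; X itself stays unscaled.
    standardScaler = StandardScaler()
    standardScaler.fit(X_train)
    X_train = standardScaler.transform(X_train)
    X_test = standardScaler.transform(X_test)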
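StandardScaler maps every feature to z = (x - mean) / std, where the mean and the population standard deviation (ddof=0) are learned from the training set only and then reused unchanged on the test set. A pure-Python sketch of that fit/transform split for one feature; `fit_standardizer` and `transform` are hypothetical helper names, and the sample values are illustrative:

```python
import math


def fit_standardizer(train_column):
    """Learn the mean and population standard deviation (ddof=0) of a training column."""
    n = len(train_column)
    mean = sum(train_column) / n
    std = math.sqrt(sum((x - mean) ** 2 for x in train_column) / n)
    # avoid dividing by zero for a constant column (StandardScaler also uses a scale of 1 there)
    return mean, std if std > 0 else 1.0


def transform(column, mean, std):
    """Apply z = (x - mean) / std with statistics learned on the training data."""
    return [(x - mean) / std for x in column]


train = [120.0, 130.0, 140.0]   # e.g. a few trestbps values (illustrative)
test = [150.0]

mean, std = fit_standardizer(train)
print(transform(train, mean, std))   # symmetric around 0
print(transform(test, mean, std))    # the test value reuses the training mean and std
```

Fitting only on the training rows keeps any information about the test rows out of the preprocessing step.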

Model building: Logistic Regression

    from sklearn.linear_model import LogisticRegression

    # solver='liblinear' handles both the 'l1' and 'l2' penalties searched below
    # (it was the default in the scikit-learn version used here; newer versions default to 'lbfgs')
    log_reg = LogisticRegression(solver='liblinear')
    log_reg.fit(X_train,y_train)

    LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)

    log_reg.score(X_train,y_train)

    0.8810572687224669

    log_reg.score(X_test,y_test)

    0.8289473684210527

    from sklearn.metrics import accuracy_score
    y_predict_log = log_reg.predict(X_test)

    # accuracy_score gives the classification accuracy, the same value as log_reg.score above
    accuracy_score(y_test,y_predict_log)

    0.8289473684210527
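LogisticRegression scores each sample with a linear function w·x + b (the value decision_function returns, which is used for the curves later on), turns that score into a probability with the sigmoid 1 / (1 + e^(-z)), and predict returns class 1 when that probability is above 0.5, which is the same as the score being positive. A minimal sketch with made-up coefficients (w and b below are illustrative, not the fitted model's values):

```python
import math


def sigmoid(z):
    """Map a real-valued score to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def decision_score(w, b, x):
    """Linear score w·x + b, the quantity decision_function returns."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def predict(w, b, x):
    """Class 1 if sigmoid(score) > 0.5, i.e. if the score is positive."""
    return 1 if sigmoid(decision_score(w, b, x)) > 0.5 else 0


w, b = [0.8, -0.5], 0.1            # hypothetical coefficients for two standardized features
x = [1.0, 0.4]
score = decision_score(w, b, x)     # about 0.8 - 0.2 + 0.1 = 0.7
print(score, sigmoid(score), predict(w, b, x))
```

The parameter C searched next is the inverse of the regularization strength: a small C such as 0.01 keeps the weights w small.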
Use grid search to look for better hyperparameters

    param_grid = [
        {
            'C':[0.01,0.1,1,10,100],
            'penalty':['l2','l1'],
            'class_weight':['balanced',None]
        }
    ]

    from sklearn.model_selection import GridSearchCV

    # 10-fold cross-validation on the training set, using all CPU cores
    grid_search = GridSearchCV(log_reg,param_grid,cv=10,n_jobs=-1)

    %%time
    grid_search.fit(X_train,y_train)

    Wall time: 2.88 s

    GridSearchCV(cv=10, error_score='raise',
       estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False),
       fit_params=None, iid=True, n_jobs=-1,
       param_grid=[{'C': [0.01, 0.1, 1, 10, 100], 'penalty': ['l2', 'l1'], 'class_weight': ['balanced', None]}],
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=0)

    grid_search.best_estimator_

    LogisticRegression(C=0.01, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)

    grid_search.best_score_   # mean cross-validated accuracy of the best setting

    0.8502202643171806

    grid_search.best_params_

    {'C': 0.01, 'class_weight': None, 'penalty': 'l2'}

    # Use the best estimator (refitted on the whole training set) from now on
    log_reg = grid_search.best_estimator_
    log_reg.score(X_train,y_train)

    0.8634361233480177

    log_reg.score(X_test,y_test)

    0.8289473684210527
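The grid above has 5 x 2 x 2 = 20 combinations, and with cv=10 GridSearchCV fits each one ten times before refitting the best one on the full training set. A small sketch that expands a param_grid the same way and counts the candidates for the three grids used in this post; `expand_grid` is a hypothetical helper (scikit-learn uses its own ParameterGrid class internally):

```python
from itertools import product


def expand_grid(param_grid):
    """Expand a dict, or a list of dicts, of {name: [values]} into parameter dicts."""
    if isinstance(param_grid, dict):
        param_grid = [param_grid]
    candidates = []
    for grid in param_grid:
        names = sorted(grid)
        for values in product(*(grid[name] for name in names)):
            candidates.append(dict(zip(names, values)))
    return candidates


log_reg_grid = [{'C': [0.01, 0.1, 1, 10, 100], 'penalty': ['l2', 'l1'],
                 'class_weight': ['balanced', None]}]
knn_grid = [{'weights': ['uniform'], 'n_neighbors': list(range(1, 31))},
            {'weights': ['distance'], 'n_neighbors': list(range(1, 31)),
             'p': list(range(1, 6))}]
dt_grid = [{'max_features': ['auto', 'sqrt', 'log2'],
            'min_samples_split': list(range(2, 19)),
            'min_samples_leaf': list(range(1, 12))}]

n_lr = len(expand_grid(log_reg_grid))
print(n_lr, n_lr * 10)              # 20 200  (candidates, fits with cv=10)
print(len(expand_grid(knn_grid)))   # 180     (30 + 30 * 5)
print(len(expand_grid(dt_grid)))    # 561     (3 * 17 * 11)
```

Because every dict in the list is expanded separately, the KNN grid only pairs p with weights='distance'.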

F1 score

    from sklearn.metrics import f1_score

    # Note: y_predict_log was computed before the grid search, i.e. with the C=1.0 model;
    # call log_reg.predict(X_test) again to evaluate the tuned model
    f1_score(y_test,y_predict_log)

    0.8470588235294118

    from sklearn.metrics import classification_report
    print(classification_report(y_test,y_predict_log))

                 precision    recall  f1-score   support

              0       0.87      0.75      0.81        36
              1       0.80      0.90      0.85        40

    avg / total       0.83      0.83      0.83        76
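The numbers in the report follow directly from the confusion-matrix counts. A sketch that recomputes them for class 1 from the Logistic Regression confusion matrix of this section, [[27, 9], [4, 36]]; `precision_recall_f1` is a hypothetical helper:

```python
def precision_recall_f1(tn, fp, fn, tp):
    """Precision, recall and F1 for the positive class (label 1)."""
    precision = tp / (tp + fp)   # of the predicted positives, how many are real
    recall = tp / (tp + fn)      # of the real positives, how many were found
    f1 = 2 * precision * recall / (precision + recall)   # harmonic mean of the two
    return precision, recall, f1


# sklearn's confusion_matrix layout for labels [0, 1]: [[TN, FP], [FN, TP]]
(tn, fp), (fn, tp) = [[27, 9], [4, 36]]
p, r, f1 = precision_recall_f1(tn, fp, fn, tp)
print(round(p, 2), round(r, 2), round(f1, 2))   # 0.8 0.9 0.85
accuracy = (tn + tp) / (tn + fp + fn + tp)       # 63 / 76
print(accuracy)
```

Equivalently, F1 = 2·TP / (2·TP + FP + FN) = 72 / 85, which matches the f1_score value above.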

Confusion matrix

    from sklearn.metrics import confusion_matrix
    cnf_matrix = confusion_matrix(y_test,y_predict_log)
    cnf_matrix

    array([[27,  9],
           [ 4, 36]], dtype=int64)

    def plot_cnf_matirx(cnf_matrix,description):
        """Draw a 2x2 confusion matrix as an annotated heat map."""
        class_names = [0,1]
        fig,ax = plt.subplots()
        tick_marks = np.arange(len(class_names))
        plt.xticks(tick_marks,class_names)
        plt.yticks(tick_marks,class_names)

        # create a heat map (rows: actual class, columns: predicted class)
        sns.heatmap(pd.DataFrame(cnf_matrix), annot = True, cmap = 'OrRd',
                    fmt = 'g')
        ax.xaxis.set_label_position('top')
        plt.tight_layout()
        plt.title(description, y = 1.1,fontsize=16)
        plt.ylabel('Actual 0/1',fontsize=12)
        plt.xlabel('Predicted 0/1',fontsize=12)
        plt.show()

    plot_cnf_matirx(cnf_matrix,'Confusion matrix -- Logistic Regression')

[Figure: confusion matrix, Logistic Regression]

    # Signed score of each test sample (positive means class 1)
    decision_scores = log_reg.decision_function(X_test)

    from sklearn.metrics import precision_recall_curve

    precisions,recalls,thresholds = precision_recall_curve(y_test,decision_scores)

    # precisions/recalls have one more entry than thresholds, so drop the last one
    plt.plot(thresholds,precisions[:-1],label='precision')
    plt.plot(thresholds,recalls[:-1],label='recall')
    plt.legend()
    plt.grid()
    plt.show()    # the thresholds do not start at the smallest score; sklearn picks its own starting point

[Figure: precision and recall versus decision threshold, Logistic Regression]
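Each point of the precision/recall-versus-threshold plot comes from predicting class 1 for every sample whose score is at or above the threshold. A pure-Python sketch on a handful of made-up scores (not the model's real decision scores); `precision_recall_at` is a hypothetical helper:

```python
def precision_recall_at(y_true, scores, threshold):
    """Predict 1 when score >= threshold, then compute precision and recall."""
    y_pred = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    # no positive predictions: precision 1.0 by convention (sklearn's curve ends at precision 1, recall 0)
    precision = tp / (tp + fp) if tp + fp else 1.0
    recall = tp / (tp + fn)
    return precision, recall


y_true = [0, 0, 1, 0, 1, 1]
scores = [-2.0, -0.5, -0.3, 0.4, 1.2, 2.5]

for threshold in sorted(scores):
    p, r = precision_recall_at(y_true, scores, threshold)
    print(f"threshold={threshold:5.1f}  precision={p:.2f}  recall={r:.2f}")
```

Raising the threshold trades recall for precision, which is the crossing pattern the two curves show.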

ROC curve

    from sklearn.metrics import roc_curve

    fprs,tprs,thresholds = roc_curve(y_test,decision_scores)

    def plot_roc_curve(fprs,tprs):
        """Plot an ROC curve together with the diagonal of a random classifier."""
        plt.figure(figsize=(8,6),dpi=80)
        plt.plot(fprs,tprs)
        plt.plot([0,1],linestyle='--')   # diagonal: random guessing
        plt.xticks(fontsize=13)
        plt.yticks(fontsize=13)
        plt.ylabel('TP rate',fontsize=15)
        plt.xlabel('FP rate',fontsize=15)
        plt.title('ROC curve',fontsize=17)
        plt.show()

    plot_roc_curve(fprs,tprs)

[Figure: ROC curve, Logistic Regression]

    # The area under the ROC curve summarises the curve as a single score
    from sklearn.metrics import roc_auc_score  # AUC: area under the curve

    roc_auc_score(y_test,decision_scores)

    0.8784722222222222
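The ROC AUC has a handy interpretation: it is the probability that a randomly chosen positive sample gets a higher score than a randomly chosen negative one, with ties counting half. A brute-force sketch of that pairwise definition on made-up scores; `pairwise_auc` is a hypothetical helper:

```python
def pairwise_auc(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = 0.0
    for sp in pos:
        for sn in neg:
            if sp > sn:
                wins += 1.0
            elif sp == sn:
                wins += 0.5   # ties count as half a correct ranking
    return wins / (len(pos) * len(neg))


y_true = [0, 0, 1, 0, 1, 1]
scores = [-2.0, -0.5, -0.3, 0.4, 1.2, 2.5]
print(pairwise_auc(y_true, scores))   # 8 of the 9 pairs are ranked correctly
```

Because only the ranking matters, passing log_reg.predict_proba(X_test)[:,1] instead of decision_scores would give the same AUC for logistic regression, since the sigmoid is monotonic.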

Model building: KNN (k-nearest neighbors)

Skip the baseline model and go straight to tuning the hyperparameters with grid search

    # Uniform weights: search n_neighbors only.
    # Distance weights: also search the Minkowski power p (1 = Manhattan, 2 = Euclidean).
    param_grid = [
        {
            'weights':['uniform'],
            'n_neighbors':[i for i in range(1,31)]
        },
        {
            'weights':['distance'],
            'n_neighbors':[i for i in range(1,31)],
            'p':[i for i in range(1,6)]
        }
    ]

    %%time
    from sklearn.neighbors import KNeighborsClassifier
    knn_clf = KNeighborsClassifier()

    grid_search = GridSearchCV(knn_clf,param_grid)

    grid_search.fit(X_train,y_train)

    Wall time: 7.23 s

    grid_search.best_estimator_

    KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=24, p=3,
           weights='distance')

    grid_search.best_score_

    0.8502202643171806

    grid_search.best_params_

    {'n_neighbors': 24, 'p': 3, 'weights': 'distance'}

    knn_clf = grid_search.best_estimator_
    # weights='distance': each training sample is its own zero-distance neighbor, hence 1.0 here
    knn_clf.score(X_train,y_train)

    1.0

    knn_clf.score(X_test,y_test)

    0.8421052631578947

    y_predict_knn = knn_clf.predict(X_test)
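With weights='distance', each of the k nearest neighbors votes with weight 1/d, where d is the Minkowski distance with power p (the search above picked k=24, p=3). A pure-Python sketch of that rule on a few made-up 2-D points; the helpers are hypothetical, not scikit-learn's implementation. Like scikit-learn, the sketch lets only zero-distance neighbors vote when a query coincides with training points, which is what drives the 1.0 training accuracy:

```python
from collections import defaultdict


def minkowski(a, b, p):
    """Minkowski distance: p=1 is Manhattan, p=2 is Euclidean."""
    return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)


def knn_predict(train_X, train_y, x, k, p):
    """Distance-weighted k-NN vote, a sketch of weights='distance'."""
    nearest = sorted((minkowski(xi, x, p), yi) for xi, yi in zip(train_X, train_y))[:k]
    votes = defaultdict(float)
    exact = [yi for d, yi in nearest if d == 0]
    if exact:
        # zero distance: only the identical training points vote (no division by zero)
        for yi in exact:
            votes[yi] += 1.0
    else:
        for d, yi in nearest:
            votes[yi] += 1.0 / d   # closer neighbors weigh more
    return max(votes, key=votes.get)


train_X = [[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [1.2, 0.9], [0.9, 1.1]]
train_y = [0, 0, 1, 1, 1]

# Two very close class-0 points outweigh three distant class-1 points
print(knn_predict(train_X, train_y, [0.05, 0.1], k=5, p=3))   # 0

train_acc = sum(knn_predict(train_X, train_y, xi, k=5, p=3) == yi
                for xi, yi in zip(train_X, train_y)) / len(train_y)
print(train_acc)   # 1.0
```

With uniform weights and k=5, the same query would be labelled 1, because three of its five neighbors are class 1.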

F1 score

    f1_score(y_test,y_predict_knn)

    0.8536585365853658

    print(classification_report(y_test,y_predict_knn))

                 precision    recall  f1-score   support

              0       0.85      0.81      0.83        36
              1       0.83      0.88      0.85        40

    avg / total       0.84      0.84      0.84        76

Confusion matrix

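    cnf_matrix = confusion_matrix(y_test,y_predict_knn)
    cnf_matrix

    array([[29,  7],
           [ 5, 35]], dtype=int64)

    # reuse the plotting function defined earlier
    plot_cnf_matirx(cnf_matrix,'Confusion matrix -- KNN')

[Figure: confusion matrix, KNN]

    # KNN has no decision_function, so use the predicted probability of class 1 as the score
    y_probabilities = knn_clf.predict_proba(X_test)[:,1]

    from sklearn.metrics import precision_recall_curve

    precisions,recalls,thresholds = precision_recall_curve(y_test,y_probabilities)

    plt.plot(thresholds,precisions[:-1],label='precision')
    plt.plot(thresholds,recalls[:-1],label='recall')
    plt.legend()
    plt.grid()
    plt.show()    # the thresholds do not start at the smallest score; sklearn picks its own starting point

[Figure: precision and recall versus threshold, KNN]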

ROC curve

    from sklearn.metrics import roc_curve
    fprs2,tprs2,thresholds2 = roc_curve(y_test,y_probabilities)
    # reuse the plotting function defined earlier
    plot_roc_curve(fprs2,tprs2)

[Figure: ROC curve, KNN]

    # The area under the ROC curve summarises the curve as a single score
    from sklearn.metrics import roc_auc_score  # AUC: area under the curve

    roc_auc_score(y_test,y_probabilities)

    0.8739583333333334

Model building: Decision Tree

    from sklearn.tree import DecisionTreeClassifier
    dt_clf= DecisionTreeClassifier(random_state=6)

    from sklearn.model_selection import GridSearchCV
    # max_features='auto' meant 'sqrt' for classifiers; it was removed in scikit-learn 1.3,
    # so on newer versions use 'sqrt' or None instead
    param_grid = [
        {
            'max_features':['auto','sqrt','log2'],
            'min_samples_split':[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18],
            'min_samples_leaf':[1,2,3,4,5,6,7,8,9,10,11]
        }
    ]
    grid_search = GridSearchCV(dt_clf,param_grid)

    grid_search.fit(X_train,y_train)

    GridSearchCV(cv=None, error_score='raise',
       estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=6,
            splitter='best'),
       fit_params=None, iid=True, n_jobs=1,
       param_grid=[{'max_features': ['auto', 'sqrt', 'log2'], 'min_samples_split': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], 'min_samples_leaf': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}],
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=0)

    grid_search.best_estimator_

    DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=2, min_samples_split=8,
            min_weight_fraction_leaf=0.0, presort=False, random_state=6,
            splitter='best')

    grid_search.best_score_

    0.7929515418502202

    grid_search.best_params_

    {'max_features': 'auto', 'min_samples_leaf': 2, 'min_samples_split': 8}

    dt_clf = grid_search.best_estimator_
    dt_clf.score(X_train,y_train)

    0.8854625550660793

    dt_clf.score(X_test,y_test)

    0.7236842105263158

    y_predict_dt = dt_clf.predict(X_test)
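The tree above uses criterion='gini' (shown in its parameters): at each node it picks the split that most reduces the Gini impurity 1 - sum(p_k^2), averaged over the two children weighted by their size. A small sketch of that computation on made-up labels; `gini` and `split_impurity` are hypothetical helpers:

```python
from collections import Counter


def gini(labels):
    """Gini impurity 1 - sum(p_k^2) of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def split_impurity(left, right):
    """Size-weighted Gini impurity of the two children of a split."""
    n = len(left) + len(right)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


parent = [0, 0, 0, 1, 1, 1]
print(gini(parent))                           # 0.5: perfectly mixed
print(split_impurity([0, 0, 0], [1, 1, 1]))   # 0.0: a perfect split
print(split_impurity([0, 0, 1], [0, 1, 1]))   # about 0.444: a weaker split
```

min_samples_split and min_samples_leaf, tuned above, stop the splitting once nodes get small, which limits overfitting; the gap between training accuracy (0.885) and test accuracy (0.724) suggests the tree still overfits.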

F1 score

    f1_score(y_test,y_predict_dt)

    0.7123287671232875

Classification report and confusion matrix

    print(classification_report(y_test,y_predict_dt))

                 precision    recall  f1-score   support

              0       0.67      0.81      0.73        36
              1       0.79      0.65      0.71        40

    avg / total       0.73      0.72      0.72        76

    cnf_matrix = confusion_matrix(y_test,y_predict_dt)
    cnf_matrix

    array([[29,  7],
           [14, 26]], dtype=int64)

    # reuse the plotting function defined earlier
    plot_cnf_matirx(cnf_matrix,'Confusion matrix -- DecisionTree')

[Figure: confusion matrix, Decision Tree]

    y_probabilities = dt_clf.predict_proba(X_test)[:,1]

    from sklearn.metrics import precision_recall_curve

    precisions,recalls,thresholds = precision_recall_curve(y_test,y_probabilities)

    plt.plot(thresholds,precisions[:-1],label='precision')
    plt.plot(thresholds,recalls[:-1],label='recall')
    plt.legend()
    plt.grid()
    plt.show()    # the thresholds do not start at the smallest score; sklearn picks its own starting point

[Figure: precision and recall versus threshold, Decision Tree]

ROC curve

    from sklearn.metrics import roc_curve
    fprs3,tprs3,thresholds3 = roc_curve(y_test,y_probabilities)
    # reuse the plotting function defined earlier
    plot_roc_curve(fprs3,tprs3)

[Figure: ROC curve, Decision Tree]

    # The area under the ROC curve summarises the curve as a single score
    from sklearn.metrics import roc_auc_score  # AUC: area under the curve

    roc_auc_score(y_test,y_probabilities)

    0.7527777777777778

All models together: ROC curves

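    # fprs/tprs: Logistic Regression, fprs2/tprs2: KNN, fprs3/tprs3: Decision Tree
    sns.set_style('whitegrid')
    plt.figure(figsize=(12,8))
    plt.title('ROC Curve',fontsize=18)
    plt.plot(fprs,tprs,label='Log_Reg')
    plt.plot(fprs2,tprs2,label='KNN')
    plt.plot(fprs3,tprs3,label='dt_Clf')
    plt.plot([0,1],ls='--')          # random-guess diagonal
    plt.plot([0,0],[1,0],c='.8')     # light-grey frame line at x = 0
    plt.plot([1,1],c='.8')           # light-grey frame line at y = 1
    plt.ylabel('TP rate',fontsize=15)
    plt.xlabel('FP rate',fontsize=15)
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.legend()
    plt.show()

[Figure: ROC curves of Logistic Regression, KNN and Decision Tree]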

Model ensemble

Combine the different kinds of classifiers trained above in a soft-voting ensemble:

    from sklearn.ensemble import VotingClassifier

    # voting='soft': average the predicted class probabilities of the three models
    voting_clf = VotingClassifier(estimators=[
        ('log_clf',log_reg),
        ('knn_clf',knn_clf),
        ('dt_clf',dt_clf)
    ],voting='soft')

    voting_clf.fit(X_train,y_train)

    VotingClassifier(estimators=[('log_clf', LogisticRegression(C=0.01, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)), ('knn_cl...         min_weight_fraction_leaf=0.0, presort=False, random_state=6,
            splitter='best'))],
         flatten_transform=None, n_jobs=1, voting='soft', weights=None)

    voting_clf.score(X_train,y_train)

    F:\software\anaconda\anaconda\lib\site-packages\sklearn\preprocessing\label.py:151: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
      if diff:

    0.9955947136563876

    voting_clf.score(X_test,y_test)

    0.7894736842105263

    y_predict_voting = voting_clf.predict(X_test)

    f1_score(y_test,y_predict_voting)

    0.7948717948717949

    y_probabilities = voting_clf.predict_proba(X_test)[:,1]

    roc_auc_score(y_test,y_probabilities)

    0.8666666666666667

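The ensemble does not help here: its test accuracy (0.789) and ROC AUC (0.867) are below those of both Logistic Regression (0.829 and 0.878) and KNN (0.842 and 0.874). The small size of the dataset may be part of the reason.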
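With voting='soft', VotingClassifier averages the class probabilities of its members (optionally weighted) and predicts the class with the highest average; voting='hard' would instead take a majority vote of the predicted labels. A sketch with made-up probabilities for a single sample; `soft_vote` and `hard_vote` are hypothetical helpers, and the three rows only stand in for the three models:

```python
def soft_vote(probas, weights=None):
    """Average per-class probabilities over models; return (label, averaged probabilities)."""
    weights = weights or [1.0] * len(probas)
    n_classes = len(probas[0])
    avg = [sum(w * p[c] for w, p in zip(weights, probas)) / sum(weights)
           for c in range(n_classes)]
    return max(range(n_classes), key=lambda c: avg[c]), avg


def hard_vote(labels):
    """Majority vote over predicted labels."""
    return max(set(labels), key=labels.count)


# [P(class 0), P(class 1)] from three hypothetical models for one patient
probas = [[0.40, 0.60],   # e.g. logistic regression
          [0.45, 0.55],   # e.g. KNN
          [1.00, 0.00]]   # e.g. a decision tree whose leaf is pure
label, avg = soft_vote(probas)
print(label, avg)              # soft voting picks class 0
print(hard_vote([1, 1, 0]))    # a majority vote would pick class 1
```

One possible side effect shown here: a member that outputs very confident probabilities, such as a tree with pure leaves, can dominate the soft average.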

Random forest (which itself combines bagging with decision trees)

    from sklearn.ensemble import RandomForestClassifier

    # oob_score=True: score each row with the trees whose bootstrap sample did not contain it
    rf_clf = RandomForestClassifier(n_estimators=500,random_state=666,oob_score=True,n_jobs=-1)
    # Caution: this fits on the full, unscaled X. X_test below is standardized and is also
    # part of X, so the test-set numbers for this model are not a clean hold-out estimate.
    rf_clf.fit(X,y)

    RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=500, n_jobs=-1,
            oob_score=True, random_state=666, verbose=0, warm_start=False)

Score on the out-of-bag (OOB) samples:

    rf_clf.oob_score_

    0.8118811881188119

    rf_clf.score(X_test,y_test)

    0.7894736842105263

    y_probabilities_rf = rf_clf.predict_proba(X_test)[:,1]

    roc_auc_score(y_test,y_probabilities_rf)

    0.9288194444444444

    # AUC on the full data set, which the forest was trained on
    y_probabilities = rf_clf.predict_proba(X)[:,1]

    roc_auc_score(y,y_probabilities)

    1.0

    from sklearn.metrics import roc_curve
    fprs4,tprs4,thresholds4 = roc_curve(y_test,y_probabilities_rf)
    # reuse the plotting function defined earlier
    plot_roc_curve(fprs4,tprs4)

[Figure: ROC curve, Random Forest]
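Each tree in a random forest is trained on a bootstrap sample: n draws with replacement from the n rows. A given row is missed by one such sample with probability (1 - 1/n)^n, which approaches e^(-1) ≈ 0.368, so roughly a third of the rows are "out of bag" for every tree, and oob_score_ scores each row only with the trees that never saw it. A sketch that checks this fraction for the 303 rows here, analytically and by simulation; the helper names and the seed are arbitrary choices for illustration:

```python
import math
import random


def oob_probability(n):
    """Probability that a given row is not drawn in a bootstrap sample of size n."""
    return (1 - 1 / n) ** n


def simulated_oob_fraction(n, n_trees, seed=0):
    """Average fraction of rows left out of each of n_trees bootstrap samples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_trees):
        drawn = {rng.randrange(n) for _ in range(n)}   # n draws with replacement
        total += (n - len(drawn)) / n
    return total / n_trees


n = 303
print(oob_probability(n), math.exp(-1))
print(simulated_oob_fraction(n, n_trees=500))
```

Since every row is scored only by trees that did not train on it, the OOB score stays an out-of-sample estimate even when the forest is fitted on all rows.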

Overview

    # fprs/tprs: Logistic Regression, fprs2/tprs2: KNN,
    # fprs3/tprs3: Decision Tree, fprs4/tprs4: Random Forest
    sns.set_style('whitegrid')
    plt.figure(figsize=(12,8))
    plt.title('ROC Curve',fontsize=18)
    plt.plot(fprs,tprs,label='Log_Reg')
    plt.plot(fprs2,tprs2,label='KNN')
    plt.plot(fprs3,tprs3,label='dt_Clf')
    plt.plot(fprs4,tprs4,label='rf_Clf')
    plt.plot([0,1],ls='--')          # random-guess diagonal
    plt.plot([0,0],[1,0],c='.8')     # light-grey frame line at x = 0
    plt.plot([1,1],c='.8')           # light-grey frame line at y = 1
    plt.ylabel('TP rate',fontsize=15)
    plt.xlabel('FP rate',fontsize=15)
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.legend()
    plt.show()

[Figure: ROC curves of all four models]

Changing the parameters: limit each tree to at most 16 leaf nodes

    # max_leaf_nodes=16 caps the size of each tree (again fitted on the full, unscaled X)
    rf_clf2 = RandomForestClassifier(n_estimators=500,max_leaf_nodes=16,random_state=666,oob_score=True,n_jobs=-1)
    rf_clf2.fit(X,y)

    RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=16,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=500, n_jobs=-1,
            oob_score=True, random_state=666, verbose=0, warm_start=False)

    rf_clf2.oob_score_

    0.8316831683168316

    y_probabilities = rf_clf2.predict_proba(X_test)[:,1]

    roc_auc_score(y_test,y_probabilities)

    0.8954861111111111

    y_probabilities = rf_clf2.predict_proba(X)[:,1]

    roc_auc_score(y,y_probabilities)

    0.9757136583223539

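With the default parameters the random forest reaches an ROC AUC of 0.929 on the test split created earlier and 1.0 (perfect separation) on the full data set. With max_leaf_nodes=16 the test-split AUC falls to 0.895 and the full-data AUC to 0.976, while the OOB score rises from 0.812 to 0.832. Both forests were fitted on the full data set, which contains the test split, so the OOB score is the more reliable of these figures; by that measure the smaller trees appear to overfit less.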

This is only a first, rough comparison of random-forest settings, which is enough for this demo; there is still plenty to study and explore.
