数据挖掘实战一:收入预测分类

# 导入第三方包
import pandas as pd
import numpy as np
import seaborn as sns

# Load the data
income = pd.read_excel(r'./income.xlsx')
income.head()
# Show the first few rows to get a feel for the overall structure of the data.

age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country income
0 39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
1 50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
2 38 Private 215646 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
3 53 Private 234721 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
4 28 Private 338409 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K
# Count the missing values in every column.
income.isnull().sum()
# info() lists the non-null count and the dtype of each column.
income.info()
# The data set holds roughly 30,000 rows.

RangeIndex: 32561 entries, 0 to 32560
Data columns (total 15 columns):
age               32561 non-null int64
workclass         30725 non-null object
fnlwgt            32561 non-null int64
education         32561 non-null object
education-num     32561 non-null int64
marital-status    32561 non-null object
occupation        30718 non-null object
relationship      32561 non-null object
race              32561 non-null object
sex               32561 non-null object
capital-gain      32561 non-null int64
capital-loss      32561 non-null int64
hours-per-week    32561 non-null int64
native-country    31978 non-null object
income            32561 non-null object
dtypes: int64(6), object(9)
memory usage: 3.7+ MB

从上面的输出可以看出,workclass、occupation、native-country 三列存在缺失值;这三列都是类别型变量,因此用众数进行填充。

# Missing-value handling: fill each gap with the most frequent value (the
# mode) of its own column.
mode_by_column = {
    column: income[column].mode()[0]
    for column in ('workclass', 'occupation', 'native-country')
}
income.fillna(value=mode_by_column, inplace=True)
income.head()
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country income
0 39 State-gov 77516 Bachelors 13 Never-married Adm-clerical Not-in-family White Male 2174 0 40 United-States <=50K
1 50 Self-emp-not-inc 83311 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 0 0 13 United-States <=50K
2 38 Private 215646 HS-grad 9 Divorced Handlers-cleaners Not-in-family White Male 0 0 40 United-States <=50K
3 53 Private 234721 11th 7 Married-civ-spouse Handlers-cleaners Husband Black Male 0 0 40 United-States <=50K
4 28 Private 338409 Bachelors 13 Married-civ-spouse Prof-specialty Wife Black Female 0 0 40 Cuba <=50K
# Exploratory data analysis
income.describe()

# Gives a rough picture of how the data are distributed; string-typed columns are not included.
age fnlwgt education-num capital-gain capital-loss hours-per-week
count 32561.000000 3.256100e+04 32561.000000 32561.000000 32561.000000 32561.000000
mean 38.581647 1.897784e+05 10.080679 1077.648844 87.303830 40.437456
std 13.640433 1.055500e+05 2.572720 7385.292085 402.960219 12.347429
min 17.000000 1.228500e+04 1.000000 0.000000 0.000000 1.000000
25% 28.000000 1.178270e+05 9.000000 0.000000 0.000000 40.000000
50% 37.000000 1.783560e+05 10.000000 0.000000 0.000000 40.000000
75% 48.000000 2.370510e+05 12.000000 0.000000 0.000000 45.000000
max 90.000000 1.484705e+06 16.000000 99999.000000 4356.000000 99.000000
# Summary of the object-typed (string) columns only.
income.describe(include =[ 'object'])
workclass education marital-status occupation relationship race sex native-country income
count 32561 32561 32561 32561 32561 32561 32561 32561 32561
unique 8 16 7 14 6 5 2 41 2
top Private HS-grad Married-civ-spouse Prof-specialty Husband White Male United-States <=50K
freq 24532 10501 14976 5983 13193 27816 21790 29753 24720
绘制不同收入水平下的年龄与每周工作小时数的核密度图
# Plotting module
import matplotlib.pyplot as plt
# Use the ggplot look for the figures
plt.style.use('ggplot')
# Two stacked panels: age on top, weekly working hours below
fig, axes = plt.subplots(2, 1)
# Per income level: (value stored in the income column, legend label, line style)
income_levels = ((' <=50K', '<=50K', '-'), (' >50K', '>50K', '--'))
for panel, column in zip(axes, ('age', 'hours-per-week')):
    for level, legend_label, dash in income_levels:
        # Kernel density estimate of this column for one income level
        subset = income[column][income['income'] == level]
        subset.plot(kind='kde', label=legend_label, ax=panel, legend=True, linestyle=dash)
# Render the figure
plt.show()

数据挖掘实战一:输入预测分类_第1张图片

构造不同收入水平下各种族人数的数据
def _income_counts(frame, key):
    """Count the rows of ``frame`` for every (key, income) combination.

    ``GroupBy.size()`` counts rows directly instead of applying ``np.size``
    to every column and then keeping only one of them. The single result
    column is named 'age' so the printed table keeps its original header.

    Returns a one-column DataFrame indexed by (key, income).
    """
    return frame.groupby(by=[key, 'income']).size().to_frame(name='age')


def _tidy_counts(sizes, key):
    """Turn grouped counts into a flat, sorted table ready for plotting.

    Moves the (key, income) index back into columns, renames the count
    column to 'counts' and sorts by ``key`` and then 'counts', descending.
    """
    flat = sizes.reset_index().rename(columns={'age': 'counts'})
    return flat.sort_values(by=[key, 'counts'], ascending=False)


# Head counts per race and income level
race = _income_counts(income, 'race')
print(race)
race = _tidy_counts(race, 'race')

# Head counts per family relationship and income level
relationship = _tidy_counts(_income_counts(income, 'relationship'), 'relationship')

print(race)
# One grouped bar chart per breakdown, each in its own 9x5 figure
for key, counts in (('race', race), ('relationship', relationship)):
    plt.figure(figsize=(9, 5))
    sns.barplot(x=key, y='counts', hue='income', data=counts)
    plt.show()
                              age
race                income       
 Amer-Indian-Eskimo  <=50K    275
                     >50K      36
 Asian-Pac-Islander  <=50K    763
                     >50K     276
 Black               <=50K   2737
                     >50K     387
 Other               <=50K    246
                     >50K      25
 White               <=50K  20699
                     >50K    7117
                  race  income  counts
8                White   <=50K   20699
9                White    >50K    7117
6                Other   <=50K     246
7                Other    >50K      25
4                Black   <=50K    2737
5                Black    >50K     387
2   Asian-Pac-Islander   <=50K     763
3   Asian-Pac-Islander    >50K     276
0   Amer-Indian-Eskimo   <=50K     275
1   Amer-Indian-Eskimo    >50K      36

数据挖掘实战一:输入预测分类_第2张图片

数据挖掘实战一:输入预测分类_第3张图片

# Re-encode the discrete variables: every object-typed column is replaced by
# its integer category codes.
text_columns = [name for name in income.columns if income[name].dtype == 'object']
for name in text_columns:
    income[name] = pd.Categorical(income[name]).codes
income.head()
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country income
0 39 6 77516 9 13 4 0 1 4 1 2174 0 40 38 0
1 50 5 83311 9 13 2 3 0 4 1 0 0 13 38 0
2 38 3 215646 11 9 0 5 1 4 1 0 0 40 38 0
3 53 3 234721 1 7 2 5 0 2 1 0 0 40 38 0
4 28 3 338409 9 13 2 9 5 2 0 0 0 40 4 0
# Remove the 'education' and 'fnlwgt' columns from the frame in place.
income.drop(columns=['education', 'fnlwgt'], inplace=True)
income.head()
age workclass education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country income
0 39 6 13 4 0 1 4 1 2174 0 40 38 0
1 50 5 13 2 3 0 4 1 0 0 13 38 0
2 38 3 9 0 5 1 4 1 0 0 40 38 0
3 53 3 7 2 5 0 2 1 0 0 40 38 0
4 28 3 13 2 9 5 2 0 0 0 40 4 0
# Split the data into a training and a test set
from sklearn.model_selection import train_test_split
# Predictors: every column from 'age' through 'native-country'; target: 'income'
predictors = income.loc[:, 'age':'native-country']
target = income['income']
X_train, X_test, y_train, y_test = train_test_split(
    predictors, target, train_size=0.75, random_state=1234)
print('训练数据集共有%d条观测' % X_train.shape[0])
print('测试数据集共有%d条观测' % X_test.shape[0])
训练数据集共有24420条观测
测试数据集共有8141条观测
# k-nearest-neighbours classifier class
from sklearn.neighbors import KNeighborsClassifier
# Fit a k-NN model constructed with no arguments (library defaults)
kn = KNeighborsClassifier()
kn.fit(X_train, y_train)
print(kn)

# Predict the test set and cross-tabulate the predictions against the true labels
kn_pred = kn.predict(X_test)
print(pd.crosstab(kn_pred, y_test))

# Accuracy on the training and the test split
print('模型在训练集上的准确率%f' % kn.score(X_train, y_train))
print('模型在测试集上的准确率%f' % kn.score(X_test, y_test))

# Model-evaluation utilities
from sklearn import metrics

# ROC inputs: the second column of predict_proba is used as the score
kn_scores = kn.predict_proba(X_test)[:, 1]
fpr, tpr, _ = metrics.roc_curve(y_test, kn_scores)
kn_auc = metrics.auc(fpr, tpr)
# ROC curve
plt.plot(fpr, tpr, linestyle='solid', color='red')
# Shade the area under the curve
plt.stackplot(fpr, tpr, color='steelblue')
# Diagonal reference line
plt.plot([0, 1], [0, 1], linestyle='dashed', color='black')
# Write the AUC value onto the plot
plt.text(0.6, 0.4, 'AUC=%.3f' % kn_auc, fontdict=dict(size=18))
plt.show()
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=5, p=2,
                     weights='uniform')
income     0     1
row_0             
0       5637   723
1        589  1192
模型在训练集上的准确率0.890500
模型在测试集上的准确率0.838840

数据挖掘实战一:输入预测分类_第4张图片

# Gradient-boosted decision tree classifier class
from sklearn.ensemble import GradientBoostingClassifier
# Fit a GBDT model constructed with no arguments (library defaults)
gbdt = GradientBoostingClassifier()
gbdt.fit(X_train, y_train)
print(gbdt)

# Predict the test set and cross-tabulate the predictions against the true labels
gbdt_pred = gbdt.predict(X_test)
print(pd.crosstab(gbdt_pred, y_test))

# Accuracy on the training and the test split
print('模型在训练集上的准确率%f' % gbdt.score(X_train, y_train))
print('模型在测试集上的准确率%f' % gbdt.score(X_test, y_test))

# ROC curve: the second column of predict_proba is used as the score
gbdt_scores = gbdt.predict_proba(X_test)[:, 1]
fpr, tpr, _ = metrics.roc_curve(y_test, gbdt_scores)
gbdt_auc = metrics.auc(fpr, tpr)
plt.plot(fpr, tpr, linestyle='solid', color='red')
# Shaded area under the curve, then the diagonal reference line
plt.stackplot(fpr, tpr, color='steelblue')
plt.plot([0, 1], [0, 1], linestyle='dashed', color='black')
plt.text(0.6, 0.4, 'AUC=%.3f' % gbdt_auc, fontdict=dict(size=18))
plt.show()
GradientBoostingClassifier(criterion='friedman_mse', init=None,
                           learning_rate=0.1, loss='deviance', max_depth=3,
                           max_features=None, max_leaf_nodes=None,
                           min_impurity_decrease=0.0, min_impurity_split=None,
                           min_samples_leaf=1, min_samples_split=2,
                           min_weight_fraction_leaf=0.0, n_estimators=100,
                           n_iter_no_change=None, presort='auto',
                           random_state=None, subsample=1.0, tol=0.0001,
                           validation_fraction=0.1, verbose=0,
                           warm_start=False)
income     0     1
row_0             
0       5862   784
1        364  1131
模型在训练集上的准确率0.869451
模型在测试集上的准确率0.858985

数据挖掘实战一:输入预测分类_第5张图片

# Grid search for the k-NN model
# Grid-search helper
from sklearn.model_selection import GridSearchCV
# Candidate neighbour counts: 1 through 11
k_options = list(range(1, 12))
parameters = {'n_neighbors': k_options}
# Evaluate every candidate k with cv=10, scored by accuracy, on 2 parallel jobs
grid_kn = GridSearchCV(
    estimator=KNeighborsClassifier(),
    param_grid=parameters,
    cv=10,
    scoring='accuracy',
    verbose=0,
    n_jobs=2,
)
grid_kn.fit(X_train, y_train)
print(grid_kn)
# Show the per-candidate results, the best parameters and the best score
grid_kn.cv_results_, grid_kn.best_params_, grid_kn.best_score_
GridSearchCV(cv=10, error_score='raise-deprecating',
             estimator=KNeighborsClassifier(algorithm='auto', leaf_size=30,
                                            metric='minkowski',
                                            metric_params=None, n_jobs=None,
                                            n_neighbors=5, p=2,
                                            weights='uniform'),
             iid='warn', n_jobs=2,
             param_grid={'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]},
             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,
             scoring='accuracy', verbose=0)





({'mean_fit_time': array([0.48654635, 0.46757383, 0.47592268, 0.49453475, 0.47880325,
         0.46707897, 0.48813548, 0.49772682, 0.47156236, 0.46706924,
         0.4772419 ]),
  'std_fit_time': array([0.01297885, 0.0209712 , 0.02072921, 0.0111157 , 0.01790766,
         0.02473602, 0.0163508 , 0.00986862, 0.0199231 , 0.02273718,
         0.03890239]),
  'mean_score_time': array([0.13643334, 0.14554381, 0.14788237, 0.14755256, 0.15460474,
         0.15969527, 0.15761855, 0.15811694, 0.16488271, 0.16624339,
         0.16210868]),
  'std_score_time': array([0.00353061, 0.00297507, 0.00272643, 0.00230722, 0.00245015,
         0.0026701 , 0.0032928 , 0.0039421 , 0.00378626, 0.00273812,
         0.00302414]),
  'param_n_neighbors': masked_array(data=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
               mask=[False, False, False, False, False, False, False, False,
                     False, False, False],
         fill_value='?',
              dtype=object),
  'params': [{'n_neighbors': 1},
   {'n_neighbors': 2},
   {'n_neighbors': 3},
   {'n_neighbors': 4},
   {'n_neighbors': 5},
   {'n_neighbors': 6},
   {'n_neighbors': 7},
   {'n_neighbors': 8},
   {'n_neighbors': 9},
   {'n_neighbors': 10},
   {'n_neighbors': 11}],
  'split0_test_score': array([0.8014736 , 0.82316824, 0.8215309 , 0.8264429 , 0.82889889,
         0.83381089, 0.83831355, 0.83503889, 0.83831355, 0.84036021,
         0.84322554]),
  'split1_test_score': array([0.81416291, 0.83585755, 0.83544822, 0.84936553, 0.84240688,
         0.8489562 , 0.84527221, 0.84486287, 0.84240688, 0.84568154,
         0.84117888]),
  'split2_test_score': array([0.81866558, 0.83790422, 0.84158821, 0.85550553, 0.84568154,
         0.85345886, 0.84854687, 0.85100287, 0.84731887, 0.84854687,
         0.84690954]),
  'split3_test_score': array([0.81539091, 0.84854687, 0.84936553, 0.84936553, 0.84527221,
         0.84977487, 0.84568154, 0.84609087, 0.84240688, 0.84486287,
         0.8501842 ]),
  'split4_test_score': array([0.80917281, 0.83538084, 0.83701884, 0.84316134, 0.83783784,
         0.84111384, 0.83783784, 0.83824734, 0.83619984, 0.83865684,
         0.83701884]),
  'split5_test_score': array([0.82186732, 0.83783784, 0.83619984, 0.84111384, 0.83210483,
         0.84029484, 0.83415233, 0.83947584, 0.84234234, 0.84520885,
         0.84316134]),
  'split6_test_score': array([0.81810733, 0.8467841 , 0.84104875, 0.85866448, 0.85210979,
         0.86112249, 0.85006145, 0.85784515, 0.85210979, 0.85088079,
         0.84842278]),
  'split7_test_score': array([0.82466202, 0.8455551 , 0.84965178, 0.85538714, 0.85743548,
         0.86071282, 0.85907415, 0.85743548, 0.86112249, 0.85948382,
         0.85538714]),
  'split8_test_score': array([0.8095043 , 0.83408439, 0.8275297 , 0.83490373, 0.83941008,
         0.84309709, 0.83900041, 0.83981975, 0.84104875, 0.84350676,
         0.84391643]),
  'split9_test_score': array([0.81482999, 0.83941008, 0.83039738, 0.83818107, 0.83900041,
         0.84473576, 0.84514543, 0.84596477, 0.8443261 , 0.8467841 ,
         0.84350676]),
  'mean_test_score': array([0.81478296, 0.83845209, 0.83697789, 0.84520885, 0.84201474,
         0.8477068 , 0.84430794, 0.8455774 , 0.84475839, 0.8463964 ,
         0.84529075]),
  'std_test_score': array([0.00641223, 0.00701638, 0.00852103, 0.00976708, 0.00816947,
         0.00842365, 0.00694003, 0.00746034, 0.00688403, 0.00552323,
         0.00487327]),
  'rank_test_score': array([11,  9, 10,  5,  8,  1,  7,  3,  6,  2,  4], dtype=int32)},
 {'n_neighbors': 6},
 0.8477067977067977)
# Predict the test set with the grid-searched k-NN model
grid_kn_pred = grid_kn.predict(X_test)
print(pd.crosstab(grid_kn_pred, y_test))

# Accuracy on the training and the test split
print('模型在训练集上的准确率%f' % grid_kn.score(X_train, y_train))
print('模型在测试集上的准确率%f' % grid_kn.score(X_test, y_test))

# ROC curve: the second column of predict_proba is used as the score
grid_kn_scores = grid_kn.predict_proba(X_test)[:, 1]
fpr, tpr, _ = metrics.roc_curve(y_test, grid_kn_scores)
grid_kn_auc = metrics.auc(fpr, tpr)
plt.plot(fpr, tpr, linestyle='solid', color='red')
# Shaded area under the curve, then the diagonal reference line
plt.stackplot(fpr, tpr, color='steelblue')
plt.plot([0, 1], [0, 1], linestyle='dashed', color='black')
plt.text(0.6, 0.4, 'AUC=%.3f' % grid_kn_auc, fontdict=dict(size=18))
plt.show()
income     0     1
row_0             
0       5834   867
1        392  1048
模型在训练集上的准确率0.882473
模型在测试集上的准确率0.845351

数据挖掘实战一:输入预测分类_第6张图片

# The data set holds roughly 30,000 rows in total
# Grid search for the GBDT model
# Candidate values for each hyper-parameter
learning_rate_options = [0.01, 0.05, 0.1]
max_depth_options = [3, 5, 7, 9]
n_estimators_options = [100, 300, 500]
parameters = {
    'learning_rate': learning_rate_options,
    'max_depth': max_depth_options,
    'n_estimators': n_estimators_options,
}

# Evaluate every combination with cv=10, scored by accuracy, on 4 parallel jobs
grid_gbdt = GridSearchCV(
    estimator=GradientBoostingClassifier(),
    param_grid=parameters,
    cv=10,
    scoring='accuracy',
    n_jobs=4,
)
grid_gbdt.fit(X_train, y_train)

# Show the per-candidate results, the best parameters and the best score.
# The original line was garbled into `grid_gbdt.` followed by a line break,
# which is a syntax error; `cv_results_` is the attribute the printed output
# corresponds to.
grid_gbdt.cv_results_, grid_gbdt.best_params_, grid_gbdt.best_score_
({'mean_fit_time': array([ 1.04818089,  3.06958101,  4.94570587,  2.10741355,  6.49262218,
         10.31357944,  3.79211471, 12.02091534, 19.82360842,  6.47681627,
         21.55113347, 36.75011723,  0.98570092,  2.69694235,  4.36460373,
          2.08887751,  5.49897749,  8.82655048,  3.94254134, 10.71680896,
         17.65818636,  7.33641875, 20.94076185, 36.44300106,  0.93227916,
          2.53501928,  4.2528425 ,  1.93049173,  5.26935222,  8.83973567,
          3.70204911, 10.58135281, 18.04044161,  7.09446683, 22.26043143,
         37.91747351]),
  'std_fit_time': array([0.00483648, 0.0206611 , 0.02346627, 0.01078171, 0.04006112,
         0.08443816, 0.05559104, 0.16459559, 0.27928162, 0.14452495,
         0.66920276, 0.75215195, 0.01029911, 0.02504625, 0.07578231,
         0.04792665, 0.12168305, 0.20212313, 0.07955249, 0.19383459,
         0.27861969, 0.16787211, 0.50177507, 0.5893442 , 0.0070915 ,
         0.01666785, 0.08162466, 0.03402036, 0.10060916, 0.15109724,
         0.05160359, 0.18684248, 0.14383105, 0.10197265, 0.23977306,
         2.26425837]),
  'mean_score_time': array([0.00471151, 0.0112366 , 0.0171237 , 0.00647447, 0.0172616 ,
         0.0264991 , 0.00875354, 0.02411616, 0.03694515, 0.01122572,
         0.0318507 , 0.05002267, 0.00462885, 0.00959365, 0.01390805,
         0.00651336, 0.01392448, 0.02110162, 0.00865908, 0.01957636,
         0.03109238, 0.01125824, 0.02738936, 0.04511173, 0.00438557,
         0.00871301, 0.01321261, 0.00598361, 0.01336231, 0.02128084,
         0.00788417, 0.01946959, 0.03210063, 0.01043718, 0.0282335 ,
         0.04428184]),
  'std_score_time': array([1.66779829e-04, 1.03325180e-04, 1.07752274e-04, 7.75191493e-05,
         1.42365449e-04, 2.01649271e-04, 8.73166179e-05, 1.47984253e-04,
         2.08206887e-04, 7.54081784e-05, 1.55344452e-04, 2.96993442e-04,
         2.14154570e-05, 4.05286409e-05, 1.29071777e-04, 8.44396348e-05,
         5.86331590e-05, 1.71900728e-04, 4.77111023e-05, 1.84118423e-04,
         3.62546530e-04, 9.20189008e-05, 2.26528819e-04, 3.68539617e-04,
         2.51831874e-05, 6.76646881e-05, 1.10660028e-04, 6.27695447e-05,
         1.87586799e-04, 2.69605476e-04, 4.36938978e-05, 2.31657194e-04,
         2.70279819e-04, 6.23716901e-05, 2.04953060e-04, 5.62525009e-03]),
  'param_learning_rate': masked_array(data=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01,
                     0.01, 0.01, 0.01, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
                     0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1,
                     0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
               mask=[False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False],
         fill_value='?',
              dtype=object),
  'param_max_depth': masked_array(data=[3, 3, 3, 5, 5, 5, 7, 7, 7, 9, 9, 9, 3, 3, 3, 5, 5, 5,
                     7, 7, 7, 9, 9, 9, 3, 3, 3, 5, 5, 5, 7, 7, 7, 9, 9, 9],
               mask=[False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False],
         fill_value='?',
              dtype=object),
  'param_n_estimators': masked_array(data=[100, 300, 500, 100, 300, 500, 100, 300, 500, 100, 300,
                     500, 100, 300, 500, 100, 300, 500, 100, 300, 500, 100,
                     300, 500, 100, 300, 500, 100, 300, 500, 100, 300, 500,
                     100, 300, 500],
               mask=[False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False, False, False, False, False,
                     False, False, False, False],
         fill_value='?',
              dtype=object),
  'params': [{'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 100},
   {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 300},
   {'learning_rate': 0.01, 'max_depth': 3, 'n_estimators': 500},
   {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 100},
   {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 300},
   {'learning_rate': 0.01, 'max_depth': 5, 'n_estimators': 500},
   {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 100},
   {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 300},
   {'learning_rate': 0.01, 'max_depth': 7, 'n_estimators': 500},
   {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 100},
   {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 300},
   {'learning_rate': 0.01, 'max_depth': 9, 'n_estimators': 500},
   {'learning_rate': 0.05, 'max_depth': 3, 'n_estimators': 100},
   {'learning_rate': 0.05, 'max_depth': 3, 'n_estimators': 300},
   {'learning_rate': 0.05, 'max_depth': 3, 'n_estimators': 500},
   {'learning_rate': 0.05, 'max_depth': 5, 'n_estimators': 100},
   {'learning_rate': 0.05, 'max_depth': 5, 'n_estimators': 300},
   {'learning_rate': 0.05, 'max_depth': 5, 'n_estimators': 500},
   {'learning_rate': 0.05, 'max_depth': 7, 'n_estimators': 100},
   {'learning_rate': 0.05, 'max_depth': 7, 'n_estimators': 300},
   {'learning_rate': 0.05, 'max_depth': 7, 'n_estimators': 500},
   {'learning_rate': 0.05, 'max_depth': 9, 'n_estimators': 100},
   {'learning_rate': 0.05, 'max_depth': 9, 'n_estimators': 300},
   {'learning_rate': 0.05, 'max_depth': 9, 'n_estimators': 500},
   {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 100},
   {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 300},
   {'learning_rate': 0.1, 'max_depth': 3, 'n_estimators': 500},
   {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 100},
   {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 300},
   {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 500},
   {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 100},
   {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 300},
   {'learning_rate': 0.1, 'max_depth': 7, 'n_estimators': 500},
   {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 100},
   {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 300},
   {'learning_rate': 0.1, 'max_depth': 9, 'n_estimators': 500}],
  'split0_test_score': array([0.82685223, 0.83422022, 0.84404421, 0.83381089, 0.84445354,
         0.8514122 , 0.83667622, 0.8465002 , 0.84854687, 0.83381089,
         0.84322554, 0.84322554, 0.84322554, 0.8514122 , 0.85427753,
         0.8526402 , 0.85345886, 0.85550553, 0.84813754, 0.84977487,
         0.84813754, 0.83954155, 0.84527221, 0.83626688, 0.8489562 ,
         0.85345886, 0.85059353, 0.85304953, 0.85059353, 0.85223086,
         0.84936553, 0.84609087, 0.84281621, 0.84445354, 0.83790422,
         0.83462955]),
  'split1_test_score': array([0.85304953, 0.86614818, 0.86860418, 0.86410151, 0.8714695 ,
         0.8726975 , 0.86410151, 0.87556283, 0.87883749, 0.86000819,
         0.8739255 , 0.87597217, 0.86901351, 0.87228817, 0.87679083,
         0.87474417, 0.87924683, 0.88129349, 0.87842816, 0.88088416,
         0.87842816, 0.87679083, 0.87310684, 0.86573885, 0.86778551,
         0.87965616, 0.88252149, 0.87679083, 0.87679083, 0.87842816,
         0.88047483, 0.8739255 , 0.8726975 , 0.8751535 , 0.86369218,
         0.86000819]),
  'split2_test_score': array([0.84486287, 0.85714286, 0.86369218, 0.84977487, 0.86942284,
         0.87187884, 0.85509619, 0.8751535 , 0.87597217, 0.85304953,
         0.87351617, 0.87842816, 0.86369218, 0.87474417, 0.8763815 ,
         0.87187884, 0.87679083, 0.87556283, 0.87433483, 0.8751535 ,
         0.8726975 , 0.87597217, 0.87228817, 0.86778551, 0.86983217,
         0.8763815 , 0.8763815 , 0.87351617, 0.87679083, 0.87187884,
         0.87597217, 0.87351617, 0.86860418, 0.87474417, 0.86410151,
         0.85509619]),
  'split3_test_score': array([0.83872288, 0.85509619, 0.86287352, 0.84527221, 0.87024151,
         0.8776095 , 0.85468686, 0.88006549, 0.88538682, 0.85918952,
         0.87597217, 0.87965616, 0.86369218, 0.87228817, 0.8763815 ,
         0.87801883, 0.88415882, 0.88620549, 0.88374949, 0.88620549,
         0.88252149, 0.88047483, 0.88293082, 0.86983217, 0.87065084,
         0.8763815 , 0.88129349, 0.88211216, 0.88538682, 0.88088416,
         0.88661482, 0.87842816, 0.87065084, 0.88252149, 0.87024151,
         0.86041752]),
  'split4_test_score': array([0.85176085, 0.86322686, 0.86936937, 0.86076986, 0.86895987,
         0.87755938, 0.85995086, 0.87223587, 0.87592138, 0.85462735,
         0.87264537, 0.87018837, 0.86895987, 0.87510238, 0.87510238,
         0.87960688, 0.87592138, 0.87387387, 0.87633088, 0.87428337,
         0.87141687, 0.87100737, 0.85913186, 0.86036036, 0.87469287,
         0.87469287, 0.87264537, 0.87633088, 0.87305487, 0.87387387,
         0.87674038, 0.86732187, 0.85954136, 0.86363636, 0.85135135,
         0.84602785]),
  'split5_test_score': array([0.83701884, 0.84930385, 0.85462735, 0.84930385, 0.85667486,
         0.86977887, 0.85012285, 0.86486486, 0.87346437, 0.84889435,
         0.87305487, 0.87510238, 0.85421785, 0.86977887, 0.87469287,
         0.87100737, 0.87510238, 0.87551188, 0.87469287, 0.87346437,
         0.87346437, 0.87469287, 0.87305487, 0.86527437, 0.86445536,
         0.87674038, 0.87714988, 0.87346437, 0.87755938, 0.87305487,
         0.87837838, 0.87469287, 0.86773137, 0.87387387, 0.86445536,
         0.85094185]),
  'split6_test_score': array([0.84719377, 0.85702581, 0.8648095 , 0.85661614, 0.86603851,
         0.8738222 , 0.85948382, 0.87832855, 0.88529291, 0.85948382,
         0.87709955, 0.8816059 , 0.8648095 , 0.8738222 , 0.88365424,
         0.87546088, 0.88652192, 0.88406391, 0.88324457, 0.88652192,
         0.88365424, 0.8816059 , 0.87996723, 0.8725932 , 0.86972552,
         0.88529291, 0.88816059, 0.88406391, 0.88365424, 0.8828349 ,
         0.88611225, 0.88447358, 0.86972552, 0.8828349 , 0.86603851,
         0.86317083]),
  'split7_test_score': array([0.8455551 , 0.85415813, 0.85702581, 0.85620647, 0.86153216,
         0.86644818, 0.85702581, 0.87054486, 0.8725932 , 0.85661614,
         0.87218353, 0.87423187, 0.85702581, 0.86685785, 0.86808685,
         0.86685785, 0.87013519, 0.87177386, 0.87587054, 0.87668988,
         0.8725932 , 0.87750922, 0.86849652, 0.86439984, 0.8635805 ,
         0.86849652, 0.87054486, 0.87218353, 0.87505121, 0.87668988,
         0.87505121, 0.86931585, 0.86644818, 0.87177386, 0.85948382,
         0.85866448]),
  'split8_test_score': array([0.84104875, 0.85006145, 0.8533388 , 0.84842278, 0.8545678 ,
         0.85989349, 0.84883245, 0.86030315, 0.86562884, 0.84965178,
         0.85866448, 0.8623515 , 0.85374846, 0.86030315, 0.86808685,
         0.85907415, 0.86972552, 0.86849652, 0.86685785, 0.86399017,
         0.8635805 , 0.8648095 , 0.86030315, 0.85497747, 0.85825481,
         0.86644818, 0.86726751, 0.86521917, 0.86644818, 0.86767718,
         0.86439984, 0.8623515 , 0.85866448, 0.86439984, 0.85497747,
         0.83981975]),
  'split9_test_score': array([0.84063908, 0.85292913, 0.85661614, 0.84924211, 0.86071282,
         0.86849652, 0.85210979, 0.87095453, 0.8725932 , 0.85292913,
         0.87095453, 0.87587054, 0.85866448, 0.87054486, 0.87709955,
         0.86931585, 0.87832855, 0.87996723, 0.87218353, 0.87587054,
         0.87218353, 0.87300287, 0.87300287, 0.87013519, 0.86521917,
         0.87546088, 0.87791889, 0.87587054, 0.87750922, 0.87505121,
         0.87587054, 0.8738222 , 0.86767718, 0.86931585, 0.86030315,
         0.86153216]),
  'mean_test_score': array([0.84266994, 0.8539312 , 0.85950041, 0.85135135, 0.86240786,
         0.86895987, 0.85380835, 0.86945127, 0.87342342, 0.85282555,
         0.86912367, 0.87166257, 0.85970516, 0.86871417, 0.87305487,
         0.86986077, 0.87493857, 0.87522523, 0.87338247, 0.87428337,
         0.87186732, 0.87153972, 0.86875512, 0.86273546, 0.86531532,
         0.87330057, 0.87444717, 0.87325962, 0.87428337, 0.87325962,
         0.87489762, 0.87039312, 0.86445536, 0.87027027, 0.85925471,
         0.8530303 ]),
  'std_test_score': array([0.00727039, 0.00826494, 0.0074335 , 0.00816746, 0.00818114,
         0.00769633, 0.00723565, 0.00950916, 0.01003632, 0.00735759,
         0.00986804, 0.01077537, 0.00758753, 0.00712003, 0.00757436,
         0.00802774, 0.00874896, 0.00840377, 0.00963988, 0.01025699,
         0.00965164, 0.01160156, 0.01053757, 0.0100632 , 0.0069548 ,
         0.00827922, 0.00981026, 0.00835648, 0.00932593, 0.00816926,
         0.01036404, 0.00990618, 0.00838173, 0.01055096, 0.00880306,
         0.0093857 ]),
  'rank_test_score': array([36, 31, 29, 35, 27, 21, 32, 19,  7, 34, 20, 14, 28, 23, 12, 18,  2,
          1,  8,  5, 13, 15, 22, 26, 24,  9,  4, 10,  5, 10,  3, 16, 25, 17,
         30, 33], dtype=int32)},
 {'learning_rate': 0.05, 'max_depth': 5, 'n_estimators': 500},
 0.8752252252252253)
# Predict the test set with the grid-searched GBDT model
grid_gbdt_pred = grid_gbdt.predict(X_test)
print(pd.crosstab(grid_gbdt_pred, y_test))

# Accuracy on the training and the test split
print('模型在训练集上的准确率%f' % grid_gbdt.score(X_train, y_train))
print('模型在测试集上的准确率%f' % grid_gbdt.score(X_test, y_test))

# ROC curve. Score with the predicted probability (second column of
# predict_proba), as the k-NN, GBDT and grid-searched k-NN sections of this
# script do. Passing the hard 0/1 predictions instead leaves roc_curve with a
# single usable threshold, so the curve collapses to one corner point and the
# AUC is not comparable with the other models.
grid_gbdt_scores = grid_gbdt.predict_proba(X_test)[:, 1]
fpr, tpr, _ = metrics.roc_curve(y_test, grid_gbdt_scores)
plt.plot(fpr, tpr, linestyle='solid', color='red')
# Shaded area under the curve, then the diagonal reference line
plt.stackplot(fpr, tpr, color='steelblue')
plt.plot([0, 1], [0, 1], linestyle='dashed', color='black')
plt.text(0.6, 0.4, 'AUC=%.3f' % metrics.auc(fpr, tpr), fontdict=dict(size=18))
plt.show()
income     0     1
row_0             
0       5833   655
1        393  1260
模型在训练集上的准确率0.897379
模型在测试集上的准确率0.871269

数据挖掘实战一:输入预测分类_第7张图片
代码来自:
从零开始学python数据挖掘与分析第二章

你可能感兴趣的:(数据挖掘)