sklearn 决策树实例——泰坦尼克号生存率预测

题目地址：Kaggle 竞赛 Titanic: Machine Learning from Disaster（https://www.kaggle.com/c/titanic）
我在 Excel 中对训练集和测试集做了少量调整，下载链接：（原文链接缺失），提取码：mdwi。
最终分数不高，但用到的都是最基础的知识。

from sklearn.tree import DecisionTreeClassifier
import pandas as pd
import warnings
from sklearn import tree  # 导入决策树
import graphviz
"""
打印处理
"""
# 不显示warning
warnings.filterwarnings('ignore')
# 显示所有列
pd.set_option('display.max_columns', None)
# 显示所有行
pd.set_option('display.max_rows', None)
# 显示宽度为1000
pd.set_option('display.width', 1000)


def _encode(series, codes, fill=None):
    """Return a copy of *series* with each label in *codes* replaced by its code.

    Values that are not keys of *codes* are kept unchanged. If *fill* is not
    None, entries that were missing (NaN) in *series* are set to *fill*.
    The input Series itself is not modified.
    """
    # object dtype can hold the integer codes next to any leftover labels,
    # whatever dtype pandas inferred for the text column.
    encoded = series.astype(object)
    # Build every mask from the original labels before writing any code.
    hits = [(encoded == label, code) for label, code in codes.items()]
    missing = encoded.isna()
    for mask, code in hits:
        encoded[mask] = code
    if fill is not None:
        encoded[missing] = fill
    return encoded


def prep(csvfile):
    """Read a Titanic CSV file and return it with model-ready columns.

    ``Embarked`` is encoded as C -> 1, Q -> 2, S -> 3, and missing -> 3.
    ``Sex`` is encoded as male -> 1, female -> 0. Values outside these
    mappings are left as they are. Missing ``Age`` and ``Fare`` values are
    replaced by the mean of that column in this file. All other columns are
    returned unchanged.

    :param csvfile: anything ``pandas.read_csv`` accepts (path, URL, file-like).
    :return: the processed DataFrame.
    """
    df = pd.read_csv(csvfile)
    # Each column is rebuilt and assigned back with ``df[col] = ...``.
    # Chained assignment (``df[col][mask] = value``) may write into a
    # temporary copy and leave ``df`` unchanged. Under pandas Copy-on-Write
    # it never updates ``df``.
    # Embarked 1 = Cherbourg, 2 = Queenstown, 3 = Southampton; missing -> 3
    df['Embarked'] = _encode(df['Embarked'], {'C': 1, 'Q': 2, 'S': 3}, fill=3)
    # Sex 1 = male, 0 = female
    df['Sex'] = _encode(df['Sex'], {'male': 1, 'female': 0})
    # Impute missing numeric values with this file's column mean.
    df['Age'] = df['Age'].fillna(df['Age'].mean())
    df['Fare'] = df['Fare'].fillna(df['Fare'].mean())
    return df


"""
获取训练集、测试集
"""
# 训练集
train = prep('data.csv')
Y_train = train['Survived'].values
X_train = train.drop('Survived', axis=1).values
# 测试集
X_test = prep('test.csv')

"""
实例化&训练决策树
"""
dtc = DecisionTreeClassifier(splitter='random', max_depth=10).fit(X_train, Y_train)
# 打分
print('训练集的分数:{}分'.format(round(100 * dtc.score(X_train, Y_train), 1)))
# 查看重要度
print(*zip([column for column in (train.drop('Survived', axis=1))],
           dtc.feature_importances_
           )
      )
# 测试
result = pd.DataFrame(dtc.predict(X_test), columns=['Survived'])
result.index = result.index + 892
# 格式化输出测试结果
result.to_csv('Titanic_result.csv',
              index_label='PassengerId'
              )

# Create Tree Picture
dot_data = tree.export_graphviz(dtc,
                                filled=True,
                                rounded=True,
                                out_file=None,
                                )
graph = graphviz.Source(dot_data)
graph.render()

结果：
（此处原为结果截图，图片已缺失）

你可能感兴趣的:(sklearn 决策树实例——泰坦尼克号生存率预测)