Python打卡DAY31

今日的示例代码包含2个部分

  1. notebook文件夹内的ipynb文件,介绍下今天的思路
  2. 项目文件夹中其他部分:拆分后的信贷项目,学习下如何拆分的,未来你看到的很多大项目都是类似的拆分方法

知识点回顾

  1. 规范的文件命名
  2. 规范的文件夹管理
  3. 机器学习项目的拆分
  4. 编码格式和类型注解

作业:尝试针对之前的心脏病项目ipynb,将他按照今天的示例项目整理成规范的形式,思考下哪些部分可以未来复用。

src/data/data_loader.py

import pandas as pd
from sklearn.model_selection import train_test_split

def load_and_split_data(file_path, target_column, test_size=0.2, random_state=42):
    """
    加载数据并划分训练集和测试集
    """
    data = pd.read_csv(file_path)
    X = data.drop(target_column, axis=1)
    y = data[target_column]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
    return X_train, X_test, y_train, y_test
src/models/random_forest.py
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
import time

def train_random_forest(X_train, y_train, X_test, y_test, random_state=42):
    """
    训练随机森林模型并评估性能
    """
    start_time = time.time()
    model = RandomForestClassifier(random_state=random_state)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    end_time = time.time()

    print(f"训练与预测耗时: {end_time - start_time:.4f} 秒")
    print("\n默认随机森林 在测试集上的分类报告:")
    print(classification_report(y_test, y_pred))
    print("默认随机森林 在测试集上的混淆矩阵:")
    print(confusion_matrix(y_test, y_pred))
    return model

src/utils/visualization.py

import shap
import matplotlib.pyplot as plt

def plot_shap_values(model, X_test):
    """
    绘制SHAP值的条形图、蜂巢图和依赖图
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test)

    print("shap_values[0] shape:", shap_values[0].shape)
    print("X_test shape:", X_test.shape)

    # SHAP特征重要性条形图
    print("--- 1. SHAP 特征重要性条形图 ---")
    shap.summary_plot(shap_values[0], X_test, plot_type="bar", show=False)
    plt.title("SHAP Feature Importance (Bar Plot)")
    plt.show()

    # SHAP特征重要性蜂巢图
    print("--- 2. SHAP 特征重要性蜂巢图 ---")
    shap.summary_plot(shap_values[0], X_test, plot_type="violin", show=False, max_display=10)
    plt.title("SHAP Feature Importance (Violin Plot)")
    plt.show()

    # SHAP特征重要性依赖图
    print("--- 3. SHAP 特征重要性依赖图 ---")
    shap.dependence_plot('Years in current job', shap_values[0], X_test, show=False)
    plt.title("SHAP Feature Importance (dependence plot)")
    plt.show()
src/main.py
from src.data.data_loader import load_and_split_data
from src.models.random_forest import train_random_forest
from src.utils.visualization import plot_shap_values

if __name__ == "__main__":
    # 数据加载与划分
    file_path = "data/raw/heart.csv"
    target_column = "target"
    X_train, X_test, y_train, y_test = load_and_split_data(file_path, target_column)

    # 模型训练与评估
    model = train_random_forest(X_train, y_train, X_test, y_test)

    # SHAP值可视化
    plot_shap_values(model, X_test)

@浙大疏锦行

你可能感兴趣的:(浙大疏锦行打卡,python,机器学习)