作业：尝试将之前的心脏病项目（ipynb 笔记本）按照今天示例项目的结构整理成规范的工程形式，并思考哪些部分可以在未来的项目中复用。
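在多级目录结构中，相对导入（如 `from .base_model import BaseModel`）只在同一个包内部有效；尤其是下级目录中的文件需要导入上级目录中的模块时，应改用以项目根目录为起点的绝对导入（如 `from src.config import CONFIG`），并确保项目根目录在 `sys.path` 中（例如在项目根目录下用 `python -m` 方式运行）。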
# src/config.py
from pathlib import Path

# Project root = the parent of the `src/` directory containing this file.
# Resolving from __file__ (instead of using a relative string) keeps data
# paths correct regardless of the working directory the code is run from.
# Previously PROJECT_ROOT was referenced below without being defined,
# which made importing this module fail with NameError.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Central configuration: data location, train/test split settings and
# per-model hyperparameters, read by the loader and model modules.
CONFIG = {
    "data_path": PROJECT_ROOT / "data" / "raw" / "heart.csv",
    "test_size": 0.2,  # fraction of rows held out for the test split
    "random_state": 42,  # seed shared by the data split and the models
    "models": {
        "random_forest": {
            "n_estimators": 100,
            "max_depth": 5,
        },
        "xgboost": {
            "learning_rate": 0.1,
            "max_depth": 3,
            "n_estimators": 200,
        },
    },
}
# src/data/loader.py
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split
from src.config import CONFIG
def load_data(
    data_path: "str | Path | None" = None,
    *,
    target: "str | None" = None,
    test_size: "float | None" = None,
    random_state: "int | None" = None,
    stratify: bool = False,
) -> tuple:
    """Load the CSV dataset and split it into train/test sets.

    Every optional argument defaults to the matching ``CONFIG`` value, so
    ``load_data()`` with no arguments behaves exactly as before, while other
    projects can reuse the function by passing their own values.

    Args:
        data_path: CSV file to read. Defaults to ``CONFIG["data_path"]``.
        target: Name of the target column. When omitted, the last column
            is used as the target and all other columns as features.
        test_size: Held-out share passed to ``train_test_split``.
            Defaults to ``CONFIG["test_size"]``.
        random_state: Split seed. Defaults to ``CONFIG["random_state"]``
            (so ``None`` here means "use the config", not "unseeded").
        stratify: If True, stratify the split on the target so both
            splits keep the same class proportions. Default False keeps
            the original (unstratified) split.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as returned by
        ``sklearn.model_selection.train_test_split``.

    Raises:
        KeyError: If ``target`` is given but is not a column of the CSV.
    """
    if data_path is None:
        data_path = CONFIG["data_path"]
    if test_size is None:
        test_size = CONFIG["test_size"]
    if random_state is None:
        random_state = CONFIG["random_state"]

    df = pd.read_csv(data_path)

    if target is None:
        # Fallback: assume the last column is the target variable.
        X = df.iloc[:, :-1]
        y = df.iloc[:, -1]
    else:
        X = df.drop(columns=[target])
        y = df[target]

    return train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y if stratify else None,
    )
# src/models/base_model.py
from abc import ABC, abstractmethod
import pandas as pd
class BaseModel(ABC):
    """Common interface that every model wrapper implements.

    ``train``, ``predict`` and ``save`` are abstract: because the class
    derives from ``ABC``, neither this class nor a subclass that leaves
    any of them unimplemented can be instantiated.
    """

    @abstractmethod
    def train(self, X_train: pd.DataFrame, y_train: pd.Series):
        """Fit the model on training features ``X_train`` and labels ``y_train``."""

    @abstractmethod
    def predict(self, X_test: pd.DataFrame) -> pd.Series:
        """Return predictions for the rows of ``X_test``.

        NOTE(review): annotated as ``pd.Series``, but implementations that
        forward to an sklearn estimator's ``predict`` return a NumPy array;
        callers should treat the result as array-like.
        """

    @abstractmethod
    def save(self, path: str):
        """Persist the trained model to ``path``."""
# src/models/random_forest.py
from sklearn.ensemble import RandomForestClassifier
from .base_model import BaseModel
from src.config import CONFIG
class RandomForestModel(BaseModel):
    """Random-forest classifier exposed through the common ``BaseModel`` interface.

    Hyperparameters default to ``CONFIG["models"]["random_forest"]`` plus the
    shared ``CONFIG["random_state"]``. Keyword arguments given to the
    constructor override those defaults, so the class can be reused (e.g.
    for tuning) without editing the config file.
    """

    def __init__(self, **overrides):
        """Build the wrapped ``RandomForestClassifier``.

        Args:
            **overrides: ``RandomForestClassifier`` keyword arguments that
                replace the config defaults (e.g. ``n_estimators=300``).

        Note:
            Every key in ``CONFIG["models"]["random_forest"]`` is passed to
            ``RandomForestClassifier``, so that section must only contain
            valid constructor arguments.
        """
        params = {
            **CONFIG["models"]["random_forest"],
            "random_state": CONFIG["random_state"],
            **overrides,  # caller-supplied values win over config defaults
        }
        self.model = RandomForestClassifier(**params)

    def train(self, X_train, y_train):
        """Fit the forest on training features and labels."""
        self.model.fit(X_train, y_train)

    def predict(self, X_test):
        """Return the estimator's class predictions for ``X_test``."""
        return self.model.predict(X_test)

    def save(self, path):
        """Serialize the wrapped estimator to ``path`` with ``joblib.dump``.

        Missing parent directories are created first.
        """
        # joblib was referenced here without ever being imported, so save()
        # raised NameError; import it locally where it is needed.
        import joblib
        from pathlib import Path

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(self.model, path)
# src/evaluation/metrics.py
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
def calculate_all_metrics(y_true, y_pred, *, average="binary") -> dict:
    """Compute accuracy, precision, recall and F1 in a single call.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        average: Averaging strategy forwarded to precision/recall/F1.
            The default ``"binary"`` reproduces the previous behaviour
            (scores for the positive class only, which requires binary
            targets); pass ``"macro"``, ``"micro"`` or ``"weighted"`` to
            reuse this helper on multiclass problems.

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"`` and ``"f1"``.
    """
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, average=average),
        "recall": recall_score(y_true, y_pred, average=average),
        "f1": f1_score(y_true, y_pred, average=average),
    }
# scripts/train_model.py
import sys
from pathlib import Path

# `python scripts/train_model.py` puts the script's own directory
# (`scripts/`), not the project root, at the front of sys.path, so
# `import src...` would fail with ModuleNotFoundError. Prepend the project
# root (the parent of `scripts/`) so the absolute `src.` imports resolve
# however the script is launched.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))


def main() -> None:
    """Entry point: delegate to ``src.models.train.train_all_models()``."""
    # Imported here, after the sys.path setup above, so `src` is importable.
    # NOTE(review): src/models/train.py (providing train_all_models) is not
    # shown in these notes -- confirm it exists.
    from src.models import train

    train.train_all_models()


if __name__ == "__main__":
    main()