Python实例题
题目
问题描述
解题思路
关键代码框架
难点分析
股票数据分析与预测系统
开发一个股票数据分析系统,实现以下功能:
使用 requests 库调用 Alpha Vantage 或 Yahoo Finance API 获取数据
使用 pandas 进行数据处理和分析
使用 scikit-learn 构建预测模型
使用 matplotlib 或 plotly 进行可视化
import requests
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
import time
import os
from datetime import datetime, timedelta
# Configure matplotlib so Chinese labels render: list candidate CJK fonts and
# disable the Unicode minus so negative signs display correctly.
plt.rcParams.update(
    {
        "font.family": ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"],
        "axes.unicode_minus": False,
    }
)
class StockDataCollector:
    """Fetch price time series from the Alpha Vantage query API with a CSV cache.

    Downloaded frames are written to ``stock_data_cache/<symbol>_<interval>.csv``
    and reused until the file is older than the cache lifetime.
    """

    # interval -> (API "function" parameter, JSON key that holds the series)
    _ENDPOINTS = {
        "daily": ("TIME_SERIES_DAILY", "Time Series (Daily)"),
        "weekly": ("TIME_SERIES_WEEKLY", "Weekly Time Series"),
        "monthly": ("TIME_SERIES_MONTHLY", "Monthly Time Series"),
    }

    def __init__(self, api_key="YOUR_API_KEY", timeout=30):
        """Create the collector and make sure the cache directory exists.

        Args:
            api_key: Key sent as the ``apikey`` query parameter.
            timeout: Seconds to wait for the HTTP request before giving up.
        """
        self.api_key = api_key
        self.timeout = timeout
        self.base_url = "https://www.alphavantage.co/query"
        self.cache_dir = "stock_data_cache"
        os.makedirs(self.cache_dir, exist_ok=True)

    def _cache_filename(self, symbol, interval):
        """Return the cache file path for ``symbol`` at ``interval``."""
        return os.path.join(self.cache_dir, f"{symbol}_{interval}.csv")

    def _save_to_cache(self, symbol, interval, df):
        """Write ``df`` (index included) to the cache file as CSV."""
        df.to_csv(self._cache_filename(symbol, interval))

    def _load_from_cache(self, symbol, interval, days=7):
        """Load a cached frame, or return None when it is missing, stale or unreadable.

        Args:
            symbol: Ticker symbol used in the cache file name.
            interval: Interval name used in the cache file name.
            days: Maximum age of the cache file, in days.

        Returns:
            The cached DataFrame with a parsed date index, or None.
        """
        filename = self._cache_filename(symbol, interval)
        try:
            modified = datetime.fromtimestamp(os.path.getmtime(filename))
        except OSError:
            # Missing (or vanished) file: plain cache miss.
            return None
        if datetime.now() - modified > timedelta(days=days):
            return None
        try:
            return pd.read_csv(filename, index_col=0, parse_dates=True)
        except (OSError, ValueError):
            # pandas' empty-data / parser errors derive from ValueError; an
            # unreadable cache file is treated as a miss so the caller refetches.
            return None

    @staticmethod
    def _parse_time_series(data, time_series_key):
        """Convert the decoded JSON payload into a float frame sorted by date.

        Raises:
            ValueError: If ``time_series_key`` is absent from ``data``.
        """
        if time_series_key not in data:
            raise ValueError(f"API返回格式错误: {data}")
        df = pd.DataFrame(data[time_series_key]).T
        # Column labels are expected to look like "1. open"; drop the ordinal
        # prefix but keep the full remaining name.
        df.columns = [col.split(" ", 1)[-1] for col in df.columns]
        df = df.astype(float)
        df.index = pd.to_datetime(df.index)
        df.sort_index(inplace=True)
        return df

    def get_stock_data(self, symbol, interval="daily", output_size="full", use_cache=True):
        """Return the price history for ``symbol``, from cache or from the API.

        Args:
            symbol: Ticker symbol to query.
            interval: One of ``"daily"``, ``"weekly"`` or ``"monthly"``.
            output_size: Value sent as the ``outputsize`` query parameter.
            use_cache: When True, a fresh cache file is returned without any
                network access.

        Returns:
            A DataFrame indexed by date, or None if the download or parsing
            failed (the error is printed).

        Raises:
            ValueError: If ``interval`` is not supported.
        """
        if use_cache:
            cached_data = self._load_from_cache(symbol, interval)
            if cached_data is not None:
                print(f"从缓存加载 {symbol} 的 {interval} 数据")
                return cached_data

        print(f"从API获取 {symbol} 的 {interval} 数据")
        if interval not in self._ENDPOINTS:
            raise ValueError(f"不支持的时间间隔: {interval}")
        function_name, time_series_key = self._ENDPOINTS[interval]
        params = {
            "function": function_name,
            "symbol": symbol,
            "apikey": self.api_key,
            "outputsize": output_size,
        }
        try:
            # Without a timeout requests waits indefinitely on a stalled server.
            response = requests.get(self.base_url, params=params, timeout=self.timeout)
            response.raise_for_status()
            df = self._parse_time_series(response.json(), time_series_key)
            self._save_to_cache(symbol, interval, df)
            return df
        except Exception as e:
            # Deliberate best-effort boundary: callers check for None.
            print(f"获取数据失败: {e}")
            return None
class TechnicalIndicator:
    """Static helpers that append technical-indicator columns to a price frame.

    Each method writes its result columns into the DataFrame it receives
    (in place) and returns that same DataFrame.
    """

    @staticmethod
    def calculate_moving_average(df, window=5):
        """Add the rolling mean of ``close`` as column ``MA{window}``.

        Args:
            df: Frame with a ``close`` column.
            window: Rolling window length in rows.

        Returns:
            ``df`` with the new column; the first ``window - 1`` rows are NaN.
        """
        df[f"MA{window}"] = df["close"].rolling(window=window).mean()
        return df

    @staticmethod
    def calculate_macd(df, fast=12, slow=26, signal=9):
        """Add MACD columns computed from exponential means of ``close``.

        Adds ``EMA_fast``, ``EMA_slow``, ``MACD`` (fast minus slow),
        ``MACD_signal`` (exponential mean of ``MACD``) and ``MACD_hist``
        (``MACD`` minus ``MACD_signal``).

        Args:
            df: Frame with a ``close`` column.
            fast: Span of the fast exponential mean.
            slow: Span of the slow exponential mean.
            signal: Span of the signal-line exponential mean.

        Returns:
            ``df`` with the new columns.
        """
        df["EMA_fast"] = df["close"].ewm(span=fast, adjust=False).mean()
        df["EMA_slow"] = df["close"].ewm(span=slow, adjust=False).mean()
        df["MACD"] = df["EMA_fast"] - df["EMA_slow"]
        df["MACD_signal"] = df["MACD"].ewm(span=signal, adjust=False).mean()
        df["MACD_hist"] = df["MACD"] - df["MACD_signal"]
        return df

    @staticmethod
    def calculate_kdj(df, window=9):
        """Add KDJ columns ``RSV``, ``K``, ``D`` and ``J``.

        ``RSV`` is the position of ``close`` inside the rolling low/high
        range, scaled to 0-100; ``K`` and ``D`` are successive exponential
        means (``com=2``) and ``J = 3K - 2D``.

        Args:
            df: Frame with ``close``, ``high`` and ``low`` columns.
            window: Rolling window length for the low/high range.

        Returns:
            ``df`` with the new columns; undefined values in those four
            columns are replaced by 0.
        """
        lowest = df["low"].rolling(window=window).min()
        highest = df["high"].rolling(window=window).max()
        df["RSV"] = (df["close"] - lowest) / (highest - lowest) * 100
        df["K"] = df["RSV"].ewm(com=2).mean()
        df["D"] = df["K"].ewm(com=2).mean()
        df["J"] = 3 * df["K"] - 2 * df["D"]
        # Fill only the columns created here. Filling the whole frame would
        # turn the warm-up NaNs of unrelated columns (e.g. moving averages)
        # into fake zero values.
        kdj_columns = ["RSV", "K", "D", "J"]
        df[kdj_columns] = df[kdj_columns].fillna(0)
        return df

    @staticmethod
    def calculate_rsi(df, window=14):
        """Add column ``RSI`` based on rolling mean gains and losses of ``close``.

        Args:
            df: Frame with a ``close`` column.
            window: Rolling window length for the average gain/loss.

        Returns:
            ``df`` with the new column; undefined ``RSI`` values are
            replaced by 0.
        """
        delta = df["close"].diff()
        gain = delta.where(delta > 0, 0)
        loss = -delta.where(delta < 0, 0)
        avg_gain = gain.rolling(window=window).mean()
        avg_loss = loss.rolling(window=window).mean()
        rs = avg_gain / avg_loss
        # Restrict the NaN fill to the RSI column so other columns keep
        # their own missing-value markers.
        df["RSI"] = (100 - (100 / (1 + rs))).fillna(0)
        return df
class StockPredictor:
    """Random-forest model that forecasts a future ``close`` value.

    Features are the input columns plus technical indicators and lagged
    closes; the target is ``close`` shifted ``forecast_days`` rows ahead.
    """

    def __init__(self):
        """Create an unfitted model with a fixed random seed."""
        self.model = RandomForestRegressor(n_estimators=100, random_state=42)
        self.features = []
        self.target = "close"

    def _build_features(self, df, look_back):
        """Return a copy of ``df`` with indicator and lag columns appended.

        Working on a copy keeps the caller's frame untouched, so repeated
        calls always start from the same data.
        """
        frame = df.copy()
        for window in (5, 10, 20):
            frame = TechnicalIndicator.calculate_moving_average(frame, window)
        frame = TechnicalIndicator.calculate_macd(frame)
        frame = TechnicalIndicator.calculate_kdj(frame)
        frame = TechnicalIndicator.calculate_rsi(frame)
        for lag in range(1, look_back + 1):
            frame[f"lag_{lag}"] = frame[self.target].shift(lag)
        return frame

    def _select_features(self, frame, forecast_days):
        """List the columns of ``frame`` that are used as model inputs."""
        excluded = {self.target, f"future_{forecast_days}", "EMA_fast", "EMA_slow"}
        return [col for col in frame.columns if col not in excluded]

    def prepare_data(self, df, look_back=5, forecast_days=1):
        """Build the supervised-learning frame and record the feature list.

        Args:
            df: Price frame with ``close``, ``high`` and ``low`` columns.
            look_back: Number of lagged ``close`` columns to add.
            forecast_days: How many rows ahead the target lies.

        Returns:
            A new DataFrame with feature columns and ``future_{forecast_days}``,
            with every row containing NaN removed. ``df`` is not modified.
        """
        frame = self._build_features(df, look_back)
        frame[f"future_{forecast_days}"] = frame[self.target].shift(-forecast_days)
        frame = frame.dropna()
        self.features = self._select_features(frame, forecast_days)
        return frame

    def train(self, df, look_back=5, forecast_days=1):
        """Fit the model on the first 80% of rows and score it on the rest.

        Args:
            df: Price frame (see ``prepare_data``).
            look_back: Number of lagged ``close`` features.
            forecast_days: Forecast horizon in rows.

        Returns:
            Dict with ``mse``, ``mae``, ``y_test`` and ``y_pred``.
        """
        prepared_df = self.prepare_data(df, look_back, forecast_days)
        X = prepared_df[self.features]
        y = prepared_df[f"future_{forecast_days}"]
        # shuffle=False keeps chronological order: test rows follow train rows.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, shuffle=False
        )
        self.model.fit(X_train, y_train)

        y_pred = self.model.predict(X_test)
        mse = mean_squared_error(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        print(f"模型评估 - MSE: {mse:.4f}, MAE: {mae:.4f}")
        return {
            "mse": mse,
            "mae": mae,
            "y_test": y_test,
            "y_pred": y_pred,
        }

    def predict(self, df, look_back=5, forecast_days=1):
        """Forecast the target from the most recent row of ``df``.

        The newest row has no known future value, so it must not be filtered
        out together with the training target; only rows with incomplete
        features are skipped.

        Args:
            df: Price frame (see ``prepare_data``).
            look_back: Number of lagged ``close`` features; should match the
                value used for training.
            forecast_days: Forecast horizon the model was trained for.

        Returns:
            The predicted value for the latest row.

        Raises:
            ValueError: If no row has a complete set of features.
        """
        frame = self._build_features(df, look_back)
        # Reuse the training column list when available so order and names
        # match what the model was fitted on.
        columns = self.features or self._select_features(frame, forecast_days)
        latest = frame[columns].dropna().iloc[-1:]
        if latest.empty:
            raise ValueError("没有足够的数据用于预测")
        # Pass a DataFrame (not a bare array) to keep the fitted feature names.
        return self.model.predict(latest)[0]

    def plot_prediction(self, df, look_back=5, forecast_days=1):
        """Plot actual versus predicted targets for the last 20% of rows.

        Args:
            df: Price frame (see ``prepare_data``).
            look_back: Number of lagged ``close`` features.
            forecast_days: Forecast horizon in rows.
        """
        prepared_df = self.prepare_data(df, look_back, forecast_days)
        X = prepared_df[self.features]
        y = prepared_df[f"future_{forecast_days}"]
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, shuffle=False
        )
        y_pred = self.model.predict(X_test)

        plt.figure(figsize=(12, 6))
        plt.plot(y_test.index, y_test.values, label="实际价格")
        plt.plot(y_test.index, y_pred, label="预测价格", linestyle="--")
        plt.title("股票价格预测")
        plt.xlabel("日期")
        plt.ylabel("价格")
        plt.legend()
        plt.grid(True)
        plt.show()
class TradingStrategy:
    """MACD crossover signal generation and back-test evaluation helpers."""

    @staticmethod
    def macd_crossover_strategy(df):
        """Add ``signal`` and ``position`` columns from MACD/signal-line crossings.

        ``signal`` is 1 on the row where ``MACD`` moves from at-or-below to
        above ``MACD_signal``, -1 where it moves from at-or-above to below,
        and 0 otherwise. ``position`` is 1 after a buy signal and 0 after a
        sell signal, carried forward between signals and starting at 0.

        Args:
            df: Frame with ``MACD`` and ``MACD_signal`` columns.

        Returns:
            ``df`` with the two columns added in place.
        """
        macd = df["MACD"]
        signal_line = df["MACD_signal"]
        prev_macd = macd.shift(1)
        prev_signal_line = signal_line.shift(1)

        # Comparisons against the NaN produced by shift() are False, so the
        # first row never produces a signal.
        crossed_up = (macd > signal_line) & (prev_macd <= prev_signal_line)
        crossed_down = (macd < signal_line) & (prev_macd >= prev_signal_line)
        signals = np.select(
            [crossed_up.to_numpy(), crossed_down.to_numpy()], [1, -1], default=0
        )

        # Carry the last non-zero signal forward, then map buy -> 1, sell -> 0.
        signal_series = pd.Series(signals)
        positions = (
            signal_series.where(signal_series != 0)
            .ffill()
            .clip(lower=0)
            .fillna(0)
            .astype(int)
        )

        df["signal"] = signals
        df["position"] = positions.to_numpy()
        return df

    @staticmethod
    def evaluate_strategy(df, plot=True):
        """Compare the strategy against buy-and-hold and report the results.

        Adds ``daily_return``, ``strategy_return``, ``cumulative_returns``
        and ``cumulative_strategy`` to ``df`` in place, prints a summary and
        optionally plots both cumulative curves.

        Args:
            df: Frame with ``close`` and ``position`` columns.
            plot: When True (default) show the cumulative-return chart.

        Returns:
            Dict with ``total_return``, ``strategy_return`` and
            ``sharpe_ratio``. The Sharpe ratio is NaN when the strategy
            returns have no measurable volatility.

        Raises:
            ValueError: If ``df`` has no rows.
        """
        if df.empty:
            raise ValueError("数据为空, 无法评估策略")

        df["daily_return"] = df["close"].pct_change()
        # Yesterday's position earns today's return (avoids look-ahead).
        df["strategy_return"] = df["position"].shift(1) * df["daily_return"]
        df["cumulative_returns"] = (1 + df["daily_return"]).cumprod()
        df["cumulative_strategy"] = (1 + df["strategy_return"]).cumprod()

        total_return = df["cumulative_returns"].iloc[-1] - 1
        strategy_return = df["cumulative_strategy"].iloc[-1] - 1

        risk_free_rate = 0.03  # assumed annual risk-free rate of 3%
        trading_days = 252
        daily_risk_free = (1 + risk_free_rate) ** (1 / trading_days) - 1
        excess_returns = df["strategy_return"] - daily_risk_free
        volatility = excess_returns.std()
        if not np.isfinite(volatility) or np.isclose(volatility, 0.0):
            # Dividing by a zero spread would report an infinite ratio.
            sharpe_ratio = float("nan")
        else:
            sharpe_ratio = np.sqrt(trading_days) * excess_returns.mean() / volatility

        print(f"基准收益率: {total_return:.2%}")
        print(f"策略收益率: {strategy_return:.2%}")
        print(f"夏普比率: {sharpe_ratio:.2f}")

        if plot:
            plt.figure(figsize=(12, 6))
            plt.plot(df.index, df["cumulative_returns"], label="基准收益")
            plt.plot(df.index, df["cumulative_strategy"], label="策略收益")
            plt.title("策略收益 vs 基准收益")
            plt.xlabel("日期")
            plt.ylabel("累积收益")
            plt.legend()
            plt.grid(True)
            plt.show()

        return {
            "total_return": total_return,
            "strategy_return": strategy_return,
            "sharpe_ratio": sharpe_ratio,
        }
# Usage example
if __name__ == "__main__":
    # Prefer a key from the environment so real credentials stay out of the
    # source; fall back to the placeholder used previously.
    collector = StockDataCollector(
        api_key=os.environ.get("ALPHAVANTAGE_API_KEY", "YOUR_API_KEY")
    )
    symbol = "AAPL"  # Apple Inc.

    # Fetch the price history (served from the local cache when fresh).
    df = collector.get_stock_data(symbol)
    if df is None:
        print("无法获取股票数据")
    else:
        # Prediction. Every stage receives its own copy because the
        # predictor's data preparation adds columns and drops rows; sharing
        # one frame would make each later stage see different data.
        predictor = StockPredictor()
        results = predictor.train(df.copy(), look_back=5, forecast_days=1)
        next_day_price = predictor.predict(df.copy(), look_back=5, forecast_days=1)
        print(f"预测下一个交易日价格: {next_day_price:.2f}")
        predictor.plot_prediction(df.copy())

        # Strategy back-test on the untouched price history.
        strategy_df = TechnicalIndicator.calculate_macd(df.copy())
        strategy_df = TradingStrategy.macd_crossover_strategy(strategy_df)
        TradingStrategy.evaluate_strategy(strategy_df)