学习来源:@浙大疏锦行
SHAP(SHapley Additive exPlanations)库是一个用于解释机器学习模型预测结果的开源 Python 库。
**一、核心概念**
1. **Shapley 值**
* 它来源于合作博弈论。在机器学习模型解释的语境下,可以这样理解:对于一个模型的预测结果,每个特征都看作是一个 “玩家”,模型的输出是这些 “玩家” 合作的结果。Shapley 值表示每个特征对预测结果的平均边际贡献。例如,在一个预测房价的模型中,特征包括房屋面积、房龄、周边学校质量等。Shapley 值能够衡量每个特征(如房屋面积)对房价预测结果的独立贡献,考虑了该特征在所有可能的特征组合中的贡献情况。
2. **SHAP 值的性质**
* **局部准确性** :SHAP 值的总和应该等于模型的输出。这保证了模型输出能够被准确地分解为各个特征的贡献之和。
* **一致性** :如果一个特征在模型中对预测结果的贡献更大,那么这个特征应该分配到更大的 SHAP 值。例如,当模型更依赖高分辨率的医学影像特征来诊断疾病时,该影像特征的 SHAP 值会比其他不太重要的特征更大。
**二、功能特点**
1. **模型无关性**
* SHAP 库可以用于解释各种类型的机器学习模型,无论是线性模型(如线性回归)、树模型(如决策树、随机森林、梯度提升树)还是深度学习模型(如神经网络)。这使得它具有很强的通用性。例如,对于一个复杂的深度神经网络模型,SHAP 能够像解释简单的线性模型一样,提供各个神经元或输入特征对输出的贡献解释。
2. **丰富的可视化功能**
* **SHAP 散点图** :可以展示单个特征对模型输出的全局影响。特征值在 x 轴上,SHAP 值在 y 轴上。通过这种散点图,能够观察到特征值的变化是如何影响模型预测结果的。例如,在一个预测信用卡违约概率的模型中,收入水平这个特征的 SHAP 散点图可以直观地展示高收入是否与低违约概率相关。
* **依赖图** :可以用于分析两个特征之间的相互作用。它能够展示一个特征的 SHAP 值如何随着另一个特征的变化而变化。比如,在一个预测产品销量的模型中,通过依赖图可以观察到广告投入和产品价格这两个特征之间的相互关系,以及它们是如何共同影响销量预测的。
* **摘要图(Summary Plot)** :这是一种展示所有特征对模型预测结果贡献的全局概览图。它将特征按照 SHAP 值的绝对值大小排序,能够快速了解哪些特征对模型输出影响最大。在基因数据分析中,摘要图可以帮助研究人员快速确定对疾病易感性预测最重要的基因特征。
3. **性能优化**
* SHAP 库在实现过程中考虑了性能问题,对于一些常用模型(如 XGBoost、LightGBM 等)提供了高效的计算方法。例如,它利用树结构的特点来加速 SHAP 值的计算,使得在处理大规模数据和复杂树模型时能够快速得到解释结果。
import shap
import matplotlib.pyplot as plt
# 初始化 SHAP 解释器
explainer = shap.TreeExplainer(rf_model)
# 计算 SHAP 值(基于测试集),这个shap_values是一个numpy数组,表示每个特征对每个样本的贡献值
# 这里大家先知道这是个numpy数组即可,我们后面学习完numpy在来回头解读这个值
shap_values = explainer.shap_values(X_test) # 这个计算耗时
shap_values # 每一行代表一个样本,每一列代表一个特征,值表示该特征对该样本的预测结果的影响程度。正值表示该特征对预测结果有正向影响,负值表示负向影响。
shap_values.shape # 第一维是样本数,第二维是特征数,第三维是类别数
print("shap_values shape:", shap_values.shape)
print("shap_values[0] shape:", shap_values[0].shape)
print("shap_values[:, :, 0] shape:", shap_values[:, :, 0].shape)
print("X_test shape:", X_test.shape)
# --- 1. SHAP 特征重要性条形图 (Summary Plot - Bar) ---
print("--- 1. SHAP 特征重要性条形图 ---")
shap.summary_plot(shap_values[:, :, 0], X_test, plot_type="bar",show=False) # 这里的show=False表示不直接显示图形,这样可以继续用plt来修改元素,不然就直接输出了
plt.title("SHAP Feature Importance (Bar Plot)")
plt.show()
# --- 2. SHAP 特征重要性蜂巢图 (Summary Plot - Violin) ---
print("--- 2. SHAP 特征重要性蜂巢图 ---")
shap.summary_plot(shap_values[:, :, 0], X_test,plot_type="violin",show=False,max_display=10) # 这里的show=False表示不直接显示图形,这样可以继续用plt来修改元素,不然就直接输出了
plt.title("SHAP Feature Importance (Violin Plot)")
plt.show()
# 注意下上面几个参数,plot_type可以是bar和violin,max_display表示显示前多少个特征,默认是20个
# --- 新增:SHAP 特征重要性条形图 (Bar Plot) ---
print("--- SHAP 特征重要性条形图 ---")
# 创建 Explanation 对象适配 shap.plots.bar
shap_explanation = shap.Explanation(
values=shap_values[:, :, 0], # 提取对应维度的 SHAP 值
base_values=explainer.expected_value[0], # 模型基准值(如分类任务需调整)
data=X_test.values, # 特征数据
feature_names=X_test.columns.tolist() # 特征名称
)
# 绘制条形图
plt.figure(figsize=(10, 6)) # 统一图形大小
shap.plots.bar(shap_explanation,
max_display=10, # 显示前10个重要特征
show=False)
plt.title("SHAP Feature Importance (Bar Plot)")
plt.tight_layout()
plt.savefig("SHAP_combined_plots.pdf", format='pdf', bbox_inches='tight') # 保存组合图
plt.show()
# --- 新增:局部特征重要性条形图 ---
import matplotlib.pyplot as plt
# 选择单个样本(例如第5个测试样本)
sample_index = 5 # 可替换为需要分析的样本索引
shap_explanation = shap.Explanation(
values=shap_values[:, :, 0], # 提取指定类别的 SHAP 值
data=X_test.values,
feature_names=X_test.columns.tolist()
)
# 绘制局部条形图
plt.figure(figsize=(10, 6)) # 统一图形尺寸
shap.plots.bar(
shap_explanation[sample_index], # 关键:提取单个样本的 SHAP 值
show_data=True, # 显示特征值(默认灰色显示在左侧)[2](@ref)
max_display=10, # 限制显示前10个重要特征
show=False # 禁用自动显示,通过plt控制
)
plt.title(f"Local SHAP Feature Importance (Sample {sample_index})")
plt.tight_layout()
plt.show()
# --- 新增:SHAP 依赖图 (Dependence Plot) ---
print("--- SHAP 依赖图 ---")
plt.figure(figsize=(10, 6))
# 绘制 'age' 特征的依赖图,交互特征自动选择
shap.dependence_plot(
ind='Annual Income', # 主特征名称
shap_values=shap_values[:, :, 0], # 提取指定类别的 SHAP 值
features=X_test, # 特征数据(需为 DataFrame)
interaction_index='auto', # 自动选择交互特征(如网页1中的 'thal')
show=False
)
# 优化图形显示
plt.title("SHAP Dependence Plot for Annual Income", fontsize=14)
plt.xlabel("Annual Income", fontsize=12)
plt.ylabel("SHAP Value (Impact on Prediction)", fontsize=12)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig("SHAP_dependence_plot.pdf", format='pdf', dpi=1200, bbox_inches='tight')
plt.show()
# --- 2. SHAP Force Plot 局部解释 ---
print("--- 2. SHAP Force Plot 局部解释 ---")
import matplotlib.pyplot as plt
# 选择关键样本(例如预测概率边界样本)
sample_index = np.argmax(np.abs(shap_values[:, :, 0].sum(axis=1) + explainer.expected_value[0])) # 选择对模型输出影响最大的样本
# 创建 Force Plot
plt.figure(figsize=(12, 4))
shap.force_plot(
explainer.expected_value[0], # 基准值
shap_values[sample_index, :, 0], # 样本SHAP值
X_test.iloc[sample_index],
feature_names=X_test.columns.tolist(),
matplotlib=True, # 集成到matplotlib
show=False, # 禁用自动显示
text_rotation=15, # 特征名称旋转角度
plot_cmap=['#FF0051', '#008BFB'] # 红蓝配色[3,6](@ref)
)
# 添加自定义标题和样式
plt.title(f"Local Feature Contributions (Sample Index: {sample_index})",
fontsize=14, y=1.08)
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.tight_layout()
# 保存高清矢量图
plt.savefig("SHAP_force_plot.pdf", format='pdf', dpi=1200, bbox_inches='tight')
plt.show()