import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
matplotlib.pyplot
:Python 中最常用的绘图库,用于创建各种图表confusion_matrix
:来自 scikit-learn 库,专门用于计算分类模型的混淆矩阵seaborn
:基于 matplotlib 的高级可视化库,提供更美观的默认图表样式y_true = [0, 1, 0, 1, 1] # 真实标签
y_pred = [0, 1, 1, 1, 0] # 模型预测标签
这是两组简单的分类标签数据,用于演示混淆矩阵的计算。其中:
cm = confusion_matrix(y_true, y_pred)
confusion_matrix
函数计算出的混淆矩阵是一个二维数组,其结构为:
[[TN, FP],
[FN, TP]]
对于示例数据,计算得到的混淆矩阵为:
[[1, 1],
[1, 2]]
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
heatmap():seaborn 中用于绘制热力图的函数
annot=True:在每个单元格中显示数值
fmt='d':指定数值格式为整数
cmap='Blues':使用蓝色系的配色方案
plt.xlabel('Predicted') # x轴标签:预测类别
plt.ylabel('True') # y轴标签:真实类别
plt.title('Confusion Matrix') # 图表标题plt.show()
最终生成的混淆矩阵热图直观展示了模型在分类任务中的表现:
这个可视化工具常用于评估分类模型的性能,特别是在处理不平衡数据集时非常有用。
import matplotlib.pyplot as plt
# 示例:绘制混淆矩阵热图
from sklearn.metrics import confusion_matrix
import seaborn as sns
y_true = [0, 1, 0, 1, 1]
y_pred = [0, 1, 1, 1, 0]
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
混淆矩阵的计算核心是统计 “真实标签” 与 “预测标签” 的所有组合情况
y_true = [0, 1, 0, 1, 1] # 真实标签:[0,1,0,1,1]
y_pred = [0, 1, 1, 1, 0] # 预测标签:[0,1,1,1,0]
y_true = [0, 1, 0, 1, 1]
(共 5 个样本,0 和 1 为两类)y_pred = [0, 1, 1, 1, 0]
(模型对每个样本的预测结果)二分类问题的混淆矩阵是2×2 矩阵,行代表 “真实标签”,列代表 “预测标签”,包含 4 个核心指标:
真实标签 \ 预测标签 | 预测为 0(负类) | 预测为 1(正类) |
---|---|---|
真实为 0(负类) | TN(真负例) | FP(假正例) |
真实为 1(正类) | FN(假负例) | TP(真正例) |
将每个样本的 “真实标签” 和 “预测标签” 一一对应,判断属于哪类指标:
样本序号 | 真实标签 | 预测标签 | 对应指标 |
---|---|---|---|
1 | 0 | 0 | TN |
2 | 1 | 1 | TP |
3 | 0 | 1 | FP |
4 | 1 | 1 | TP |
5 | 1 | 0 | FN |
统计每个指标的数量:
按 “真实标签行 × 预测标签列” 的顺序,将统计结果填入矩阵:
# 第一行(真实为0):TN=1,FP=1
# 第二行(真实为1):FN=1,TP=2混淆矩阵 = [[1, 1], # 真实0对应的预测结果
[1, 2]] # 真实1对应的预测结果
运行sklearn
的confusion_matrix
函数,结果完全一致:
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 0, 1, 1]
y_pred = [0, 1, 1, 1, 0]
print(confusion_matrix(y_true, y_pred))
输出结果:
[[1 1]
[1 2]]
计算混淆矩阵的关键是:
通过混淆矩阵可以直观看到模型的错误类型(如 FP 多说明容易误判负类为正类),是评估分类模型的基础工具。