写在开头:目前网上关于测试集混淆矩阵可视化的资料,大都基于MNIST、CIFAR等标准数据集来绘制。如果想把自己制作的数据集的预测结果可视化,该怎么做?本文为你提供一种解决方案。
上篇博文介绍了自制数据集的制作、训练和可视化流程。本篇博文在此基础上将测试集的预测结果可视化,让我们能直观地观察哪些数据预测正确、哪些数据预测错误。数据集的制作流程及代码在上篇博文中已有详细介绍,本文不再赘述,重点讲解绘制混淆矩阵部分的代码及实现。
from __future__ import print_function #这行代码必须放在第一行,不然会报错,原因我也没有细察,请各位读者注意一下
import os,glob,random, csv,itertools
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import optimizers,losses,layers
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix
# Fix the TensorFlow and NumPy random seeds so repeated runs behave the same.
# (Python's own `random` module is not seeded here.)
tf.random.set_seed(22)
np.random.seed(22)

# Ask TensorFlow's C++ runtime to log less (level '2' is commonly used to hide
# INFO and WARNING messages).
# NOTE(review): tensorflow has already been imported above this line; the
# variable may need to be set before `import tensorflow` to take full effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# This script targets TensorFlow 2.x only.
assert tf.__version__.startswith('2.')

# Let every visible GPU allocate memory on demand instead of reserving it all
# up front. Nothing happens on a CPU-only install (training will be much
# slower there). Memory growth must be configured before the GPUs are
# initialized; TensorFlow raises RuntimeError otherwise, which is reported here.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        print(err)
def load_csv(root, filename, name2label):
    """Return ``(image_paths, labels)`` for the dataset under *root*, cached in a CSV.

    On the first call (when ``root/filename`` does not exist yet), every
    ``*.png``, ``*.jpg`` and ``*.jpeg`` file inside ``root/<class_name>/`` is
    collected for each class name in *name2label*. The list is shuffled with the
    ``random`` module and written to ``root/filename`` as ``path,label`` rows.
    Every call, including the first, then reads the rows back from that CSV.
    The shuffled order, and therefore any split derived from it, stays fixed
    for as long as the CSV file exists.

    Args:
        root: dataset root directory containing one sub-directory per class.
        filename: name of the CSV cache file, created inside *root*.
        name2label: mapping from class-directory name to integer label.

    Returns:
        Two equal-length lists: image paths (``str``) and labels (``int``).
    """
    csv_path = os.path.join(root, filename)
    if not os.path.exists(csv_path):
        images = []
        for name in name2label:
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                images += glob.glob(os.path.join(root, name, pattern))
        print(len(images), images)
        random.shuffle(images)
        # newline='' is what the csv module expects for files it writes.
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:
                # The class name is the image's parent directory.
                name = os.path.basename(os.path.dirname(img))
                writer.writerow([img, name2label[name]])
            print('written into csv file:', filename)

    images, labels = [], []
    # newline='' is also the recommended mode when reading with csv.reader.
    with open(csv_path, newline='') as f:
        for img, label in csv.reader(f):
            images.append(img)
            labels.append(int(label))
    assert len(images) == len(labels)
    return images, labels
def load_huangtu(root, mode='train'):
    """Load one split of the 'huangtu' image dataset stored under *root*.

    Every sub-directory of *root* is treated as a class. Classes are numbered
    0, 1, 2, ... in sorted directory-name order. The full ``(path, label)`` list
    comes from ``load_csv`` (cached in ``root/images.csv``) and is cut 60/20/20:

    * ``mode == 'train'``: the first 60 %
    * ``mode == 'val'``: the next 20 % (60 %-80 %)
    * any other value: the last 20 % (80 %-100 %), i.e. the test split

    Returns:
        ``(images, labels, name2label)``: the selected paths, their labels, and
        the class-name -> label mapping.
    """
    class_dirs = [entry for entry in sorted(os.listdir(root))
                  if os.path.isdir(os.path.join(root, entry))]
    name2label = {name: idx for idx, name in enumerate(class_dirs)}

    images, labels = load_csv(root, 'images.csv', name2label)

    total = len(images)
    train_end, val_end = int(0.6 * total), int(0.8 * total)
    if mode == 'train':
        window = slice(0, train_end)
    elif mode == 'val':
        window = slice(train_end, val_end)
    else:
        window = slice(val_end, total)
    return images[window], labels[window], name2label
# Per-channel mean and standard deviation used by `normalize` (three values,
# presumably in the R, G, B order produced by decoding with channels=3).
# NOTE(review): these are the widely used ImageNet channel statistics,
# presumably chosen to match the ImageNet-pretrained DenseNet121 used in main();
# confirm before reusing them with a different backbone.
img_mean = tf.constant([0.485, 0.456, 0.406])
img_std = tf.constant([0.229, 0.224, 0.225])
def normalize(x, mean=img_mean, std=img_std):
    """Standardize *x* per channel: subtract *mean*, then divide by *std*.

    *mean* and *std* default to the module-level 3-element constants and
    broadcast over the last axis of *x*. In this script, ``preprocess`` calls
    it on images that were already scaled to [0, 1].
    """
    centered = x - mean
    return centered / std
def preprocess(x, y):
    """Turn one ``(image path, integer label)`` pair into a model-ready example.

    Pipeline: read file -> decode to 3 channels -> resize to 244x244 -> random
    up/down and left/right flips -> random 224x224 crop -> scale to [0, 1] ->
    per-channel standardization with ``normalize`` -> one-hot label.

    Args:
        x: string tensor holding the path of an image file (PNG or JPEG, as
            collected by ``load_csv``).
        y: integer class label.

    Returns:
        ``(image, one_hot_label)``: a float32 image tensor of shape (224, 224, 3)
        and a one-hot label vector of length 7.

    NOTE(review): the random flips and crop run on every call, and main() maps
    this same function over the validation and test sets. Evaluation is
    therefore done on randomly augmented images; confirm that is intended.
    """
    x = tf.io.read_file(x)
    # decode_image handles both PNG and JPEG (decode_jpeg is documented for
    # JPEG only). expand_animations=False guarantees a 3-D result, so the
    # resize/crop ops below get a tensor of known rank. channels=3 -> RGB.
    x = tf.io.decode_image(x, channels=3, expand_animations=False)
    # Resize slightly larger than the final 224x224 so random_crop has room.
    x = tf.image.resize(x, [244, 244])
    x = tf.image.random_flip_up_down(x)     # augmentation: vertical flip
    x = tf.image.random_flip_left_right(x)  # augmentation: horizontal flip
    x = tf.image.random_crop(x, [224, 224, 3])  # augmentation: random crop
    x = tf.cast(x, dtype=tf.float32) / 255.  # [0, 255] -> [0, 1]
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=7)  # 7 classes; must match the model's output size
    return x, y
def plot_confusion_matrix(y_true, y_pred, title="Confusion matrix",
                          cmap=plt.cm.Blues, save_flg=False,
                          num_classes=7, class_names=None,
                          save_path="./confusion_matrix.png"):
    """Plot (and optionally save) the confusion matrix of a classifier.

    Rows are true labels and columns are predicted labels, so correct
    predictions lie on the diagonal. Each cell is annotated with its raw count
    (no normalization).

    Args:
        y_true: ground-truth integer labels of the test set.
        y_pred: predicted integer labels, in the same order as *y_true*.
        title: figure title.
        cmap: matplotlib colormap for the cells.
        save_flg: if True, save the figure to *save_path* before showing it.
        num_classes: number of classes; labels ``0 .. num_classes - 1`` are
            used. Defaults to 7, the number of classes in this script's dataset.
        class_names: optional tick labels, one per class. Defaults to
            ``"0", "1", ...``.
        save_path: output file used when *save_flg* is True.

    Returns:
        The ``num_classes x num_classes`` confusion-matrix array.

    Raises:
        ValueError: if *class_names* does not have exactly *num_classes* entries.
    """
    labels = list(range(num_classes))
    if class_names is None:
        class_names = [str(i) for i in labels]
    if len(class_names) != num_classes:
        raise ValueError("class_names must have exactly num_classes entries")

    cm = confusion_matrix(y_true, y_pred, labels=labels)

    plt.figure(figsize=(14, 12))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=40)
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, class_names, fontsize=20)
    plt.yticks(tick_marks, class_names, fontsize=20)

    print('Confusion matrix, without normalization')
    # Use white text on dark (high-count) cells and black text elsewhere, so
    # the numbers stay readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label', fontsize=30)
    plt.xlabel('Predicted label', fontsize=30)
    if save_flg:
        plt.savefig(save_path)
    plt.show()
    return cm
def main():
    """Train a DenseNet121-based classifier on 'huangtu' and plot its test confusion matrix.

    Steps:
    1. Build train/val/test ``tf.data`` pipelines from ``load_huangtu``.
    2. Add a new classification head to a frozen ImageNet-pretrained
       DenseNet121 and train it for 50 epochs.
    3. Report test loss and accuracy.
    4. Predict on the test set and draw the confusion matrix of true versus
       predicted labels.
    """
    images, labels, table = load_huangtu('huangtu', 'train')
    print('images', len(images), images)
    print('labels', len(labels), labels)
    print(table)
    db = tf.data.Dataset.from_tensor_slices((images, labels))
    db = db.shuffle(1000).map(preprocess).batch(32)
    # Optional TensorBoard image preview of `db` (disabled). To enable it,
    # uncomment these lines and set the log directory. It also needs
    # `import time` and a `denormalize` helper, neither of which is defined in
    # this file as shown.
    # writter = tf.summary.create_file_writer('<log directory>')
    # for step, (x, y) in enumerate(db):
    #     with writter.as_default():
    #         x = denormalize(x)
    #         tf.summary.image('img', x, step=step, max_outputs=9)
    #         time.sleep(5)

    batchsz = 128
    # Training set: reshuffled by tf.data on each pass.
    images, labels, table = load_huangtu('huangtu', mode='train')
    db_train = tf.data.Dataset.from_tensor_slices((images, labels))
    db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
    # Validation set.
    images2, labels2, table = load_huangtu('huangtu', mode='val')
    db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
    db_val = db_val.map(preprocess).batch(batchsz)
    # Test set: deliberately NOT shuffled, so predict() output stays aligned
    # with labels3 for the confusion matrix below.
    images3, labels3, table = load_huangtu('huangtu', mode='test')
    db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
    db_test = db_test.map(preprocess).batch(batchsz)

    # Transfer learning: frozen DenseNet121 backbone (ImageNet weights) plus a
    # newly trained head.
    net = keras.applications.DenseNet121(weights='imagenet', include_top=False, pooling='max')
    net.trainable = False
    newnet = keras.Sequential([
        net,
        layers.Dense(1024, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(rate=0.5),
        layers.Dense(7),  # 7 logits; must match one_hot depth in preprocess
    ])
    newnet.build(input_shape=(4, 224, 224, 3))
    newnet.summary()
    # `learning_rate` is the supported keyword. The legacy `lr` alias is
    # deprecated and rejected by newer Keras releases.
    newnet.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                   loss=losses.CategoricalCrossentropy(from_logits=True),
                   metrics=['accuracy'])
    # Optional TensorBoard callback (disabled). To use it, uncomment and pass
    # callbacks=[tensorboard_callback] to fit():
    # log_dir = '<same log directory as above>'
    # tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=50)
    score = newnet.evaluate(db_test)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])

    # Confusion-matrix inputs.
    # NOTE(review): `preprocess` applies random flips/crops and predict()
    # re-runs it, so these predictions are made on different random views of
    # the test images than evaluate() above.
    logits = newnet.predict(db_test)            # one row of 7 logits per test image
    pred_classes = np.argmax(logits, axis=1)    # predicted label per test image
    # y_true = ground-truth labels, y_pred = model predictions. Swapping them
    # would transpose the matrix.
    plot_confusion_matrix(labels3, pred_classes, save_flg=True)
# Run the whole pipeline (data loading, training, evaluation, confusion matrix)
# only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
将上面所有代码按顺序组合起来,就可以作为一个完整的数据集制作、模型训练与混淆矩阵可视化的程序。
下图为自制数据集的混淆矩阵可视化结果
图中每一行对应真实标签,每一列对应预测标签:对角线上的数值为预测正确的样本数,非对角线上的数值为被错分到其他类别的样本数。
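写在最后:与标准数据集的预测相比,使用自制数据集时要重点注意和思考传参。plot_confusion_matrix 的 y_true 应传入测试集的真实标签,y_pred 应传入模型预测的类别(对 predict 的输出取 argmax),二者顺序不能颠倒,否则得到的是转置后的矩阵。同时,测试集在构建时不能打乱顺序,这样才能保证预测结果与真实标签一一对应。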
本文到此结束