Custom Dataset Classification in TensorFlow 2.0 (Pokemon dataset), plus a supplement on tf.where!
Note: if you split off only 10% for the test set (i.e., relatively few test samples), test performance will fluctuate noticeably. Here 10% would still be a bit over 100 images (20-odd per class), which is workable, but with a very small dataset the test metrics swing a lot. To make evaluation more reliable, we deliberately enlarge both the test and validation splits to 20% each. The whole dataset is roughly 1,100+ images, which can be considered a small-to-medium-scale dataset.
Overall, we proceed in four main steps: (1) building the dataset index and loader, (2) preprocessing and augmentation, (3) defining the network (ResNet-18), and (4) training, both from scratch and via transfer learning.
# pokemon.py
import os, glob
import random, csv
import tensorflow as tf
def load_csv(root, filename, name2label):
    """
    Load (or first build) the CSV index file.
    :param root: dataset root directory
    :param filename: name of the CSV index file
    :param name2label: class-name-to-label lookup table
    :return: two aligned lists of image paths and integer labels
    """
    # If the CSV index does not exist yet, build it by scanning the class folders.
    if not os.path.exists(os.path.join(root, filename)):
        images = []
        for name in name2label.keys():
            # e.g. 'pokemon\\mewtwo\\00001.png'
            images += glob.glob(os.path.join(root, name, '*.png'))
            images += glob.glob(os.path.join(root, name, '*.jpg'))
            images += glob.glob(os.path.join(root, name, '*.jpeg'))
        # e.g. 1167, 'pokemon\\bulbasaur\\00000000.png'
        print(len(images), images)
        random.shuffle(images)
        with open(os.path.join(root, filename), mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:  # e.g. 'pokemon\\bulbasaur\\00000000.png'
                # The class name is the image's parent folder.
                name = img.split(os.sep)[-2]
                label = name2label[name]
                # e.g. 'pokemon\\bulbasaur\\00000000.png', 0
                writer.writerow([img, label])
            print('written into csv file:', filename)
    # Read the image paths and labels back from the CSV file.
    images, labels = [], []
    with open(os.path.join(root, filename)) as f:
        reader = csv.reader(f)
        for row in reader:
            # e.g. 'pokemon\\bulbasaur\\00000000.png', 0
            img, label = row
            label = int(label)
            images.append(img)
            labels.append(label)
    assert len(images) == len(labels)
    return images, labels
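# Usage sketch (hypothetical, not part of the script): given a dataset root
# named 'pokemon' and a prebuilt table such as {'bulbasaur': 0, 'pikachu': 1},
# the call and its result would look like:
#   imgs, labs = load_csv('pokemon', 'images.csv', name2label)
#   print(imgs[0], labs[0])  # e.g. 'pokemon/bulbasaur/00000000.png' 0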
# Utility for loading the Pokemon dataset.
def load_pokemon(root, mode='train'):
    """
    Load the Pokemon dataset.
    :param root: directory where the dataset is stored
    :param mode: which split to load: 'train', 'val', or 'test'
    :return: image paths, labels, and the name-to-label table
    """
    # Build the label-encoding table, range 0-4: the root directory holds 5
    # class folders, so each class name maps to an integer label, e.g. "squirtle": 0.
    name2label = {}
    for name in sorted(os.listdir(os.path.join(root))):  # list all entries
        if not os.path.isdir(os.path.join(root, name)):
            continue
        # Assign the next free integer code to each class.
        name2label[name] = len(name2label.keys())
    # Read the label info; the index file images.csv stores each image path
    # together with its class, as two aligned lists, e.g. [file1, file2, ...]
    # with labels [3, 1, ...].
    images, labels = load_csv(root, 'images.csv', name2label)
    # Split the data: 60% train, 20% validation, 20% test.
    if mode == 'train':  # first 60%
        images = images[:int(0.6 * len(images))]
        labels = labels[:int(0.6 * len(labels))]
    elif mode == 'val':  # 20%: from 60% to 80%
        images = images[int(0.6 * len(images)):int(0.8 * len(images))]
        labels = labels[int(0.6 * len(labels)):int(0.8 * len(labels))]
    else:  # 20%: from 80% to 100%
        images = images[int(0.8 * len(images)):]
        labels = labels[int(0.8 * len(labels)):]
    return images, labels, name2label
# Data normalization.
# Where do the mean and std values below come from? They were computed over
# all of ImageNet (millions of images). They transfer well because the pixel
# statistics of most natural-image datasets are essentially the same as
# ImageNet's, so these six numbers are practically universal; a quick search
# turns them up everywhere.
img_mean = tf.constant([0.485, 0.456, 0.406])
img_std = tf.constant([0.229, 0.224, 0.225])
def normalize(x, mean=img_mean, std=img_std):
    # x shape: [224, 224, 3]
    # mean/std have shape [3]; broadcasting aligns shapes from the right:
    # [3] -> [1, 1, 3] (insert leading 1s) -> [224, 224, 3] (tile to match x).
    x = (x - mean) / std
    return x
# After normalizing the data, we also need the inverse operation, e.g. to
# undo normalization before visualizing images.
def denormalize(x, mean=img_mean, std=img_std):
    x = x * std + mean
    return x
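# Quick sanity check (a sketch, not part of the original script):
# denormalize inverts normalize up to float rounding.
#   x = tf.random.uniform([224, 224, 3])  # fake image in [0, 1]
#   print(float(tf.reduce_max(tf.abs(x - denormalize(normalize(x))))))  # ~0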
def preprocess(x, y):
    # x: image path, y: integer label
    x = tf.io.read_file(x)  # read the image file from its path
    # Note some images have a 4th alpha channel (RGBA); force 3 channels.
    x = tf.image.decode_jpeg(x, channels=3)
    # Resize to 244x244, slightly larger than the 224x224 ResNet input so the
    # random crop below has room to work.
    x = tf.image.resize(x, [244, 244])
    # Data augmentation on the 0~255 image; this must come before
    # normalization (it operates on the raw picture).
    # x = tf.image.random_flip_up_down(x)  # random vertical flip; flipping every image would add nothing, so only a random subset is flipped
    x = tf.image.random_flip_left_right(x)  # random horizontal flip
    # x = tf.image.random_crop(x, [224, 224, 3])  # random crop to 224x224; note the resize above must be larger than 224 (e.g. 244 or 250), otherwise the crop does nothing
    # Then normalization: [0, 255] => [0, 1] (or -0.5~0.5)
    x = tf.cast(x, dtype=tf.float32) / 255.
    # [0, 1] => roughly zero mean, unit variance
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    return x, y
def main():
    import time
    images, labels, table = load_pokemon('/home/zhangkf/tf/TF2/TF2_8_data/pokeman', 'train')
    # image paths
    print('images', len(images), images)
    # image labels
    print('labels', len(labels), labels)
    # the encoding table mapping class names to labels
    print(table)
    # Wrap the lists in a Dataset object.
    db = tf.data.Dataset.from_tensor_slices((images, labels))
    # Preprocess: shuffle, map paths to image tensors, and batch.
    db = db.shuffle(1000).map(preprocess).batch(32)
    # Visualize the images in TensorBoard.
    writer = tf.summary.create_file_writer('log')
    for step, (x, y) in enumerate(db):
        # x: [32, 244, 244, 3] (the random crop above is commented out here)
        # y: [32]
        with writer.as_default():
            # Record 9 images per step. (For faithful colors you would log
            # denormalize(x) here, since x has been normalized.)
            tf.summary.image('img', x, step=step, max_outputs=9)
        time.sleep(5)  # refresh one batch every 5 seconds if it feels too fast
if __name__ == '__main__':
    main()
# resnet.py
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
class ResnetBlock(keras.Model):

    def __init__(self, channels, strides=1):
        super(ResnetBlock, self).__init__()
        self.channels = channels
        self.strides = strides
        # Explicit per-dimension padding of 1 pixel on height and width,
        # equivalent to 'same' padding for a 3x3 kernel at stride 1.
        self.conv1 = layers.Conv2D(channels, 3, strides=strides,
                                   padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
        self.bn1 = keras.layers.BatchNormalization()
        self.conv2 = layers.Conv2D(channels, 3, strides=1,
                                   padding=[[0, 0], [1, 1], [1, 1], [0, 0]])
        self.bn2 = keras.layers.BatchNormalization()
        if strides != 1:
            # 1x1 conv on the shortcut so the residual matches in shape.
            self.down_conv = layers.Conv2D(channels, 1, strides=strides, padding='valid')
            self.down_bn = tf.keras.layers.BatchNormalization()
    def call(self, inputs, training=None):
        residual = inputs
        x = self.conv1(inputs)
        x = tf.nn.relu(x)
        x = self.bn1(x, training=training)
        x = self.conv2(x)
        x = tf.nn.relu(x)
        x = self.bn2(x, training=training)
        # Residual (skip) connection; downsample the input when strides != 1.
        if self.strides != 1:
            residual = self.down_conv(inputs)
            residual = tf.nn.relu(residual)
            residual = self.down_bn(residual, training=training)
        x = x + residual
        x = tf.nn.relu(x)
        return x
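# Shape check (a sketch; it relies on tf.keras accepting the explicit
# padding lists above, which very recent versions may reject):
#   block = ResnetBlock(64, strides=2)
#   out = block(tf.random.normal([4, 32, 32, 16]), training=False)
#   print(out.shape)  # (4, 16, 16, 64): strides=2 halves height and width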
# The ResNet-18 implementation.
class ResNet(keras.Model):

    def __init__(self, num_classes, initial_filters=16, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        self.stem = layers.Conv2D(initial_filters, 3, strides=3, padding='valid')
        # 8 ResnetBlock modules in total: 8 blocks x 2 conv layers = 16 layers,
        # plus the stem layer and the output layer, giving the 18 of ResNet-18.
        # The blocks form 4 groups: the first block of each group reduces
        # height and width, the second (strides=1) keeps dimensions unchanged.
        self.blocks = keras.models.Sequential([
            ResnetBlock(initial_filters * 2, strides=3),
            ResnetBlock(initial_filters * 2, strides=1),
            # layers.Dropout(rate=0.5),
            ResnetBlock(initial_filters * 4, strides=3),
            ResnetBlock(initial_filters * 4, strides=1),
            ResnetBlock(initial_filters * 8, strides=2),
            ResnetBlock(initial_filters * 8, strides=1),
            ResnetBlock(initial_filters * 16, strides=2),
            ResnetBlock(initial_filters * 16, strides=1),
        ])
        self.final_bn = layers.BatchNormalization()
        # Note: despite the attribute name, this is global *max* pooling.
        self.avg_pool = layers.GlobalMaxPool2D()
        self.fc = layers.Dense(num_classes)  # the fully connected output layer
    def call(self, inputs, training=None):
        # print('x:', inputs.shape)
        out = self.stem(inputs)  # the stem (root) convolution
        out = tf.nn.relu(out)
        # print('stem:', out.shape)
        out = self.blocks(out, training=training)
        # print('res:', out.shape)
        out = self.final_bn(out, training=training)
        # out = tf.nn.relu(out)
        out = self.avg_pool(out)
        # print('avg_pool:', out.shape)
        out = self.fc(out)  # the classification layer produces the output logits
        # print('out:', out.shape)
        return out
def main():
    num_classes = 5
    resnet18 = ResNet(num_classes)
    resnet18.build(input_shape=(4, 224, 224, 3))
    resnet18.summary()
if __name__ == '__main__':
    main()
# train_scratch.py
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
# Import the utilities defined earlier.
from pokemon import load_pokemon, normalize, denormalize
from resnet import ResNet  # import the model
# The preprocessing function, copied over from pokemon.py (with the random
# crop and one-hot steps enabled).
def preprocess(x, y):
    # x: image path, y: integer label
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)  # force 3 channels (some images are RGBA)
    x = tf.image.resize(x, [244, 244])
    x = tf.image.random_flip_left_right(x)
    # x = tf.image.random_flip_up_down(x)
    x = tf.image.random_crop(x, [224, 224, 3])
    # x: [0, 255] => normalized
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)
    return x, y
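# Note (a sketch, not in the original script): because y is one-hot encoded
# above, the compile step below uses losses.CategoricalCrossentropy. An
# equivalent setup skips tf.one_hot and uses the sparse loss instead; the two
# compute identical values, e.g.:
#   logits = tf.constant([[2.0, 0.5, 0.1, 0.1, 0.1]])
#   losses.SparseCategoricalCrossentropy(from_logits=True)(tf.constant([0]), logits)
#   losses.CategoricalCrossentropy(from_logits=True)(tf.one_hot([0], depth=5), logits)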
batchsz = 8
# Create the train db; only training data generally needs shuffling.
images, labels, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman', mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))  # wrap as a Dataset object
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)  # map turns paths into image tensors
# Create the validation db.
images2, labels2, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman', mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# Create the test db.
images3, labels3, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman', mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)
# The training set is very small while ResNet is highly expressive; the small
# 4-layer network below (commented out) is swapped in later — see the
# discussion after the first training log.
# resnet = keras.Sequential([
#     layers.Conv2D(16, 5, 3),
#     layers.MaxPool2D(3, 3),
#     layers.ReLU(),
#     layers.Conv2D(64, 5, 3),
#     layers.MaxPool2D(2, 2),
#     layers.ReLU(),
#     layers.Flatten(),
#     layers.Dense(64),
#     layers.ReLU(),
#     layers.Dense(5)
# ])
# First, create the ResNet-18.
resnet = ResNet(5)
resnet.build(input_shape=(batchsz, 224, 224, 3))
resnet.summary()
# EarlyStopping monitor: if val_accuracy fails to improve by at least
# min_delta for `patience` consecutive validations, training stops early.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=20
)
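# Optional variant (a sketch, not used in the run below): by default
# EarlyStopping keeps the last-epoch weights when it stops; recent tf.keras
# versions can also roll back to the best validation epoch:
# early_stopping_best = EarlyStopping(
#     monitor='val_accuracy',
#     min_delta=0.001,
#     patience=20,
#     restore_best_weights=True  # restore the weights from the best val epoch
# )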
# Compile (assemble) the network.
resnet.compile(optimizer=optimizers.Adam(learning_rate=1e-4),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
# The standard train/val/test workflow: model selection must go through
# db_val, which is exactly what the EarlyStopping callback provides.
resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=1000,
           callbacks=[early_stopping])  # validate once per epoch; stop early when triggered
resnet.evaluate(db_test)
ssh://[email protected]:22/home/zhangkf/anaconda3/envs/tf2b/bin/python -u /home/zhangkf/johnCodes/TF2/TF2_8_data/train_scratch.py
WARNING: Logging before flag parsing goes to stderr.
W0828 10:51:31.044445 139752945207040 deprecation.py:323] From /home/zhangkf/anaconda3/envs/tf2b/lib/python3.7/site-packages/tensorflow/python/data/util/random_seed.py:58: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "res_net"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) multiple 448
_________________________________________________________________
sequential (Sequential) multiple 2797280
_________________________________________________________________
batch_normalization_20 (Batc multiple 1024
_________________________________________________________________
global_max_pooling2d (Global multiple 0
_________________________________________________________________
dense (Dense) multiple 1285
=================================================================
Total params: 2,800,037
Trainable params: 2,794,725
Non-trainable params: 5,312
_________________________________________________________________
Epoch 1/1000
88/88 [==============================] - 37s 425ms/step - loss: 2.3255 - accuracy: 0.2472 - val_loss: 1.6504 - val_accuracy: 0.1845
Epoch 2/1000
88/88 [==============================] - 24s 277ms/step - loss: 0.3819 - accuracy: 0.8250 - val_loss: 1.7054 - val_accuracy: 0.1760
Epoch 3/1000
88/88 [==============================] - 24s 273ms/step - loss: 0.0643 - accuracy: 0.9928 - val_loss: 1.7032 - val_accuracy: 0.1931
Epoch 4/1000
88/88 [==============================] - 24s 271ms/step - loss: 0.0198 - accuracy: 0.9949 - val_loss: 1.7639 - val_accuracy: 0.2318
Epoch 5/1000
88/88 [==============================] - 21s 242ms/step - loss: 0.0069 - accuracy: 1.0000 - val_loss: 1.8334 - val_accuracy: 0.2103
Epoch 6/1000
88/88 [==============================] - 23s 266ms/step - loss: 0.0046 - accuracy: 1.0000 - val_loss: 1.8315 - val_accuracy: 0.3004
Epoch 7/1000
88/88 [==============================] - 24s 271ms/step - loss: 0.0036 - accuracy: 1.0000 - val_loss: 1.9369 - val_accuracy: 0.3090
Epoch 8/1000
88/88 [==============================] - 23s 264ms/step - loss: 0.0029 - accuracy: 1.0000 - val_loss: 2.0374 - val_accuracy: 0.2918
Epoch 9/1000
88/88 [==============================] - 24s 268ms/step - loss: 0.0024 - accuracy: 1.0000 - val_loss: 2.0581 - val_accuracy: 0.3219
Epoch 10/1000
88/88 [==============================] - 24s 269ms/step - loss: 0.0021 - accuracy: 1.0000 - val_loss: 2.0795 - val_accuracy: 0.3219
Epoch 11/1000
88/88 [==============================] - 24s 268ms/step - loss: 0.0018 - accuracy: 1.0000 - val_loss: 2.0932 - val_accuracy: 0.3047
Epoch 12/1000
88/88 [==============================] - 24s 272ms/step - loss: 0.0016 - accuracy: 1.0000 - val_loss: 2.0997 - val_accuracy: 0.3047
Epoch 13/1000
88/88 [==============================] - 21s 243ms/step - loss: 0.0014 - accuracy: 1.0000 - val_loss: 2.1042 - val_accuracy: 0.3090
Epoch 14/1000
88/88 [==============================] - 23s 265ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 2.1075 - val_accuracy: 0.3090
Epoch 15/1000
88/88 [==============================] - 24s 270ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 2.1106 - val_accuracy: 0.3090
Epoch 16/1000
88/88 [==============================] - 23s 260ms/step - loss: 0.0010 - accuracy: 1.0000 - val_loss: 2.1136 - val_accuracy: 0.3090
Epoch 17/1000
88/88 [==============================] - 23s 267ms/step - loss: 9.1804e-04 - accuracy: 1.0000 - val_loss: 2.1162 - val_accuracy: 0.3090
Epoch 18/1000
88/88 [==============================] - 24s 268ms/step - loss: 8.3833e-04 - accuracy: 1.0000 - val_loss: 2.1183 - val_accuracy: 0.3090
Epoch 19/1000
88/88 [==============================] - 24s 272ms/step - loss: 7.6828e-04 - accuracy: 1.0000 - val_loss: 2.1211 - val_accuracy: 0.3047
Epoch 20/1000
88/88 [==============================] - 24s 269ms/step - loss: 7.0653e-04 - accuracy: 1.0000 - val_loss: 2.1231 - val_accuracy: 0.3047
Epoch 21/1000
88/88 [==============================] - 24s 268ms/step - loss: 6.5178e-04 - accuracy: 1.0000 - val_loss: 2.1257 - val_accuracy: 0.3047
Epoch 22/1000
88/88 [==============================] - 23s 262ms/step - loss: 6.0310e-04 - accuracy: 1.0000 - val_loss: 2.1278 - val_accuracy: 0.3047
Epoch 23/1000
88/88 [==============================] - 24s 268ms/step - loss: 5.5923e-04 - accuracy: 1.0000 - val_loss: 2.1298 - val_accuracy: 0.3047
Epoch 24/1000
88/88 [==============================] - 23s 267ms/step - loss: 5.1979e-04 - accuracy: 1.0000 - val_loss: 2.1319 - val_accuracy: 0.3047
Epoch 25/1000
88/88 [==============================] - 23s 259ms/step - loss: 4.8411e-04 - accuracy: 1.0000 - val_loss: 2.1338 - val_accuracy: 0.3047
Epoch 26/1000
Problem: small datasets are hard to train on. Observe the pattern above: training accuracy quickly reaches about 100%, while validation accuracy stays down around 20-30%. That means the network has not really learned anything useful; it is badly overfitting. (Once the validation accuracy fails to improve by min_delta for `patience` consecutive epochs, EarlyStopping fires and the run stops.)
The validation accuracy of ResNet-18 here is so low that the network is essentially not working. The main reason is that ResNet is too big for this task: it is deep and has a large parameter count, while ImageNet has millions of images and this Pokemon dataset has only 200-odd images per class. What can be done? The first and most fundamental fix is to get more data.
The second option is to constrain the network: reduce its depth and parameter count, or add regularization and other anti-overfitting measures. All of these are worth trying.
But our images here are genuinely too few, so we take another route and simply switch to a much smaller network. When data is this scarce, a small network often works surprisingly well, and it is also easier to train. Let's test it.
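Concretely, the swap just means replacing ResNet(5) with the small 4-layer stack that is commented out in the script above:
resnet = keras.Sequential([
    layers.Conv2D(16, 5, 3),   # 16 filters, 5x5 kernel, stride 3
    layers.MaxPool2D(3, 3),
    layers.ReLU(),
    layers.Conv2D(64, 5, 3),
    layers.MaxPool2D(2, 2),
    layers.ReLU(),
    layers.Flatten(),
    layers.Dense(64),
    layers.ReLU(),
    layers.Dense(5)
])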
With the small network in place, the results are as follows: test accuracy reaches about 87% (see the final evaluate line below).
ssh://[email protected]:22/home/zhangkf/anaconda3/envs/tf2b/bin/python -u /home/zhangkf/johnCodes/TF2/TF2_8_data/train_scratch.py
WARNING: Logging before flag parsing goes to stderr.
W0828 10:33:44.677829 139788463359744 deprecation.py:323] From /home/zhangkf/anaconda3/envs/tf2b/lib/python3.7/site-packages/tensorflow/python/data/util/random_seed.py:58: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) multiple 1216
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple 0
_________________________________________________________________
re_lu (ReLU) multiple 0
_________________________________________________________________
conv2d_1 (Conv2D) multiple 25664
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple 0
_________________________________________________________________
re_lu_1 (ReLU) multiple 0
_________________________________________________________________
flatten (Flatten) multiple 0
_________________________________________________________________
dense (Dense) multiple 36928
_________________________________________________________________
re_lu_2 (ReLU) multiple 0
_________________________________________________________________
dense_1 (Dense) multiple 325
=================================================================
Total params: 64,133
Trainable params: 64,133
Non-trainable params: 0
_________________________________________________________________
Epoch 1/1000
88/88 [==============================] - 23s 266ms/step - loss: 1.4197 - accuracy: 0.3153 - val_loss: 1.1087 - val_accuracy: 0.6223
Epoch 2/1000
88/88 [==============================] - 22s 251ms/step - loss: 0.9204 - accuracy: 0.6749 - val_loss: 0.7547 - val_accuracy: 0.7940
Epoch 3/1000
88/88 [==============================] - 23s 266ms/step - loss: 0.6312 - accuracy: 0.8405 - val_loss: 0.5803 - val_accuracy: 0.8326
Epoch 4/1000
88/88 [==============================] - 19s 221ms/step - loss: 0.4844 - accuracy: 0.8674 - val_loss: 0.4981 - val_accuracy: 0.8369
Epoch 5/1000
88/88 [==============================] - 22s 253ms/step - loss: 0.4089 - accuracy: 0.8874 - val_loss: 0.4607 - val_accuracy: 0.8455
Epoch 6/1000
88/88 [==============================] - 23s 259ms/step - loss: 0.3629 - accuracy: 0.9137 - val_loss: 0.4383 - val_accuracy: 0.8584
Epoch 7/1000
88/88 [==============================] - 21s 243ms/step - loss: 0.3281 - accuracy: 0.9312 - val_loss: 0.4283 - val_accuracy: 0.8498
Epoch 8/1000
88/88 [==============================] - 24s 269ms/step - loss: 0.3019 - accuracy: 0.9329 - val_loss: 0.4155 - val_accuracy: 0.8541
Epoch 9/1000
88/88 [==============================] - 22s 249ms/step - loss: 0.2772 - accuracy: 0.9367 - val_loss: 0.4091 - val_accuracy: 0.8541
Epoch 10/1000
88/88 [==============================] - 23s 265ms/step - loss: 0.2553 - accuracy: 0.9413 - val_loss: 0.4041 - val_accuracy: 0.8541
Epoch 11/1000
88/88 [==============================] - 23s 259ms/step - loss: 0.2344 - accuracy: 0.9513 - val_loss: 0.4001 - val_accuracy: 0.8584
Epoch 12/1000
88/88 [==============================] - 22s 252ms/step - loss: 0.2156 - accuracy: 0.9560 - val_loss: 0.3960 - val_accuracy: 0.8584
Epoch 13/1000
88/88 [==============================] - 23s 257ms/step - loss: 0.1981 - accuracy: 0.9651 - val_loss: 0.3915 - val_accuracy: 0.8627
Epoch 14/1000
88/88 [==============================] - 23s 258ms/step - loss: 0.1804 - accuracy: 0.9686 - val_loss: 0.3909 - val_accuracy: 0.8627
Epoch 15/1000
88/88 [==============================] - 21s 237ms/step - loss: 0.1655 - accuracy: 0.9717 - val_loss: 0.3849 - val_accuracy: 0.8627
Epoch 16/1000
88/88 [==============================] - 24s 269ms/step - loss: 0.1501 - accuracy: 0.9758 - val_loss: 0.3862 - val_accuracy: 0.8627
Epoch 17/1000
88/88 [==============================] - 18s 206ms/step - loss: 0.1362 - accuracy: 0.9784 - val_loss: 0.3825 - val_accuracy: 0.8627
Epoch 18/1000
88/88 [==============================] - 19s 212ms/step - loss: 0.1225 - accuracy: 0.9788 - val_loss: 0.3804 - val_accuracy: 0.8627
Epoch 19/1000
88/88 [==============================] - 22s 254ms/step - loss: 0.1107 - accuracy: 0.9856 - val_loss: 0.3789 - val_accuracy: 0.8712
Epoch 20/1000
88/88 [==============================] - 23s 265ms/step - loss: 0.0996 - accuracy: 0.9883 - val_loss: 0.3785 - val_accuracy: 0.8755
Epoch 21/1000
88/88 [==============================] - 22s 249ms/step - loss: 0.0897 - accuracy: 0.9896 - val_loss: 0.3850 - val_accuracy: 0.8670
Epoch 22/1000
88/88 [==============================] - 22s 251ms/step - loss: 0.0808 - accuracy: 0.9935 - val_loss: 0.3869 - val_accuracy: 0.8670
Epoch 23/1000
88/88 [==============================] - 23s 258ms/step - loss: 0.0726 - accuracy: 0.9937 - val_loss: 0.3926 - val_accuracy: 0.8670
Epoch 24/1000
88/88 [==============================] - 23s 266ms/step - loss: 0.0652 - accuracy: 0.9937 - val_loss: 0.3958 - val_accuracy: 0.8627
Epoch 25/1000
88/88 [==============================] - 23s 260ms/step - loss: 0.0583 - accuracy: 0.9937 - val_loss: 0.3974 - val_accuracy: 0.8670
Epoch 26/1000
88/88 [==============================] - 20s 222ms/step - loss: 0.0521 - accuracy: 0.9958 - val_loss: 0.4033 - val_accuracy: 0.8670
Epoch 27/1000
88/88 [==============================] - 22s 255ms/step - loss: 0.0471 - accuracy: 0.9970 - val_loss: 0.4051 - val_accuracy: 0.8670
Epoch 28/1000
88/88 [==============================] - 23s 260ms/step - loss: 0.0424 - accuracy: 0.9986 - val_loss: 0.4089 - val_accuracy: 0.8627
Epoch 29/1000
88/88 [==============================] - 20s 231ms/step - loss: 0.0380 - accuracy: 0.9994 - val_loss: 0.4088 - val_accuracy: 0.8584
Epoch 30/1000
88/88 [==============================] - 20s 223ms/step - loss: 0.0342 - accuracy: 0.9994 - val_loss: 0.4104 - val_accuracy: 0.8627
Epoch 31/1000
88/88 [==============================] - 23s 264ms/step - loss: 0.0311 - accuracy: 0.9994 - val_loss: 0.4137 - val_accuracy: 0.8627
Epoch 32/1000
88/88 [==============================] - 19s 215ms/step - loss: 0.0278 - accuracy: 0.9994 - val_loss: 0.4183 - val_accuracy: 0.8627
Epoch 33/1000
88/88 [==============================] - 22s 253ms/step - loss: 0.0251 - accuracy: 0.9994 - val_loss: 0.4170 - val_accuracy: 0.8627
Epoch 34/1000
88/88 [==============================] - 22s 253ms/step - loss: 0.0226 - accuracy: 0.9994 - val_loss: 0.4236 - val_accuracy: 0.8584
Epoch 35/1000
88/88 [==============================] - 22s 251ms/step - loss: 0.0205 - accuracy: 0.9994 - val_loss: 0.4269 - val_accuracy: 0.8584
Epoch 36/1000
88/88 [==============================] - 23s 257ms/step - loss: 0.0187 - accuracy: 0.9994 - val_loss: 0.4276 - val_accuracy: 0.8584
Epoch 37/1000
88/88 [==============================] - 23s 264ms/step - loss: 0.0169 - accuracy: 0.9994 - val_loss: 0.4359 - val_accuracy: 0.8584
Epoch 38/1000
88/88 [==============================] - 23s 259ms/step - loss: 0.0153 - accuracy: 0.9994 - val_loss: 0.4345 - val_accuracy: 0.8584
Epoch 39/1000
88/88 [==============================] - 23s 262ms/step - loss: 0.0142 - accuracy: 0.9994 - val_loss: 0.4405 - val_accuracy: 0.8584
Epoch 40/1000
88/88 [==============================] - 21s 244ms/step - loss: 0.0129 - accuracy: 0.9994 - val_loss: 0.4409 - val_accuracy: 0.8627
30/30 [==============================] - 5s 154ms/step - loss: 0.4972 - accuracy: 0.8755
Process finished with exit code 0
# train_transfer.py
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, optimizers, losses
from tensorflow.keras.callbacks import EarlyStopping
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
# Import the utilities defined earlier.
from pokemon import load_pokemon, normalize, denormalize
from resnet import ResNet  # import the model
# The preprocessing function, copied over again.
def preprocess(x, y):
    # x: image path, y: integer label
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)  # force 3 channels (some images are RGBA)
    x = tf.image.resize(x, [244, 244])
    x = tf.image.random_flip_left_right(x)
    # x = tf.image.random_flip_up_down(x)
    x = tf.image.random_crop(x, [224, 224, 3])
    # x: [0, 255] => normalized
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)
    return x, y
batchsz = 128
# Create the train db; only training data generally needs shuffling.
images, labels, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman', mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))  # wrap as a Dataset object
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)  # map turns paths into image tensors
# Create the validation db.
images2, labels2, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman', mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# Create the test db.
images3, labels3, table = load_pokemon('/home/zhangkf/johnCodes/TF2/TF2_8_data/pokeman', mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)
# Import a network that has already been trained elsewhere, with its weights.
# keras.applications provides several classic architectures together with
# their pretrained parameters. Here we use VGG19 and load its ImageNet
# weights; it was trained on ImageNet's 1000 classes, so we drop its output
# layers (include_top=False) and attach our own 5-way head.
net = keras.applications.VGG19(weights='imagenet', include_top=False,
                               pooling='max')
net.trainable = False  # freeze the old network; it takes no part in backprop
newnet = keras.Sequential([
    net,
    layers.Dense(5)
])
newnet.build(input_shape=(4, 224, 224, 3))
newnet.summary()
# EarlyStopping monitor: if val_accuracy fails to improve by at least
# min_delta for 5 consecutive validations, training stops early.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=5
)
# Compile (assemble) the network.
newnet.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
# The standard train/val/test workflow: model selection must go through
# db_val, which is exactly what the EarlyStopping callback provides.
newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
           callbacks=[early_stopping])  # validate once per epoch; stop early when triggered
newnet.evaluate(db_test)
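A possible follow-up once the new head has converged (a sketch, not part of the run logged below): unfreeze the VGG19 base and fine-tune the whole model with a much lower learning rate:
# Sketch: optional fine-tuning stage after training the head.
net.trainable = True  # unfreeze the pretrained base
newnet.compile(optimizer=optimizers.Adam(learning_rate=1e-5),  # much lower LR
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
newnet.fit(db_train, validation_data=db_val, epochs=10, callbacks=[early_stopping])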
ssh://[email protected]:22/home/zhangkf/anaconda3/envs/tf2b/bin/python -u /home/zhangkf/johnCodes/TF2/TF2_8_data/train_transfer.py
WARNING: Logging before flag parsing goes to stderr.
W0828 14:10:43.894998 140497148106496 deprecation.py:323] From /home/zhangkf/anaconda3/envs/tf2b/lib/python3.7/site-packages/tensorflow/python/data/util/random_seed.py:58: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
vgg19 (Model) (None, 512) 20024384
_________________________________________________________________
dense (Dense) (None, 5) 2565
=================================================================
Total params: 20,026,949
Trainable params: 2,565
Non-trainable params: 20,024,384
_________________________________________________________________
Epoch 1/100
6/6 [==============================] - 27s 4s/step - loss: 2.4913 - accuracy: 0.2389 - val_loss: 1.9317 - val_accuracy: 0.2489
Epoch 2/100
6/6 [==============================] - 23s 4s/step - loss: 1.8861 - accuracy: 0.2632 - val_loss: 1.7571 - val_accuracy: 0.2618
Epoch 3/100
6/6 [==============================] - 21s 4s/step - loss: 1.6289 - accuracy: 0.3133 - val_loss: 1.4749 - val_accuracy: 0.4034
Epoch 4/100
6/6 [==============================] - 18s 3s/step - loss: 1.3627 - accuracy: 0.4664 - val_loss: 1.4244 - val_accuracy: 0.4249
Epoch 5/100
6/6 [==============================] - 23s 4s/step - loss: 1.2218 - accuracy: 0.5365 - val_loss: 1.1910 - val_accuracy: 0.5193
Epoch 6/100
6/6 [==============================] - 23s 4s/step - loss: 1.0595 - accuracy: 0.5823 - val_loss: 1.0799 - val_accuracy: 0.5794
Epoch 7/100
6/6 [==============================] - 22s 4s/step - loss: 0.9539 - accuracy: 0.6295 - val_loss: 0.9891 - val_accuracy: 0.6009
Epoch 8/100
6/6 [==============================] - 23s 4s/step - loss: 0.8529 - accuracy: 0.6724 - val_loss: 0.9117 - val_accuracy: 0.6395
Epoch 9/100
6/6 [==============================] - 24s 4s/step - loss: 0.7763 - accuracy: 0.7153 - val_loss: 0.8449 - val_accuracy: 0.6867
Epoch 10/100
6/6 [==============================] - 24s 4s/step - loss: 0.7108 - accuracy: 0.7539 - val_loss: 0.7860 - val_accuracy: 0.7082
Epoch 11/100
6/6 [==============================] - 23s 4s/step - loss: 0.6567 - accuracy: 0.7754 - val_loss: 0.7379 - val_accuracy: 0.7296
Epoch 12/100
6/6 [==============================] - 21s 4s/step - loss: 0.6082 - accuracy: 0.8011 - val_loss: 0.7027 - val_accuracy: 0.7511
Epoch 13/100
6/6 [==============================] - 22s 4s/step - loss: 0.5690 - accuracy: 0.8255 - val_loss: 0.6692 - val_accuracy: 0.7639
Epoch 14/100
6/6 [==============================] - 23s 4s/step - loss: 0.5346 - accuracy: 0.8383 - val_loss: 0.6384 - val_accuracy: 0.7811
Epoch 15/100
6/6 [==============================] - 24s 4s/step - loss: 0.5036 - accuracy: 0.8512 - val_loss: 0.6139 - val_accuracy: 0.7983
Epoch 16/100
6/6 [==============================] - 23s 4s/step - loss: 0.4764 - accuracy: 0.8584 - val_loss: 0.5924 - val_accuracy: 0.8069
Epoch 17/100
6/6 [==============================] - 19s 3s/step - loss: 0.4525 - accuracy: 0.8670 - val_loss: 0.5731 - val_accuracy: 0.8112
Epoch 18/100
6/6 [==============================] - 23s 4s/step - loss: 0.4308 - accuracy: 0.8813 - val_loss: 0.5558 - val_accuracy: 0.8112
Epoch 19/100
6/6 [==============================] - 23s 4s/step - loss: 0.4112 - accuracy: 0.8898 - val_loss: 0.5401 - val_accuracy: 0.8240
Epoch 20/100
6/6 [==============================] - 17s 3s/step - loss: 0.3933 - accuracy: 0.8941 - val_loss: 0.5262 - val_accuracy: 0.8240
Epoch 21/100
6/6 [==============================] - 23s 4s/step - loss: 0.3770 - accuracy: 0.9013 - val_loss: 0.5137 - val_accuracy: 0.8283
Epoch 22/100
6/6 [==============================] - 24s 4s/step - loss: 0.3620 - accuracy: 0.9070 - val_loss: 0.5021 - val_accuracy: 0.8283
Epoch 23/100
6/6 [==============================] - 23s 4s/step - loss: 0.3481 - accuracy: 0.9156 - val_loss: 0.4914 - val_accuracy: 0.8283
Epoch 24/100
6/6 [==============================] - 20s 3s/step - loss: 0.3353 - accuracy: 0.9213 - val_loss: 0.4817 - val_accuracy: 0.8283
Epoch 25/100
6/6 [==============================] - 23s 4s/step - loss: 0.3234 - accuracy: 0.9256 - val_loss: 0.4727 - val_accuracy: 0.8369
Epoch 26/100
6/6 [==============================] - 23s 4s/step - loss: 0.3122 - accuracy: 0.9328 - val_loss: 0.4643 - val_accuracy: 0.8326
Epoch 27/100
6/6 [==============================] - 23s 4s/step - loss: 0.3018 - accuracy: 0.9328 - val_loss: 0.4565 - val_accuracy: 0.8369
Epoch 28/100
6/6 [==============================] - 23s 4s/step - loss: 0.2921 - accuracy: 0.9342 - val_loss: 0.4492 - val_accuracy: 0.8369
Epoch 29/100
6/6 [==============================] - 23s 4s/step - loss: 0.2829 - accuracy: 0.9342 - val_loss: 0.4424 - val_accuracy: 0.8369
Epoch 30/100
6/6 [==============================] - 23s 4s/step - loss: 0.2743 - accuracy: 0.9342 - val_loss: 0.4361 - val_accuracy: 0.8369
2/2 [==============================] - 5s 3s/step - loss: 0.3366 - accuracy: 0.8970
Process finished with exit code 0
With the frozen VGG19 features, test accuracy reaches about 90% (0.8970), versus roughly 87% for the small network trained from scratch: on a dataset this small, transfer learning gives a clear boost.
tf.where(
    condition,
    x=None,
    y=None,
    name=None
)
Returns elements chosen from x or y depending on condition.
Intuition: where does exactly what the name says — it finds the elements you want according to a condition.
condition: the condition, a boolean tensor
x: a data tensor
y: a data tensor with the same shape as x
Returns the selected data: wherever the condition is True, the element is taken from x; wherever it is False, it is taken from y.
import tensorflow as tf
import numpy as np
# Define a condition tensor filled with random values.
condition = tf.convert_to_tensor(np.random.random([2, 3]), dtype=tf.float32)
print(condition)
# Define two tensors holding the source data.
a = tf.ones(shape=[2, 3], name='a')
print(a)
b = tf.zeros(shape=[2, 3], name='b')
print(b)
# Wherever the value exceeds 0.5 pick from a; elsewhere pick from b.
result = tf.where(condition > 0.5, a, b)
print(result)
ssh://[email protected]:22/home/zhangkf/anaconda3/envs/tf2c/bin/python -u /home/zhangkf/johnCodes/TF1/test.py
tf.Tensor(
[[0.04526703 0.08822254 0.6437674 ]
[0.3951503 0.39249578 0.51326084]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[1. 1. 1.]
[1. 1. 1.]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[0. 0. 0.]
[0. 0. 0.]], shape=(2, 3), dtype=float32)
tf.Tensor(
[[0. 0. 1.]
[0. 0. 1.]], shape=(2, 3), dtype=float32)
Process finished with exit code 0
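Incidentally, the deprecation warning in the training logs above ("Use tf.where in 2.0, which has the same broadcast rule as np.where") refers to exactly this form: in TF 2.x the condition may broadcast against x and y, as in this sketch:
# Sketch: a [3]-shaped condition broadcasts against [2, 3]-shaped x and y.
cond = tf.constant([True, False, True])
print(tf.where(cond, tf.ones([2, 3]), tf.zeros([2, 3])))
# [[1. 0. 1.]
#  [1. 0. 1.]]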
import tensorflow as tf
import numpy as np
# Define a condition tensor.
condition = tf.constant([1, 2, 2, 4])
print(condition)
# Define the source data.
a = tf.constant([[1, 2, 2, 4], [3, 4, 5, 6], [7, 8, 9, 10], [2, 3, 3, 4]])
print(a)
# Find the coordinates (here: row indices) where condition == 2, then use
# them to gather the corresponding rows of a.
result_index = tf.where(condition == 2)
print(result_index)
result = tf.gather_nd(a, result_index)  # returns rows 1 and 2 of a (the 2nd and 3rd rows)
print(result)
tf.Tensor([1 2 2 4], shape=(4,), dtype=int32)
tf.Tensor(
[[ 1 2 2 4]
[ 3 4 5 6]
[ 7 8 9 10]
[ 2 3 3 4]], shape=(4, 4), dtype=int32)
tf.Tensor(
[[1]
[2]], shape=(2, 1), dtype=int64)
tf.Tensor(
[[ 3 4 5 6]
[ 7 8 9 10]], shape=(2, 4), dtype=int32)
Process finished with exit code 0
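For row selection like this, tf.boolean_mask offers an equivalent one-step alternative (a sketch, reusing a and condition from the example above):
# Sketch: equivalent to the tf.where + tf.gather_nd pair above.
print(tf.boolean_mask(a, condition == 2))  # same rows: [[3 4 5 6], [7 8 9 10]]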
A related supplement: np.argmax returns the indices of the maxima along the given axis, and the result shape is the input shape with that axis removed:
>>> a = np.arange(24).reshape(2,3,4)
>>> a
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> np.argmax(a,axis = 0) #shape(3,4)
array([[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]])
>>> np.argmax(a,axis = 1) #shape(2,4)
array([[2, 2, 2, 2],
[2, 2, 2, 2]])
>>> np.argmax(a,axis = -1) #shape(2,3)
array([[3, 3, 3],
[3, 3, 3]])
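The TensorFlow counterpart behaves the same way; tf.argmax also reduces along the given axis and returns int64 indices (a sketch):
import tensorflow as tf
import numpy as np
a = np.arange(24).reshape(2, 3, 4)
print(tf.argmax(a, axis=0))   # shape (3, 4), all ones
print(tf.argmax(a, axis=-1))  # shape (2, 3), all threes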