当 NLP 玩家遇到一个 CV 图像分类的任务时,老实的说,我是有点懵逼的。。。
任务目标是,训练一个层数较少,结构较为简单的图像分类模型,使用其最后一层隐藏层输出,作为特征向量来表征图像。
之前都是使用 Keras 较多,于是本次准备借着这个简单的任务上手 TensorFlow 2.1 。
TensorFlow 2.1 自带的 tf.data.Dataset 处理训练数据十分好用,并且自带 shuffle,repeat,和划分 batch 的方法。可以通过python generator, numpy list, Tensor slices 等数据结构直接构成 Dataset。
我训练使用的数据是文档中的插图,5个类别共 10w 张。
起初我使用的方法是:构造一个 python generator,训练时,使用 tf 自带的 tf.io.read_file() 和 tf.image.decode_jpeg() 方法从磁盘中读取数据,再使用 tf.data.Dataset.from_generator 生成数据集。
但训练时发现这样的数据处理有着很大的问题:受制于generator 的读取数据速度,batch 数据生成的速度跟不上 GPU 的训练速度,导致 GPU 的利用率只有不到 10%,训练速度很慢。很慢。。很慢。。。
这时我想到了 TFRecord 。TFRecord 可以将数据转存为二进制文件保存,这样在训练时读取数据就不会遇到以上的问题了。
使用 TFRecord 来进行数据处理,首先需要将原始图片数据转存为 TFRecord 格式:
def pares_image(pic):
'''
图片预处理,并转为字符串
'''
label = pares_label(pic)
try:
img = Image.open(pic)
img = img.convert('RGB')
img = img.resize((54,96))
img_raw = img.tobytes()
except:
return None, None
return img_raw, label
train_data_list = get_file_path(train_data_path)
writer = tf.io.TFRecordWriter('./data/test_data')
for data in tqdm(test_data_list):
img_raw, label = pares_image(data)
if (img_raw is not None) and (label != 'not valid'):
exam = tf.train.Example(
features = tf.train.Features(
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List (value=[int(label_2_idx[label])])),
'data' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}
)
)
writer.write (exam.SerializeToString())
writer.close()
接下来读取 TFRecord 文件,加载进 tf.data,Dataset
train_reader = tf.data.TFRecordDataset('./data/train_data')
test_reader = tf.data.TFRecordDataset('./data/test_data')
valid_reader = tf.data.TFRecordDataset('./data/valid_data')
feature_description = {
'data' : tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([] , tf.int64, default_value=-1)
}
def _parse_function (exam_proto):
temp = tf.io.parse_single_example (exam_proto, feature_description)
img = tf.io.decode_raw(temp['data'], tf.uint8)
img = tf.reshape(img, [54, 96, 3])
img = img / 255
label = temp['label']
return (img, label)
train_dataset = train_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)
test_dataset = test_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)
valid_dataset = test_reader.repeat(5).shuffle(12800, reshuffle_each_iteration=True).map(_parse_function).batch(128)
读取文件时,需要重新将二进制的图像数据重新 decode 为 Tensor,并进行数值归一化处理。