小心使用tf.image.resize_images,填坑经验分享给你

上上周,我在一个项目上线前对模型进行测试时出现了问题,这个问题困扰了我近两周,终于找到了问题根源,做个简短总结分享给你,希望对大家有帮助。

问题描述:

线上线下测试结果不一致,且差异很大

具体来说,
线下测试直接load由ckpt存储的模型,然后使用cv进行数据预处理,然后评估测试集上的准召,一切正常。
线上测试时,首先使用tf.image相关函数将预处理写死在模型中,将ckpt模型转为savemodel格式,然后使用tf-serving部署后,发送请求进行线上实测,此时和线下测试结果差异较大。

问题定位:

主要问题出在ckpt转savemodel时,预处理部分 tf.image.resize_images 和 tf.cast 两个函数的使用上
虽然问题发生在模型转换时,但真正的问题出在对于tf.image.resize_images函数的使用上,因此任何可能的使用场景,包括预处理,数据增强,模型转换等,都有可能被它坑到,这也是我写这篇文章的原因,提醒大家不要向我一样被它坑到。
小小吐槽: 在发现真正的问题所在之前,由于我大幅的修改了我的训练框架,所以我从模型结构到loss函数,再到数据增强方法,排查了一遍,最终才发现,问题的出现,仅仅是我将
tf.image.resize_images 的 method 参数:
tf.image.ResizeMethod.BILINEAR 修改为了 tf.image.ResizeMethod.BICUBIC

what?这样小的一个修改就崩了?
下面我将我的排查过程详细描述出来,希望对大家有所启发。

如果打印tf.image.resize_images函数前后的数据类型

print(img_decoded.dtype)
resized_image = tf.image.resize_images(img_decoded, [new_height, new_width], method=tf.image.ResizeMethod.BICUBIC)
print(resized_image.dtype)

可以观察到如下结果
tf.uint8
tf.float32

而如果打印resize后的数据范围
tf img max 294.077484131
tf img min -25.2455863953

可以看到本来是0-255的uint8数据处理后不但数据类型发生了变化,而且像素值越界了!

此外,在预处理结束后我还使用了tf.cast函数转换数据类型

padd_image = tf.cast(resized_image, tf.uint8)

如果输入数据已经越界,此时tf.cast函数的使用也存在问题:

为方便理解问题,观察以下可视化结果:
使用cv2进行预处理的结果
小心使用tf.image.resize_images,填坑经验分享给你_第1张图片
tf版本的预处理结果(resize_images + cast)
小心使用tf.image.resize_images,填坑经验分享给你_第2张图片
可以看到resize_images + cast 函数的使用对原图有很明显的破坏

我们找到越界的部分,对resize后越界的部分进行可视化(用255或0截断后显示,正常区域用黑色填充)
小于0的部分
小心使用tf.image.resize_images,填坑经验分享给你_第3张图片
超过255的部分
小心使用tf.image.resize_images,填坑经验分享给你_第4张图片
上面两张图是正常越界截断后的结果,为了观察与tf.cast函数处理的区别

将 resize+cast 后 >255 部分的像素值可视化出来(为了凸显这部分像素,正常区域改用白色填充)
小心使用tf.image.resize_images,填坑经验分享给你_第5张图片
通过上图可以观察到,tf.cast对越界的处理机制并不是截断,而是类似取余操作,或者类似变量赋值时超过数据类型取值范围时的处理机制。

具体来说,如果越界的像素值是256,得到的返回结果对应的像素值是0;如果是257,得到的像素值是1,以此类推。

从图中越界的黄色区域(255,255,0)被tf.cast函数处理后变为蓝色区域(0,0,255)可以印证这一说法。

解决方法:
首先这并不能算google工程师的一个bug,因为tf.image.resize_images函数并没有对返回值的取值范围做保证,本质它就是进行插值,插值结果它不管。只是cv2或者PIL的类似函数中帮我们做了很多的“保护“。

通过尝试,最简便的解决方法是修改插值方法,经验证:
小心使用tf.image.resize_images,填坑经验分享给你_第6张图片
上面两种插值方法都不会造成像素值越界
如果你需要确保你的返回结果是在正常范围内的,那就在上面两个方法中选一个。此外,最邻近插值会带来比较明显的“不连续感”,因此推荐选择双线性插值,同时它也是默认参数。三次样条,虽然平滑性好,但是tf的实现版本真的是坑到我了。。。

当然,单单使用tf.image.resize_images也仅仅是对图片造成了微弱的扰动,但是配合上tf.cast函数的特有机制,对模型的干扰就比较大了。

综上:

1.使用 tf.image.resize_images函数时,如果使用三次样条插值,不要想当然的认为返回值是0-255的。
2.tf.cast函数的处理机制要注意,类似取余,而不是截断

搜了一下,被其它和resize相关的问题困扰的人也不少
感兴趣可以探究下
How Tensorflow’s tf.image.resize stole 60 days of my life
tensorflow-issues-19627
tf-image-resize-bilinear-vs-cv2-resize
说明tf的resize实现多少有些问题,这些应该不是bug,但确实给tensorflow的使用者们造成了不少困扰
小心使用tf.image.resize_images,填坑经验分享给你_第7张图片
这些函数的使用并不像cv2的api那样安全可信任
因此使用tf.image系的函数要慎重,一定要check数据类型,check函数处理后是否在0-255的范围,尤其是resize相关。

你可能感兴趣的:(AI杂货铺,tf.cast)