Tensorflow转ONNX

参考文章:

https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowExport.ipynb

https://www.jianshu.com/p/8ec3a6c9c453

使用TensorFlow框架训练模型,然后导出为ONNX格式,一般需要经历以下几个步骤:

训练(Training)

转PB文件(Graph Freezing)

模型格式转换(Model Conversion)


训练(Training)

为了成功将TF model转成ONNX model,需要准备3部分信息:

1. TF的图定义,即只包含网络拓扑信息(不含权重信息)。获取方法为在inference代码中,插入以下代码进行输出:

with open("net.proto", "wb") as file:
    graph = tf.get_default_graph().as_graph_def(add_shapes=True)
    file.write(graph.SerializeToString())

2. 形状信息。默认情况下,as_graph_def() 方法不会导出形状信息,通过指定add_shapes参数为True来强制其输出,方法参考上述代码。

3. 检查点文件(checkpoint,ckpt)。也就是我们平常所说的权重文件,TensorFlow一般将其导出为checkpoint文件。

总而言之,经过上述几步,我们可以得到:记录网络拓扑信息和形状信息的*.proto文件,以及记录权重的checkpoint文件若干。


转PB文件(Graph Freezing)

如上所述,常规TF的模型导出方法会将网络信息与权重信息分开存储在不同文件当中,这在部署时候不是很方便。官方提供了一种Freeze Graph的方式,用于将模型相关信息统统打包到一个*.pb文件当中。

官方提供了相关工具freeze_graph,一般安装完TensorFlow后会自动添加到用户PATH相应的bin目录下,如果没有找到的话可以去TensorFlow源码tensorflow/python/tools/free_graph.py这个位置去找一下,或者直接通过命令行导入module的方式调用。

举例如下,如果有多个输出节点,用逗号隔开:

# 1.直接调用
freeze_graph --input_graph=/home/mnist-tf/graph.proto \
    --input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \
    --output_graph=/tmp/frozen_graph.pb \
    --output_node_names=fc2/add \
    --input_binary=True

# 2. 通过调用module的方式
python -m tensorflow.python.tools.freeze_graph \
    --input_graph=my_checkpoint_dir/graphdef.pb \
    --input_binary=true \
    --output_node_names=output \
    --input_checkpoint=my_checkpoint_dir \
    --output_graph=tests/models/fc-layers/frozen.pb

其中有一个比较困扰的点在于需要准确的知道输出节点的节点名称,我的做法是通过tf.get_default_graph().as_graph_def().node来得到各节点信息,然后再从中查看具体的输出节点名。

print([tensor.name for tensor in tf.get_default_graph().as_graph_def().node])

模型格式转换(Model Conversion)

使用https://github.com/onnx/tensorflow-onnx提供的tf2onnx工具,示例如下:

python -m tf2onnx.convert\
    --input tests/models/fc-layers/frozen.pb\
    --inputs X:0\
    --outputs output:0\
    --output tests/models/fc-layers/model.onnx\
    --verbose

模型输入输出名字需要用 node_name:port_id的格式,要不然后面会转换出错。

大功告成。

你可能感兴趣的:(Coooding)