import onnx
import onnx_graphsurgeon as gs
import numpy as np
# --- Stage 1: attach class-name metadata and make the input's leading dim symbolic. ---
labels = ['mouse']
# {class_id: label}; stored below as str(dict) under the "names" metadata key,
# e.g. "{0: 'mouse'}".
names = dict(enumerate(labels))
model = onnx.load("/wjr/develop/projects/yolov5/qat.onnx")
meta = model.metadata_props.add()
meta.key, meta.value = "names", str(names)
# Assigning dim_param replaces any fixed dim_value on the first input's leading axis
# (presumably the batch axis) with the symbolic name 'None'.
model_input = model.graph.input[0]
model_input.type.tensor_type.shape.dim[0].dim_param = 'None'
# Write an intermediate model; the next stage re-loads it with onnx-graphsurgeon.
onnx.save(model, "/wjr/develop/projects/yolov5/qat2.onnx")
# --- Stage 2: split the raw prediction tensor into boxes and per-class scores. ---
n_cls = 1  # number of object classes; keep in sync with the `labels` list
graph = gs.import_onnx(onnx.load("/wjr/develop/projects/yolov5/qat2.onnx"))

# Split the first graph output along axis 2 into 4 box channels and n_cls score channels.
# NOTE(review): assumes graph.outputs[0] has exactly 4 + n_cls channels on axis 2 -- confirm
# against the exported model (Split fails if the sizes do not sum to that axis length).
# NOTE(review): passing split sizes as an input (not an attribute) requires opset >= 13.
# Constant names are prefixed so they cannot clash with tensor names already in the model.
split = gs.Constant(
    "nms_split_sizes", values=np.array([4, n_cls], dtype=np.int64))
output_box_0 = gs.Variable(
    name="output_box_0", shape=(1, 25500, 4), dtype=np.float32)
output_scores = gs.Variable(
    name="output_scores", shape=(1, 25500, n_cls), dtype=np.float32)
split_node = gs.Node(
    op="Split",
    inputs=[graph.outputs[0], split],
    outputs=[output_box_0, output_scores],
    attrs={"axis": 2},
)

# Insert a size-1 axis so boxes become [N, anchors, 1, 4], the layout the NMS plugin expects.
# Reshape target [0, -1, 1, 4]: 0 copies the incoming leading (batch) dim and -1 infers the
# anchor count, so the reshape no longer hard-codes batch=1 and 25500 boxes.
box_shape = gs.Constant(
    "nms_box_shape", values=np.array([0, -1, 1, 4], dtype=np.int64))
# Static shape hint only (batch 1, 25500 anchors); the Reshape itself is size-agnostic.
output_boxes = gs.Variable(
    name="output_boxes", shape=(1, 25500, 1, 4), dtype=np.float32)
reshape_node = gs.Node(
    op="Reshape", inputs=[output_box_0, box_shape], outputs=[output_boxes])
# --- Stage 3: append TensorRT's BatchedNMSDynamic_TRT plugin node. ---
keepTopK = 100  # max detections kept per image after NMS
topK = 1000     # max candidate boxes fed into the NMS step

# Declared output tensors, matching the plugin's output layout. num_detections is
# rank 2 ([batch, 1]), like the other outputs that carry a leading batch axis.
num_detections = gs.Variable(
    name="num_detections", dtype=np.int32, shape=[1, 1])
nmsed_boxes = gs.Variable(
    name="nmsed_boxes", dtype=np.float32, shape=[1, keepTopK, 4])
nmsed_scores = gs.Variable(
    name="nmsed_scores", dtype=np.float32, shape=[1, keepTopK])
nmsed_classes = gs.Variable(
    name="nmsed_classes", dtype=np.float32, shape=[1, keepTopK])
new_outputs = [num_detections, nmsed_boxes, nmsed_scores, nmsed_classes]

attrs = {
    # The boxes input has a single class slot (axis 2 == 1), so boxes must be shared
    # across classes. With shareLocation=False the plugin expects numClasses box sets
    # per anchor, which only happens to work while n_cls == 1.
    "shareLocation": True,
    "backgroundLabelId": -1,  # -1: no class is treated as background
    "numClasses": n_cls,
    "topK": topK,
    "keepTopK": keepTopK,
    "scoreThreshold": 0.50,
    "iouThreshold": 0.7,
    "isNormalized": False,  # box coordinates are not normalized to [0, 1]
    "clipBoxes": False,
    "scoreBits": 16,
}
nms_node = gs.Node(
    op="BatchedNMSDynamic_TRT",
    attrs=attrs,
    inputs=[output_boxes, output_scores],
    outputs=new_outputs,
)
# --- Stage 4: wire the new nodes in, replace the graph outputs, and export. ---
graph.nodes += [split_node, reshape_node, nms_node]
# The NMS plugin outputs become the model's only outputs.
graph.outputs = new_outputs
# Drop nodes/tensors that no longer contribute to the outputs, then re-establish
# a valid topological node order for the appended nodes.
graph.cleanup()
graph.toposort()
exported_model = gs.export_onnx(graph)
onnx.save(exported_model, "/wjr/develop/projects/yolov5/qat3.onnx")