最近在作pytorch 模型转tensorflow,通过onnx 中间转换和容易,但是再转换时有一个注意事项,即如何处理batch
pytorch 模型转tensorflow: //www.greatytc.com/p/3e5623696a8e
通过 onnx 手动修改batch 为动态值: model_onnx.graph.input[0].type.tensor_type.shape.dim[0].dim_param ='?'
这下就可以使用batch 了。
最后别忘了用transform_graph 压缩下模型大小: //www.greatytc.com/p/d2637646cda1
self.model.load_state_dict(model_dict)
example = torch.ones(1,1,112,112).cuda() #限定好的tensor 输入大小
# traced_script_module = torch.jit.trace(self.model, example)
# traced_script_module.save('./lt_model.pt')
torch.onnx.export(self.model, example,'./model_simple.onnx',input_names=['input'],
output_names=['output'])
model_onnx = onnx.load('./model_simple.onnx')
model_onnx.graph.input[0].type.tensor_type.shape.dim[0].dim_param ='?'
tf_rep = prepare(model_onnx)
print(tf_rep.tensor_dict)
tf_rep.export_graph('./lt_model.pb')