ArmNN mainly targets TFLite models; its ONNX support is much weaker. It does, however, ship a basic ONNX parsing framework that covers only a small set of ONNX operators, so supporting additional operators means extending the parser yourself. The steps below add support for QuantizeLinear and DequantizeLinear by mapping them to ArmNN's QuantizeLayer and DequantizeLayer.
In armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type), add a case for the new data type:
case onnx::TensorProto::INT8:
{
    type = DataType::QAsymmS8;
    break;
}
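For orientation, this is roughly where the new case sits inside the existing switch in ToTensorInfo. The surrounding cases below are shown from memory and may differ slightly between ArmNN versions; only the INT8 case is the actual addition:

DataType type;
switch (data_type)
{
    case onnx::TensorProto::FLOAT:
    {
        type = DataType::Float32;
        break;
    }
    case onnx::TensorProto::INT32:
    case onnx::TensorProto::INT64:
    {
        type = DataType::Signed32;
        break;
    }
    // New: map ONNX INT8 tensors to ArmNN's asymmetric signed 8-bit type.
    case onnx::TensorProto::INT8:
    {
        type = DataType::QAsymmS8;
        break;
    }
    default:
    {
        throw ParseException("Unsupported ONNX tensor data type");
    }
}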
Declare the two new handlers in the parser header:
void ParseQuantize(const onnx::NodeProto &nodeProto);
void ParseDequantize(const onnx::NodeProto &nodeProto);
Add the implementations:
void OnnxParserImpl::ParseQuantize(const onnx::NodeProto& node)
{
    // QuantizeLinear has three inputs (x, y_scale, y_zero_point) and one output.
    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 3);
    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);

    IConnectableLayer* const layer = m_Network->AddQuantizeLayer(node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    // Input 1 is the scale, a scalar float initializer.
    onnx::TensorProto scaleTensor = *m_TensorsInfo[node.input(1)].m_tensor;
    float scale = *scaleTensor.float_data().data();

    // Input 2 is the zero point; here it is read from raw_data as a 32-bit value.
    onnx::TensorProto zeroPointTensor = *m_TensorsInfo[node.input(2)].m_tensor;
    auto zeroPointData = reinterpret_cast<const int32_t*>(zeroPointTensor.raw_data().c_str());
    int32_t zeroPoint = zeroPointData[0];

    // The output keeps the input shape and carries the quantization parameters.
    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    outputInfo[0].SetQuantizationScale(scale);
    outputInfo[0].SetQuantizationOffset(zeroPoint);
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // Register the input connection slots for the layer; connections are made after all layers have been created.
    // Only the tensors for the inputs are relevant, excluding the const tensors (scale and zero point).
    RegisterInputSlots(layer, {node.input(0)});

    // Register the output connection slots for the layer; connections are made after all layers have been created.
    RegisterOutputSlots(layer, {node.output(0)});
}
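Note that this reads the scale from float_data() and the zero point from raw_data(). In practice, ONNX initializers may store scalars in either the typed repeated fields or in raw_data, and the zero point's element type is normally INT8/UINT8 rather than INT32, so reading four bytes from raw_data can go wrong for other exporters. A more defensive sketch, assuming hypothetical helpers ReadScalarFloat / ReadScalarZeroPoint (not part of ArmNN; needs <cstring> for memcpy):

// Hypothetical helpers: read a scalar scale/zero point regardless of whether the
// initializer stores the value in a typed field or in raw_data.
static float ReadScalarFloat(const onnx::TensorProto& t)
{
    if (t.float_data_size() > 0)
    {
        return t.float_data(0);
    }
    // Fall back to raw_data, which holds the value as little-endian bytes.
    float value = 0.0f;
    std::memcpy(&value, t.raw_data().data(), sizeof(float));
    return value;
}

static int32_t ReadScalarZeroPoint(const onnx::TensorProto& t)
{
    if (t.int32_data_size() > 0)
    {
        // INT8/UINT8 initializers use the int32_data field when not in raw_data.
        return t.int32_data(0);
    }
    if (t.data_type() == onnx::TensorProto::INT8)
    {
        return static_cast<int32_t>(static_cast<int8_t>(t.raw_data()[0]));
    }
    if (t.data_type() == onnx::TensorProto::UINT8)
    {
        return static_cast<int32_t>(static_cast<uint8_t>(t.raw_data()[0]));
    }
    int32_t value = 0;
    std::memcpy(&value, t.raw_data().data(), sizeof(int32_t));
    return value;
}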
void OnnxParserImpl::ParseDequantize(const onnx::NodeProto& node)
{
    // DequantizeLinear has three inputs (x, x_scale, x_zero_point) and one output.
    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 3);
    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);

    IConnectableLayer* const layer = m_Network->AddDequantizeLayer(node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    // Compute the output TensorInfo from the input shape.
    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // Register the input connection slots for the layer; connections are made after all layers have been created.
    // Only the tensors for the inputs are relevant, excluding the const tensors (scale and zero point).
    RegisterInputSlots(layer, {node.input(0)});

    // Register the output connection slots for the layer; connections are made after all layers have been created.
    RegisterOutputSlots(layer, {node.output(0)});
}
Finally, register the new handlers in std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParserImpl::m_ParserFunctions, so the parser dispatches the corresponding ONNX op types to them:
{"QuantizeLinear", &OnnxParserImpl::ParseQuantize},
{"DequantizeLinear", &OnnxParserImpl::ParseDequantize},