mirror of
https://github.com/zebrajr/opencv.git
synced 2026-01-15 12:15:17 +00:00
Added conv kernel size
This commit is contained in:
@@ -2014,6 +2014,25 @@ void ONNXImporter::parseConv(LayerParams& layerParams, const opencv_onnx::NodePr
|
||||
layerParams.blobs.push_back(getBlob(node_proto, j));
|
||||
}
|
||||
}
|
||||
// ONNX allows omitting 'kernel_shape' attribute for Conv. In that case, it should be inferred from weights.
|
||||
// See: https://onnx.ai/onnx/operators/onnx__Conv.html
|
||||
if (!layerParams.has("kernel_size"))
|
||||
{
|
||||
Mat weights;
|
||||
if (!layerParams.blobs.empty())
|
||||
weights = layerParams.blobs[0];
|
||||
else if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
|
||||
weights = getBlob(node_proto, 1);
|
||||
|
||||
if (!weights.empty() && weights.dims >= 3)
|
||||
{
|
||||
const int kDims = weights.dims - 2;
|
||||
std::vector<int32_t> kernel(kDims);
|
||||
for (int i = 0; i < kDims; ++i)
|
||||
kernel[i] = weights.size[2 + i];
|
||||
layerParams.set("kernel_size", DictValue::arrayInt(kernel.data(), static_cast<int>(kernel.size())));
|
||||
}
|
||||
}
|
||||
int outCn = layerParams.blobs.empty() ? outShapes[node_proto.input(1)][0] : layerParams.blobs[0].size[0];
|
||||
layerParams.set("num_output", outCn);
|
||||
|
||||
@@ -2030,6 +2049,20 @@ void ONNXImporter::parseConvTranspose(LayerParams& layerParams, const opencv_onn
|
||||
layerParams.set("num_output", layerParams.blobs[0].size[1] * layerParams.get<int>("group", 1));
|
||||
layerParams.set("bias_term", node_proto.input_size() == 3);
|
||||
|
||||
// ONNX allows omitting 'kernel_shape' attribute for ConvTranspose. Infer it from weights if needed.
|
||||
if (!layerParams.has("kernel_size"))
|
||||
{
|
||||
const Mat& weights = layerParams.blobs[0];
|
||||
if (!weights.empty() && weights.dims >= 3)
|
||||
{
|
||||
const int kDims = weights.dims - 2;
|
||||
std::vector<int32_t> kernel(kDims);
|
||||
for (int i = 0; i < kDims; ++i)
|
||||
kernel[i] = weights.size[2 + i];
|
||||
layerParams.set("kernel_size", DictValue::arrayInt(kernel.data(), static_cast<int>(kernel.size())));
|
||||
}
|
||||
}
|
||||
|
||||
if (!layerParams.has("kernel_size"))
|
||||
CV_Error(Error::StsNotImplemented,
|
||||
"Required attribute 'kernel_size' is not present.");
|
||||
|
||||
Reference in New Issue
Block a user