[Mlir-commits] [mlir] [mlir][tosa] Refactor convolution infer return type (PR #178869)
Luke Hutton
llvmlistbot at llvm.org
Fri Jan 30 03:23:18 PST 2026
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/178869
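Much of the shape-inference logic was duplicated across the Conv2D, Conv3D and Conv2DBlockScaled ops. This commit factors the common logic out into a shared helper to reduce code duplication.
In doing so, it also fixes a bug in how the bias shape was used: for Conv2D and Conv3D, a broadcast bias of size 1 was incorrectly taken as the output channel count. Since DepthwiseConv2D and TransposeConv2D have already been fixed separately, this commit fixes #175765.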
>From 63d12b3fdb1eb53a3f6034d71bdc71cbe8f56f47 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 30 Jan 2026 10:22:11 +0000
Subject: [PATCH] [mlir][tosa] Refactor convolution infer return type
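Much of the shape-inference logic was duplicated across the Conv2D,
Conv3D and Conv2DBlockScaled ops. This commit factors the common logic
out into a shared helper to reduce code duplication.
In doing so, it also fixes a bug in how the bias shape was used: for
Conv2D and Conv3D, a broadcast bias of size 1 was incorrectly taken as
the output channel count. Since DepthwiseConv2D and TransposeConv2D
have already been fixed separately, this commit fixes #175765.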
Change-Id: I5d13d876b05b1ff6b9955dc5cabf016160509524
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 373 +++++++++---------
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 22 +-
2 files changed, 216 insertions(+), 179 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c412d788c9b29..52432c8d87049 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3435,162 +3435,241 @@ static LogicalResult poolingInferReturnTypes(
return success();
}
-LogicalResult Conv2DOp::inferReturnTypeComponents(
- MLIRContext *context, ::std::optional<Location> location,
- Conv2DOp::Adaptor adaptor,
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
-
- int64_t inputWidth = ShapedType::kDynamic;
- int64_t inputHeight = ShapedType::kDynamic;
- int64_t weightWidth = ShapedType::kDynamic;
- int64_t weightHeight = ShapedType::kDynamic;
+/// Per-op hooks used by inferConvReturnTypeComponents. Each supported
+/// convolution adaptor type provides an explicit specialization.
+template <typename AdaptorT>
+class ConvInferShapeAdaptor;
- // Input shape describes input width/height and batch.
+/// Helpers shared by the ConvInferShapeAdaptor specializations.
+class ConvInferShapeAdaptorBase {
+protected:
+ // Replace `current` with `candidate` only while `current` is still dynamic,
+ // so a size already known from an earlier operand is never overridden.
+ static void updateIfDynamic(int64_t &current, int64_t candidate) {
+ if (ShapedType::isDynamic(current))
+ current = candidate;
+ }
+};
- ShapeAdaptor inputShape(adaptor.getInput().getType());
- if (inputShape.hasRank()) {
+/// Conv2D: batch and input spatial sizes (H, W) are input dims 0-2; output
+/// channels and kernel spatial sizes are weight dims 0-2.
+template <>
+class ConvInferShapeAdaptor<Conv2DOp::Adaptor>
+ : public ConvInferShapeAdaptorBase {
+public:
+ explicit ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
+ : adaptor(adaptor) {}
+
+ void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<int64_t> &inputSpatial) {
+ const ShapeAdaptor inputShape(adaptor.getInput().getType());
+ if (!inputShape.hasRank())
+ return;
outputShape[0] = inputShape.getDimSize(0);
- inputHeight = inputShape.getDimSize(1);
- inputWidth = inputShape.getDimSize(2);
+ inputSpatial[0] = inputShape.getDimSize(1);
+ inputSpatial[1] = inputShape.getDimSize(2);
}
- // Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape(adaptor.getWeight().getType());
- if (weightShape.hasRank()) {
+ void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<int64_t> &weightSpatial) {
+ const ShapeAdaptor weightShape(adaptor.getWeight().getType());
+ if (!weightShape.hasRank())
+ return;
outputShape[3] = weightShape.getDimSize(0);
- weightHeight = weightShape.getDimSize(1);
- weightWidth = weightShape.getDimSize(2);
+ weightSpatial[0] = weightShape.getDimSize(1);
+ weightSpatial[1] = weightShape.getDimSize(2);
}
- // Bias shape can describe the output channels.
- ShapeAdaptor biasShape(adaptor.getBias().getType());
- if (biasShape.hasRank()) {
- outputShape[3] = ShapedType::isDynamic(outputShape[3])
- ? biasShape.getDimSize(0)
- : outputShape[3];
- }
+ int64_t getNumSpatialDims() const { return 2; }
+ int64_t getOutputRank() const { return 4; }
- llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
- llvm::ArrayRef<int64_t> stride = adaptor.getStride();
- llvm::ArrayRef<int64_t> padding = adaptor.getPad();
-
- if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
- int64_t inputSize = inputHeight + padding[0] + padding[1];
- int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
- int64_t unstridedResult = inputSize - filterSize + 1;
- outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
+ // Conv2D holds pad/stride/dilation as integer arrays on the adaptor, so they
+ // can simply be copied out and this always succeeds.
+ LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
+ SmallVector<int64_t> &strideValues,
+ SmallVector<int64_t> &dilationValues) {
+ padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
+ strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
+ dilationValues.assign(adaptor.getDilation().begin(),
+ adaptor.getDilation().end());
+ return success();
}
- if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
- int64_t inputSize = inputWidth + padding[2] + padding[3];
- int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
- int64_t unstridedResult = inputSize - filterSize + 1;
- outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
- }
+private:
+ Conv2DOp::Adaptor adaptor;
+};
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
- return success();
-}
+/// Conv2DBlockScaled: sizes are read from input_data / weight_data and, where
+/// those leave a dim dynamic, from the matching input_scale / weight_scale.
+template <>
+class ConvInferShapeAdaptor<Conv2DBlockScaledOp::Adaptor>
+ : public ConvInferShapeAdaptorBase {
+public:
+ explicit ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
+ : adaptor(adaptor) {}
+
+ void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<int64_t> &inputSpatial) {
+ const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+ if (inputDataShape.hasRank()) {
+ outputShape[0] = inputDataShape.getDimSize(0);
+ inputSpatial[0] = inputDataShape.getDimSize(1);
+ inputSpatial[1] = inputDataShape.getDimSize(2);
+ }
-LogicalResult Conv2DOp::verify() {
- if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
- verifyConvOpErrorIf(*this).failed())
- return failure();
- return success();
-}
+ // Dims 0-2 of input_scale are the fallback for batch and spatial sizes that
+ // input_data left dynamic (or for all of them if input_data is unranked).
+ const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+ if (!inputScaleShape.hasRank())
+ return;
+ updateIfDynamic(outputShape[0], inputScaleShape.getDimSize(0));
+ updateIfDynamic(inputSpatial[0], inputScaleShape.getDimSize(1));
+ updateIfDynamic(inputSpatial[1], inputScaleShape.getDimSize(2));
+ }
+
+ void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<int64_t> &weightSpatial) {
+ const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
+ if (weightDataShape.hasRank()) {
+ outputShape[3] = weightDataShape.getDimSize(0);
+ weightSpatial[0] = weightDataShape.getDimSize(1);
+ weightSpatial[1] = weightDataShape.getDimSize(2);
+ }
-LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
- MLIRContext *context, ::std::optional<Location> location,
- Conv2DBlockScaledOp::Adaptor adaptor,
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+ // Dims 0-2 of weight_scale are the fallback for output channels and kernel
+ // spatial sizes that weight_data left dynamic.
+ const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
+ if (!weightScaleShape.hasRank())
+ return;
+ updateIfDynamic(outputShape[3], weightScaleShape.getDimSize(0));
+ updateIfDynamic(weightSpatial[0], weightScaleShape.getDimSize(1));
+ updateIfDynamic(weightSpatial[1], weightScaleShape.getDimSize(2));
+ }
+
+ int64_t getNumSpatialDims() const { return 2; }
+ int64_t getOutputRank() const { return 4; }
+
+ // pad/stride/dilation are operands here; fail if tosa::getConstShapeValues
+ // cannot extract constant values from any of them.
+ LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
+ SmallVector<int64_t> &strideValues,
+ SmallVector<int64_t> &dilationValues) {
+ if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(),
+ padValues) ||
+ !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+ strideValues) ||
+ !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
+ dilationValues))
+ return failure();
+ return success();
+ }
- int64_t inputWidth = ShapedType::kDynamic;
- int64_t inputHeight = ShapedType::kDynamic;
- int64_t weightWidth = ShapedType::kDynamic;
- int64_t weightHeight = ShapedType::kDynamic;
+private:
+ Conv2DBlockScaledOp::Adaptor adaptor;
+};
- // Input shape describes input width/height and batch.
- const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
- if (inputDataShape.hasRank()) {
- outShape[0] = inputDataShape.getDimSize(0);
- inputHeight = inputDataShape.getDimSize(1);
- inputWidth = inputDataShape.getDimSize(2);
- }
- const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
- if (inputScaleShape.hasRank()) {
- outShape[0] = ShapedType::isDynamic(outShape[0])
- ? inputScaleShape.getDimSize(0)
- : outShape[0];
- inputHeight = ShapedType::isDynamic(inputHeight)
- ? inputScaleShape.getDimSize(1)
- : inputHeight;
- inputWidth = ShapedType::isDynamic(inputWidth)
- ? inputScaleShape.getDimSize(2)
- : inputWidth;
+/// Conv3D: batch and input spatial sizes (D, H, W) are input dims 0-3; output
+/// channels and kernel spatial sizes are weight dims 0-3.
+template <>
+class ConvInferShapeAdaptor<Conv3DOp::Adaptor>
+ : public ConvInferShapeAdaptorBase {
+public:
+ explicit ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
+ : adaptor(adaptor) {}
+
+ void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<int64_t> &inputSpatial) {
+ const ShapeAdaptor inputShape(adaptor.getInput().getType());
+ if (!inputShape.hasRank())
+ return;
+ outputShape[0] = inputShape.getDimSize(0);
+ inputSpatial[0] = inputShape.getDimSize(1);
+ inputSpatial[1] = inputShape.getDimSize(2);
+ inputSpatial[2] = inputShape.getDimSize(3);
}
- // Weight shapes describes the filter width/height and the output channels.
- const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
- if (weightDataShape.hasRank()) {
- outShape[3] = weightDataShape.getDimSize(0);
- weightHeight = weightDataShape.getDimSize(1);
- weightWidth = weightDataShape.getDimSize(2);
+ void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
+ SmallVectorImpl<int64_t> &weightSpatial) {
+ const ShapeAdaptor weightShape(adaptor.getWeight().getType());
+ if (!weightShape.hasRank())
+ return;
+ outputShape[4] = weightShape.getDimSize(0);
+ weightSpatial[0] = weightShape.getDimSize(1);
+ weightSpatial[1] = weightShape.getDimSize(2);
+ weightSpatial[2] = weightShape.getDimSize(3);
}
- const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
- if (weightScaleShape.hasRank()) {
- outShape[3] = ShapedType::isDynamic(outShape[3])
- ? weightScaleShape.getDimSize(0)
- : outShape[3];
- weightHeight = ShapedType::isDynamic(weightHeight)
- ? weightScaleShape.getDimSize(1)
- : weightHeight;
- weightWidth = ShapedType::isDynamic(weightWidth)
- ? weightScaleShape.getDimSize(2)
- : weightWidth;
+
+ int64_t getNumSpatialDims() const { return 3; }
+ int64_t getOutputRank() const { return 5; }
+
+ // Conv3D holds pad/stride/dilation as integer arrays on the adaptor, so they
+ // can simply be copied out and this always succeeds.
+ LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
+ SmallVector<int64_t> &strideValues,
+ SmallVector<int64_t> &dilationValues) {
+ padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
+ strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
+ dilationValues.assign(adaptor.getDilation().begin(),
+ adaptor.getDilation().end());
+ return success();
}
- // Bias shape can describe the output channels.
- const ShapeAdaptor biasShape(adaptor.getBias().getType());
+private:
+ Conv3DOp::Adaptor adaptor;
+};
+
+/// Shared result-shape inference for Conv2D, Conv2DBlockScaled and Conv3D.
+///
+/// The result is laid out as [batch, spatial..., output channels]. Batch and
+/// spatial sizes come from the input operand(s), output channels from the
+/// weight operand(s), via the op's ConvInferShapeAdaptor specialization. The
+/// bias fills in the output channels only when they are still dynamic and the
+/// bias is not a broadcast (size 1) bias. Each output spatial size is
+///   (in + pad_before + pad_after - ((kernel - 1) * dilation + 1)) / stride + 1
+/// and is computed only when the input and kernel sizes are static and the
+/// pad/stride/dilation values are available and well formed; any dimension
+/// that cannot be computed is left dynamic. Always returns success and pushes
+/// exactly one shape.
+template <typename AdaptorT>
+LogicalResult inferConvReturnTypeComponents(
+ AdaptorT adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ ConvInferShapeAdaptor<AdaptorT> convShapeAdaptor(adaptor);
+ llvm::SmallVector<int64_t> outputShape(convShapeAdaptor.getOutputRank(),
+ ShapedType::kDynamic);
+ llvm::SmallVector<int64_t> inputSpatial(convShapeAdaptor.getNumSpatialDims(),
+ ShapedType::kDynamic);
+ llvm::SmallVector<int64_t> weightSpatial(convShapeAdaptor.getNumSpatialDims(),
+ ShapedType::kDynamic);
+ const size_t numSpatialDims = inputSpatial.size();
+
+ convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
+ convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
+
+ // The bias may supply the output channel count, but a bias of size 1 is
+ // broadcast across channels and says nothing about it.
+ const ShapeAdaptor biasShape = adaptor.getBias().getType();
if (biasShape.hasRank()) {
const int64_t biasSize = biasShape.getDimSize(0);
- // Bias of size 1 may be broadcast
if (biasSize != 1) {
- outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
+ const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
+ outputShape[outputChannelDim] =
+ ShapedType::isDynamic(outputShape[outputChannelDim])
+ ? biasSize
+ : outputShape[outputChannelDim];
}
}
SmallVector<int64_t> padValues;
SmallVector<int64_t> strideValues;
SmallVector<int64_t> dilationValues;
- if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) ||
- !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
- strideValues) ||
- !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
- dilationValues)) {
- inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+ // Nothing here guarantees the op has been verified, and for
+ // Conv2DBlockScaled these values come from operands, so also reject arrays
+ // whose length does not match the spatial rank instead of indexing past
+ // their end below.
+ if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
+ dilationValues)) ||
+ padValues.size() != 2 * numSpatialDims ||
+ strideValues.size() != numSpatialDims ||
+ dilationValues.size() != numSpatialDims) {
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
- if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
- const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
- const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
+ for (size_t dim = 0; dim < numSpatialDims; ++dim) {
+ // A zero stride would divide by zero and a negative one is meaningless;
+ // leave such dims dynamic.
+ if (!ShapedType::isStatic(inputSpatial[dim]) ||
+ !ShapedType::isStatic(weightSpatial[dim]) || strideValues[dim] <= 0)
+ continue;
+ const int64_t inputSize =
+ inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
+ const int64_t filterSize =
+ (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
const int64_t unstridedResult = inputSize - filterSize + 1;
- outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
+ outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
}
- if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
- const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
- const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
- const int64_t unstridedResult = inputSize - filterSize + 1;
- outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
- }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ return success();
+}
- inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+// Delegates to the shared inferConvReturnTypeComponents helper.
+LogicalResult Conv2DOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ Conv2DOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
+}
+
+// Runs verifyConvOp, verifyConvOpModes and verifyConvOpErrorIf in turn and
+// fails as soon as any of them fails.
+LogicalResult Conv2DOp::verify() {
+ if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+ verifyConvOpErrorIf(*this).failed())
+ return failure();
return success();
}
+// Delegates to the shared inferConvReturnTypeComponents helper.
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ Conv2DBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
+}
+
LogicalResult Conv2DBlockScaledOp::verify() {
if (failed(verifySameElementTypes(*this, getInputData().getType(),
getWeightData().getType(), "input_data",
@@ -3732,67 +3811,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
Conv3DOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
-
- int64_t inputWidth = ShapedType::kDynamic;
- int64_t inputHeight = ShapedType::kDynamic;
- int64_t inputDepth = ShapedType::kDynamic;
-
- int64_t weightWidth = ShapedType::kDynamic;
- int64_t weightHeight = ShapedType::kDynamic;
- int64_t weightDepth = ShapedType::kDynamic;
-
- // Input shape describes input width/height and batch.
- ShapeAdaptor inputShape(adaptor.getInput().getType());
- if (inputShape.hasRank()) {
- outputShape[0] = inputShape.getDimSize(0);
- inputDepth = inputShape.getDimSize(1);
- inputHeight = inputShape.getDimSize(2);
- inputWidth = inputShape.getDimSize(3);
- }
-
- // Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape(adaptor.getWeight().getType());
- if (weightShape.hasRank()) {
- outputShape[4] = weightShape.getDimSize(0);
- weightDepth = weightShape.getDimSize(1);
- weightHeight = weightShape.getDimSize(2);
- weightWidth = weightShape.getDimSize(3);
- }
-
- // Bias shape can describe the output channels.
- ShapeAdaptor biasShape(adaptor.getBias().getType());
- if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
- outputShape[4] = biasShape.getDimSize(0);
- }
-
- llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
- llvm::ArrayRef<int64_t> stride = adaptor.getStride();
- llvm::ArrayRef<int64_t> pad = adaptor.getPad();
-
- if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
- int32_t inputSize = inputDepth + pad[0] + pad[1];
- int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
- int32_t unstridedResult = inputSize - filterSize + 1;
- outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
- }
-
- if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
- int32_t inputSize = inputHeight + pad[2] + pad[3];
- int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
- int32_t unstridedResult = inputSize - filterSize + 1;
- outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
- }
-
- if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
- int32_t inputSize = inputWidth + pad[4] + pad[5];
- int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
- int32_t unstridedResult = inputSize - filterSize + 1;
- outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
- }
-
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
- return success();
+ return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
}
LogicalResult Conv3DOp::verify() {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index bc5f41b1af304..231eec8192eec 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1741,8 +1741,26 @@ func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: ten
%0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp
{ acc_type = f32, pad = array<i64: 0, 0, 0, 0>, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
: (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
- return
- }
+ return
+}
+
+// -----
+
+// A bias of size 1 is broadcast, so it must not be taken as the output channel
+// count: with a dynamic weight dim 0 the channel dim has to stay dynamic.
+// CHECK-LABEL: @test_conv2d_bias_broadcast
+func.func @test_conv2d_bias_broadcast(%input: tensor<2x8x9x3xf32>, %weights: tensor<?x3x6x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> () {
+ // CHECK: -> tensor<2x6x4x?xf32>
+ %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x3x6x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+ return
+}
+
+// -----
+
+// Conv3D variant of the size-1 bias broadcast check.
+// CHECK-LABEL: @test_conv3d_bias_broadcast
+func.func @test_conv3d_bias_broadcast(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<?x3x6x4x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> () {
+ // CHECK: -> tensor<2x6x4x7x?xf32>
+ %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<2x8x9x10x3xf32>, tensor<?x3x6x4x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
+ return
+}
// -----
More information about the Mlir-commits
mailing list