[Mlir-commits] [mlir] [mlir][tosa] Refactor convolution infer return type (PR #178869)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 30 03:23:48 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

Lots of logic was repeated across the Conv2D, Conv3D and Conv2DBlockScaled ops. This commit factors out the common logic to reduce code duplication.

In doing so, a bug in calculating the bias shape was also fixed: a bias of size 1 may be broadcast, so it must not be used to infer the output-channel dimension. Since DepthwiseConv2D and TransposeConv2D were already fixed independently, this commit fixes #175765.

---

Patch is 20.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178869.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+196-177) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+20-2) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c412d788c9b29..52432c8d87049 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3435,162 +3435,241 @@ static LogicalResult poolingInferReturnTypes(
   return success();
 }
 
-LogicalResult Conv2DOp::inferReturnTypeComponents(
-    MLIRContext *context, ::std::optional<Location> location,
-    Conv2DOp::Adaptor adaptor,
-    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
-
-  int64_t inputWidth = ShapedType::kDynamic;
-  int64_t inputHeight = ShapedType::kDynamic;
-  int64_t weightWidth = ShapedType::kDynamic;
-  int64_t weightHeight = ShapedType::kDynamic;
+template <typename AdaptorT>
+class ConvInferShapeAdaptor;
 
-  // Input shape describes input width/height and batch.
+class ConvInferShapeAdaptorBase {
+protected:
+  static void updateIfDynamic(int64_t &current, int64_t candidate) {
+    if (ShapedType::isDynamic(current))
+      current = candidate;
+  }
+};
 
-  ShapeAdaptor inputShape(adaptor.getInput().getType());
-  if (inputShape.hasRank()) {
+template <>
+class ConvInferShapeAdaptor<Conv2DOp::Adaptor>
+    : public ConvInferShapeAdaptorBase {
+public:
+  explicit ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
+      : adaptor(adaptor) {}
+
+  void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
+                       SmallVectorImpl<int64_t> &inputSpatial) {
+    const ShapeAdaptor inputShape(adaptor.getInput().getType());
+    if (!inputShape.hasRank())
+      return;
     outputShape[0] = inputShape.getDimSize(0);
-    inputHeight = inputShape.getDimSize(1);
-    inputWidth = inputShape.getDimSize(2);
+    inputSpatial[0] = inputShape.getDimSize(1);
+    inputSpatial[1] = inputShape.getDimSize(2);
   }
 
-  // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape(adaptor.getWeight().getType());
-  if (weightShape.hasRank()) {
+  void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
+                        SmallVectorImpl<int64_t> &weightSpatial) {
+    const ShapeAdaptor weightShape(adaptor.getWeight().getType());
+    if (!weightShape.hasRank())
+      return;
     outputShape[3] = weightShape.getDimSize(0);
-    weightHeight = weightShape.getDimSize(1);
-    weightWidth = weightShape.getDimSize(2);
+    weightSpatial[0] = weightShape.getDimSize(1);
+    weightSpatial[1] = weightShape.getDimSize(2);
   }
 
-  // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape(adaptor.getBias().getType());
-  if (biasShape.hasRank()) {
-    outputShape[3] = ShapedType::isDynamic(outputShape[3])
-                         ? biasShape.getDimSize(0)
-                         : outputShape[3];
-  }
+  int64_t getNumSpatialDims() const { return 2; }
+  int64_t getOutputRank() const { return 4; }
 
-  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
-  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
-  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
-
-  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
-    int64_t inputSize = inputHeight + padding[0] + padding[1];
-    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
-    int64_t unstridedResult = inputSize - filterSize + 1;
-    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
+  LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
+                                     SmallVector<int64_t> &strideValues,
+                                     SmallVector<int64_t> &dilationValues) {
+    padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
+    strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
+    dilationValues.assign(adaptor.getDilation().begin(),
+                          adaptor.getDilation().end());
+    return success();
   }
 
-  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
-    int64_t inputSize = inputWidth + padding[2] + padding[3];
-    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
-    int64_t unstridedResult = inputSize - filterSize + 1;
-    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
-  }
+private:
+  Conv2DOp::Adaptor adaptor;
+};
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
-  return success();
-}
+template <>
+class ConvInferShapeAdaptor<Conv2DBlockScaledOp::Adaptor>
+    : public ConvInferShapeAdaptorBase {
+public:
+  explicit ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
+      : adaptor(adaptor) {}
+
+  void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
+                       SmallVectorImpl<int64_t> &inputSpatial) {
+    const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+    if (inputDataShape.hasRank()) {
+      outputShape[0] = inputDataShape.getDimSize(0);
+      inputSpatial[0] = inputDataShape.getDimSize(1);
+      inputSpatial[1] = inputDataShape.getDimSize(2);
+    }
 
-LogicalResult Conv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
-      verifyConvOpErrorIf(*this).failed())
-    return failure();
-  return success();
-}
+    const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+    if (!inputScaleShape.hasRank())
+      return;
+    updateIfDynamic(outputShape[0], inputScaleShape.getDimSize(0));
+    updateIfDynamic(inputSpatial[0], inputScaleShape.getDimSize(1));
+    updateIfDynamic(inputSpatial[1], inputScaleShape.getDimSize(2));
+  }
+
+  void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
+                        SmallVectorImpl<int64_t> &weightSpatial) {
+    const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
+    if (weightDataShape.hasRank()) {
+      outputShape[3] = weightDataShape.getDimSize(0);
+      weightSpatial[0] = weightDataShape.getDimSize(1);
+      weightSpatial[1] = weightDataShape.getDimSize(2);
+    }
 
-LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
-    MLIRContext *context, ::std::optional<Location> location,
-    Conv2DBlockScaledOp::Adaptor adaptor,
-    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+    const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
+    if (!weightScaleShape.hasRank())
+      return;
+    updateIfDynamic(outputShape[3], weightScaleShape.getDimSize(0));
+    updateIfDynamic(weightSpatial[0], weightScaleShape.getDimSize(1));
+    updateIfDynamic(weightSpatial[1], weightScaleShape.getDimSize(2));
+  }
+
+  int64_t getNumSpatialDims() const { return 2; }
+  int64_t getOutputRank() const { return 4; }
+
+  LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
+                                     SmallVector<int64_t> &strideValues,
+                                     SmallVector<int64_t> &dilationValues) {
+    if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(),
+                                   padValues) ||
+        !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+                                   strideValues) ||
+        !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
+                                   dilationValues))
+      return failure();
+    return success();
+  }
 
-  int64_t inputWidth = ShapedType::kDynamic;
-  int64_t inputHeight = ShapedType::kDynamic;
-  int64_t weightWidth = ShapedType::kDynamic;
-  int64_t weightHeight = ShapedType::kDynamic;
+private:
+  Conv2DBlockScaledOp::Adaptor adaptor;
+};
 
-  // Input shape describes input width/height and batch.
-  const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
-  if (inputDataShape.hasRank()) {
-    outShape[0] = inputDataShape.getDimSize(0);
-    inputHeight = inputDataShape.getDimSize(1);
-    inputWidth = inputDataShape.getDimSize(2);
-  }
-  const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
-  if (inputScaleShape.hasRank()) {
-    outShape[0] = ShapedType::isDynamic(outShape[0])
-                      ? inputScaleShape.getDimSize(0)
-                      : outShape[0];
-    inputHeight = ShapedType::isDynamic(inputHeight)
-                      ? inputScaleShape.getDimSize(1)
-                      : inputHeight;
-    inputWidth = ShapedType::isDynamic(inputWidth)
-                     ? inputScaleShape.getDimSize(2)
-                     : inputWidth;
+template <>
+class ConvInferShapeAdaptor<Conv3DOp::Adaptor>
+    : public ConvInferShapeAdaptorBase {
+public:
+  explicit ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
+      : adaptor(adaptor) {}
+
+  void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
+                       SmallVectorImpl<int64_t> &inputSpatial) {
+    const ShapeAdaptor inputShape(adaptor.getInput().getType());
+    if (!inputShape.hasRank())
+      return;
+    outputShape[0] = inputShape.getDimSize(0);
+    inputSpatial[0] = inputShape.getDimSize(1);
+    inputSpatial[1] = inputShape.getDimSize(2);
+    inputSpatial[2] = inputShape.getDimSize(3);
   }
 
-  // Weight shapes describes the filter width/height and the output channels.
-  const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
-  if (weightDataShape.hasRank()) {
-    outShape[3] = weightDataShape.getDimSize(0);
-    weightHeight = weightDataShape.getDimSize(1);
-    weightWidth = weightDataShape.getDimSize(2);
+  void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
+                        SmallVectorImpl<int64_t> &weightSpatial) {
+    const ShapeAdaptor weightShape(adaptor.getWeight().getType());
+    if (!weightShape.hasRank())
+      return;
+    outputShape[4] = weightShape.getDimSize(0);
+    weightSpatial[0] = weightShape.getDimSize(1);
+    weightSpatial[1] = weightShape.getDimSize(2);
+    weightSpatial[2] = weightShape.getDimSize(3);
   }
-  const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
-  if (weightScaleShape.hasRank()) {
-    outShape[3] = ShapedType::isDynamic(outShape[3])
-                      ? weightScaleShape.getDimSize(0)
-                      : outShape[3];
-    weightHeight = ShapedType::isDynamic(weightHeight)
-                       ? weightScaleShape.getDimSize(1)
-                       : weightHeight;
-    weightWidth = ShapedType::isDynamic(weightWidth)
-                      ? weightScaleShape.getDimSize(2)
-                      : weightWidth;
+
+  int64_t getNumSpatialDims() const { return 3; }
+  int64_t getOutputRank() const { return 5; }
+
+  LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
+                                     SmallVector<int64_t> &strideValues,
+                                     SmallVector<int64_t> &dilationValues) {
+    padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
+    strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
+    dilationValues.assign(adaptor.getDilation().begin(),
+                          adaptor.getDilation().end());
+    return success();
   }
 
-  // Bias shape can describe the output channels.
-  const ShapeAdaptor biasShape(adaptor.getBias().getType());
+private:
+  Conv3DOp::Adaptor adaptor;
+};
+
+template <typename AdaptorT>
+LogicalResult inferConvReturnTypeComponents(
+    AdaptorT adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ConvInferShapeAdaptor<AdaptorT> convShapeAdaptor(adaptor);
+  llvm::SmallVector<int64_t> outputShape(convShapeAdaptor.getOutputRank(),
+                                         ShapedType::kDynamic);
+  llvm::SmallVector<int64_t> inputSpatial(convShapeAdaptor.getNumSpatialDims(),
+                                          ShapedType::kDynamic);
+  llvm::SmallVector<int64_t> weightSpatial(convShapeAdaptor.getNumSpatialDims(),
+                                           ShapedType::kDynamic);
+
+  convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
+  convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
+
+  const ShapeAdaptor biasShape = adaptor.getBias().getType();
   if (biasShape.hasRank()) {
     const int64_t biasSize = biasShape.getDimSize(0);
-    // Bias of size 1 may be broadcast
     if (biasSize != 1) {
-      outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3];
+      const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
+      outputShape[outputChannelDim] =
+          ShapedType::isDynamic(outputShape[outputChannelDim])
+              ? biasSize
+              : outputShape[outputChannelDim];
     }
   }
 
   SmallVector<int64_t> padValues;
   SmallVector<int64_t> strideValues;
   SmallVector<int64_t> dilationValues;
-  if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) ||
-      !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
-                                 strideValues) ||
-      !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
-                                 dilationValues)) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
+                                                   dilationValues))) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
-    const int64_t inputSize = inputHeight + padValues[0] + padValues[1];
-    const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1;
+  for (int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
+    if (!ShapedType::isStatic(inputSpatial[dim]) ||
+        !ShapedType::isStatic(weightSpatial[dim]))
+      continue;
+    const int64_t inputSize =
+        inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
+    const int64_t filterSize =
+        (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
     const int64_t unstridedResult = inputSize - filterSize + 1;
-    outShape[1] = (unstridedResult - 1) / strideValues[0] + 1;
+    outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
   }
 
-  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
-    const int64_t inputSize = inputWidth + padValues[2] + padValues[3];
-    const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1;
-    const int64_t unstridedResult = inputSize - filterSize + 1;
-    outShape[2] = (unstridedResult - 1) / strideValues[1] + 1;
-  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+LogicalResult Conv2DOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    Conv2DOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
+}
+
+LogicalResult Conv2DOp::verify() {
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
+    return failure();
   return success();
 }
 
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    Conv2DBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
+}
+
 LogicalResult Conv2DBlockScaledOp::verify() {
   if (failed(verifySameElementTypes(*this, getInputData().getType(),
                                     getWeightData().getType(), "input_data",
@@ -3732,67 +3811,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     Conv3DOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
-
-  int64_t inputWidth = ShapedType::kDynamic;
-  int64_t inputHeight = ShapedType::kDynamic;
-  int64_t inputDepth = ShapedType::kDynamic;
-
-  int64_t weightWidth = ShapedType::kDynamic;
-  int64_t weightHeight = ShapedType::kDynamic;
-  int64_t weightDepth = ShapedType::kDynamic;
-
-  // Input shape describes input width/height and batch.
-  ShapeAdaptor inputShape(adaptor.getInput().getType());
-  if (inputShape.hasRank()) {
-    outputShape[0] = inputShape.getDimSize(0);
-    inputDepth = inputShape.getDimSize(1);
-    inputHeight = inputShape.getDimSize(2);
-    inputWidth = inputShape.getDimSize(3);
-  }
-
-  // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape(adaptor.getWeight().getType());
-  if (weightShape.hasRank()) {
-    outputShape[4] = weightShape.getDimSize(0);
-    weightDepth = weightShape.getDimSize(1);
-    weightHeight = weightShape.getDimSize(2);
-    weightWidth = weightShape.getDimSize(3);
-  }
-
-  // Bias shape can describe the output channels.
-  ShapeAdaptor biasShape(adaptor.getBias().getType());
-  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
-    outputShape[4] = biasShape.getDimSize(0);
-  }
-
-  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
-  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
-  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
-
-  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
-    int32_t inputSize = inputDepth + pad[0] + pad[1];
-    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
-    int32_t unstridedResult = inputSize - filterSize + 1;
-    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
-  }
-
-  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
-    int32_t inputSize = inputHeight + pad[2] + pad[3];
-    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
-    int32_t unstridedResult = inputSize - filterSize + 1;
-    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
-  }
-
-  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
-    int32_t inputSize = inputWidth + pad[4] + pad[5];
-    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
-    int32_t unstridedResult = inputSize - filterSize + 1;
-    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
-  }
-
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
-  return success();
+  return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
 }
 
 LogicalResult Conv3DOp::verify() {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index bc5f41b1af304..231eec8192eec 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1741,8 +1741,26 @@ func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: ten
   %0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp
        { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
        : (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
-    return
-  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_conv2d_bias_broadcast
+func.func @test_conv2d_bias_broadcast(%input: tensor<2x8x9x3xf32>, %weights: tensor<?x3x6x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> () {
+  // CHECK: -> tensor<2x6x4x?xf32>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<2x8x9x3xf32>, tensor<?x3x6x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_conv3d_bias_broadcast
+func.func @test_conv3d_bias_broadcast(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<?x3x6x4x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> () {
+  // CHECK: -> tensor<2x6x4x7x?xf32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, dilation = array<i64: 1, 1, 1>, ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/178869


More information about the Mlir-commits mailing list