[Mlir-commits] [mlir] [mlir][tosa] Enhance CONV3D & DEPTHWISE_CONV2D verifier (PR #135738)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 14 20:09:10 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: TatWai Chong (tatwaichong)

<details>
<summary>Changes</summary>

Verify the correctness of the pad, stride, and dilation attributes, and of the dimensions of the input, weight, bias, and output tensors.

---

Patch is 74.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135738.diff


11 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+152-94) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+22-22) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+6-6) 
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+152) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+4-4) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+60-60) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+8-8) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+17-18) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 9579d71a2afe9..194c577edd368 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -428,6 +428,152 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ERROR_IF functions.
+// ERROR_IF is a predicate that must set an error if the condition holds.
+//===----------------------------------------------------------------------===//
+
+template <typename T>
+static LogicalResult verifyConvOpErrorIf(T op) {
+  llvm::ArrayRef<int64_t> padding = op.getPad();
+  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
+    return op.emitOpError("expect all padding values to be >= 0, got ")
+           << padding;
+
+  llvm::ArrayRef<int64_t> strides = op.getStride();
+  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
+    return op.emitOpError("expect all stride values to be >= 1, got ")
+           << strides;
+
+  llvm::ArrayRef<int64_t> dilations = op.getDilation();
+  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
+    return op.emitOpError("expect all dilation values to be >= 1, got ")
+           << dilations;
+
+  const RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
+  if (!outputType)
+    // Skip following checks if output is not ranked
+    return success();
+
+  const RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+  const RankedTensorType weightType =
+      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+
+  if (inputType && weightType) {
+    const auto verifyOutputSize =
+        [&op](const int64_t inputSize, const int64_t kernelSize,
+              const int64_t outputSize, const int64_t padBefore,
+              const int64_t padAfter, const int64_t stride,
+              const int64_t dilation, const llvm::StringRef dimName,
+              const llvm::StringRef dimAxis,
+              const llvm::StringRef padBeforeName,
+              const llvm::StringRef padAfterName) -> LogicalResult {
+      if (inputSize == ShapedType::kDynamic ||
+          kernelSize == ShapedType::kDynamic)
+        return success();
+
+      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+          stride);
+      if (!calculatedOutSizeMinusOne.has_value())
+        return op.emitOpError("expected input_")
+               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+               << padAfterName << " - (kernel_" << dimName
+               << " - 1) * dilation_" << dimAxis
+               << " to be wholly divisible by stride_" << dimAxis << ", got ("
+               << inputSize << " - 1 + " << padBefore << " + " << padAfter
+               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
+               << stride;
+
+      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+        return op.emitOpError("calculated output ")
+               << dimName << " did not match expected: "
+               << "calculated=" << calculatedOutSize
+               << ", expected=" << outputSize;
+
+      return success();
+    };
+
+    /// ERROR_IF: O != idiv_check(I - 1 + p_a + p_b - (K - 1) * d, s) + 1
+
+    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(1),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(2),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "width", "x", "left", "right")))
+        return failure();
+    }
+
+    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(0),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(1),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "width", "x", "left", "right")))
+        return failure();
+    }
+
+    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
+    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(1), weightType.getDimSize(1),
+              outputType.getDimSize(1), padding[0], padding[1], strides[0],
+              dilations[0], "depth", "d", "front", "back")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(2), weightType.getDimSize(2),
+              outputType.getDimSize(2), padding[2], padding[3], strides[1],
+              dilations[1], "height", "y", "top", "bottom")))
+        return failure();
+
+      if (failed(verifyOutputSize(
+              inputType.getDimSize(3), weightType.getDimSize(3),
+              outputType.getDimSize(3), padding[4], padding[5], strides[2],
+              dilations[2], "width", "x", "left", "right")))
+        return failure();
+    }
+  }
+
+  const RankedTensorType biasType =
+      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
+  if (!biasType)
+    // Skip following checks if bias is not ranked
+    return success();
+
+  const int64_t biasChannels = biasType.getDimSize(0);
+  const int64_t outputChannels = outputType.getDimSize(3);
+  if (biasChannels == ShapedType::kDynamic ||
+      outputChannels == ShapedType::kDynamic)
+    // Skip following checks if biasChannels or outputChannels is dynamic dim
+    return success();
+
+  if (biasChannels != outputChannels && biasChannels != 1)
+    return op.emitOpError(
+               "bias channels expected to be equal to output channels (")
+           << outputChannels << ") or 1, got " << biasChannels;
+
+  return success();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -2570,99 +2716,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
 }
 
 LogicalResult Conv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
-
-  llvm::ArrayRef<int64_t> padding = getPad();
-  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
-    return emitOpError("expect all padding values to be >= 0, got ") << padding;
-
-  llvm::ArrayRef<int64_t> strides = getStride();
-  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
-    return emitOpError("expect all stride values to be >= 1, got ") << strides;
-
-  llvm::ArrayRef<int64_t> dilations = getDilation();
-  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
-    return emitOpError("expect all dilation values to be >= 1, got ")
-           << dilations;
-
-  const RankedTensorType outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-  if (!outputType)
-    // Skip following checks if output is not ranked
-    return success();
-
-  const RankedTensorType inputType =
-      llvm::dyn_cast<RankedTensorType>(getInput().getType());
-  const RankedTensorType weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
-  if (inputType && weightType) {
-    const auto verifyOutputSize =
-        [this](const int64_t inputSize, const int64_t kernelSize,
-               const int64_t outputSize, const int64_t padBefore,
-               const int64_t padAfter, const int64_t stride,
-               const int64_t dilation, const llvm::StringRef dimName,
-               const llvm::StringRef dimAxis,
-               const llvm::StringRef padBeforeName,
-               const llvm::StringRef padAfterName) -> LogicalResult {
-      if (inputSize == ShapedType::kDynamic ||
-          kernelSize == ShapedType::kDynamic)
-        return success();
-
-      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
-          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
-          stride);
-      if (!calculatedOutSizeMinusOne.has_value())
-        return emitOpError("expected input_")
-               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
-               << padAfterName << " - (kernel_" << dimName
-               << " - 1) * dilation_" << dimAxis
-               << " to be wholly divisible by stride_" << dimAxis << ", got ("
-               << inputSize << " - 1 + " << padBefore << " + " << padAfter
-               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
-               << stride;
-
-      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
-        return emitOpError("calculated output ")
-               << dimName << " did not match expected: "
-               << "calculated=" << calculatedOutSize
-               << ", expected=" << outputSize;
-
-      return success();
-    };
-
-    if (failed(verifyOutputSize(
-            inputType.getDimSize(1), weightType.getDimSize(1),
-            outputType.getDimSize(1), padding[0], padding[1], strides[0],
-            dilations[0], "height", "y", "top", "bottom")))
-      return failure();
-
-    if (failed(verifyOutputSize(
-            inputType.getDimSize(2), weightType.getDimSize(2),
-            outputType.getDimSize(2), padding[2], padding[3], strides[1],
-            dilations[1], "width", "x", "left", "right")))
-      return failure();
-  }
-
-  const RankedTensorType biasType =
-      llvm::dyn_cast<RankedTensorType>(getBias().getType());
-  if (!biasType)
-    // Skip following checks if bias is not ranked
-    return success();
-
-  const int64_t biasChannels = biasType.getDimSize(0);
-  const int64_t outputChannels = outputType.getDimSize(3);
-  if (biasChannels == ShapedType::kDynamic ||
-      outputChannels == ShapedType::kDynamic)
-    // Skip following checks if biasChannels or outputChannels is dynamic dim
-    return success();
-
-  if (biasChannels != outputChannels && biasChannels != 1)
-    return emitOpError(
-               "bias channels expected to be equal to output channels (")
-           << outputChannels << ") or 1, got " << biasChannels;
   return success();
 }
 
@@ -2737,7 +2793,8 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
 }
 
 LogicalResult Conv3DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
   return success();
 }
@@ -2847,7 +2904,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
 }
 
 LogicalResult DepthwiseConv2DOp::verify() {
-  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
+      verifyConvOpErrorIf(*this).failed())
     return failure();
   return success();
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 242772fe5cdcf..a737a8a05bae6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -878,22 +878,22 @@ func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : ten
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 // CHECK-LABEL: @conv3d_f32
-func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK-DAG:  %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
-  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<43x3x4x5x27xf32>, %bias: tensor<43xf32>) -> () {
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xf32>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xf32>
   // CHECK:      %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x47x45x43x28xf32>) {
+  // CHECK-SAME: ins(%arg2 : tensor<43xf32>) outs(%[[INIT]] : tensor<1x47x45x43x43xf32>) {
   // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
   // CHECK:        linalg.yield %[[IN]] : f32
-  // CHECK:      } -> tensor<1x47x45x43x28xf32>
+  // CHECK:      } -> tensor<1x47x45x43x43xf32>
   // CHECK:      linalg.conv_3d_ndhwc_dhwcf
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
-  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
-  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x43xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
   %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
   %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<43x3x4x5x27xf32>, tensor<43xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x43xf32>
   return
 }
 
@@ -919,40 +919,40 @@ func.func @conv3d_scalar_bias_f32(%input: tensor<1x49x48x47x27xf32>, %weights: t
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
 // CHECK-LABEL: @conv3d_i8
-func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
-  // CHECK-DAG:  %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xi8>) permutation = [1, 2, 3, 4, 0]
-  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
+func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<43x3x4x5x27xi8>, %bias: tensor<43xi32>) -> () {
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<43x3x4x5x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x43xi8>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x43xi32>
   // CHECK:      %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2 : tensor<28xi32>)
-  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK-SAME: ins(%arg2 : tensor<43xi32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x43xi32>) {
   // CHECK:      ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
   // CHECK:        linalg.yield %[[IN]] : i32
-  // CHECK:      } -> tensor<1x47x45x43x28xi32>
+  // CHECK:      } -> tensor<1x47x45x43x43xi32>
   // CHECK:      %[[IZP:.+]] = arith.constant -128 : i32
   // CHECK:      %[[FZP:.+]] = arith.constant 42 : i32
   // CHECK:      linalg.conv_3d_ndhwc_dhwcf_q
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
-  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
-  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x43xi8>, i32, i32)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x43xi32>) -> tensor<1x47x45x43x43xi32>
 
   %input_zp = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() <{values = dense<42> : tensor<1xi8>}> : () -> tensor<1xi8>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x28xi32>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<43x3x4x5x27xi8>, tensor<43xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x43xi32>
   return
 }
 
 // -----
 
 // CHECK-LABEL: @conv3d_f16_f32_acc
-func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<43x3x4x5x27xf16>, %bias: tensor<43xf16>) -> () {
   %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
   %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
-  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>)
   // CHECK: arith.extf %{{.*}} : f16 to f32
-  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
-  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
-  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
+  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x43xf16>) outs(%{{.*}} : tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x43xf32>) -> tensor<1x47x45x43x43xf16>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<43x3x4x5x27xf16>, tensor<43xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x43xf16>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..7374cfd1145b9 100644
--- a/mlir/test/Dialect/...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/135738


More information about the Mlir-commits mailing list