[Mlir-commits] [mlir] [mlir][tosa] Support unranked input/weight tensors for convolution ops (PR #134856)

Luke Hutton llvmlistbot at llvm.org
Wed Apr 23 06:10:26 PDT 2025


https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/134856

>From 40a5eb511ceaa106e78a0fb43ddef69b44256a1c Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 8 Apr 2025 09:48:52 +0000
Subject: [PATCH] [mlir][tosa] Support unranked input/weight tensors for
 convolution ops

This commit allows the convolution operators conv2d, depthwise_conv2d,
transpose_conv2d and conv3d to have unranked input/weight operands.

In order to support unranked operands, the tablegen definition of the
weight operand was relaxed. The constraints that the relaxed definition
no longer enforces are instead checked later by the validation pass,
should the user wish to run it.

Change-Id: I33334909e0d4d0676daae81bfc4647e86abc063a
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |   8 +-
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |   5 -
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 127 ++++++++----------
 mlir/test/Dialect/Tosa/invalid.mlir           |  16 +--
 mlir/test/Dialect/Tosa/ops.mlir               |  21 +++
 5 files changed, 86 insertions(+), 91 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c94edad62cac7..cc78aaed911e6 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -125,7 +125,7 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -172,7 +172,7 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
 
   let arguments = (ins
     Tosa_Tensor5D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
+    Tosa_Tensor5D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -218,7 +218,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
@@ -434,7 +434,7 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
+    Tosa_Tensor4D:$weight,
     Tosa_Tensor1D:$bias,
     Tosa_ScalarIntOrFloatTensor:$input_zp,
     Tosa_ScalarIntOrFloatTensor:$weight_zp,
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 67011f22fbe2a..b9ac1ff705514 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -84,11 +84,6 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                 "number">;
 
-// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
-// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
-def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
-                             Tosa_QuantizedInt, AnyFloat]>;
-
 //===----------------------------------------------------------------------===//
 // TOSA Tensor Conformance
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1ab4ce7d4558b..93fdb9aefe1e6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -278,19 +278,8 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
 
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
-  // All TOSA conv ops have an input and weight arguments which must be ranked
-  // tensors.
-  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
-  if (!inputType) {
-    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
-    return failure();
-  }
-
-  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
-  if (!weightType) {
-    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
-    return failure();
-  }
+  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
+  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
 
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
@@ -2998,14 +2987,6 @@ LogicalResult TransposeConv2DOp::verify() {
     return emitOpError("expect all stride values to be >= 1, got [")
            << strides << "]";
 
-  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
-
-  const auto outputType =
-      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
-
-  const auto weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
   const auto checkPadAgainstKernelDim =
       [this](int64_t pad_value, int64_t kernel_dim_size,
              llvm::StringRef pad_name,
@@ -3019,69 +3000,77 @@ LogicalResult TransposeConv2DOp::verify() {
   };
 
   const llvm::ArrayRef<int64_t> padding = getOutPad();
-
   const int64_t outPadTop = padding[0];
   const int64_t outPadBottom = padding[1];
+  const int64_t outPadLeft = padding[2];
+  const int64_t outPadRight = padding[3];
 
-  const int64_t kernelHeight = weightType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(kernelHeight)) {
-    if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight, "out_pad_top",
-                                        "KH")))
-      return failure();
-
-    if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
-                                        "out_pad_bottom", "KH")))
-      return failure();
-  }
+  const auto weightType =
+      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
 
-  const int64_t kernelWidth = weightType.getDimSize(2);
+  if (weightType) {
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    if (!ShapedType::isDynamic(kernelHeight)) {
+      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
+                                          "out_pad_top", "KH")))
+        return failure();
 
-  const int64_t outPadLeft = padding[2];
-  const int64_t outPadRight = padding[3];
+      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
+                                          "out_pad_bottom", "KH")))
+        return failure();
+    }
 
-  if (!ShapedType::isDynamic(kernelWidth)) {
-    if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth, "out_pad_left",
-                                        "KW")))
-      return failure();
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    if (!ShapedType::isDynamic(kernelWidth)) {
+      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
+                                          "out_pad_left", "KW")))
+        return failure();
 
-    if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
-                                        "out_pad_right", "KW")))
-      return failure();
+      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
+                                          "out_pad_right", "KW")))
+        return failure();
+    }
   }
 
   // Rest of the checks depend on the output type being a RankedTensorType
+  const auto outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
   if (!outputType)
     return success();
 
-  const int64_t inputHeight = inputType.getDimSize(1);
-  const int64_t outputHeight = outputType.getDimSize(1);
-
-  if (!ShapedType::isDynamic(inputHeight) &&
-      !ShapedType::isDynamic(outputHeight)) {
-    if (outputHeight !=
-        (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
-      return emitOpError(
-                 "dimension mismatch: expected OH == (IH - 1) * stride_y "
-                 "+ out_pad_top + out_pad_bottom + KH, but got ")
-             << outputHeight << " != (" << inputHeight << " - 1) * " << strideY
-             << " + " << outPadTop << " + " << outPadBottom << " + "
-             << kernelHeight;
-  }
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (inputType && weightType) {
+    const int64_t inputHeight = inputType.getDimSize(1);
+    const int64_t kernelHeight = weightType.getDimSize(1);
+    const int64_t outputHeight = outputType.getDimSize(1);
+
+    if (!ShapedType::isDynamic(inputHeight) &&
+        !ShapedType::isDynamic(outputHeight)) {
+      if (outputHeight !=
+          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
+        return emitOpError(
+                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
+                   "+ out_pad_top + out_pad_bottom + KH, but got ")
+               << outputHeight << " != (" << inputHeight << " - 1) * "
+               << strideY << " + " << outPadTop << " + " << outPadBottom
+               << " + " << kernelHeight;
+    }
 
-  const int64_t inputWidth = inputType.getDimSize(2);
-  const int64_t outputWidth = outputType.getDimSize(2);
+    const int64_t inputWidth = inputType.getDimSize(2);
+    const int64_t kernelWidth = weightType.getDimSize(2);
+    const int64_t outputWidth = outputType.getDimSize(2);
 
-  if (!ShapedType::isDynamic(inputWidth) &&
-      !ShapedType::isDynamic(outputWidth)) {
-    if (outputWidth !=
-        (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
-      return emitOpError(
-                 "dimension mismatch: expected OW == (IW - 1) * stride_x "
-                 "+ out_pad_left + out_pad_right + KW, but got ")
-             << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
-             << " + " << outPadLeft << " + " << outPadRight << " + "
-             << kernelWidth;
+    if (!ShapedType::isDynamic(inputWidth) &&
+        !ShapedType::isDynamic(outputWidth)) {
+      if (outputWidth !=
+          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
+        return emitOpError(
+                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
+                   "+ out_pad_left + out_pad_right + KW, but got ")
+               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
+               << " + " << outPadLeft << " + " << outPadRight << " + "
+               << kernelWidth;
+    }
   }
 
   const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..b084337fc046a 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -22,22 +22,12 @@ func.func @test_const_non_tensor_attr() {
 
 // -----
 
-func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+func.func @test_conv2d(%arg0: tensor<*xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
   // expected-error at +1 {{'tosa.conv2d' op expect both input and weight to be float or not together, got 'f32' and 'i8'}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<1x29x29x4xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
-  return %0 : tensor<1x27x27x16xi8>
-}
-
-// -----
-
-func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.conv2d' op expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
+           : (tensor<*xf32>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
 }
 
@@ -45,7 +35,7 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
   %zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
-  // expected-error at +1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
+  // expected-error at +1 {{'tosa.conv2d' op illegal: operand/result data types not supported}}
   %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index c1181825f0c97..bde9b418e66e2 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -70,6 +70,13 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
   return %0 : tensor<1x4x4x8xf32>
 }
 
+// -----
+// CHECK-LABEL: conv2d_unranked_input
+func.func @test_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<*xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
 // -----
 // CHECK-LABEL: conv2d_quant_uniform
 func.func @test_conv2d_quant_uniform(%arg0: tensor<1x4x4x4x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<8x1x1x4x!quant.uniform<i8:f32, 0.01>>, %arg2: tensor<8x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x4x4x8x!quant.uniform<i32:f32, 0.01>> {
@@ -202,6 +209,20 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x
   return %0 : tensor<1x32x32x16xf32>
 }
 
+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_input
+func.func @test_transpose_conv2d_unranked_input(%arg0: tensor<*xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<*xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d_unranked_weight
+func.func @test_transpose_conv2d_unranked_weight(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<*xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<*xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
 // -----
 // CHECK-LABEL: transpose_conv2d_with_local_bound
 func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x32x32x16xf32> {



More information about the Mlir-commits mailing list