[Mlir-commits] [mlir] [mlir][tosa] Require operand/result tensors of at least rank 1 for some operations (PR #131335)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 14 06:54:16 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit updates the following operations (operands/results) to be of at least rank 1 so that they align with the expectations of the specification:
- ARGMAX (input)
- REDUCE_ALL (input/output)
- REDUCE_ANY (input/output)
- REDUCE_MAX (input/output)
- REDUCE_MIN (input/output)
- REDUCE_PRODUCT (input/output)
- REDUCE_SUM (input/output)
- CONCAT (each input in input1/output)
- PAD (input1/output)
- REVERSE (input1/output)
- SLICE (input1/output)
- TILE (input1/output)
- TRANSPOSE (input1/output)

In addition to this change, PAD has been updated to allow unranked tensors for input1/output, in line with other operations.

---
Full diff: https://github.com/llvm/llvm-project/pull/131335.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+25-25) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+7) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+7-2) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+5-28) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+71-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b79993f48b379..0c99dd6130c2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor: $input,
+    Tosa_TensorAtLeast1D: $input,
     I32Attr: $axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
@@ -1629,12 +1629,12 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1668,12 +1668,12 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1707,13 +1707,13 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1748,13 +1748,13 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis,
     DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1789,12 +1789,12 @@ def Tosa_ReduceProductOp : Tosa_InferTensorTypeOp<"reduce_product"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1828,12 +1828,12 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input,
+    Tosa_TensorAtLeast1D:$input,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1872,12 +1872,12 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$input1,
+    Variadic<Tosa_TensorAtLeast1D>:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1923,13 +1923,13 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
   }];
 
   let arguments = (ins
-    Tosa_RankedTensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$padding,
     Tosa_ScalarTensor:$pad_const
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -1996,12 +1996,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     I32Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2028,13 +2028,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$start,
     Tosa_Shape:$size
   );
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2058,11 +2058,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     Tosa_Shape:$multiples);
 
   let results = (outs
-    Tosa_Tensor:$output
+    Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
@@ -2093,12 +2093,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
   }];
 
   let arguments = (ins
-    Tosa_Tensor:$input1,
+    Tosa_TensorAtLeast1D:$input1,
     DenseI32ArrayAttr:$perms
   );
 
   let results = (
-    outs Tosa_Tensor:$output
+    outs Tosa_TensorAtLeast1D:$output
   );
 
   list<Availability> availability = [
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 0038d8c386ca7..67011f22fbe2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -101,6 +101,10 @@ def AllDimensionsAreSizeOne : And<[
     IsRankedTensorTypePred,
     CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
 
+def AtLeastRankOne : And<[
+  IsRankedTensorTypePred,
+  CPred<"::llvm::cast<::mlir::RankedTensorType>($_self).getRank() >= 1">]>;
+
 class TosaTensorOf<
     list<Type> allowedTypes, string summary = "tosa-conformant tensor">
     : TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -183,6 +187,9 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
 def Tosa_Int32TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
 
+def Tosa_TensorAtLeast1D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
+
 //===----------------------------------------------------------------------===//
 // Generic scalar, vector, or tensor of a particular type.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 97a3009a20302..cdba332792eb0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1354,8 +1354,13 @@ LogicalResult tosa::PadOp::verify() {
     }
   }
 
-  RankedTensorType inputType = getInput1().getType();
-  RankedTensorType outputType = getOutput().getType();
+  RankedTensorType inputType =
+      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+  RankedTensorType outputType =
+      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
+  if (!inputType || !outputType)
+    return success();
+
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
 
   if (inputType.getRank() != outputType.getRank())
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 3bc438e465e1d..077a6cee0a1bb 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -915,29 +915,6 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
 
 // -----
 
-// CHECK-LABEL: @fold_reduce_rank_zero
-func.func @fold_reduce_rank_zero() {
-  // CHECK-NOT: tosa.reduce_min
-  // CHECK-NOT: tosa.reverse
-  %0 = tensor.empty() : tensor<i32>
-  %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @fold_tile_rank_zero
-func.func nested @fold_tile_rank_zero() -> tensor<i32> {
-  // CHECK-NOT: tosa.tile
-  %0 = tensor.empty() : tensor<i32>
-  %cst = tosa.const_shape { values = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
-  %1 = tosa.tile %0, %cst : (tensor<i32>, !tosa.shape<0>) -> tensor<i32>
-  return %1 : tensor<i32>
-}
-
-// -----
-
 // CHECK-LABEL: @reshape_quant_nofold
 // check that segfault is fixed
 func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
@@ -1015,12 +992,12 @@ func.func @cast_quant_nofold() -> tensor<!quant.uniform<i8:f32, 3.07574046018999
 // -----
 
 // CHECK-LABEL: @reverse_quant_fold
-func.func @reverse_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
-   // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+func.func @reverse_quant_fold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+   // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    // CHECK: return %[[CST]]
-   %0 = "tosa.const"() {values = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   return %1 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a488c051dcd3b..2dc749422c12d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -452,9 +452,9 @@ func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
 
 // -----
 
-func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
+func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<1xi32>) -> () {
   // expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
-  %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+  %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<1xi32>) -> tensor<1x10xi32>
   return
 }
 
@@ -1852,3 +1852,72 @@ func.func @test_maxpool2d_unexpected_output_width(%arg0: tensor<1x32x32x8xf32>)
          (tensor<1x32x32x8xf32>) -> tensor<1x32x2x8xf32>
   return %0 : tensor<1x32x2x8xf32>
 }
+
+// -----
+
+func.func @test_scalar_argmax(%arg0: tensor<i32>) -> tensor<i32> {
+  // expected-error@+1 {{'tosa.argmax' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+  %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+func.func @test_scalar_reduce_all(%arg0: tensor<i1>) -> tensor<i1> {
+  // expected-error@+1 {{'tosa.reduce_all' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i1>'}}
+  %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<i1>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// -----
+
+func.func @test_scalar_inputs_concat(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2xf32> {
+  // expected-error@+1 {{'tosa.concat' op operand #0 must be variadic of tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @test_scalar_pad(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+  // expected-error@+1 {{'tosa.pad' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.pad %arg0, %padding, %0 : (tensor<f32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_reverse(%arg0: tensor<f32>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.reverse' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.reverse %arg0 {axis = 0: i32} : (tensor<f32>) -> tensor<f32>
+  return %arg0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  %1 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.slice' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
+  %cst = tosa.const_shape { values = dense<[]> : tensor<0xindex> } : () -> !tosa.shape<0>
+  // expected-error@+1 {{'tosa.tile' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %0 = tosa.tile %arg0, %cst: (tensor<f32>, !tosa.shape<0>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+  %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/131335


More information about the Mlir-commits mailing list