[Mlir-commits] [mlir] [mlir][tosa] Require operand/result tensors of at least rank 1 for some operations (PR #131335)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 14 06:54:16 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
This commit updates the following operations so that the listed operands/results must be of at least rank 1, aligning with the expectations of the specification:
- ARGMAX (input)
- REDUCE_ALL (input/output)
- REDUCE_ANY (input/output)
- REDUCE_MAX (input/output)
- REDUCE_MIN (input/output)
- REDUCE_PRODUCT (input/output)
- REDUCE_SUM (input/output)
- CONCAT (each input in input1/output)
- PAD (input1/output)
- REVERSE (input1/output)
- SLICE (input1/output)
- TILE (input1/output)
- TRANSPOSE (input1/output)
In addition, PAD has been updated to allow unranked tensors for input1/output, in line with other operations.
---
Full diff: https://github.com/llvm/llvm-project/pull/131335.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+25-25)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+7)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+7-2)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+5-28)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+71-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index b79993f48b379..0c99dd6130c2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -41,7 +41,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
}];
let arguments = (ins
- Tosa_Tensor: $input,
+ Tosa_TensorAtLeast1D: $input,
I32Attr: $axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
@@ -1629,12 +1629,12 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1668,12 +1668,12 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1707,13 +1707,13 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1748,13 +1748,13 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1789,12 +1789,12 @@ def Tosa_ReduceProductOp : Tosa_InferTensorTypeOp<"reduce_product"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1828,12 +1828,12 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
}];
let arguments = (ins
- Tosa_Tensor:$input,
+ Tosa_TensorAtLeast1D:$input,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1872,12 +1872,12 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
}];
let arguments = (ins
- Variadic<Tosa_Tensor>:$input1,
+ Variadic<Tosa_TensorAtLeast1D>:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1923,13 +1923,13 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
}];
let arguments = (ins
- Tosa_RankedTensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
Tosa_Shape:$padding,
Tosa_ScalarTensor:$pad_const
);
let results = (outs
- Tosa_RankedTensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -1996,12 +1996,12 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
I32Attr:$axis
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -2028,13 +2028,13 @@ def Tosa_SliceOp : Tosa_InferShapedTypeOp<"slice"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
Tosa_Shape:$start,
Tosa_Shape:$size
);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -2058,11 +2058,11 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
Tosa_Shape:$multiples);
let results = (outs
- Tosa_Tensor:$output
+ Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
@@ -2093,12 +2093,12 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
}];
let arguments = (ins
- Tosa_Tensor:$input1,
+ Tosa_TensorAtLeast1D:$input1,
DenseI32ArrayAttr:$perms
);
let results = (
- outs Tosa_Tensor:$output
+ outs Tosa_TensorAtLeast1D:$output
);
list<Availability> availability = [
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 0038d8c386ca7..67011f22fbe2a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -101,6 +101,10 @@ def AllDimensionsAreSizeOne : And<[
IsRankedTensorTypePred,
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>;
+def AtLeastRankOne : And<[
+ IsRankedTensorTypePred,
+ CPred<"::llvm::cast<::mlir::RankedTensorType>($_self).getRank() >= 1">]>;
+
class TosaTensorOf<
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
@@ -183,6 +187,9 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+def Tosa_TensorAtLeast1D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
+
//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 97a3009a20302..cdba332792eb0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1354,8 +1354,13 @@ LogicalResult tosa::PadOp::verify() {
}
}
- RankedTensorType inputType = getInput1().getType();
- RankedTensorType outputType = getOutput().getType();
+ RankedTensorType inputType =
+ llvm::dyn_cast<RankedTensorType>(getInput1().getType());
+ RankedTensorType outputType =
+ llvm::dyn_cast<RankedTensorType>(getOutput().getType());
+ if (!inputType || !outputType)
+ return success();
+
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
if (inputType.getRank() != outputType.getRank())
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 3bc438e465e1d..077a6cee0a1bb 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -915,29 +915,6 @@ func.func @fold_abs_abs(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// -----
-// CHECK-LABEL: @fold_reduce_rank_zero
-func.func @fold_reduce_rank_zero() {
- // CHECK-NOT: tosa.reduce_min
- // CHECK-NOT: tosa.reverse
- %0 = tensor.empty() : tensor<i32>
- %1 = tosa.reduce_min %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
- %2 = tosa.reverse %0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @fold_tile_rank_zero
-func.func nested @fold_tile_rank_zero() -> tensor<i32> {
- // CHECK-NOT: tosa.tile
- %0 = tensor.empty() : tensor<i32>
- %cst = tosa.const_shape { values = dense<> : tensor<0xindex> } : () -> !tosa.shape<0>
- %1 = tosa.tile %0, %cst : (tensor<i32>, !tosa.shape<0>) -> tensor<i32>
- return %1 : tensor<i32>
-}
-
-// -----
-
// CHECK-LABEL: @reshape_quant_nofold
// check that segfault is fixed
func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
@@ -1015,12 +992,12 @@ func.func @cast_quant_nofold() -> tensor<!quant.uniform<i8:f32, 3.07574046018999
// -----
// CHECK-LABEL: @reverse_quant_fold
-func.func @reverse_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
- // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+func.func @reverse_quant_fold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
+ // CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
// CHECK: return %[[CST]]
- %0 = "tosa.const"() {values = dense<0> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- return %1 : tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %0 = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ %1 = "tosa.reverse"(%0) { axis = 0 : i32 } : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+ return %1 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index a488c051dcd3b..2dc749422c12d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -452,9 +452,9 @@ func.func @test_reduce_sum_invalid_axis(%arg0 : tensor<2x3x4xf32>) -> () {
// -----
-func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<i32>) -> () {
+func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor<1xi32>) -> () {
// expected-error@+1 {{'tosa.reduce_min' op expect output tensor rank to be equal to input tensor rank}}
- %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<1x10xi32>
+ %0 = tosa.reduce_min %arg0 {axis = 0 : i32} : (tensor<1xi32>) -> tensor<1x10xi32>
return
}
@@ -1852,3 +1852,72 @@ func.func @test_maxpool2d_unexpected_output_width(%arg0: tensor<1x32x32x8xf32>)
(tensor<1x32x32x8xf32>) -> tensor<1x32x2x8xf32>
return %0 : tensor<1x32x2x8xf32>
}
+
+// -----
+
+func.func @test_scalar_argmax(%arg0: tensor<i32>) -> tensor<i32> {
+ // expected-error@+1 {{'tosa.argmax' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i32>'}}
+ %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<i32>) -> tensor<i32>
+ return %0 : tensor<i32>
+}
+
+// -----
+
+func.func @test_scalar_reduce_all(%arg0: tensor<i1>) -> tensor<i1> {
+ // expected-error@+1 {{'tosa.reduce_all' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<i1>'}}
+ %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<i1>) -> tensor<i1>
+ return %0 : tensor<i1>
+}
+
+// -----
+
+func.func @test_scalar_inputs_concat(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2xf32> {
+ // expected-error@+1 {{'tosa.concat' op operand #0 must be variadic of tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<f32>, tensor<f32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @test_scalar_pad(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
+ // expected-error@+1 {{'tosa.pad' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %1 = tosa.pad %arg0, %padding, %0 : (tensor<f32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_reverse(%arg0: tensor<f32>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.reverse' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0 = tosa.reverse %arg0 {axis = 0: i32} : (tensor<f32>) -> tensor<f32>
+ return %arg0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_slice(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+ %1 = tosa.const_shape {values = dense<[]> : tensor<0xindex>} : () -> !tosa.shape<0>
+ // expected-error@+1 {{'tosa.slice' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<f32>, !tosa.shape<0>, !tosa.shape<0>) -> tensor<f32>
+ return %2 : tensor<f32>
+}
+
+// -----
+
+func.func @test_scalar_tile(%arg0: tensor<f32>) -> tensor<*xf32> {
+ %cst = tosa.const_shape { values = dense<[]> : tensor<0xindex> } : () -> !tosa.shape<0>
+ // expected-error@+1 {{'tosa.tile' op operand #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %0 = tosa.tile %arg0, %cst: (tensor<f32>, !tosa.shape<0>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @test_scalar_output_transpose(%arg0: tensor<*xf32>) -> tensor<f32> {
+ // expected-error@+1 {{'tosa.transpose' op result #0 must be tosa-conformant tensor of at least rank 1, but got 'tensor<f32>'}}
+ %1 = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<*xf32>) -> tensor<f32>
+ return %1 : tensor<f32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/131335
More information about the Mlir-commits
mailing list