[Mlir-commits] [mlir] [mlir][tosa] Allow unranked indices argument for gather/scatter (PR #140618)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 19 13:58:41 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
This commit allows the indices argument of the gather and scatter operations to be unranked; the concrete shape of the indices can then be resolved during shape inference.
---
Full diff: https://github.com/llvm/llvm-project/pull/140618.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-2)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+2)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..a10fe28cef853 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2125,7 +2125,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
let arguments = (ins
Tosa_Tensor3D:$values,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+ Tosa_Int32Tensor2D:$indices
);
let results = (outs
@@ -2159,7 +2159,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
let arguments = (ins
Tosa_Tensor3D:$values_in,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
+ Tosa_Int32Tensor2D:$indices,
Tosa_Tensor3D:$input
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index b9ac1ff705514..536551c8f8437 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -181,6 +181,8 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+def Tosa_Int32Tensor2D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
def Tosa_TensorAtLeast1D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e327ed900f45f..88b819325d02f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -711,6 +711,20 @@ func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %a
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: gather_unranked_indices
+func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>) -> tensor<13x26x3xf32> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<*xi32>) -> tensor<13x26x3xf32>
+ return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_unranked_indices
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: resize
func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
``````````
</details>
https://github.com/llvm/llvm-project/pull/140618
More information about the Mlir-commits mailing list