[Mlir-commits] [mlir] 498121e - [mlir][tosa] Allow unranked indices argument for gather/scatter (#140618)
llvmlistbot at llvm.org
Thu May 22 09:01:22 PDT 2025
Author: Luke Hutton
Date: 2025-05-22T17:01:19+01:00
New Revision: 498121e00454fc306f345b7854cf96cb7282374b
URL: https://github.com/llvm/llvm-project/commit/498121e00454fc306f345b7854cf96cb7282374b
DIFF: https://github.com/llvm/llvm-project/commit/498121e00454fc306f345b7854cf96cb7282374b.diff
LOG: [mlir][tosa] Allow unranked indices argument for gather/scatter (#140618)
This commit allows the indices argument for gather and scatter to be
unranked. This can be computed during shape inference.
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
mlir/test/Dialect/Tosa/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 86f9ab94ec152..f93dd901535c3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2125,7 +2125,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
let arguments = (ins
Tosa_Tensor3D:$values,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+ Tosa_Int32Tensor2D:$indices
);
let results = (outs
@@ -2159,7 +2159,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
let arguments = (ins
Tosa_Tensor3D:$values_in,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
+ Tosa_Int32Tensor2D:$indices,
Tosa_Tensor3D:$input
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index b9ac1ff705514..536551c8f8437 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -181,6 +181,8 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+def Tosa_Int32Tensor2D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
def Tosa_TensorAtLeast1D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 7aea1c06698e8..5ec506a45b3ad 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -719,6 +719,20 @@ func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %a
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: gather_unranked_indices
+func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>) -> tensor<13x26x3xf32> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<*xi32>) -> tensor<13x26x3xf32>
+ return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_unranked_indices
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: resize
func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
More information about the Mlir-commits mailing list