[Mlir-commits] [mlir] [mlir][tosa] Allow unranked indices argument for gather/scatter (PR #140618)

Luke Hutton llvmlistbot at llvm.org
Mon May 19 13:58:06 PDT 2025


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/140618

This commit allows the indices argument of gather and scatter to be unranked. The rank of the indices can then be resolved later, during shape inference.

>From 4dbf2470125cbcc4bb303b132c4628dd4c8f4814 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 19 May 2025 09:09:33 +0000
Subject: [PATCH] [mlir][tosa] Allow unranked indices argument for
 gather/scatter

This commit allows the indices argument of gather and scatter to be
unranked. The rank of the indices can then be resolved later, during
shape inference.

Change-Id: Ibac56d27661f1a12a662df6c4e6660bd4d09df10
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td       |  4 ++--
 mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td |  2 ++
 mlir/test/Dialect/Tosa/ops.mlir                    | 14 ++++++++++++++
 3 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..a10fe28cef853 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2125,7 +2125,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values,
-    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+    Tosa_Int32Tensor2D:$indices
   );
 
   let results = (outs
@@ -2159,7 +2159,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values_in,
-    TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
+    Tosa_Int32Tensor2D:$indices,
     Tosa_Tensor3D:$input
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index b9ac1ff705514..536551c8f8437 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -181,6 +181,8 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
 
 def Tosa_Int32TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+def Tosa_Int32Tensor2D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
 
 def Tosa_TensorAtLeast1D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e327ed900f45f..88b819325d02f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -711,6 +711,20 @@ func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %a
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: gather_unranked_indices
+func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>) -> tensor<13x26x3xf32> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<*xi32>) -> tensor<13x26x3xf32>
+  return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_unranked_indices
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
 // -----
 // CHECK-LABEL: resize
 func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {



More information about the Mlir-commits mailing list