[Mlir-commits] [mlir] [mlir][tosa] Allow unranked indices argument for gather/scatter (PR #140618)
Luke Hutton
llvmlistbot at llvm.org
Mon May 19 13:58:06 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/140618
This commit allows the indices argument of gather and scatter to be unranked; its shape can then be determined during shape inference.
From 4dbf2470125cbcc4bb303b132c4628dd4c8f4814 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 19 May 2025 09:09:33 +0000
Subject: [PATCH] [mlir][tosa] Allow unranked indices argument for gather/scatter
This commit allows the indices argument of gather and scatter to be
unranked; its shape can then be determined during shape inference.
Change-Id: Ibac56d27661f1a12a662df6c4e6660bd4d09df10
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 4 ++--
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td | 2 ++
mlir/test/Dialect/Tosa/ops.mlir | 14 ++++++++++++++
3 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 52bb0eb992b69..a10fe28cef853 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2125,7 +2125,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
let arguments = (ins
Tosa_Tensor3D:$values,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
+ Tosa_Int32Tensor2D:$indices
);
let results = (outs
@@ -2159,7 +2159,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
let arguments = (ins
Tosa_Tensor3D:$values_in,
- TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
+ Tosa_Int32Tensor2D:$indices,
Tosa_Tensor3D:$input
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index b9ac1ff705514..536551c8f8437 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -181,6 +181,8 @@ def Tosa_TensorUpto4D : AnyTypeOf<[
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
+def Tosa_Int32Tensor2D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
def Tosa_TensorAtLeast1D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e327ed900f45f..88b819325d02f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -711,6 +711,20 @@ func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %a
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: gather_unranked_indices
+func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>) -> tensor<13x26x3xf32> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<*xi32>) -> tensor<13x26x3xf32>
+ return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_unranked_indices
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: resize
func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
More information about the Mlir-commits mailing list