[Mlir-commits] [mlir] [mlir][tosa] Add verifier checks for Gather (PR #137204)

Tai Ly llvmlistbot at llvm.org
Thu Apr 24 08:52:03 PDT 2025


https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/137204

This adds verifier checks for the gather op
to ensure that the shapes of its inputs and output
are consistent with the TOSA specification.


>From 008fc7b55deea196f5bb2a4068eec45f12b1f8ae Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Wed, 9 Apr 2025 01:29:47 +0000
Subject: [PATCH] [mlir][tosa] Add verifier checks for Gather

This adds verifier checks for the gather op
to ensure that the shapes of its inputs and output
are consistent with the TOSA specification.

Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I16685bceef25f428669c5412d897b6918a424119
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 48 ++++++++++++++++++++++++++--
 mlir/test/Dialect/Tosa/verifier.mlir | 32 +++++++++++++++++++
 2 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..22aca774a403d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::GatherOp::verify() {
-  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
-                                /* outType = */ getOutput().getType());
+  if (failed(verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                                    /* outType = */ getOutput().getType())))
+    return failure();
+
+  // Expected shapes: values [N, K, C], indices [N, W], output [N, W, C];
+  // K is not checked. Sizes stay kDynamic until a ranked operand gives them.
+  // NOTE: assumes ranked values/output are rank 3 and ranked indices rank 2,
+  // presumably enforced by the op's ODS type constraints (not shown here).
+  const ShapeAdaptor valuesShape(getValues().getType());
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  const ShapeAdaptor outputShape(getOutput().getType());
+  int64_t N = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+
+  // Emits "requires <operand> dimension <dim> to have size <expected>, got
+  // <actual>" if both sizes are static and differ; dynamic sizes always pass.
+  auto verifyDim = [&](StringRef operand, int dim, int64_t expected,
+                       int64_t actual) -> LogicalResult {
+    if (ShapedType::isDynamic(expected) || ShapedType::isDynamic(actual) ||
+        expected == actual)
+      return success();
+    return emitOpError() << "requires " << operand << " dimension " << dim
+                         << " to have size " << expected << ", got " << actual;
+  };
+
+  if (valuesShape.hasRank()) {
+    N = valuesShape.getDimSize(0);
+    C = valuesShape.getDimSize(2);
+  }
+  if (indicesShape.hasRank()) {
+    const int64_t indicesN = indicesShape.getDimSize(0);
+    if (failed(verifyDim("indices", 0, N, indicesN)))
+      return failure();
+    // If values did not provide a static N, take it from indices instead.
+    if (ShapedType::isDynamic(N))
+      N = indicesN;
+    W = indicesShape.getDimSize(1);
+  }
+  if (outputShape.hasRank()) {
+    if (failed(verifyDim("output", 0, N, outputShape.getDimSize(0))) ||
+        failed(verifyDim("output", 1, W, outputShape.getDimSize(1))) ||
+        failed(verifyDim("output", 2, C, outputShape.getDimSize(2))))
+      return failure();
+  }
+  return success();
 }
 
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..2b78773e4ed7f 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,35 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
+// -----
+// CHECK-LABEL: @test_gather_invalid_indices_N
+func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32>
+  return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: @test_gather_invalid_out_N
+func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32>
+  return %0 : tensor<12x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: @test_gather_invalid_out_W
+func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
+}
+
+// -----
+// CHECK-LABEL: @test_gather_invalid_out_C
+func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
+  return %0 : tensor<13x26x8xf32>
+}



More information about the Mlir-commits mailing list