[Mlir-commits] [mlir] [mlir][tosa] Add verifier for `ArgMax` operator (PR #68410)

Georgios Pinitas llvmlistbot at llvm.org
Mon Oct 9 07:25:15 PDT 2023


https://github.com/GeorgeARM updated https://github.com/llvm/llvm-project/pull/68410

>From 69ebdcd4a57e0086e552855001bb2ca0e30a82f7 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Thu, 5 Oct 2023 15:19:04 +0100
Subject: [PATCH] [mlir][tosa] Add verifier for `ArgMax` operator

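The verifier ensures that the operator is valid by checking that:

* the result tensor's element type is an integer (or index) type
* the axis is within the rank of the input tensor (when the input is ranked)

Existing tests that used a floating-point result type or an out-of-range
axis are updated accordingly.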

Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td   |  2 ++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp           | 15 +++++++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir       |  6 +++---
 mlir/test/Dialect/Tosa/constrained_shapes.mlir |  2 +-
 mlir/test/Dialect/Tosa/level_check.mlir        |  6 +++---
 5 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4d9a251fb97839..a80111aedfe0b59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   let results = (outs
     Tosa_Tensor: $output
   );
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0b92a3cb7a6203d..a719171b2b359d2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
   return success();
 }
 
+/// Verifies argmax: the result element type must be an integer (any width or
+/// signedness) or index, and for a ranked input `axis` must be in [0, rank).
+LogicalResult tosa::ArgMaxOp::verify() {
+  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
+  if (!resultETy.isIntOrIndex())
+    return emitOpError("result tensor is not of integer type");
+
+  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  const int64_t axis = getAxisAttr().getInt();
+  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+    return emitOpError("specified axis is outside the rank of the tensor")
+           << " (axis = " << axis << ", rank = " << inputType.getRank() << ")";
+  return success();
+}
+
 LogicalResult tosa::AvgPool2dOp::verify() {
   auto inputType = llvm::cast<ShapedType>(getInput().getType());
   if (hasZeroDimension(inputType))
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 323864ea9013048..d36cf6a1d94a9f3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   // CHECK: tosa.argmax
-  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
-  return %0 : tensor<?x1xf32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
+  return %0 : tensor<?x1xi32>
 }
 
 // CHECK-LABEL: @add_bcast_zero_int
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 9acb024cf78d005..8c3ad828ab06f81 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -6,6 +6,6 @@
 // Uses argmax as canonical example to validate constrained TOSA tensor shapes.
 // CHECK-LABEL: argmax
 func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
-  %0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor<?xf32>) -> tensor<?xi32>
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..68238087f5c2523 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
 
 
-func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
+func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
-  return %0 : tensor<1x1x1x1x29x4xf32>
+  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32>
+  return %0 : tensor<1x1x1x1x29x4xi32>
 }
 
 // -----



More information about the Mlir-commits mailing list