[Mlir-commits] [mlir] 414709e - [mlir][tosa] Add verifier for `ArgMax` operator (#68410)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 9 08:55:23 PDT 2023


Author: Georgios Pinitas
Date: 2023-10-09T08:55:18-07:00
New Revision: 414709ee5c74ae35ac5eb204007fef35cc1c8079

URL: https://github.com/llvm/llvm-project/commit/414709ee5c74ae35ac5eb204007fef35cc1c8079
DIFF: https://github.com/llvm/llvm-project/commit/414709ee5c74ae35ac5eb204007fef35cc1c8079.diff

LOG: [mlir][tosa] Add verifier for `ArgMax` operator (#68410)

The verifier ensures that the operator is valid by checking that:

* The output element type is I32
* The axis is within the rank of the input tensor

Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir
    mlir/test/Dialect/Tosa/constrained_shapes.mlir
    mlir/test/Dialect/Tosa/level_check.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4d9a251fb97839..a80111aedfe0b59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   let results = (outs
     Tosa_Tensor: $output
   );
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0b92a3cb7a6203d..a719171b2b359d2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
   return success();
 }
 
+LogicalResult tosa::ArgMaxOp::verify() {
+  // Output element type must be i32 (isIntOrIndex would accept any width).
+  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
+  if (!resultETy.isInteger(32))
+    return emitOpError("result tensor is not of 32-bit integer type");
+
+  // Axis must lie in [0, rank) of the input; skipped for unranked inputs.
+  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  const int64_t axis = getAxisAttr().getInt();
+  if (inputType.hasRank() && (axis < 0 || axis >= inputType.getRank()))
+    return emitOpError("specified axis is outside the rank of the tensor");
+
+  return success();
+}
+
 LogicalResult tosa::AvgPool2dOp::verify() {
   auto inputType = llvm::cast<ShapedType>(getInput().getType());
   if (hasZeroDimension(inputType))

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 323864ea9013048..d36cf6a1d94a9f3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   // CHECK: tosa.argmax
-  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
-  return %0 : tensor<?x1xf32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
+  return %0 : tensor<?x1xi32>
 }
 
 // CHECK-LABEL: @add_bcast_zero_int

diff  --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 9acb024cf78d005..8c3ad828ab06f81 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -6,6 +6,6 @@
 // Uses argmax as canonical example to validate constrained TOSA tensor shapes.
 // CHECK-LABEL: argmax
 func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
-  %0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor<?xf32>) -> tensor<?xi32>
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }

diff  --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..68238087f5c2523 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
 
 
-func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
+func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
-  return %0 : tensor<1x1x1x1x29x4xf32>
+  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32>
+  return %0 : tensor<1x1x1x1x29x4xi32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list