[Mlir-commits] [mlir] [mlir][tosa] Add verifier for `ArgMax` operator (PR #68410)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 6 04:54:52 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

<details>
<summary>Changes</summary>

The verifier ensures that the operator is valid by checking that:

* the output element type is i32
* the axis is within the rank of the input tensor

---
Full diff: https://github.com/llvm/llvm-project/pull/68410.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+15) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/constrained_shapes.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4d9a251fb97839..a80111aedfe0b59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
   let results = (outs
     Tosa_Tensor: $output
   );
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0b92a3cb7a6203d..af112aa65e2a371 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
   return success();
 }
 
+LogicalResult tosa::ArgMaxOp::verify() {
+  // The result element type must be a 32-bit integer (i32).
+  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
+  if (!resultETy.isInteger(32))
+    return emitOpError("result tensor is not i32");
+
+  // Axis must satisfy 0 <= axis < rank; skipped when the input is unranked.
+  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  const int64_t axis = getAxisAttr().getInt();
+  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+    return emitOpError("specified axis is outside the rank of the tensor");
+
+  return success();
+}
+
 LogicalResult tosa::AvgPool2dOp::verify() {
   auto inputType = llvm::cast<ShapedType>(getInput().getType());
   if (hasZeroDimension(inputType))
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 323864ea9013048..d36cf6a1d94a9f3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
   // CHECK: tosa.argmax
-  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
-  return %0 : tensor<?x1xf32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
+  return %0 : tensor<?x1xi32>
 }
 
 // CHECK-LABEL: @add_bcast_zero_int
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 9acb024cf78d005..8c3ad828ab06f81 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -6,6 +6,6 @@
 // Uses argmax as canonical example to validate constrained TOSA tensor shapes.
 // CHECK-LABEL: argmax
 func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
-  %0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor<?xf32>) -> tensor<?xi32>
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..68238087f5c2523 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
 
 
-func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
+func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
   // expected-error at +1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
-  return %0 : tensor<1x1x1x1x29x4xf32>
+  %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32>
+  return %0 : tensor<1x1x1x1x29x4xi32>
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/68410


More information about the Mlir-commits mailing list