[Mlir-commits] [mlir] [mlir][tosa] Add verifier for `ArgMax` operator (PR #68410)
Georgios Pinitas
llvmlistbot at llvm.org
Fri Oct 6 04:53:46 PDT 2023
https://github.com/GeorgeARM created https://github.com/llvm/llvm-project/pull/68410
The verifier ensures that the operator is valid by checking that:
* the output element type is i32, and
* the axis is within the rank of the input tensor.
>From a0b027a2fd8e19654b0bd2278ddf5f891e084d09 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Thu, 5 Oct 2023 15:19:04 +0100
Subject: [PATCH] [mlir][tosa] Add verifier for `ArgMax` operator
The verifier ensures that the operator is valid by checking that:
* the output element type is i32, and
* the axis is within the rank of the input tensor.
Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 ++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 15 +++++++++++++++
mlir/test/Dialect/Tosa/canonicalize.mlir | 6 +++---
mlir/test/Dialect/Tosa/constrained_shapes.mlir | 2 +-
mlir/test/Dialect/Tosa/level_check.mlir | 6 +++---
5 files changed, 24 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index f4d9a251fb97839..a80111aedfe0b59 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -48,6 +48,8 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
let results = (outs
Tosa_Tensor: $output
);
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0b92a3cb7a6203d..af112aa65e2a371 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,21 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
return success();
}
+LogicalResult tosa::ArgMaxOp::verify() {
+  // The result element type must be a 32-bit integer.
+  auto outputType = llvm::cast<ShapedType>(getType());
+  if (!outputType.getElementType().isInteger(32))
+    return emitOpError("result tensor is not i32");
+  // The axis can only be bounds-checked against a ranked input; unranked passes.
+  auto inputShaped = llvm::cast<ShapedType>(getInput().getType());
+  if (!inputShaped.hasRank())
+    return success();
+  const int64_t reduceAxis = getAxisAttr().getInt(); // signed: negatives rejected
+  if (reduceAxis < 0 || reduceAxis >= inputShaped.getRank())
+    return emitOpError("specified axis is outside the rank of the tensor");
+  return success();
+}
+
LogicalResult tosa::AvgPool2dOp::verify() {
auto inputType = llvm::cast<ShapedType>(getInput().getType());
if (hasZeroDimension(inputType))
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 323864ea9013048..d36cf6a1d94a9f3 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt -canonicalize="test-convergence" %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
+ // argmax results must be i32 to satisfy the ArgMaxOp verifier.
// CHECK: tosa.argmax
- %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xf32>
- return %0 : tensor<?x1xf32>
+ %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
+ return %0 : tensor<?x1xi32>
}
// CHECK-LABEL: @add_bcast_zero_int
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 9acb024cf78d005..8c3ad828ab06f81 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -6,6 +6,6 @@
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
// CHECK-LABEL: argmax
func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
- %0 = "tosa.argmax"(%arg0) {axis = 1 : i32} : (tensor<?xf32>) -> tensor<?xi32>
+ // The input is rank 1, so axis 0 is the only axis the ArgMaxOp verifier accepts.
+ %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index e7fdf8af409b564..68238087f5c2523 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate
-func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> {
+func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
+ // Rank-7 input is meant to exceed MAX_RANK; the i32 result and axis 4 keep the ArgMaxOp verifier satisfied.
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32>
- return %0 : tensor<1x1x1x1x29x4xf32>
+ %0 = "tosa.argmax"(%arg0) {axis = 4 : i32} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32>
+ return %0 : tensor<1x1x1x1x29x4xi32>
}
// -----
More information about the Mlir-commits
mailing list