[Mlir-commits] [mlir] [mlir][tosa] Add expected output shape check to argmax verifier (PR #129870)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 5 03:22:55 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

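Adds a check to the `tosa.argmax` verifier that the output shape matches the input shape with the `axis` dimension removed. Also fixes some test cases that incorrectly declared the output shape, and adds a negative test case.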

---
Full diff: https://github.com/llvm/llvm-project/pull/129870.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+21-4) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/constrained_shapes.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+8) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 800968e6f4766..bd5c5e56398c1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -438,17 +438,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
 }
 
 LogicalResult tosa::ArgMaxOp::verify() {
+  const ShapedType resultType = llvm::cast<ShapedType>(getType());
+
   // Ensure output is of 32-bit integer
-  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
-  if (!resultETy.isIntOrIndex())
+  if (const auto resultETy = resultType.getElementType();
+      !resultETy.isIntOrIndex())
     return emitOpError("result tensor is not of integer type");
 
-  // Ensure axis is within the tensor rank
   const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  if (!inputType.hasRank())
+    return success();
+
+  // Ensure axis is within the tensor rank
   const int64_t axis = getAxisAttr().getInt();
-  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+  if (((axis < 0) || axis >= inputType.getRank()))
     return emitOpError("specified axis is outside the rank of the tensor");
 
+  if (!resultType.hasRank())
+    return success();
+
+  const ArrayRef<int64_t> inputShape = inputType.getShape();
+  const ArrayRef<int64_t> outputShape = resultType.getShape();
+  llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
+                                                 inputShape.end());
+  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
+  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
+    return emitOpError("expected output shape '")
+           << expectedOutputShape << "', got '" << outputShape << "'";
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index a0184e2d82704..09aba79647c79 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
   // CHECK: tosa.argmax
-  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
-  return %0 : tensor<?x1xi32>
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 8c3ad828ab06f..e06efbbfa1ad9 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -5,7 +5,7 @@
 // -----
 // Uses argmax as canonical example to validate constrained TOSA tensor shapes.
 // CHECK-LABEL: argmax
-func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
-  %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
-  return %0 : tensor<?xi32>
+func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<i32> {
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<i32>
+  return %0 : tensor<i32>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e665510ff0143..76093b0b3c1ca 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1392,3 +1392,11 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te
   %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
   return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
 }
+
+// -----
+
+func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+  // expected-error at +1 {{'tosa.argmax' op expected output shape '2, 3', got '1, 2, 3'}}
+  %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
+  return %0 : tensor<1x2x3xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/129870


More information about the Mlir-commits mailing list