[Mlir-commits] [mlir] [mlir][tosa] Add expected output shape check to argmax verifier (PR #129870)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 5 03:22:55 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
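Adds a check to the `tosa.argmax` verifier that the output shape matches the input shape with the `axis` dimension removed. Also fixes some test cases that declared an incorrect output shape, and adds a negative test case.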
---
Full diff: https://github.com/llvm/llvm-project/pull/129870.diff
4 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+21-4)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-3)
- (modified) mlir/test/Dialect/Tosa/constrained_shapes.mlir (+3-3)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+8)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 800968e6f4766..bd5c5e56398c1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -438,17 +438,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
}
LogicalResult tosa::ArgMaxOp::verify() {
+ const ShapedType resultType = llvm::cast<ShapedType>(getType());
+
// Ensure output is of 32-bit integer
- const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
- if (!resultETy.isIntOrIndex())
+ if (const auto resultETy = resultType.getElementType();
+ !resultETy.isIntOrIndex())
return emitOpError("result tensor is not of integer type");
- // Ensure axis is within the tensor rank
const auto inputType = llvm::cast<ShapedType>(getInput().getType());
+ if (!inputType.hasRank())
+ return success();
+
+ // Ensure axis is within the tensor rank
const int64_t axis = getAxisAttr().getInt();
- if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
+ if (((axis < 0) || axis >= inputType.getRank()))
return emitOpError("specified axis is outside the rank of the tensor");
+ if (!resultType.hasRank())
+ return success();
+
+ const ArrayRef<int64_t> inputShape = inputType.getShape();
+ const ArrayRef<int64_t> outputShape = resultType.getShape();
+ llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
+ inputShape.end());
+ expectedOutputShape.erase(expectedOutputShape.begin() + axis);
+ if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
+ return emitOpError("expected output shape '")
+ << expectedOutputShape << "', got '" << outputShape << "'";
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index a0184e2d82704..09aba79647c79 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
// CHECK-LABEL: @argmax_nofold
-func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
+func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
// CHECK: tosa.argmax
- %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<?x1xi32>
- return %0 : tensor<?x1xi32>
+ %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<?x1xf32>) -> tensor<1xi32>
+ return %0 : tensor<1xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/constrained_shapes.mlir b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
index 8c3ad828ab06f..e06efbbfa1ad9 100644
--- a/mlir/test/Dialect/Tosa/constrained_shapes.mlir
+++ b/mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -5,7 +5,7 @@
// -----
// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
// CHECK-LABEL: argmax
-func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<?xi32> {
- %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<?xi32>
- return %0 : tensor<?xi32>
+func.func @test_argmax(%arg0: tensor<?xf32>) -> tensor<i32> {
+ %0 = "tosa.argmax"(%arg0) {axis = 0 : i32} : (tensor<?xf32>) -> tensor<i32>
+ return %0 : tensor<i32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e665510ff0143..76093b0b3c1ca 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1392,3 +1392,11 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te
%0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
}
+
+// -----
+
+func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> {
+ // expected-error@+1 {{'tosa.argmax' op expected output shape '2, 3', got '1, 2, 3'}}
+ %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
+ return %0 : tensor<1x2x3xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/129870
More information about the Mlir-commits mailing list