[Mlir-commits] [mlir] [mlir][tosa] Fix MulOp verifier handling for unranked operands (PR #141980)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 29 09:53:20 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
The previous verifier checks did not correctly handle unranked operands. For example, they assumed that `rankedOperandTypes` would contain at least two entries, which isn't the case when both `a` and `b` are unranked.
This change simplifies the checks so that they operate only on the intended `a` and `b` operands, rather than also including the `shift` operand.
---
Full diff: https://github.com/llvm/llvm-project/pull/141980.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+33-52)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+28-1)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3ee5a85a21dca..298802fc7fa6c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1779,7 +1779,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
}
LogicalResult tosa::MulOp::verify() {
- auto resElemType = getElementTypeOrSelf(getOutput());
+ const Value output = getOutput();
+ auto resElemType = getElementTypeOrSelf(output);
// Verify if the element type among operands and result match tosa
// specification.
@@ -1819,59 +1820,39 @@ LogicalResult tosa::MulOp::verify() {
// Verify the op has same ranks for all main operands (excludes extra operands
// such as shift of mul op, so this is the only difference with the built-in
// `SameOperandsAndResultRank` trait) and results types, if known.
-
- // delegate function that returns true if type is a shaped type with known
- // rank
- auto hasRank = [](const Type type) {
- if (auto shaped_type = dyn_cast<ShapedType>(type))
- return shaped_type.hasRank();
-
- return false;
- };
-
- auto rankedOperandTypes =
- llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
-
- auto rankedResultTypes =
- llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
-
- // If all operands and results are unranked, then no further verification.
- if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+ TypeRange operandTypes = getOperandTypes();
+ ShapedType aType = cast<ShapedType>(operandTypes[0]);
+ ShapedType bType = cast<ShapedType>(operandTypes[1]);
+
+ const bool aHasRank = aType.hasRank();
+ const bool bHasRank = bType.hasRank();
+ if (aHasRank && bHasRank) {
+ const int64_t aRank = aType.getRank();
+ const int64_t bRank = bType.getRank();
+ if (aRank != bRank)
+ return emitOpError("a and b operands don't have matching ranks, got ")
+ << aRank << " and " << bRank;
+
+ // check for broadcast compatible shapes
+ SmallVector<int64_t> resultShape;
+ if (!mlir::OpTrait::util::getBroadcastedShape(
+ aType.getShape(), bType.getShape(), resultShape))
+ return emitOpError("a and b operands don't have broadcast-compatible "
+ "shapes, got ")
+ << aType << " and " << bType;
+ }
+
+ ShapedType resultType = cast<ShapedType>(output.getType());
+ if (!resultType.hasRank())
return success();
- // delegate function that returns rank of shaped type with known rank
- auto getRank = [](const Type type) {
- return cast<ShapedType>(type).getRank();
- };
-
- auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
- : getRank(*rankedResultTypes.begin());
-
- for (size_t i = 0; i < 2; ++i) {
- if (rank != getRank(rankedOperandTypes[i])) {
- return emitOpError("operands don't have matching ranks");
- }
- }
-
- for (const auto type : rankedResultTypes) {
- if (rank != getRank(type)) {
- return emitOpError("result type has different rank than operands");
- }
- }
-
- // check for broadcast compatible shapes in first two operands (ignoring
- // shift)
-
- // delegate function that returns shape of shaped type
- auto getShape = [](const Type type) {
- return mlir::cast<ShapedType>(type).getShape();
- };
- SmallVector<int64_t> resultShape;
- if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
- getShape(rankedOperandTypes[1]),
- resultShape)) {
- return emitOpError("operands don't have broadcast-compatible shapes");
- }
+ const int64_t resultRank = resultType.getRank();
+ if (aHasRank && resultRank != aType.getRank())
+ return emitOpError("result type has different rank than a, got ")
+ << resultRank << " vs " << aType.getRank();
+ if (bHasRank && resultRank != bType.getRank())
+ return emitOpError("result type has different rank than b, got ")
+ << resultRank << " vs " << bType.getRank();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7b589fa839b44..3298e518de2f5 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1107,11 +1107,38 @@ func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// CHECK-LABEL: test_mul_non_broadcast
func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
- // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+ // expected-error@+1 {{'tosa.mul' op a and b operands don't have broadcast-compatible shapes, got 'tensor<13x21x2xf32>' and 'tensor<3x1x3xf32>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: test_mul_different_operand_ranks
+func.func @test_mul_different_operand_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op a and b operands don't have matching ranks, got 2 and 3}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_a_and_result_ranks
+func.func @test_mul_different_a_and_result_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op result type has different rank than a, got 3 vs 2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_b_and_result_ranks
+func.func @test_mul_different_b_and_result_ranks(%arg0: tensor<*xf32>, %arg1: tensor<13x12xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op result type has different rank than b, got 3 vs 2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<13x12xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: test_resize_invalid_scale_values
func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x?x?x?xf32> {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 5ec506a45b3ad..882b59d029a4a 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -424,6 +424,22 @@ func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tenso
return %0 : tensor<13x21x3xi16>
}
+// -----
+// CHECK-LABEL: test_mul_unranked_b
+func.func @test_mul_unranked_b(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_unranked_a_and_b
+func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: pow
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
``````````
</details>
https://github.com/llvm/llvm-project/pull/141980
More information about the Mlir-commits
mailing list