[Mlir-commits] [mlir] 9d5e144 - [mlir][tosa] Fix MulOp verifier handling for unranked operands (#141980)
llvmlistbot@llvm.org
llvmlistbot@llvm.org
Thu Jun 5 00:54:04 PDT 2025
Author: Luke Hutton
Date: 2025-06-05T08:54:01+01:00
New Revision: 9d5e1449f7902329cdf448a1d238529836989582
URL: https://github.com/llvm/llvm-project/commit/9d5e1449f7902329cdf448a1d238529836989582
DIFF: https://github.com/llvm/llvm-project/commit/9d5e1449f7902329cdf448a1d238529836989582.diff
LOG: [mlir][tosa] Fix MulOp verifier handling for unranked operands (#141980)
The previous verifier checks did not correctly handle unranked operands.
For example, it could incorrectly assume the number of
`rankedOperandTypes` would be >= 2, which isn't the case when both a and
b are unranked.
This change simplifies these checks such that they only operate over the
intended a and b operands as opposed to the shift operand as well.
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a22e6b7aa9791..5eb062ec9b535 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1857,7 +1857,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
}
LogicalResult tosa::MulOp::verify() {
- auto resElemType = getElementTypeOrSelf(getOutput());
+ const Value output = getOutput();
+ auto resElemType = getElementTypeOrSelf(output);
// Verify if the element type among operands and result match tosa
// specification.
@@ -1897,59 +1898,39 @@ LogicalResult tosa::MulOp::verify() {
// Verify the op has same ranks for all main operands (excludes extra operands
// such as shift of mul op, so this is the only difference with the built-in
// `SameOperandsAndResultRank` trait) and results types, if known.
-
- // delegate function that returns true if type is a shaped type with known
- // rank
- auto hasRank = [](const Type type) {
- if (auto shaped_type = dyn_cast<ShapedType>(type))
- return shaped_type.hasRank();
-
- return false;
- };
-
- auto rankedOperandTypes =
- llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
-
- auto rankedResultTypes =
- llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
-
- // If all operands and results are unranked, then no further verification.
- if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+ TypeRange operandTypes = getOperandTypes();
+ ShapedType aType = cast<ShapedType>(operandTypes[0]);
+ ShapedType bType = cast<ShapedType>(operandTypes[1]);
+
+ const bool aHasRank = aType.hasRank();
+ const bool bHasRank = bType.hasRank();
+ if (aHasRank && bHasRank) {
+ const int64_t aRank = aType.getRank();
+ const int64_t bRank = bType.getRank();
+ if (aRank != bRank)
+ return emitOpError("a and b operands don't have matching ranks, got ")
+ << aRank << " and " << bRank;
+
+ // check for broadcast compatible shapes
+ SmallVector<int64_t> resultShape;
+ if (!mlir::OpTrait::util::getBroadcastedShape(
+ aType.getShape(), bType.getShape(), resultShape))
+ return emitOpError("a and b operands don't have broadcast-compatible "
+ "shapes, got ")
+ << aType << " and " << bType;
+ }
+
+ ShapedType resultType = cast<ShapedType>(output.getType());
+ if (!resultType.hasRank())
return success();
- // delegate function that returns rank of shaped type with known rank
- auto getRank = [](const Type type) {
- return cast<ShapedType>(type).getRank();
- };
-
- auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
- : getRank(*rankedResultTypes.begin());
-
- for (size_t i = 0; i < 2; ++i) {
- if (rank != getRank(rankedOperandTypes[i])) {
- return emitOpError("operands don't have matching ranks");
- }
- }
-
- for (const auto type : rankedResultTypes) {
- if (rank != getRank(type)) {
- return emitOpError("result type has different rank than operands");
- }
- }
-
- // check for broadcast compatible shapes in first two operands (ignoring
- // shift)
-
- // delegate function that returns shape of shaped type
- auto getShape = [](const Type type) {
- return mlir::cast<ShapedType>(type).getShape();
- };
- SmallVector<int64_t> resultShape;
- if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
- getShape(rankedOperandTypes[1]),
- resultShape)) {
- return emitOpError("operands don't have broadcast-compatible shapes");
- }
+ const int64_t resultRank = resultType.getRank();
+ if (aHasRank && resultRank != aType.getRank())
+ return emitOpError("result type has different rank than a, got ")
+ << resultRank << " vs " << aType.getRank();
+ if (bHasRank && resultRank != bType.getRank())
+ return emitOpError("result type has different rank than b, got ")
+ << resultRank << " vs " << bType.getRank();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 05505c3671674..a4617fc6fba8b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1134,11 +1134,38 @@ func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// CHECK-LABEL: test_mul_non_broadcast
func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
- // expected-error @+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+ // expected-error @+1 {{'tosa.mul' op a and b operands don't have broadcast-compatible shapes, got 'tensor<13x21x2xf32>' and 'tensor<3x1x3xf32>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: test_mul_different_operand_ranks
+func.func @test_mul_different_operand_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error @+1 {{'tosa.mul' op a and b operands don't have matching ranks, got 2 and 3}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_a_and_result_ranks
+func.func @test_mul_different_a_and_result_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error @+1 {{'tosa.mul' op result type has different rank than a, got 3 vs 2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_b_and_result_ranks
+func.func @test_mul_different_b_and_result_ranks(%arg0: tensor<*xf32>, %arg1: tensor<13x12xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error @+1 {{'tosa.mul' op result type has different rank than b, got 3 vs 2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<13x12xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: test_resize_invalid_scale_values
func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x?x?x?xf32> {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 767fa833dedd4..b44aabe2ac178 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -424,6 +424,22 @@ func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tenso
return %0 : tensor<13x21x3xi16>
}
+// -----
+// CHECK-LABEL: test_mul_unranked_b
+func.func @test_mul_unranked_b(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_unranked_a_and_b
+func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: pow
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
More information about the Mlir-commits
mailing list