[Mlir-commits] [mlir] [mlir][tosa] Fix MulOp verifier handling for unranked operands (PR #141980)
Luke Hutton
llvmlistbot at llvm.org
Thu May 29 09:52:40 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/141980
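The previous verifier checks did not correctly handle unranked operands. For example, they assumed that `rankedOperandTypes` (which also included the shift operand) had at least two entries, which is not the case when both a and b are unranked. When only one of a or b was unranked, a ranked shift operand was compared in its place instead.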
This change simplifies these checks so that they operate only on the intended a and b operands, rather than also including the shift operand.
>From b0f576c51b4782801af1da7655de1616af22e5b7 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Thu, 29 May 2025 13:59:05 +0000
Subject: [PATCH] [mlir][tosa] Fix MulOp verifier handling for unranked
operands
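The previous verifier checks did not correctly handle unranked
operands. For example, they assumed that `rankedOperandTypes`
(which also included the shift operand) had at least two
entries, which is not the case when both a and b are unranked.
When only one of a or b was unranked, a ranked shift operand
was compared in its place instead.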
This change simplifies these checks so that they operate only
on the intended a and b operands, rather than also including
the shift operand.
Change-Id: I0d0b7f7e8058f9a25dcb6c051aa0375cf780b80c
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 85 +++++++++++-----------------
mlir/test/Dialect/Tosa/invalid.mlir | 29 +++++++++-
mlir/test/Dialect/Tosa/ops.mlir | 16 ++++++
3 files changed, 77 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3ee5a85a21dca..298802fc7fa6c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1779,7 +1779,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
}
LogicalResult tosa::MulOp::verify() {
- auto resElemType = getElementTypeOrSelf(getOutput());
+ const Value output = getOutput();
+ auto resElemType = getElementTypeOrSelf(output);
// Verify if the element type among operands and result match tosa
// specification.
@@ -1819,59 +1820,39 @@ LogicalResult tosa::MulOp::verify() {
// Verify the op has same ranks for all main operands (excludes extra operands
// such as shift of mul op, so this is the only difference with the built-in
// `SameOperandsAndResultRank` trait) and results types, if known.
-
- // delegate function that returns true if type is a shaped type with known
- // rank
- auto hasRank = [](const Type type) {
- if (auto shaped_type = dyn_cast<ShapedType>(type))
- return shaped_type.hasRank();
-
- return false;
- };
-
- auto rankedOperandTypes =
- llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
-
- auto rankedResultTypes =
- llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
-
- // If all operands and results are unranked, then no further verification.
- if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+ TypeRange operandTypes = getOperandTypes();
+ ShapedType aType = cast<ShapedType>(operandTypes[0]);
+ ShapedType bType = cast<ShapedType>(operandTypes[1]);
+
+ const bool aHasRank = aType.hasRank();
+ const bool bHasRank = bType.hasRank();
+ if (aHasRank && bHasRank) {
+ const int64_t aRank = aType.getRank();
+ const int64_t bRank = bType.getRank();
+ if (aRank != bRank)
+ return emitOpError("a and b operands don't have matching ranks, got ")
+ << aRank << " and " << bRank;
+
+ // check for broadcast compatible shapes
+ SmallVector<int64_t> resultShape;
+ if (!mlir::OpTrait::util::getBroadcastedShape(
+ aType.getShape(), bType.getShape(), resultShape))
+ return emitOpError("a and b operands don't have broadcast-compatible "
+ "shapes, got ")
+ << aType << " and " << bType;
+ }
+
+ ShapedType resultType = cast<ShapedType>(output.getType());
+ if (!resultType.hasRank())
return success();
- // delegate function that returns rank of shaped type with known rank
- auto getRank = [](const Type type) {
- return cast<ShapedType>(type).getRank();
- };
-
- auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
- : getRank(*rankedResultTypes.begin());
-
- for (size_t i = 0; i < 2; ++i) {
- if (rank != getRank(rankedOperandTypes[i])) {
- return emitOpError("operands don't have matching ranks");
- }
- }
-
- for (const auto type : rankedResultTypes) {
- if (rank != getRank(type)) {
- return emitOpError("result type has different rank than operands");
- }
- }
-
- // check for broadcast compatible shapes in first two operands (ignoring
- // shift)
-
- // delegate function that returns shape of shaped type
- auto getShape = [](const Type type) {
- return mlir::cast<ShapedType>(type).getShape();
- };
- SmallVector<int64_t> resultShape;
- if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
- getShape(rankedOperandTypes[1]),
- resultShape)) {
- return emitOpError("operands don't have broadcast-compatible shapes");
- }
+ const int64_t resultRank = resultType.getRank();
+ if (aHasRank && resultRank != aType.getRank())
+ return emitOpError("result type has different rank than a, got ")
+ << resultRank << " vs " << aType.getRank();
+ if (bHasRank && resultRank != bType.getRank())
+ return emitOpError("result type has different rank than b, got ")
+ << resultRank << " vs " << bType.getRank();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7b589fa839b44..3298e518de2f5 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1107,11 +1107,38 @@ func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// CHECK-LABEL: test_mul_non_broadcast
func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
- // expected-error at +1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+ // expected-error at +1 {{'tosa.mul' op a and b operands don't have broadcast-compatible shapes, got 'tensor<13x21x2xf32>' and 'tensor<3x1x3xf32>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: test_mul_different_operand_ranks
+func.func @test_mul_different_operand_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error at +1 {{'tosa.mul' op a and b operands don't have matching ranks, got 2 and 3}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_a_and_result_ranks
+func.func @test_mul_different_a_and_result_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error at +1 {{'tosa.mul' op result type has different rank than a, got 3 vs 2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_b_and_result_ranks
+func.func @test_mul_different_b_and_result_ranks(%arg0: tensor<*xf32>, %arg1: tensor<13x12xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error at +1 {{'tosa.mul' op result type has different rank than b, got 3 vs 2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<13x12xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: test_resize_invalid_scale_values
func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x?x?x?xf32> {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 5ec506a45b3ad..882b59d029a4a 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -424,6 +424,22 @@ func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tenso
return %0 : tensor<13x21x3xi16>
}
+// -----
+// CHECK-LABEL: test_mul_unranked_b
+func.func @test_mul_unranked_b(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_unranked_a_and_b
+func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: pow
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
More information about the Mlir-commits
mailing list