[Mlir-commits] [mlir] [mlir][tosa] Fix MulOp verifier handling for unranked operands (PR #141980)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 29 09:53:20 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

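The previous verifier checks did not correctly handle unranked operands. For example, they assumed that `rankedOperandTypes` (which also included the shift operand) had at least two entries. That isn't the case when both a and b are unranked, so `rankedOperandTypes[1]` was read out of bounds. When only one of a and b was unranked, the remaining ranked operand was compared against the shift operand instead.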

This change simplifies these checks so that they operate only on the intended a and b operands, rather than also including the shift operand.

---
Full diff: https://github.com/llvm/llvm-project/pull/141980.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+33-52) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+28-1) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3ee5a85a21dca..298802fc7fa6c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1779,7 +1779,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::MulOp::verify() {
-  auto resElemType = getElementTypeOrSelf(getOutput());
+  const Value output = getOutput();
+  auto resElemType = getElementTypeOrSelf(output);
 
   // Verify if the element type among operands and result match tosa
   // specification.
@@ -1819,59 +1820,39 @@ LogicalResult tosa::MulOp::verify() {
   // Verify the op has same ranks for all main operands (excludes extra operands
   // such as shift of mul op, so this is the only difference with the built-in
   // `SameOperandsAndResultRank` trait) and results types, if known.
-
-  // delegate function that returns true if type is a shaped type with known
-  // rank
-  auto hasRank = [](const Type type) {
-    if (auto shaped_type = dyn_cast<ShapedType>(type))
-      return shaped_type.hasRank();
-
-    return false;
-  };
-
-  auto rankedOperandTypes =
-      llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
-
-  auto rankedResultTypes =
-      llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
-
-  // If all operands and results are unranked, then no further verification.
-  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+  TypeRange operandTypes = getOperandTypes();
+  ShapedType aType = cast<ShapedType>(operandTypes[0]);
+  ShapedType bType = cast<ShapedType>(operandTypes[1]);
+
+  const bool aHasRank = aType.hasRank();
+  const bool bHasRank = bType.hasRank();
+  if (aHasRank && bHasRank) {
+    const int64_t aRank = aType.getRank();
+    const int64_t bRank = bType.getRank();
+    if (aRank != bRank)
+      return emitOpError("a and b operands don't have matching ranks, got ")
+             << aRank << " and " << bRank;
+
+    // Check that a and b (but not shift) have broadcast-compatible shapes.
+    SmallVector<int64_t> resultShape;
+    if (!mlir::OpTrait::util::getBroadcastedShape(
+            aType.getShape(), bType.getShape(), resultShape))
+      return emitOpError("a and b operands don't have broadcast-compatible "
+                         "shapes, got ")
+             << aType << " and " << bType;
+  }
+
+  ShapedType resultType = cast<ShapedType>(output.getType());
+  if (!resultType.hasRank())
     return success();
 
-  // delegate function that returns rank of shaped type with known rank
-  auto getRank = [](const Type type) {
-    return cast<ShapedType>(type).getRank();
-  };
-
-  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
-                                          : getRank(*rankedResultTypes.begin());
-
-  for (size_t i = 0; i < 2; ++i) {
-    if (rank != getRank(rankedOperandTypes[i])) {
-      return emitOpError("operands don't have matching ranks");
-    }
-  }
-
-  for (const auto type : rankedResultTypes) {
-    if (rank != getRank(type)) {
-      return emitOpError("result type has different rank than operands");
-    }
-  }
-
-  // check for broadcast compatible shapes in first two operands (ignoring
-  // shift)
-
-  // delegate function that returns shape of shaped type
-  auto getShape = [](const Type type) {
-    return mlir::cast<ShapedType>(type).getShape();
-  };
-  SmallVector<int64_t> resultShape;
-  if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
-                                                getShape(rankedOperandTypes[1]),
-                                                resultShape)) {
-    return emitOpError("operands don't have broadcast-compatible shapes");
-  }
+  const int64_t resultRank = resultType.getRank();
+  if (aHasRank && resultRank != aType.getRank())
+    return emitOpError("result type has different rank than a, got ")
+           << resultRank << " vs " << aType.getRank();
+  if (bHasRank && resultRank != bType.getRank())
+    return emitOpError("result type has different rank than b, got ")
+           << resultRank << " vs " << bType.getRank();
 
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7b589fa839b44..3298e518de2f5 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1107,11 +1107,38 @@ func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
 // CHECK-LABEL: test_mul_non_broadcast
 func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
   %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
-  // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+  // expected-error@+1 {{'tosa.mul' op a and b operands don't have broadcast-compatible shapes, got 'tensor<13x21x2xf32>' and 'tensor<3x1x3xf32>'}}
   %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
+// -----
+// CHECK-LABEL: test_mul_different_operand_ranks
+func.func @test_mul_different_operand_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op a and b operands don't have matching ranks, got 2 and 3}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_a_and_result_ranks
+func.func @test_mul_different_a_and_result_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op result type has different rank than a, got 3 vs 2}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_b_and_result_ranks
+func.func @test_mul_different_b_and_result_ranks(%arg0: tensor<*xf32>, %arg1: tensor<13x12xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op result type has different rank than b, got 3 vs 2}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<13x12xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
 // -----
 // CHECK-LABEL: test_resize_invalid_scale_values
 func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x?x?x?xf32> {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 5ec506a45b3ad..882b59d029a4a 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -424,6 +424,22 @@ func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tenso
   return %0 : tensor<13x21x3xi16>
 }
 
+// -----
+// CHECK-LABEL: test_mul_unranked_b
+func.func @test_mul_unranked_b(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_unranked_a_and_b
+func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
 // -----
 // CHECK-LABEL: pow
 func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/141980


More information about the Mlir-commits mailing list