[Mlir-commits] [mlir] [mlir][tosa] Fix mul op verifier when input types don't match result (PR #141617)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 27 08:05:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit fixes a crash in the `tosa.mul` verifier that occurred when the operand element types are not integers but the result element type is. Such IR is invalid, but the verifier should report an error rather than crash.

---
Full diff: https://github.com/llvm/llvm-project/pull/141617.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+3-3) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+10) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3ee5a85a21dca..61bf9c074d8b1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1785,10 +1785,10 @@ LogicalResult tosa::MulOp::verify() {
   // specification.
   if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
     IntegerType lhsIntType =
-        cast<IntegerType>(getElementTypeOrSelf(getInput1()));
+        dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
     IntegerType rhsIntType =
-        cast<IntegerType>(getElementTypeOrSelf(getInput2()));
-    if (lhsIntType != rhsIntType)
+        dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
+    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
       return emitOpError("requires the same element type for all operands");
 
     // Though the spec requires the element type of result to be i32, a more
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7b589fa839b44..c41f079ec526c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -677,6 +677,16 @@ func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1
 
 // -----
 
+// CHECK-LABEL: test_mul_int_type_mismatch
+func.func @test_mul_int_type_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xi32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.mul' op requires the same element type for all operands}}
+  %3 = tosa.mul %arg0, %arg1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xi32>
+  return %3 : tensor<1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: test_mul_invalid_shift
 func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
   %shift = "tosa.const"() {values = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>

``````````

</details>


https://github.com/llvm/llvm-project/pull/141617


More information about the Mlir-commits mailing list