[Mlir-commits] [mlir] [mlir][tosa] Fix mul op verifier when input types don't match result (PR #141617)

Luke Hutton llvmlistbot at llvm.org
Tue May 27 08:04:37 PDT 2025


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/141617

This commit fixes a crash in the `tosa.mul` verifier when the operand element types are not integer types but the result element type is. Although such IR is invalid, the verifier should report an error rather than crash.

>From 47f7b59432b951ba15755b0a9ddbfed5f669dc9c Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 27 May 2025 13:58:58 +0000
Subject: [PATCH] [mlir][tosa] Fix mul op verifier when input types don't match
 result

This commit fixes a crash in the tosa.mul verifier when the operand
element types are not integers but the result element type is. Such IR
is invalid, but the verifier should report an error rather than crash.

Change-Id: Id89ec38bcef26b1b08cf4797c1834e51da1fef09
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp |  6 +++---
 mlir/test/Dialect/Tosa/invalid.mlir  | 10 ++++++++++
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3ee5a85a21dca..61bf9c074d8b1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1785,10 +1785,10 @@ LogicalResult tosa::MulOp::verify() {
   // specification.
   if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
     IntegerType lhsIntType =
-        cast<IntegerType>(getElementTypeOrSelf(getInput1()));
+        dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
     IntegerType rhsIntType =
-        cast<IntegerType>(getElementTypeOrSelf(getInput2()));
-    if (lhsIntType != rhsIntType)
+        dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
+    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
       return emitOpError("requires the same element type for all operands");
 
     // Though the spec requires the element type of result to be i32, a more
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7b589fa839b44..c41f079ec526c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -677,6 +677,16 @@ func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1
 
 // -----
 
+// CHECK-LABEL: test_mul_int_type_mismatch
+func.func @test_mul_int_type_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xi32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error @+1 {{'tosa.mul' op requires the same element type for all operands}}
+  %3 = tosa.mul %arg0, %arg1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xi32>
+  return %3 : tensor<1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: test_mul_invalid_shift
 func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
   %shift = "tosa.const"() {values = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>



More information about the Mlir-commits mailing list