[Mlir-commits] [mlir] [mlir][tosa] Fix mul op verifier when input types don't match result (PR #141617)
Luke Hutton
llvmlistbot at llvm.org
Tue May 27 08:04:37 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/141617
This commit fixes a crash in the `tosa.mul` verifier when the operand element types are not integers but the result element type is. Although such IR is invalid, the verifier should report an error rather than crash.
>From 47f7b59432b951ba15755b0a9ddbfed5f669dc9c Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 27 May 2025 13:58:58 +0000
Subject: [PATCH] [mlir][tosa] Fix mul op verifier when input types don't match
result
This commit fixes a crash in the `tosa.mul` verifier when the operand
element types are not integers but the result element type is. Although
such IR is invalid, the verifier should report an error rather than crash.
Change-Id: Id89ec38bcef26b1b08cf4797c1834e51da1fef09
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 6 +++---
mlir/test/Dialect/Tosa/invalid.mlir | 10 ++++++++++
2 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3ee5a85a21dca..61bf9c074d8b1 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1785,10 +1785,10 @@ LogicalResult tosa::MulOp::verify() {
// specification.
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
IntegerType lhsIntType =
- cast<IntegerType>(getElementTypeOrSelf(getInput1()));
+ dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
IntegerType rhsIntType =
- cast<IntegerType>(getElementTypeOrSelf(getInput2()));
- if (lhsIntType != rhsIntType)
+ dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
+ if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
return emitOpError("requires the same element type for all operands");
// Though the spec requires the element type of result to be i32, a more
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7b589fa839b44..c41f079ec526c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -677,6 +677,16 @@ func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1
// -----
+// CHECK-LABEL: test_mul_int_type_mismatch
+func.func @test_mul_int_type_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xi32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error @+1 {{'tosa.mul' op requires the same element type for all operands}}
+  %3 = tosa.mul %arg0, %arg1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xi32>
+  return %3 : tensor<1xi32>
+}
+
+// -----
+
// CHECK-LABEL: test_mul_invalid_shift
func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() {values = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
More information about the Mlir-commits
mailing list