[Mlir-commits] [mlir] [mlir][arith] Add constraints to the MulIOp for preventing type mismatch while folding (PR #136093)

Prakhar Dixit llvmlistbot at llvm.org
Thu Apr 17 00:23:15 PDT 2025


https://github.com/Prakhar-Dixit created https://github.com/llvm/llvm-project/pull/136093

Fixes #135289
The original `MulIMulIConstant` canonicalization pattern did not check that the types of the constant operands (`$c0`, `$c1`) matched the type of the result, so a type mismatch could crash the folder.
This patch adds constraints to the pattern so that the simplification is applied only when both constants have the same type as the result.

**Minimal crashing example:**

```
func.func @nested_muli() -> (i32) {
  %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
  %1 = "test.constant"() {value = -2147483648} : () -> i32
  %2 = "test.constant"() {value = 0x80000000} : () -> i32
  %4 = arith.muli %0, %1 : i32
  %5 = arith.muli %4, %2 : i32
  return %5 : i32
}
```
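The guard the patch adds can be modeled in a few lines of plain C++. This is a sketch under stated assumptions, not the MLIR/LLVM API: `IntConst`, `truncTo`, and `foldMulConstants` are hypothetical names, and an attribute's type is reduced to its bit width. The rewrite `(x * c0) * c1 -> x * (c0 * c1)` produces a combined constant only when both constants share the result's width; otherwise it declines and the IR is left unchanged, which is what the negative test in the patch expects.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for an integer constant attribute: the bit width
// of its type (e.g. 32 for i32, 64 for i64) plus its value.
struct IntConst {
  unsigned width;  // bit width of the attribute's type
  uint64_t bits;   // value, kept modulo 2^width
};

// Wrap a value to `width` bits (two's-complement overflow semantics).
uint64_t truncTo(uint64_t v, unsigned width) {
  assert(width >= 1 && width <= 64 && "unsupported bit width");
  return width == 64 ? v : (v & ((uint64_t{1} << width) - 1));
}

// Sketch of the constant folded by (x * c0) * c1 -> x * (c0 * c1).
// Mirrors the patch's constraints: fold only when both constants have the
// same type (here: width) as the result; on a mismatch, return nullopt so
// the caller leaves the IR as-is.
std::optional<IntConst> foldMulConstants(unsigned resultWidth, IntConst c0,
                                         IntConst c1) {
  if (c0.width != resultWidth || c1.width != resultWidth)
    return std::nullopt;
  // uint64_t multiplication wraps mod 2^64, so truncating afterwards gives
  // the correct product mod 2^resultWidth.
  return IntConst{resultWidth, truncTo(c0.bits * c1.bits, resultWidth)};
}
```

In the minimal example, the untyped `test.constant` values are parsed as `i64` attributes on `i32` results (note the `: i64` in the test's CHECK lines), so the model's equivalent call `foldMulConstants(32, IntConst{64, 0x7fffffff}, IntConst{64, 0x80000000})` declines to fold. With matching `i32` constants, `3 * 5` folds to `15`.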

From 1e3c3dab121a220e3ec9f614f8cbb066bc60db9e Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Thu, 17 Apr 2025 12:41:52 +0530
Subject: [PATCH] [mlir][arith] Prevent type mismatch during MulIOp folding

---
 .../Dialect/Arith/IR/ArithCanonicalization.td  |  4 +++-
 mlir/test/Dialect/Arith/canonicalize.mlir      | 18 ++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 6d7ac2be951dd..7e212df9029d1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -90,7 +90,9 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+             [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+              (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f0b2731707d18..d62c5b18fd041 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1234,6 +1234,24 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
   return %add : index
 }
 
+// Negative test case to ensure no further folding is performed when there's a type mismatch between the values and the result.
+// CHECK-LABEL:   func.func @nested_muli() -> i32 {
+// CHECK:           %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK:           %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK:           %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK:           %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK:           return %[[VAL_4]] : i32
+// CHECK:         }
+func.func @nested_muli() -> (i32) {
+  %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+  %1 = "test.constant"() {value = -2147483648} : () -> i32
+  %2 = "test.constant"() {value = 0x80000000} : () -> i32
+  %4 = arith.muli %0, %1 : i32
+  %5 = arith.muli %4, %2 : i32
+  return %5 : i32
+}
+
 // CHECK-LABEL: @tripleMulIMulIIndex
 //       CHECK:   %[[cres:.+]] = arith.constant 15 : index
 //       CHECK:   %[[muli:.+]] = arith.muli %arg0, %[[cres]] : index