[Mlir-commits] [mlir] [mlir][arith] Add constraints to the MulIOp for preventing type mismatch while folding (PR #136093)

llvmlistbot at llvm.org
Thu Apr 17 00:23:48 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Prakhar Dixit (Prakhar-Dixit)

<details>
<summary>Changes</summary>

Fixes #135289
The original `MulIMulIConstant` pattern did not check that the types of the two folded constants matched the type of the result, which could cause a type mismatch while folding.
This fix adds constraints so that the simplification is applied only when both constants have the same type as the result.

**Minimal crashing example:**

```
func.func @nested_muli() -> (i32) {
  %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
  %1 = "test.constant"() {value = -2147483648} : () -> i32
  %2 = "test.constant"() {value = 0x80000000} : () -> i32
  %4 = arith.muli %0, %1 : i32
  %5 = arith.muli %4, %2 : i32
  return %5 : i32
}
```
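
To illustrate why this example is left unfolded after the fix, here is a toy Python model of the added guard. It is not MLIR code: `IntAttr`, `can_fold_mul_mul`, and `fold_mul_mul` are hypothetical names introduced only for this sketch. The `i64` attribute types are taken from the CHECK lines of the test in the diff, where the `test.constant` values print as `: i64` while the results are `i32`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class IntAttr:
    """Toy stand-in for an integer attribute: a value plus a type name."""
    value: int
    type: str  # e.g. "i32" or "i64"


def can_fold_mul_mul(result_type: str, c0: IntAttr, c1: IntAttr) -> bool:
    """Model of the two added constraints: each constant's type must equal the result type."""
    return c0.type == result_type and c1.type == result_type


def fold_mul_mul(result_type: str, c0: IntAttr, c1: IntAttr) -> Optional[IntAttr]:
    """Merge the constants of (x * c0) * c1, or return None when the guard rejects."""
    if not can_fold_mul_mul(result_type, c0, c1):
        return None
    width = int(result_type[1:])  # assumes a type name of the form "iN"
    # Fixed-width multiplication wraps modulo 2**width (stored here as unsigned).
    return IntAttr((c0.value * c1.value) % (1 << width), result_type)


# Constants from the crashing example: i64 attributes feeding i32 results.
c0 = IntAttr(-2147483648, "i64")
c1 = IntAttr(0x80000000, "i64")
print(fold_mul_mul("i32", c0, c1))  # guard rejects the fold

# Matching types: the fold goes through.
print(fold_mul_mul("i32", IntAttr(3, "i32"), IntAttr(5, "i32")))
```

In this model the first call returns `None` because the attribute type (`i64`) differs from the result type (`i32`), which corresponds to the two `arith.muli` ops being left in place.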

---
Full diff: https://github.com/llvm/llvm-project/pull/136093.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+3-1) 
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+18) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 6d7ac2be951dd..7e212df9029d1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -90,7 +90,9 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+             [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+              (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f0b2731707d18..d62c5b18fd041 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1234,6 +1234,24 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
   return %add : index
 }
 
+// Negative test case to ensure no further folding is performed when there's a type mismatch between the values and the result.
+// CHECK-LABEL:   func.func @nested_muli() -> i32 {
+// CHECK:           %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK:           %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK:           %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK:           %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK:           return %[[VAL_4]] : i32
+// CHECK:         }
+func.func @nested_muli() -> (i32) {
+  %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+  %1 = "test.constant"() {value = -2147483648} : () -> i32
+  %2 = "test.constant"() {value = 0x80000000} : () -> i32
+  %4 = arith.muli %0, %1 : i32
+  %5 = arith.muli %4, %2 : i32
+  return %5 : i32
+}
+
 // CHECK-LABEL: @tripleMulIMulIIndex
 //       CHECK:   %[[cres:.+]] = arith.constant 15 : index
 //       CHECK:   %[[muli:.+]] = arith.muli %arg0, %[[cres]] : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/136093

