[Mlir-commits] [mlir] 35f4cdb - [mlir][arith] Add constraints to the MulIOp for preventing type mismatch while folding (#136093)

llvmlistbot at llvm.org
Thu Apr 17 02:17:38 PDT 2025


Author: Prakhar Dixit
Date: 2025-04-17T11:17:34+02:00
New Revision: 35f4cdbf59fca82b97869cce7e9e5d5009144938

URL: https://github.com/llvm/llvm-project/commit/35f4cdbf59fca82b97869cce7e9e5d5009144938
DIFF: https://github.com/llvm/llvm-project/commit/35f4cdbf59fca82b97869cce7e9e5d5009144938.diff

LOG: [mlir][arith] Add constraints to the MulIOp for preventing type mismatch while folding (#136093)

Fixes #135289
The original version of the MulIMulIConstant pattern did not check whether
the types of the constant attributes matched the type of the result, which
could cause type errors when folding. This fix adds constraints ensuring that
both constant attributes have the same type as the result SSA value before the
simplification is applied.
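The two new constraints can be sketched in plain C++. This is a simplified model, not MLIR code. `IntConst`, `wrapToWidth`, and `foldMulConstants` are hypothetical names, and a value's type is reduced to a bit width. The rewrite (x * c0) * c1 -> x * (c0 * c1) is performed only when both attribute widths equal the result width. Otherwise the IR is left unchanged, as in the `nested_muli` test, where `test.constant` carries i64 attributes but produces i32 values.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical, simplified model (not the MLIR API): a constant operand
// reduced to the bit width of its attribute type and its signed value.
struct IntConst {
  unsigned attrWidth;  // e.g. 64 for a `value = ... : i64` attribute
  int64_t value;
};

// Wrap `v` to a two's-complement integer of `width` bits (1..64), the way an
// iN multiplication discards the high bits.
int64_t wrapToWidth(int64_t v, unsigned width) {
  if (width >= 64) return v;
  const uint64_t mask = (uint64_t{1} << width) - 1;
  uint64_t bits = static_cast<uint64_t>(v) & mask;
  if (bits & (uint64_t{1} << (width - 1)))
    bits |= ~mask;  // sign-extend back to 64 bits
  return static_cast<int64_t>(bits);
}

// Models the MulIMulIConstant rewrite (x * c0) * c1 -> x * (c0 * c1).
// Returns the combined constant, or std::nullopt when either attribute type
// differs from the result type (the check this commit adds).
std::optional<int64_t> foldMulConstants(unsigned resultWidth, IntConst c0,
                                        IntConst c1) {
  if (c0.attrWidth != resultWidth || c1.attrWidth != resultWidth)
    return std::nullopt;  // type mismatch: do not fold
  // Multiply in uint64_t so overflow wraps instead of being undefined behavior.
  const uint64_t product =
      static_cast<uint64_t>(c0.value) * static_cast<uint64_t>(c1.value);
  return wrapToWidth(static_cast<int64_t>(product), resultWidth);
}
```

With matching 32-bit types, `foldMulConstants(32, {32, 3}, {32, 5})` yields 15. With the `nested_muli` shape (32-bit result, 64-bit attributes), it returns `std::nullopt`, so no combined constant is built.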

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 6d7ac2be951dd..7e212df9029d1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -90,7 +90,9 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+             [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+              (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f0b2731707d18..d62c5b18fd041 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1234,6 +1234,24 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
   return %add : index
 }
 
+// Negative test case to ensure no further folding is performed when there's a type mismatch between the values and the result.
+// CHECK-LABEL:   func.func @nested_muli() -> i32 {
+// CHECK:           %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK:           %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK:           %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK:           %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK:           return %[[VAL_4]] : i32
+// CHECK:         }
+func.func @nested_muli() -> (i32) {
+  %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+  %1 = "test.constant"() {value = -2147483648} : () -> i32
+  %2 = "test.constant"() {value = 0x80000000} : () -> i32
+  %4 = arith.muli %0, %1 : i32
+  %5 = arith.muli %4, %2 : i32
+  return %5 : i32
+}
+
 // CHECK-LABEL: @tripleMulIMulIIndex
 //       CHECK:   %[[cres:.+]] = arith.constant 15 : index
 //       CHECK:   %[[muli:.+]] = arith.muli %arg0, %[[cres]] : index