[Mlir-commits] [mlir] [MLIR][Arith] Fix unsound addi/subi constant-fold when merged overflow flags would be violated (PR #189164)

Mehdi Amini llvmlistbot at llvm.org
Sat Mar 28 05:48:31 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189164

The `AddIAddConstant` and related patterns fold chains of additions and subtractions involving a constant into a single operation with the combined constant. For example:

  addi(addi(x, c0) [nsw], c1) [nsw]  ->  addi(x, c0+c1) [nsw]
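Setting the flags aside, the rewrite itself is value-preserving, because addition modulo 2^n is associative. Only the nsw/nuw flags can make it unsound. The following is a minimal C++ sketch that checks this exhaustively at a small width. It models n-bit values as masked `uint32_t` bit patterns, and the helper names are hypothetical, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: keep only the low n bits (1 <= n <= 16), i.e. the
// two's-complement bit pattern of an n-bit integer.
static uint32_t wrapBits(uint32_t v, unsigned n) {
  return v & ((1u << n) - 1);
}

// Hypothetical check: true if addi(addi(x, c0), c1) and addi(x, c0 + c1)
// yield the same n-bit pattern for every x, c0 and c1 when no overflow flags
// are set (plain wraparound semantics).
static bool plainFoldPreservesValue(unsigned n) {
  const uint32_t size = 1u << n;
  for (uint32_t x = 0; x < size; ++x)
    for (uint32_t c0 = 0; c0 < size; ++c0)
      for (uint32_t c1 = 0; c1 < size; ++c1) {
        uint32_t original = wrapBits(wrapBits(x + c0, n) + c1, n);
        uint32_t folded = wrapBits(x + wrapBits(c0 + c1, n), n);
        if (original != folded)
          return false;
      }
  return true;
}
```

`plainFoldPreservesValue(4)` returns true. The soundness question arises only once a flag turns some of these wrapped results into poison.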

This transformation is guaranteed sound when the combined constant c0+c1 does not itself overflow under the merged overflow flags. If the constant sum wraps in the signed sense while nsw is set (e.g. c0=-2, c1=-1 for i2: -2 + (-1) = -3, which wraps to +1), the folded op can produce poison for inputs where the original did not:

  For x=1 (i2): original gives 1+(-2)=-1 then -1+(-1)=-2 (no poison),
  but the fold produces 1+1=2 which overflows i2 under nsw -> poison.
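The counterexample can be found mechanically. The following is a minimal sketch that assumes a simple model: an nsw addition yields `std::nullopt` (standing in for poison) whenever the mathematical sum leaves the signed n-bit range. `foldCounterexamples` and the other helpers are hypothetical:

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical helpers: signed range of an n-bit integer (2 <= n <= 16).
static int sMin(unsigned n) { return -(1 << (n - 1)); }
static int sMax(unsigned n) { return (1 << (n - 1)) - 1; }

// Wrap a mathematical integer into the signed n-bit range (two's complement).
static int wrapSigned(int v, unsigned n) {
  const int size = 1 << n;
  return ((v - sMin(n)) % size + size) % size + sMin(n);
}

// Model of `arith.addi a, b overflow<nsw>`: nullopt (poison) on overflow.
static std::optional<int> addNsw(int a, int b, unsigned n) {
  const int sum = a + b;
  if (sum < sMin(n) || sum > sMax(n))
    return std::nullopt;
  return sum;
}

// Every x for which addi(addi(x, c0) [nsw], c1) [nsw] is well defined but the
// folded addi(x, wrap(c0 + c1)) [nsw] is poison.
static std::vector<int> foldCounterexamples(int c0, int c1, unsigned n) {
  std::vector<int> bad;
  const int foldedConst = wrapSigned(c0 + c1, n);
  for (int x = sMin(n); x <= sMax(n); ++x) {
    std::optional<int> original;
    if (std::optional<int> t = addNsw(x, c0, n))
      original = addNsw(*t, c1, n);
    if (original && !addNsw(x, foldedConst, n))
      bad.push_back(x);
  }
  return bad;
}
```

With c0 = -2, c1 = -1 and n = 2 it returns exactly {1}, matching the x=1 case above. With c0 = 1, c1 = 2 and n = 4 it returns an empty list.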

Fix by adding `AddAttrsNoOverflow` / `SubAttrsNoOverflow` constraints to all affected patterns. These constraints check that the combined constant does not overflow signed (for nsw) or unsigned (for nuw) before allowing the fold. The constraints are implemented as C++ helpers in ArithOps.cpp that are called from the TableGen pattern guard predicates.
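The following is a standalone sketch of the addition guard. It assumes n-bit constants (1 <= n <= 32) passed as bit patterns and uses a plain bitmask in place of `IntegerOverflowFlags`. The patch itself calls `APInt::sadd_ov` / `uadd_ov`. Here, exact 64-bit arithmetic stands in for them. `OverflowFlags`, `kNsw`, `kNuw`, `asSigned` and `asUnsigned` are hypothetical names:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the nsw/nuw bits of IntegerOverflowFlags.
enum OverflowFlags : unsigned { kNone = 0, kNsw = 1, kNuw = 2 };

// Interpret the low n bits (1 <= n <= 32) of `bits` as unsigned / signed.
static int64_t asUnsigned(uint32_t bits, unsigned n) {
  return static_cast<int64_t>(bits & ((uint64_t{1} << n) - 1));
}
static int64_t asSigned(uint32_t bits, unsigned n) {
  const int64_t u = asUnsigned(bits, n);
  const int64_t half = int64_t{1} << (n - 1);
  return u >= half ? u - 2 * half : u;
}

// Standalone re-implementation (not the patch's code) with the same contract
// as addConstantWouldOverflow: true if c0 + c1 overflows signed (when nsw is
// set) or unsigned (when nuw is set).
static bool addConstantWouldOverflow(uint32_t c0, uint32_t c1, unsigned n,
                                     unsigned flags) {
  const int64_t half = int64_t{1} << (n - 1);
  const int64_t s = asSigned(c0, n) + asSigned(c1, n);
  if ((flags & kNsw) && (s < -half || s > half - 1))
    return true;
  const int64_t u = asUnsigned(c0, n) + asUnsigned(c1, n);
  if ((flags & kNuw) && u > 2 * half - 1)
    return true;
  return false;
}

// Same contract as addAttrsNoOverflow: intersect both ops' flags, then check.
static bool addAttrsNoOverflow(uint32_t c0, uint32_t c1, unsigned n,
                               unsigned ovf1, unsigned ovf2) {
  return !addConstantWouldOverflow(c0, c1, n, ovf1 & ovf2);
}
```

In the i2 example, c0 = -2 is bit pattern 0b10 and c1 = -1 is 0b11. `addAttrsNoOverflow(0b10, 0b11, 2, kNsw, kNsw)` is false, so the fold is skipped. If either op lacks nsw, the merged flags are empty and the fold proceeds. Checking the intersected flags, rather than each op's flags separately, matches what the rewritten op actually carries.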

Fixes #140328

Assisted-by: Claude Code

From af049a4dff50efdd72c997387bd380d4a50c2677 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 05:04:18 -0700
Subject: [PATCH] [MLIR][Arith] Fix unsound addi/subi constant-fold when merged
 overflow flags would be violated

The `AddIAddConstant` and related patterns fold chains of additions and
subtractions involving a constant into a single operation with the
combined constant. For example:

  addi(addi(x, c0) [nsw], c1) [nsw]  ->  addi(x, c0+c1) [nsw]

This transformation is guaranteed sound when the combined constant c0+c1
does not itself overflow under the merged overflow flags. If the constant
sum wraps in the signed sense while nsw is set (e.g. c0=-2, c1=-1 for i2:
-2 + (-1) = -3, which wraps to +1), the folded op can produce poison for
inputs where the original did not:

  For x=1 (i2): original gives 1+(-2)=-1 then -1+(-1)=-2 (no poison),
  but the fold produces 1+1=2 which overflows i2 under nsw -> poison.

Fix by adding `AddAttrsNoOverflow` / `SubAttrsNoOverflow` constraints to
all affected patterns. These constraints check that the combined constant
does not overflow signed (for nsw) or unsigned (for nuw) before allowing
the fold. The constraints are implemented as C++ helpers in ArithOps.cpp
that are called from the TableGen pattern guard predicates.
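The subtraction guard has the same shape and uses `ssub_ov` / `usub_ov` in the patch. The sketch below makes the same assumptions: n-bit bit patterns with 1 <= n <= 32, a hypothetical flag bitmask, and exact 64-bit math. Under these assumptions, unsigned overflow reduces to c0 < c1:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the nsw/nuw bits of IntegerOverflowFlags.
enum OverflowFlags : unsigned { kNone = 0, kNsw = 1, kNuw = 2 };

// Standalone re-implementation (not the patch's code) with the same contract
// as subConstantWouldOverflow: true if c0 - c1 overflows signed (when nsw is
// set) or unsigned (when nuw is set).
static bool subConstantWouldOverflow(uint32_t c0, uint32_t c1, unsigned n,
                                     unsigned flags) {
  const uint64_t mask = (uint64_t{1} << n) - 1;
  const int64_t u0 = static_cast<int64_t>(c0 & mask);
  const int64_t u1 = static_cast<int64_t>(c1 & mask);
  const int64_t half = int64_t{1} << (n - 1);
  // Reinterpret the n-bit patterns as two's-complement values.
  const int64_t s0 = u0 >= half ? u0 - 2 * half : u0;
  const int64_t s1 = u1 >= half ? u1 - 2 * half : u1;
  if ((flags & kNsw) && (s0 - s1 < -half || s0 - s1 > half - 1))
    return true;
  // Unsigned subtraction wraps exactly when the subtrahend is larger.
  if ((flags & kNuw) && u0 < u1)
    return true;
  return false;
}
```

With the `tripleSubSub1Ovf` constants at 8 bits, `subConstantWouldOverflow(17, 42, 8, kNsw | kNuw)` is true because of nuw alone. With only nsw it is false, since -25 lies in the signed range.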

Also update the existing test `tripleSubSub1Ovf`, whose prior expectation
no longer holds: `subi(subi(17, x), 42) [nsw,nuw]` has the combined
constant 17-42=-25, which wraps unsigned, so the new guard blocks the fold.
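Under nuw alone, the unfolded chain in that test can never be well defined: `17 - x` is at most 17, which is always below 42. The following is a minimal sketch that checks this exhaustively. It assumes a hypothetical 8-bit stand-in for `index` (17 and 42 both fit) and models poison as `std::nullopt`. The sketch ignores nsw, because nsw can only add more poison:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model of `arith.subi a, b overflow<nuw>` on an 8-bit type:
// nullopt (poison) when b > a as unsigned values.
static std::optional<uint8_t> subNuw(uint8_t a, uint8_t b) {
  if (b > a)
    return std::nullopt;
  return static_cast<uint8_t>(a - b);
}

// Number of x for which subi(subi(17, x), 42) [nuw] is not poison.
static int definedOriginalInputs() {
  int count = 0;
  for (int x = 0; x < 256; ++x) {
    std::optional<uint8_t> t = subNuw(17, static_cast<uint8_t>(x));
    if (t && subNuw(*t, 42))
      ++count;
  }
  return count;
}

// Number of x for which the folded subi(17 - 42, x) [nuw] is not poison; the
// constant wraps to 231 at 8 bits.
static int definedFoldedInputs() {
  const uint8_t folded = static_cast<uint8_t>(17 - 42);
  int count = 0;
  for (int x = 0; x < 256; ++x)
    if (subNuw(folded, static_cast<uint8_t>(x)))
      ++count;
  return count;
}
```

`definedOriginalInputs()` is 0 while `definedFoldedInputs()` is 232. In this particular case the guard is therefore conservative: the old fold only replaced a value that was poison for every input.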

Fixes #140328

Assisted-by: Claude Code
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 39 +++++++++---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 61 +++++++++++++++++++
 mlir/test/Dialect/Arith/canonicalize.mlir     | 49 ++++++++++++++-
 3 files changed, 138 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index e22fc1d478e4f..8eb6aeed6dd79 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -30,6 +30,19 @@ def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
 // Default overflow flag (all wraparounds allowed).
 defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
+// Constraint: constant c0+c1 does not overflow under the merged flags of ovf1
+// and ovf2. Without this guard, folding addi(addi(x, c0), c1) -> addi(x, c0+c1)
+// is unsound when the constant sum wraps and the result carries nsw/nuw.
+def AddAttrsNoOverflow :
+    Constraint<CPred<"addAttrsNoOverflow($0, $1, $2, $3)">,
+               "constant addition does not overflow with merged flags">;
+
+// Constraint: constant c0-c1 does not overflow under the merged flags of ovf1
+// and ovf2.
+def SubAttrsNoOverflow :
+    Constraint<CPred<"subAttrsNoOverflow($0, $1, $2, $3)">,
+               "constant subtraction does not overflow with merged flags">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -40,12 +53,15 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 // as the second operand.
 
 // addi(addi(x, c0), c1) -> addi(x, c0 + c1)
+// Guard: the constant sum c0+c1 must not overflow under the merged flags,
+// otherwise the folded op would produce poison where the originals did not.
 def AddIAddConstant :
     Pat<(Arith_AddIOp:$res
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
@@ -53,7 +69,8 @@ def AddISubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
@@ -61,7 +78,8 @@ def AddISubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -123,7 +141,8 @@ def SubILHSAddConstant :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
@@ -131,7 +150,8 @@ def SubIRHSSubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
@@ -139,7 +159,8 @@ def SubIRHSSubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(SubAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
@@ -147,7 +168,8 @@ def SubILHSSubConstantRHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
@@ -155,7 +177,8 @@ def SubILHSSubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f10a94522350..7f41db325262f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -71,6 +71,67 @@ mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
                                        val1.getValue() & val2.getValue());
 }
 
+// Check if the constant integer addition of c0 + c1 would overflow given the
+// merged overflow flags. Returns true if the constant fold would be
+// invalid.
+static bool addConstantWouldOverflow(const APInt &c0, const APInt &c1,
+                                     IntegerOverflowFlags flags) {
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) {
+    bool overflow = false;
+    (void)c0.sadd_ov(c1, overflow);
+    if (overflow)
+      return true;
+  }
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) {
+    bool overflow = false;
+    (void)c0.uadd_ov(c1, overflow);
+    if (overflow)
+      return true;
+  }
+  return false;
+}
+
+// Check if the constant integer subtraction of c0 - c1 would overflow given
+// the merged overflow flags. Returns true if the constant fold would be
+// invalid.
+static bool subConstantWouldOverflow(const APInt &c0, const APInt &c1,
+                                     IntegerOverflowFlags flags) {
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) {
+    bool overflow = false;
+    (void)c0.ssub_ov(c1, overflow);
+    if (overflow)
+      return true;
+  }
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) {
+    bool overflow = false;
+    (void)c0.usub_ov(c1, overflow);
+    if (overflow)
+      return true;
+  }
+  return false;
+}
+
+// Wrappers for use in TableGen pattern constraints.
+// Returns true if c0 + c1 would NOT overflow with the merged flags.
+static bool addAttrsNoOverflow(Attribute c0, Attribute c1,
+                               IntegerOverflowFlagsAttr ovf1,
+                               IntegerOverflowFlagsAttr ovf2) {
+  IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+  const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
+  const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
+  return !addConstantWouldOverflow(v0, v1, merged);
+}
+
+// Returns true if c0 - c1 would NOT overflow with the merged flags.
+static bool subAttrsNoOverflow(Attribute c0, Attribute c1,
+                               IntegerOverflowFlagsAttr ovf1,
+                               IntegerOverflowFlagsAttr ovf2) {
+  IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+  const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
+  const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
+  return !subConstantWouldOverflow(v0, v1, merged);
+}
+
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 18665e2eb6f4a..211196e26357b 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1310,9 +1310,14 @@ func.func @tripleSubSub1(%arg0: index) -> index {
 }
 
 // CHECK-LABEL: @tripleSubSub1Ovf
-//       CHECK:   %[[cres:.+]] = arith.constant -25 : index
-//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
-//       CHECK:   return %[[add]]
+// c0-c1 = 17-42 wraps unsigned and nuw is set, so the guard conservatively
+// blocks subi(subi(c0, x), c1) -> subi(c0-c1, x). The unfolded chain is
+// poison for every x here, so the old fold was still a valid refinement.
+//       CHECK:   %[[c17:.+]] = arith.constant 17 : index
+//       CHECK:   %[[c42:.+]] = arith.constant 42 : index
+//       CHECK:   %[[sub1:.+]] = arith.subi %[[c17]], %arg0 overflow<nsw, nuw>
+//       CHECK:   %[[sub2:.+]] = arith.subi %[[sub1]], %[[c42]] overflow<nsw, nuw>
+//       CHECK:   return %[[sub2]]
 func.func @tripleSubSub1Ovf(%arg0: index) -> index {
   %c17 = arith.constant 17 : index
   %c42 = arith.constant 42 : index
@@ -3564,3 +3569,41 @@ func.func @convertf_fold_f8() -> f8E5M2 {
   return %result : f8E5M2
 }
 
+// -----
+
+// addi(addi(x, c0), c1) -> addi(x, c0+c1) must not fold when c0+c1 wraps
+// under the merged overflow flags. Here c0=-2, c1=-1 for i2: -2 + -1 = -3,
+// which wraps to +1 in i2. Since both ops have nsw the fold is unsound
+// (for x=1: original gives -2 with no poison, folded gives poison).
+//
+// CHECK-LABEL: @addi_no_fold_const_signed_overflow_nsw
+//       CHECK:   %[[cm2:.+]] = arith.constant -2 : i2
+//       CHECK:   %[[cm1:.+]] = arith.constant -1 : i2
+//       CHECK:   %[[t:.+]] = arith.addi %arg0, %[[cm2]] overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.addi %[[t]], %[[cm1]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @addi_no_fold_const_signed_overflow_nsw(%arg0: i2) -> i2 {
+  %cm2 = arith.constant -2 : i2
+  %cm1 = arith.constant -1 : i2
+  %t = arith.addi %arg0, %cm2 overflow<nsw> : i2
+  %r = arith.addi %t, %cm1 overflow<nsw> : i2
+  return %r : i2
+}
+
+// -----
+
+// When c0+c1 does not overflow, the fold is sound.
+// c0=1, c1=2 for i4: 1+2=3, no overflow in signed or unsigned sense.
+//
+// CHECK-LABEL: @addi_fold_const_no_overflow
+//       CHECK:   %[[c3:.+]] = arith.constant 3 : i4
+//       CHECK:   %[[r:.+]] = arith.addi %arg0, %[[c3]] overflow<nsw, nuw>
+//       CHECK:   return %[[r]]
+func.func @addi_fold_const_no_overflow(%arg0: i4) -> i4 {
+  %c1 = arith.constant 1 : i4
+  %c2 = arith.constant 2 : i4
+  %t = arith.addi %arg0, %c1 overflow<nsw, nuw> : i4
+  %r = arith.addi %t, %c2 overflow<nsw, nuw> : i4
+  return %r : i4
+}
+


