[Mlir-commits] [mlir] [MLIR][Arith] Fix unsound addi/subi constant-fold when merged overflow flags would be violated (PR #189164)

Mehdi Amini llvmlistbot at llvm.org
Sun Mar 29 06:21:25 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/189164

From 739edcefa9d5991b82537ff685393427e2a89875 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 05:04:18 -0700
Subject: [PATCH 1/2] [MLIR][Arith] Fix unsound addi/subi constant-fold when
 merged overflow flags would be violated

The `AddIAddConstant` and related patterns fold chains of additions and
subtractions involving a constant into a single operation with the
combined constant. For example:

  addi(addi(x, c0) [nsw], c1) [nsw]  ->  addi(x, c0+c1) [nsw]

This transformation is only guaranteed to be valid when the combined
constant c0+c1 does not itself overflow under the merged overflow flags.
If the constant sum wraps (e.g. c0=-2, c1=-1 for i2: -2 + -1 = -3 wraps
to +1), the folded op can produce poison for inputs where the original
did not:

  For x=1 (i2): original gives 1+(-2)=-1 then -1+(-1)=-2 (no poison),
  but the fold produces 1+1=2 which overflows i2 under nsw -> poison.
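To make the counterexample concrete, here is a small brute-force sketch in
plain C++ (not MLIR code; every helper name is hypothetical). It models iN
values in `int64_t`, treats nsw as "poison if the exact result does not fit
in N signed bits", and searches for an input where the original pair is
well-defined but the folded op is poison or yields a different value. Inputs
where the original is already poison are skipped, since poison may be
refined to anything.

```cpp
#include <cassert>
#include <cstdint>

// Signed range of an iN value (this sketch assumes 1 <= bits <= 32).
int64_t minSigned(unsigned bits) { return -(int64_t{1} << (bits - 1)); }
int64_t maxSigned(unsigned bits) { return (int64_t{1} << (bits - 1)) - 1; }

bool fitsSigned(int64_t v, unsigned bits) {
  return v >= minSigned(bits) && v <= maxSigned(bits);
}

// Two's-complement wrap to N bits, e.g. -3 in i2 becomes +1.
int64_t wrapSigned(int64_t v, unsigned bits) {
  assert(bits >= 1 && bits <= 32);
  const int64_t mod = int64_t{1} << bits;
  const int64_t r = ((v % mod) + mod) % mod;  // now in [0, mod)
  return r > maxSigned(bits) ? r - mod : r;
}

// Hypothetical helper: searches for an x where
// addi(addi(x, c0) [nsw], c1) [nsw] is well-defined but the folded
// addi(x, wrap(c0 + c1)) [nsw] is poison or produces a different value.
bool findAddNswCounterexample(int64_t c0, int64_t c1, unsigned bits,
                              int64_t &witness) {
  const int64_t folded = wrapSigned(c0 + c1, bits);
  for (int64_t x = minSigned(bits); x <= maxSigned(bits); ++x) {
    const int64_t t = x + c0;
    if (!fitsSigned(t, bits) || !fitsSigned(t + c1, bits))
      continue;  // original is poison: any folded result refines it
    const int64_t r = x + folded;
    if (!fitsSigned(r, bits) || r != t + c1) {
      witness = x;
      return true;
    }
  }
  return false;
}
```

For c0=-2, c1=-1 in i2 the search stops at x=1, matching the trace above;
for c0=1, c1=2 in i4, where the constant sum does not overflow, it finds
nothing.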

Fix by adding `AddAttrsNoOverflow` / `SubAttrsNoOverflow` constraints to
all affected patterns. These constraints check that the combined constant
does not overflow signed (for nsw) or unsigned (for nuw) before allowing
the fold. The constraints are implemented as C++ helpers in ArithOps.cpp
that are called from the TableGen pattern guard predicates.
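The guard logic can be modeled outside MLIR roughly as follows. This is a
plain C++ sketch, not the patch's code: `OverflowFlags`, `kNsw`, `kNuw` and
`constantsNoOverflowSketch` are hypothetical stand-ins, iN constants are held
in `int64_t` (N <= 32), and exact-result range checks play the role of
`APInt::sadd_ov`/`uadd_ov`/`ssub_ov`/`usub_ov`. As in the patch, the two
ops' flags are intersected first, and a check only runs for a flag that
survives the merge.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for arith's IntegerOverflowFlags bits.
enum OverflowFlags : unsigned { kNone = 0, kNsw = 1, kNuw = 2 };

// Signed and unsigned readings of an iN bit pattern (1 <= bits <= 32).
int64_t asSigned(int64_t v, unsigned bits) {
  assert(bits >= 1 && bits <= 32);
  const int64_t mod = int64_t{1} << bits;
  const int64_t u = ((v % mod) + mod) % mod;
  return u >= mod / 2 ? u - mod : u;
}
int64_t asUnsigned(int64_t v, unsigned bits) {
  assert(bits >= 1 && bits <= 32);
  const int64_t mod = int64_t{1} << bits;
  return ((v % mod) + mod) % mod;
}
bool fitsSigned(int64_t v, unsigned bits) {
  return v >= -(int64_t{1} << (bits - 1)) && v < (int64_t{1} << (bits - 1));
}
bool fitsUnsigned(int64_t v, unsigned bits) {
  return v >= 0 && v < (int64_t{1} << bits);
}

// Hypothetical helper: true if combining c0 and c1 into one constant keeps
// every merged flag honest. `isSub` selects c0 - c1 instead of c0 + c1.
bool constantsNoOverflowSketch(int64_t c0, int64_t c1, unsigned bits,
                               unsigned ovf1, unsigned ovf2, bool isSub) {
  const unsigned merged = ovf1 & ovf2;  // intersect, like mergeOverflowFlags
  if (merged & kNsw) {
    const int64_t a = asSigned(c0, bits), b = asSigned(c1, bits);
    if (!fitsSigned(isSub ? a - b : a + b, bits))
      return false;  // signed overflow, like sadd_ov / ssub_ov
  }
  if (merged & kNuw) {
    const int64_t a = asUnsigned(c0, bits), b = asUnsigned(c1, bits);
    if (!fitsUnsigned(isSub ? a - b : a + b, bits))
      return false;  // unsigned wrap, like uadd_ov / usub_ov
  }
  return true;
}
```

When only one of the two ops carries a flag, the merge drops it and the
sketch allows the fold, which mirrors how the merged flags on the folded op
are computed.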

Also update the existing test `tripleSubSub1Ovf`: in
`subi(subi(17, x), 42) [nsw,nuw]` the combined constant 17-42=-25 wraps
unsigned, so the new guard no longer applies the fold.
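The arithmetic behind this test can be checked with plain 64-bit integers,
assuming 64-bit `index` constants. The helpers below are hypothetical and
only mimic `APInt::ssub_ov`/`usub_ov`: 17-42 wraps as an unsigned
subtraction, while -25 fits easily as a signed result, so it is the nuw half
of the guard that blocks this particular fold.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Unsigned a - b wraps exactly when a < b (mimics APInt::usub_ov).
bool usubOverflows(uint64_t a, uint64_t b) { return a < b; }

// Signed a - b overflows int64_t; checked without performing the
// overflowing subtraction (mimics APInt::ssub_ov).
bool ssubOverflows(int64_t a, int64_t b) {
  if (b > 0) return a < std::numeric_limits<int64_t>::min() + b;
  if (b < 0) return a > std::numeric_limits<int64_t>::max() + b;
  return false;
}

// Hypothetical guard for subi(subi(c0, x), c1) [nsw, nuw] -> subi(c0 - c1, x):
// allowed only if c0 - c1 neither signed-overflows nor unsigned-wraps.
bool subSubFoldAllowed(int64_t c0, int64_t c1) {
  return !ssubOverflows(c0, c1) &&
         !usubOverflows(static_cast<uint64_t>(c0), static_cast<uint64_t>(c1));
}
```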

Fixes #140328

Assisted-by: Claude Code
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 39 +++++++++---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 61 +++++++++++++++++++
 mlir/test/Dialect/Arith/canonicalize.mlir     | 49 ++++++++++++++-
 3 files changed, 138 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index e22fc1d478e4f..8eb6aeed6dd79 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -30,6 +30,19 @@ def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
 // Default overflow flag (all wraparounds allowed).
 defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
+// Constraint: constant c0+c1 does not overflow under the merged flags of ovf1
+// and ovf2. Without this guard, folding addi(addi(x, c0), c1) -> addi(x, c0+c1)
+// is unsound when the constant sum wraps and the result carries nsw/nuw.
+def AddAttrsNoOverflow :
+    Constraint<CPred<"addAttrsNoOverflow($0, $1, $2, $3)">,
+               "constant addition does not overflow with merged flags">;
+
+// Constraint: constant c0-c1 does not overflow under the merged flags of ovf1
+// and ovf2.
+def SubAttrsNoOverflow :
+    Constraint<CPred<"subAttrsNoOverflow($0, $1, $2, $3)">,
+               "constant subtraction does not overflow with merged flags">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -40,12 +53,15 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 // as the second operand.
 
 // addi(addi(x, c0), c1) -> addi(x, c0 + c1)
+// Guard: the constant sum c0+c1 must not overflow under the merged flags,
+// otherwise the folded op would produce poison where the originals did not.
 def AddIAddConstant :
     Pat<(Arith_AddIOp:$res
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
@@ -53,7 +69,8 @@ def AddISubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
@@ -61,7 +78,8 @@ def AddISubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -123,7 +141,8 @@ def SubILHSAddConstant :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
@@ -131,7 +150,8 @@ def SubIRHSSubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
@@ -139,7 +159,8 @@ def SubIRHSSubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(SubAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
@@ -147,7 +168,8 @@ def SubILHSSubConstantRHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
@@ -155,7 +177,8 @@ def SubILHSSubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (MergeOverflow $ovf1, $ovf2))>;
+            (MergeOverflow $ovf1, $ovf2)),
+        [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f10a94522350..7f41db325262f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -71,6 +71,67 @@ mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
                                        val1.getValue() & val2.getValue());
 }
 
+// Check if the constant integer addition of c0 + c1 would overflow given the
+// merged overflow flags. Returns true if the constant fold would be
+// invalid.
+static bool addConstantWouldOverflow(const APInt &c0, const APInt &c1,
+                                     IntegerOverflowFlags flags) {
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) {
+    bool overflow = false;
+    (void)c0.sadd_ov(c1, overflow);
+    if (overflow)
+      return true;
+  }
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) {
+    bool overflow = false;
+    (void)c0.uadd_ov(c1, overflow);
+    if (overflow)
+      return true;
+  }
+  return false;
+}
+
+// Check if the constant integer subtraction of c0 - c1 would overflow given
+// the merged overflow flags. Returns true if the constant fold would be
+// invalid.
+static bool subConstantWouldOverflow(const APInt &c0, const APInt &c1,
+                                     IntegerOverflowFlags flags) {
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) {
+    bool overflow = false;
+    (void)c0.ssub_ov(c1, overflow);
+    if (overflow)
+      return true;
+  }
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) {
+    bool overflow = false;
+    (void)c0.usub_ov(c1, overflow);
+    if (overflow)
+      return true;
+  }
+  return false;
+}
+
+// Wrappers for use in TableGen pattern constraints.
+// Returns true if c0 + c1 would NOT overflow with the merged flags.
+static bool addAttrsNoOverflow(Attribute c0, Attribute c1,
+                               IntegerOverflowFlagsAttr ovf1,
+                               IntegerOverflowFlagsAttr ovf2) {
+  IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+  const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
+  const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
+  return !addConstantWouldOverflow(v0, v1, merged);
+}
+
+// Returns true if c0 - c1 would NOT overflow with the merged flags.
+static bool subAttrsNoOverflow(Attribute c0, Attribute c1,
+                               IntegerOverflowFlagsAttr ovf1,
+                               IntegerOverflowFlagsAttr ovf2) {
+  IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+  const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
+  const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
+  return !subConstantWouldOverflow(v0, v1, merged);
+}
+
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 18665e2eb6f4a..211196e26357b 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1310,9 +1310,14 @@ func.func @tripleSubSub1(%arg0: index) -> index {
 }
 
 // CHECK-LABEL: @tripleSubSub1Ovf
-//       CHECK:   %[[cres:.+]] = arith.constant -25 : index
-//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
-//       CHECK:   return %[[add]]
+// Folding subi(subi(c0, x), c1) -> subi(c0-c1, x) is unsound when c0-c1
+// unsigned-wraps and nuw is set, because the original always produces poison
+// while the folded form may not. Do not fold.
+//       CHECK:   %[[c17:.+]] = arith.constant 17 : index
+//       CHECK:   %[[c42:.+]] = arith.constant 42 : index
+//       CHECK:   %[[sub1:.+]] = arith.subi %[[c17]], %arg0 overflow<nsw, nuw>
+//       CHECK:   %[[sub2:.+]] = arith.subi %[[sub1]], %[[c42]] overflow<nsw, nuw>
+//       CHECK:   return %[[sub2]]
 func.func @tripleSubSub1Ovf(%arg0: index) -> index {
   %c17 = arith.constant 17 : index
   %c42 = arith.constant 42 : index
@@ -3564,3 +3569,41 @@ func.func @convertf_fold_f8() -> f8E5M2 {
   return %result : f8E5M2
 }
 
+// -----
+
+// addi(addi(x, c0), c1) -> addi(x, c0+c1) must not fold when c0+c1 wraps
+// under the merged overflow flags. Here c0=-2, c1=-1 for i2: -2 + -1 = -3,
+// which wraps to +1 in i2. Since both ops have nsw the fold is unsound
+// (for x=1: original gives -2 with no poison, folded gives poison).
+//
+// CHECK-LABEL: @addi_no_fold_const_signed_overflow_nsw
+//       CHECK:   %[[cm2:.+]] = arith.constant -2 : i2
+//       CHECK:   %[[cm1:.+]] = arith.constant -1 : i2
+//       CHECK:   %[[t:.+]] = arith.addi %arg0, %[[cm2]] overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.addi %[[t]], %[[cm1]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @addi_no_fold_const_signed_overflow_nsw(%arg0: i2) -> i2 {
+  %cm2 = arith.constant -2 : i2
+  %cm1 = arith.constant -1 : i2
+  %t = arith.addi %arg0, %cm2 overflow<nsw> : i2
+  %r = arith.addi %t, %cm1 overflow<nsw> : i2
+  return %r : i2
+}
+
+// -----
+
+// When c0+c1 does not overflow, the fold is sound.
+// c0=1, c1=2 for i4: 1+2=3, no overflow in signed or unsigned sense.
+//
+// CHECK-LABEL: @addi_fold_const_no_overflow
+//       CHECK:   %[[c3:.+]] = arith.constant 3 : i4
+//       CHECK:   %[[r:.+]] = arith.addi %arg0, %[[c3]] overflow<nsw, nuw>
+//       CHECK:   return %[[r]]
+func.func @addi_fold_const_no_overflow(%arg0: i4) -> i4 {
+  %c1 = arith.constant 1 : i4
+  %c2 = arith.constant 2 : i4
+  %t = arith.addi %arg0, %c1 overflow<nsw, nuw> : i4
+  %r = arith.addi %t, %c2 overflow<nsw, nuw> : i4
+  return %r : i4
+}
+

From ad93f61715697246b7231e5052e32457ea1867a0 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 19:14:18 -0700
Subject: [PATCH 2/2] [MLIR][Arith] Fix unsound MulIMulIConstant fold when
 constant product overflows

The `MulIMulIConstant` pattern folds `muli(muli(x, c0), c1)` into
`muli(x, c0*c1)` with merged overflow flags. This is only guaranteed to be
valid when the combined constant c0*c1 does not itself overflow under those
flags.
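A brute-force sketch of this condition, in plain C++ with hypothetical helper
names (not MLIR code): iN values are modeled in `int64_t`, nsw means "poison
if the exact product does not fit in N signed bits", and the search looks for
an x where the original pair is well-defined but the folded op is poison or
differs. For instance, in i4 with c0=2, c1=4 the folded constant 8 wraps to
-8; for x=-1 the original computes -2 and then -8, but the folded op computes
(-1)*(-8)=8, which overflows. Whenever the exact product c0*c1 fits, the
search finds nothing, because x*(c0*c1) then equals the original result.

```cpp
#include <cassert>
#include <cstdint>

// Signed iN range; this sketch assumes 1 <= bits <= 16 so every
// intermediate product is exact in int64_t.
int64_t minSigned(unsigned bits) { return -(int64_t{1} << (bits - 1)); }
int64_t maxSigned(unsigned bits) { return (int64_t{1} << (bits - 1)) - 1; }
bool fitsSigned(int64_t v, unsigned bits) {
  return v >= minSigned(bits) && v <= maxSigned(bits);
}

// Two's-complement wrap of v to N bits (e.g. 8 in i4 becomes -8).
int64_t wrapSigned(int64_t v, unsigned bits) {
  assert(bits >= 1 && bits <= 16);
  const int64_t mod = int64_t{1} << bits;
  const int64_t r = ((v % mod) + mod) % mod;
  return r > maxSigned(bits) ? r - mod : r;
}

// Hypothetical helper: searches for an x where
// muli(muli(x, c0) [nsw], c1) [nsw] is well-defined but the folded
// muli(x, wrap(c0 * c1)) [nsw] is poison or produces a different value.
bool findMulNswCounterexample(int64_t c0, int64_t c1, unsigned bits,
                              int64_t &witness) {
  const int64_t k = wrapSigned(c0 * c1, bits);
  for (int64_t x = minSigned(bits); x <= maxSigned(bits); ++x) {
    const int64_t t = x * c0;
    if (!fitsSigned(t, bits) || !fitsSigned(t * c1, bits))
      continue;  // original is poison: any folded result refines it
    if (!fitsSigned(x * k, bits) || x * k != t * c1) {
      witness = x;
      return true;
    }
  }
  return false;
}
```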

Example: in i8 with nsw,
  muli(muli(x, 16) [nsw], 8) [nsw]
the fold produces muli(x, -128) [nsw] because 16*8=128 wraps to -128 in
i8. For x=-1 the original computes -16 and then -128 without poison, but
the folded op computes (-1)*(-128)=128, which overflows i8 -> poison.

Fix by adding a `MulAttrsNoOverflow` constraint (analogous to the already-
present `AddAttrsNoOverflow`/`SubAttrsNoOverflow`) that checks c0*c1 does
not overflow signed (for nsw) or unsigned (for nuw) before allowing the fold.
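The constraint's logic can be modeled in plain C++ as below. This is a
sketch, not the patch's code: `OverflowFlags`, `kNsw`, `kNuw` and
`mulNoOverflowSketch` are hypothetical, constants are iN bit patterns held in
64-bit integers (N <= 32), and exact-product range checks stand in for
`APInt::smul_ov`/`umul_ov`. It keeps the early exit for the case where no
flag survives the merge.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for arith's IntegerOverflowFlags bits.
enum OverflowFlags : unsigned { kNone = 0, kNsw = 1, kNuw = 2 };

// Hypothetical helper: true if c0 * c1 stays in range for every flag that
// both ops carry. c0 and c1 are iN bit patterns (1 <= bits <= 32).
bool mulNoOverflowSketch(uint64_t c0, uint64_t c1, unsigned bits,
                         unsigned ovf1, unsigned ovf2) {
  assert(bits >= 1 && bits <= 32);
  const unsigned merged = ovf1 & ovf2;
  if (merged == kNone)
    return true;  // plain wrapping multiply: any constant is fine
  const uint64_t mod = uint64_t{1} << bits;
  const uint64_t ua = c0 % mod, ub = c1 % mod;
  if (merged & kNsw) {
    // Sign-extend the N-bit patterns; |a|, |b| <= 2^31 keeps a * b exact.
    const int64_t a = ua >= mod / 2 ? int64_t(ua) - int64_t(mod) : int64_t(ua);
    const int64_t b = ub >= mod / 2 ? int64_t(ub) - int64_t(mod) : int64_t(ub);
    const int64_t p = a * b;
    if (p < -int64_t(mod / 2) || p >= int64_t(mod / 2))
      return false;  // like APInt::smul_ov reporting overflow
  }
  if ((merged & kNuw) && ua * ub >= mod)
    return false;  // like APInt::umul_ov; ua * ub < 2^64 when bits <= 32
  return true;
}
```

With this model, 3*3 in i4 is rejected under nsw (9 > 7) but accepted under
nuw alone (9 <= 15), matching the signed/unsigned split described above.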

Part of the same root-cause class as the addi/subi fix in this branch.

Fixes #140328

Assisted-by: Claude Code
---
 .../Dialect/Arith/IR/ArithCanonicalization.td |  15 +-
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  28 +++
 mlir/test/Dialect/Arith/canonicalize.mlir     | 176 +++++++++++++++++-
 3 files changed, 215 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 8eb6aeed6dd79..2582cade66252 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -38,11 +38,19 @@ def AddAttrsNoOverflow :
                "constant addition does not overflow with merged flags">;
 
 // Constraint: constant c0-c1 does not overflow under the merged flags of ovf1
-// and ovf2.
+// and ovf2. Without this guard, folding addi(subi(x, c0), c1) -> addi(x, c1-c0)
+// is unsound when the constant difference wraps and the result carries nsw/nuw.
 def SubAttrsNoOverflow :
     Constraint<CPred<"subAttrsNoOverflow($0, $1, $2, $3)">,
                "constant subtraction does not overflow with merged flags">;
 
+// Constraint: constant c0*c1 does not overflow under the merged flags of ovf1
+// and ovf2. Without this guard, folding muli(muli(x, c0), c1) -> muli(x, c0*c1)
+// is unsound when the constant product wraps and the result carries nsw/nuw.
+def MulAttrsNoOverflow :
+    Constraint<CPred<"mulAttrsNoOverflow($0, $1, $2, $3)">,
+               "constant multiplication does not overflow with merged flags">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -103,6 +111,8 @@ def AddIMulNegativeOneLhs :
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
+// Guard: the constant product c0*c1 must not overflow under the merged flags,
+// otherwise the folded op would produce poison where the originals did not.
 def MulIMulIConstant :
     Pat<(Arith_MulIOp:$res
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
@@ -110,7 +120,8 @@ def MulIMulIConstant :
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
             (MergeOverflow $ovf1, $ovf2)),
              [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
-              (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
+              (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1),
+              (MulAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7f41db325262f..d55b10ec69ff5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -117,6 +117,8 @@ static bool addAttrsNoOverflow(Attribute c0, Attribute c1,
                                IntegerOverflowFlagsAttr ovf1,
                                IntegerOverflowFlagsAttr ovf2) {
   IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+  if (merged == IntegerOverflowFlags::none)
+    return true;
   const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
   const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
   return !addConstantWouldOverflow(v0, v1, merged);
@@ -127,11 +129,37 @@ static bool subAttrsNoOverflow(Attribute c0, Attribute c1,
                                IntegerOverflowFlagsAttr ovf1,
                                IntegerOverflowFlagsAttr ovf2) {
   IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+  if (merged == IntegerOverflowFlags::none)
+    return true;
   const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
   const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
   return !subConstantWouldOverflow(v0, v1, merged);
 }
 
+// Returns true if c0 * c1 would NOT overflow with the merged flags.
+static bool mulAttrsNoOverflow(Attribute c0, Attribute c1,
+                               IntegerOverflowFlagsAttr ovf1,
+                               IntegerOverflowFlagsAttr ovf2) {
+  IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+  if (merged == IntegerOverflowFlags::none)
+    return true;
+  const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
+  const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
+  if (bitEnumContainsAny(merged, IntegerOverflowFlags::nsw)) {
+    bool overflow = false;
+    (void)v0.smul_ov(v1, overflow);
+    if (overflow)
+      return false;
+  }
+  if (bitEnumContainsAny(merged, IntegerOverflowFlags::nuw)) {
+    bool overflow = false;
+    (void)v0.umul_ov(v1, overflow);
+    if (overflow)
+      return false;
+  }
+  return true;
+}
+
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 211196e26357b..f26d75f62eeb6 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1311,8 +1311,9 @@ func.func @tripleSubSub1(%arg0: index) -> index {
 
 // CHECK-LABEL: @tripleSubSub1Ovf
 // Folding subi(subi(c0, x), c1) -> subi(c0-c1, x) is unsound when c0-c1
-// unsigned-wraps and nuw is set, because the original always produces poison
-// while the folded form may not. Do not fold.
+// signed-wraps under nsw. Here 17-42=-25 only wraps unsigned: the original is
+// poison for every %arg0 under nuw, so folding would be a legal refinement,
+// but the nuw check conservatively rejects it. Do not fold.
 //       CHECK:   %[[c17:.+]] = arith.constant 17 : index
 //       CHECK:   %[[c42:.+]] = arith.constant 42 : index
 //       CHECK:   %[[sub1:.+]] = arith.subi %[[c17]], %arg0 overflow<nsw, nuw>
@@ -3607,3 +3608,174 @@ func.func @addi_fold_const_no_overflow(%arg0: i4) -> i4 {
   return %r : i4
 }
 
+// -----
+
+// addi(addi(x, c0), c1) is not folded when c0+c1 wraps unsigned and nuw is set.
+// c0=-2 (=6 unsigned), c1=3 for i3: 6+3=9 wraps unsigned (i3 unsigned max=7),
+// but -2+3=1 does NOT overflow signed, so only the nuw check blocks the fold.
+//
+// CHECK-LABEL: @addi_no_fold_const_unsigned_overflow_nuw
+//       CHECK:   %[[cm2:.+]] = arith.constant -2 : i3
+//       CHECK:   %[[c3:.+]] = arith.constant 3 : i3
+//       CHECK:   %[[t:.+]] = arith.addi %arg0, %[[cm2]] overflow<nuw>
+//       CHECK:   %[[r:.+]] = arith.addi %[[t]], %[[c3]] overflow<nuw>
+//       CHECK:   return %[[r]]
+func.func @addi_no_fold_const_unsigned_overflow_nuw(%arg0: i3) -> i3 {
+  %cm2 = arith.constant -2 : i3
+  %c3 = arith.constant 3 : i3
+  %t = arith.addi %arg0, %cm2 overflow<nuw> : i3
+  %r = arith.addi %t, %c3 overflow<nuw> : i3
+  return %r : i3
+}
+
+// -----
+
+// subi(subi(x, c0), c1) -> subi(x, c0+c1) must not fold when c0+c1 overflows
+// under the merged flags. c0=3, c1=1 for i3: 3+1=4 overflows signed (i3
+// signed max = 3). With nsw the fold is unsound.
+//
+// CHECK-LABEL: @subi_rhs_no_fold_const_overflow
+//       CHECK:   %[[c3:.+]] = arith.constant 3 : i3
+//       CHECK:   %[[c1:.+]] = arith.constant 1 : i3
+//       CHECK:   %[[t:.+]] = arith.subi %arg0, %[[c3]] overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.subi %[[t]], %[[c1]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @subi_rhs_no_fold_const_overflow(%arg0: i3) -> i3 {
+  %c3 = arith.constant 3 : i3
+  %c1 = arith.constant 1 : i3
+  %t = arith.subi %arg0, %c3 overflow<nsw> : i3
+  %r = arith.subi %t, %c1 overflow<nsw> : i3
+  return %r : i3
+}
+
+// -----
+
+// addi(subi(x, c0), c1) -> addi(x, c1-c0) must not fold when c1-c0 overflows
+// under the merged flags. c0=-3, c1=1 for i3: 1-(-3)=4 overflows signed.
+//
+// CHECK-LABEL: @addi_subi_rhs_no_fold_const_overflow
+//       CHECK:   %[[cm3:.+]] = arith.constant -3 : i3
+//       CHECK:   %[[c1:.+]] = arith.constant 1 : i3
+//       CHECK:   %[[t:.+]] = arith.subi %arg0, %[[cm3]] overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.addi %[[t]], %[[c1]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @addi_subi_rhs_no_fold_const_overflow(%arg0: i3) -> i3 {
+  %cm3 = arith.constant -3 : i3
+  %c1 = arith.constant 1 : i3
+  %t = arith.subi %arg0, %cm3 overflow<nsw> : i3
+  %r = arith.addi %t, %c1 overflow<nsw> : i3
+  return %r : i3
+}
+
+// -----
+
+// muli(muli(x, c0), c1) -> muli(x, c0*c1) is not folded when c0*c1 overflows
+// under the merged flags. c0=3, c1=3 for i4: 3*3=9 overflows signed (i4
+// signed max = 7), so the nsw check rejects the fold.
+//
+// CHECK-LABEL: @muli_no_fold_const_overflow
+//       CHECK:   %[[c3:.+]] = arith.constant 3 : i4
+//       CHECK:   %[[t:.+]] = arith.muli %arg0, %[[c3]] overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.muli %[[t]], %[[c3]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @muli_no_fold_const_overflow(%arg0: i4) -> i4 {
+  %c3a = arith.constant 3 : i4
+  %c3b = arith.constant 3 : i4
+  %t = arith.muli %arg0, %c3a overflow<nsw> : i4
+  %r = arith.muli %t, %c3b overflow<nsw> : i4
+  return %r : i4
+}
+
+// -----
+
+// When c0*c1 does not overflow, the fold is sound.
+// c0=2, c1=2 for i4: 2*2=4, no overflow signed or unsigned.
+//
+// CHECK-LABEL: @muli_fold_const_no_overflow
+//       CHECK:   %[[c4:.+]] = arith.constant 4 : i4
+//       CHECK:   %[[r:.+]] = arith.muli %arg0, %[[c4]] overflow<nsw, nuw>
+//       CHECK:   return %[[r]]
+func.func @muli_fold_const_no_overflow(%arg0: i4) -> i4 {
+  %c2a = arith.constant 2 : i4
+  %c2b = arith.constant 2 : i4
+  %t = arith.muli %arg0, %c2a overflow<nsw, nuw> : i4
+  %r = arith.muli %t, %c2b overflow<nsw, nuw> : i4
+  return %r : i4
+}
+
+// -----
+
+// addi(subi(c0, x), c1) -> subi(c0+c1, x) must not fold when c0+c1 overflows
+// under the merged flags. c0=5, c1=4 for i4: 5+4=9 overflows signed (i4 max=7).
+//
+// CHECK-LABEL: @addi_subi_lhs_no_fold_const_overflow
+//       CHECK:   %[[c5:.+]] = arith.constant 5 : i4
+//       CHECK:   %[[c4:.+]] = arith.constant 4 : i4
+//       CHECK:   %[[t:.+]] = arith.subi %[[c5]], %arg0 overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.addi %[[t]], %[[c4]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @addi_subi_lhs_no_fold_const_overflow(%arg0: i4) -> i4 {
+  %c5 = arith.constant 5 : i4
+  %c4 = arith.constant 4 : i4
+  %t = arith.subi %c5, %arg0 overflow<nsw> : i4
+  %r = arith.addi %t, %c4 overflow<nsw> : i4
+  return %r : i4
+}
+
+// -----
+
+// subi(c1, addi(x, c0)) -> subi(c1-c0, x) must not fold when c1-c0 overflows
+// under the merged flags. c1=3, c0=-5 for i4: 3-(-5)=8 overflows signed (i4 max=7).
+//
+// CHECK-LABEL: @subi_lhs_add_no_fold_const_overflow
+//       CHECK:   %[[cm5:.+]] = arith.constant -5 : i4
+//       CHECK:   %[[c3:.+]] = arith.constant 3 : i4
+//       CHECK:   %[[t:.+]] = arith.addi %arg0, %[[cm5]] overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.subi %[[c3]], %[[t]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @subi_lhs_add_no_fold_const_overflow(%arg0: i4) -> i4 {
+  %cm5 = arith.constant -5 : i4
+  %c3 = arith.constant 3 : i4
+  %t = arith.addi %arg0, %cm5 overflow<nsw> : i4
+  %r = arith.subi %c3, %t overflow<nsw> : i4
+  return %r : i4
+}
+
+// -----
+
+// subi(c1, subi(x, c0)) -> subi(c0+c1, x) must not fold when c0+c1 overflows
+// under the merged flags. c0=5, c1=4 for i4: 5+4=9 overflows signed (i4 max=7).
+//
+// CHECK-LABEL: @subi_lhs_subi_rhs_no_fold_const_overflow
+//       CHECK:   %[[c5:.+]] = arith.constant 5 : i4
+//       CHECK:   %[[c4:.+]] = arith.constant 4 : i4
+//       CHECK:   %[[t:.+]] = arith.subi %arg0, %[[c5]] overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.subi %[[c4]], %[[t]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @subi_lhs_subi_rhs_no_fold_const_overflow(%arg0: i4) -> i4 {
+  %c5 = arith.constant 5 : i4
+  %c4 = arith.constant 4 : i4
+  %t = arith.subi %arg0, %c5 overflow<nsw> : i4
+  %r = arith.subi %c4, %t overflow<nsw> : i4
+  return %r : i4
+}
+
+// -----
+
+// subi(c1, subi(c0, x)) -> addi(x, c1-c0) must not fold when c1-c0 overflows
+// under the merged flags. c1=3, c0=-5 for i4: 3-(-5)=8 overflows signed (i4 max=7).
+//
+// CHECK-LABEL: @subi_lhs_subi_lhs_no_fold_const_overflow
+//       CHECK:   %[[cm5:.+]] = arith.constant -5 : i4
+//       CHECK:   %[[c3:.+]] = arith.constant 3 : i4
+//       CHECK:   %[[t:.+]] = arith.subi %[[cm5]], %arg0 overflow<nsw>
+//       CHECK:   %[[r:.+]] = arith.subi %[[c3]], %[[t]] overflow<nsw>
+//       CHECK:   return %[[r]]
+func.func @subi_lhs_subi_lhs_no_fold_const_overflow(%arg0: i4) -> i4 {
+  %cm5 = arith.constant -5 : i4
+  %c3 = arith.constant 3 : i4
+  %t = arith.subi %cm5, %arg0 overflow<nsw> : i4
+  %r = arith.subi %c3, %t overflow<nsw> : i4
+  return %r : i4
+}
+
