[Mlir-commits] [mlir] [MLIR][Arith] Fix unsound addi/subi constant-fold when merged overflow flags would be violated (PR #189164)
Mehdi Amini
llvmlistbot at llvm.org
Sat Mar 28 06:03:39 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/189164
From af049a4dff50efdd72c997387bd380d4a50c2677 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 05:04:18 -0700
Subject: [PATCH 1/2] [MLIR][Arith] Fix unsound addi/subi constant-fold when
merged overflow flags would be violated
The `AddIAddConstant` and related patterns fold chains of additions and
subtractions involving a constant into a single operation with the
combined constant. For example:
addi(addi(x, c0) [nsw], c1) [nsw] -> addi(x, c0+c1) [nsw]
The rewrite is only guaranteed to be valid when the combined constant c0+c1
does not itself overflow under the merged overflow flags. If the constant
sum wraps (e.g. c0=-2, c1=-1 for i2: -2+(-1)=-3 wraps to +1), the folded op
can produce poison for inputs where the original did not:
For x=1 (i2): original gives 1+(-2)=-1 then -1+(-1)=-2 (no poison),
but the fold produces 1+1=2 which overflows i2 under nsw -> poison.
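Exhaustive enumeration can check the i2 counterexample. The following is a minimal C++ sketch, not MLIR code. It assumes a hypothetical model in which an n-bit value is held in an `int64_t` and poison is represented as `std::nullopt`. All function names are illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model: n-bit two's-complement values stored in int64_t,
// poison represented as std::nullopt. Not MLIR/LLVM API.
constexpr int64_t sMin(unsigned n) { return -(int64_t(1) << (n - 1)); }
constexpr int64_t sMax(unsigned n) { return (int64_t(1) << (n - 1)) - 1; }

// Wrap an exact integer into the n-bit signed range (what an unguarded
// constant fold does when it computes c0 + c1 in the result type).
int64_t wrapSigned(int64_t v, unsigned n) {
  int64_t mod = int64_t(1) << n;
  int64_t r = ((v - sMin(n)) % mod + mod) % mod;
  return r + sMin(n);
}

// addi ... overflow<nsw>: poison if the exact sum leaves the signed range.
std::optional<int64_t> addNsw(std::optional<int64_t> a, int64_t b,
                              unsigned n) {
  if (!a)
    return std::nullopt; // poison propagates
  int64_t exact = *a + b;
  if (exact < sMin(n) || exact > sMax(n))
    return std::nullopt;
  return exact;
}

// Original: addi(addi(x, c0) [nsw], c1) [nsw].
std::optional<int64_t> original(int64_t x, int64_t c0, int64_t c1,
                                unsigned n) {
  return addNsw(addNsw(x, c0, n), c1, n);
}

// Unguarded fold: addi(x, wrap(c0 + c1)) [nsw].
std::optional<int64_t> folded(int64_t x, int64_t c0, int64_t c1, unsigned n) {
  return addNsw(x, wrapSigned(c0 + c1, n), n);
}

// True if some input x is well-defined in the original but poison after the
// fold, i.e. the fold is not a refinement.
bool foldIntroducesPoison(int64_t c0, int64_t c1, unsigned n) {
  for (int64_t x = sMin(n); x <= sMax(n); ++x)
    if (original(x, c0, c1, n) && !folded(x, c0, c1, n))
      return true;
  return false;
}
```

Here `foldIntroducesPoison(-2, -1, 2)` finds the x=1 input described above. `foldIntroducesPoison(1, 2, 4)` is the nsw half of the i4 case that is still expected to fold, and it finds none.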
Fix by adding `AddAttrsNoOverflow` / `SubAttrsNoOverflow` constraints to
all affected patterns. These constraints check that the combined constant
does not overflow signed (for nsw) or unsigned (for nuw) before allowing
the fold. The constraints are implemented as C++ helpers in ArithOps.cpp
that are called from the TableGen pattern guard predicates.
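The sketch below shows the same guard logic for widths up to 32 bits. It uses plain 64-bit arithmetic in place of `APInt::sadd_ov` / `APInt::uadd_ov`. The `Flags` enum, the function names and the raw-bit-pattern encoding are hypothetical. The sketch mirrors the check as described and is not the actual ArithOps.cpp code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for IntegerOverflowFlags; not the MLIR enum.
enum Flags : unsigned { None = 0, Nsw = 1, Nuw = 2 };

// Sign-extend the low n bits of `bits` (1 <= n <= 32 in this sketch).
int64_t signExtend(uint64_t bits, unsigned n) {
  uint64_t mask = (uint64_t(1) << n) - 1;
  bits &= mask;
  if (bits >> (n - 1)) // sign bit set
    return int64_t(bits) - int64_t(uint64_t(1) << n);
  return int64_t(bits);
}

// Returns true if folding c0 + c1 is allowed: the constant sum must not
// overflow under any flag carried by BOTH original ops. c0 and c1 are raw
// n-bit patterns, e.g. -1 in i2 is passed as 0b11.
bool addNoOverflow(uint64_t c0, uint64_t c1, unsigned ovf1, unsigned ovf2,
                   unsigned n) {
  unsigned merged = ovf1 & ovf2; // flag intersection, as in mergeOverflowFlags
  if (merged == None)
    return true; // nothing can be violated; fold freely
  const uint64_t mask = (uint64_t(1) << n) - 1;
  if (merged & Nsw) {
    int64_t sum = signExtend(c0, n) + signExtend(c1, n); // exact in 64 bits
    int64_t lo = -(int64_t(1) << (n - 1));
    int64_t hi = (int64_t(1) << (n - 1)) - 1;
    if (sum < lo || sum > hi)
      return false;
  }
  if ((merged & Nuw) && (c0 & mask) + (c1 & mask) > mask)
    return false;
  return true;
}
```

Because the flags are intersected first, a pair like (nsw, none) never blocks the fold. Only a flag that would survive onto the folded op is checked.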
Also update the expectation of the existing test `tripleSubSub1Ovf`: in
`subi(subi(17, x), 42) [nsw,nuw]` the combined constant 17-42=-25 wraps
unsigned, so the guard now rejects the fold.
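The arithmetic for the `tripleSubSub1Ovf` constant can be checked directly. This sketch assumes a 64-bit `index`, although the real index width is target-dependent. The helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Assumes index is 64 bits wide; the actual index width is target-dependent.
// True if a - b wraps when a and b are read as unsigned 64-bit values.
constexpr bool usubWraps(uint64_t a, uint64_t b) { return a < b; }

// True if a - b overflows when a and b are read as signed 64-bit values.
// The bounds are rearranged so the check itself cannot overflow.
constexpr bool ssubOverflows(int64_t a, int64_t b) {
  return (b > 0 && a < INT64_MIN + b) || (b < 0 && a > INT64_MAX + b);
}
```

For 17-42 only the unsigned check fires, because the signed difference -25 is in range. The fold is therefore blocked solely because nuw is among the merged flags.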
Fixes #140328
Assisted-by: Claude Code
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 39 +++++++++---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 61 +++++++++++++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 49 ++++++++++++++-
3 files changed, 138 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index e22fc1d478e4f..8eb6aeed6dd79 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -30,6 +30,19 @@ def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
// Default overflow flag (all wraparounds allowed).
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
+// Constraint: constant c0+c1 does not overflow under the merged flags of ovf1
+// and ovf2. Without this guard, folding addi(addi(x, c0), c1) -> addi(x, c0+c1)
+// is unsound when the constant sum wraps and the result carries nsw/nuw.
+def AddAttrsNoOverflow :
+ Constraint<CPred<"addAttrsNoOverflow($0, $1, $2, $3)">,
+ "constant addition does not overflow with merged flags">;
+
+// Constraint: constant c0-c1 does not overflow under the merged flags of ovf1
+// and ovf2.
+def SubAttrsNoOverflow :
+ Constraint<CPred<"subAttrsNoOverflow($0, $1, $2, $3)">,
+ "constant subtraction does not overflow with merged flags">;
+
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
//===----------------------------------------------------------------------===//
@@ -40,12 +53,15 @@ class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
// as the second operand.
// addi(addi(x, c0), c1) -> addi(x, c0 + c1)
+// Guard: the constant sum c0+c1 must not overflow under the merged flags,
+// otherwise the folded op would produce poison where the originals did not.
def AddIAddConstant :
Pat<(Arith_AddIOp:$res
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
- (MergeOverflow $ovf1, $ovf2))>;
+ (MergeOverflow $ovf1, $ovf2)),
+ [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
def AddISubConstantRHS :
@@ -53,7 +69,8 @@ def AddISubConstantRHS :
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
- (MergeOverflow $ovf1, $ovf2))>;
+ (MergeOverflow $ovf1, $ovf2)),
+ [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
def AddISubConstantLHS :
@@ -61,7 +78,8 @@ def AddISubConstantLHS :
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
- (MergeOverflow $ovf1, $ovf2))>;
+ (MergeOverflow $ovf1, $ovf2)),
+ [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
def IsScalarOrSplatNegativeOne :
Constraint<And<[
@@ -123,7 +141,8 @@ def SubILHSAddConstant :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
- (MergeOverflow $ovf1, $ovf2))>;
+ (MergeOverflow $ovf1, $ovf2)),
+ [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
def SubIRHSSubConstantRHS :
@@ -131,7 +150,8 @@ def SubIRHSSubConstantRHS :
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
- (MergeOverflow $ovf1, $ovf2))>;
+ (MergeOverflow $ovf1, $ovf2)),
+ [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
def SubIRHSSubConstantLHS :
@@ -139,7 +159,8 @@ def SubIRHSSubConstantLHS :
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
- (MergeOverflow $ovf1, $ovf2))>;
+ (MergeOverflow $ovf1, $ovf2)),
+ [(SubAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
def SubILHSSubConstantRHS :
@@ -147,7 +168,8 @@ def SubILHSSubConstantRHS :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
- (MergeOverflow $ovf1, $ovf2))>;
+ (MergeOverflow $ovf1, $ovf2)),
+ [(AddAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
def SubILHSSubConstantLHS :
@@ -155,7 +177,8 @@ def SubILHSSubConstantLHS :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
- (MergeOverflow $ovf1, $ovf2))>;
+ (MergeOverflow $ovf1, $ovf2)),
+ [(SubAttrsNoOverflow $c1, $c0, $ovf1, $ovf2)]>;
// subi(subi(a, b), a) -> subi(0, b)
def SubISubILHSRHSLHS :
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f10a94522350..7f41db325262f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -71,6 +71,67 @@ mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
val1.getValue() & val2.getValue());
}
+// Check if the constant integer addition of c0 + c1 would overflow given the
+// merged overflow flags. Returns true if the constant fold would be
+// invalid.
+static bool addConstantWouldOverflow(const APInt &c0, const APInt &c1,
+ IntegerOverflowFlags flags) {
+ if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) {
+ bool overflow = false;
+ (void)c0.sadd_ov(c1, overflow);
+ if (overflow)
+ return true;
+ }
+ if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) {
+ bool overflow = false;
+ (void)c0.uadd_ov(c1, overflow);
+ if (overflow)
+ return true;
+ }
+ return false;
+}
+
+// Check if the constant integer subtraction of c0 - c1 would overflow given
+// the merged overflow flags. Returns true if the constant fold would be
+// invalid.
+static bool subConstantWouldOverflow(const APInt &c0, const APInt &c1,
+ IntegerOverflowFlags flags) {
+ if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw)) {
+ bool overflow = false;
+ (void)c0.ssub_ov(c1, overflow);
+ if (overflow)
+ return true;
+ }
+ if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw)) {
+ bool overflow = false;
+ (void)c0.usub_ov(c1, overflow);
+ if (overflow)
+ return true;
+ }
+ return false;
+}
+
+// Wrappers for use in TableGen pattern constraints.
+// Returns true if c0 + c1 would NOT overflow with the merged flags.
+static bool addAttrsNoOverflow(Attribute c0, Attribute c1,
+ IntegerOverflowFlagsAttr ovf1,
+ IntegerOverflowFlagsAttr ovf2) {
+ IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+ const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
+ const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
+ return !addConstantWouldOverflow(v0, v1, merged);
+}
+
+// Returns true if c0 - c1 would NOT overflow with the merged flags.
+static bool subAttrsNoOverflow(Attribute c0, Attribute c1,
+ IntegerOverflowFlagsAttr ovf1,
+ IntegerOverflowFlagsAttr ovf2) {
+ IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+ const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
+ const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
+ return !subConstantWouldOverflow(v0, v1, merged);
+}
+
/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 18665e2eb6f4a..211196e26357b 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1310,9 +1310,14 @@ func.func @tripleSubSub1(%arg0: index) -> index {
}
// CHECK-LABEL: @tripleSubSub1Ovf
-// CHECK: %[[cres:.+]] = arith.constant -25 : index
-// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
-// CHECK: return %[[add]]
+// Folding subi(subi(c0, x), c1) -> subi(c0-c1, x) is unsound when c0-c1
+// unsigned-wraps and nuw is set, because the original always produces poison
+// while the folded form may not. Do not fold.
+// CHECK: %[[c17:.+]] = arith.constant 17 : index
+// CHECK: %[[c42:.+]] = arith.constant 42 : index
+// CHECK: %[[sub1:.+]] = arith.subi %[[c17]], %arg0 overflow<nsw, nuw>
+// CHECK: %[[sub2:.+]] = arith.subi %[[sub1]], %[[c42]] overflow<nsw, nuw>
+// CHECK: return %[[sub2]]
func.func @tripleSubSub1Ovf(%arg0: index) -> index {
%c17 = arith.constant 17 : index
%c42 = arith.constant 42 : index
@@ -3564,3 +3569,41 @@ func.func @convertf_fold_f8() -> f8E5M2 {
return %result : f8E5M2
}
+// -----
+
+// addi(addi(x, c0), c1) -> addi(x, c0+c1) must not fold when c0+c1 wraps
+// under the merged overflow flags. Here c0=-2, c1=-1 for i2: -2 + -1 = -3,
+// which wraps to +1 in i2. Since both ops have nsw the fold is unsound
+// (for x=1: original gives -2 with no poison, folded gives poison).
+//
+// CHECK-LABEL: @addi_no_fold_const_signed_overflow_nsw
+// CHECK: %[[cm2:.+]] = arith.constant -2 : i2
+// CHECK: %[[cm1:.+]] = arith.constant -1 : i2
+// CHECK: %[[t:.+]] = arith.addi %arg0, %[[cm2]] overflow<nsw>
+// CHECK: %[[r:.+]] = arith.addi %[[t]], %[[cm1]] overflow<nsw>
+// CHECK: return %[[r]]
+func.func @addi_no_fold_const_signed_overflow_nsw(%arg0: i2) -> i2 {
+ %cm2 = arith.constant -2 : i2
+ %cm1 = arith.constant -1 : i2
+ %t = arith.addi %arg0, %cm2 overflow<nsw> : i2
+ %r = arith.addi %t, %cm1 overflow<nsw> : i2
+ return %r : i2
+}
+
+// -----
+
+// When c0+c1 does not overflow, the fold is sound.
+// c0=1, c1=2 for i4: 1+2=3, no overflow in signed or unsigned sense.
+//
+// CHECK-LABEL: @addi_fold_const_no_overflow
+// CHECK: %[[c3:.+]] = arith.constant 3 : i4
+// CHECK: %[[r:.+]] = arith.addi %arg0, %[[c3]] overflow<nsw, nuw>
+// CHECK: return %[[r]]
+func.func @addi_fold_const_no_overflow(%arg0: i4) -> i4 {
+ %c1 = arith.constant 1 : i4
+ %c2 = arith.constant 2 : i4
+ %t = arith.addi %arg0, %c1 overflow<nsw, nuw> : i4
+ %r = arith.addi %t, %c2 overflow<nsw, nuw> : i4
+ return %r : i4
+}
+
From 96c6fe64b7a627071d7856d00e1935bb81bf3544 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 06:02:46 -0700
Subject: [PATCH 2/2] [MLIR][Arith] Address review findings: muli guard, early
exit, tests, comment fix
- Add MulAttrsNoOverflow guard to MulIMulIConstant (same soundness bug as
addi/subi: c0*c1 can overflow under merged nsw/nuw flags)
- Add mulAttrsNoOverflow C++ helper using smul_ov/umul_ov
- Add early exit for merged==none in add/sub/mulAttrsNoOverflow to skip
APInt extraction in the common no-flags case
- Fix inverted poison direction in tripleSubSub1Ovf comment
- Add no-fold tests for: nuw-only unsigned overflow (addi), subi pattern
(SubIRHSSubConstantRHS), AddISubConstantRHS pattern, muli pattern
- Add muli fold-allowed test (c0*c1 safe)
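The muli guard with the merged==none early exit can be sketched in the same style. Plain 64-bit arithmetic stands in for `APInt::smul_ov` / `APInt::umul_ov`. The enum, the function name and the 16-bit width limit are assumptions of the sketch.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical flag bitmask; not the MLIR IntegerOverflowFlags enum.
enum Flags : unsigned { None = 0, Nsw = 1, Nuw = 2 };

// Guard for muli(muli(x, c0), c1) -> muli(x, c0 * c1). c0 and c1 are the
// signed values of an n-bit integer type, 1 <= n <= 16 in this sketch, so
// the exact product always fits in 64 bits.
bool mulNoOverflow(int64_t c0, int64_t c1, unsigned ovf1, unsigned ovf2,
                   unsigned n) {
  unsigned merged = ovf1 & ovf2;
  if (merged == None)
    return true; // early exit: no flags survive the merge
  const int64_t lo = -(int64_t(1) << (n - 1));
  const int64_t hi = (int64_t(1) << (n - 1)) - 1;
  const int64_t sprod = c0 * c1; // exact signed product
  if ((merged & Nsw) && (sprod < lo || sprod > hi))
    return false;
  if (merged & Nuw) {
    const uint64_t mask = (uint64_t(1) << n) - 1;
    // Reinterpret both constants as unsigned n-bit values.
    const uint64_t uprod = (uint64_t(c0) & mask) * (uint64_t(c1) & mask);
    if (uprod > mask)
      return false;
  }
  return true;
}
```

Take c0=2 and c1=4 in i4 with nsw on both ops. For x=-1 the original computes -2 and then -8, both in range. The unguarded fold would compute x * wrap(8) = -1 * -8 = 8, which overflows. The guard rejects this case because 2*4=8 exceeds the i4 signed maximum of 7.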
Assisted-by: Claude Code
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 15 ++-
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 28 +++++
mlir/test/Dialect/Arith/canonicalize.mlir | 100 +++++++++++++++++-
3 files changed, 139 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 8eb6aeed6dd79..2582cade66252 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -38,11 +38,19 @@ def AddAttrsNoOverflow :
"constant addition does not overflow with merged flags">;
// Constraint: constant c0-c1 does not overflow under the merged flags of ovf1
-// and ovf2.
+// and ovf2. Without this guard, folding addi(subi(x, c0), c1) -> addi(x, c1-c0)
+// is unsound when the constant difference wraps and the result carries nsw/nuw.
def SubAttrsNoOverflow :
Constraint<CPred<"subAttrsNoOverflow($0, $1, $2, $3)">,
"constant subtraction does not overflow with merged flags">;
+// Constraint: constant c0*c1 does not overflow under the merged flags of ovf1
+// and ovf2. Without this guard, folding muli(muli(x, c0), c1) -> muli(x, c0*c1)
+// is unsound when the constant product wraps and the result carries nsw/nuw.
+def MulAttrsNoOverflow :
+ Constraint<CPred<"mulAttrsNoOverflow($0, $1, $2, $3)">,
+ "constant multiplication does not overflow with merged flags">;
+
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
//===----------------------------------------------------------------------===//
@@ -103,6 +111,8 @@ def AddIMulNegativeOneLhs :
[(IsScalarOrSplatNegativeOne $c0)]>;
// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
+// Guard: the constant product c0*c1 must not overflow under the merged flags,
+// otherwise the folded op would produce poison where the originals did not.
def MulIMulIConstant :
Pat<(Arith_MulIOp:$res
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
@@ -110,7 +120,8 @@ def MulIMulIConstant :
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
(MergeOverflow $ovf1, $ovf2)),
[(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
- (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
+ (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1),
+ (MulAttrsNoOverflow $c0, $c1, $ovf1, $ovf2)]>;
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7f41db325262f..d55b10ec69ff5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -117,6 +117,8 @@ static bool addAttrsNoOverflow(Attribute c0, Attribute c1,
IntegerOverflowFlagsAttr ovf1,
IntegerOverflowFlagsAttr ovf2) {
IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+ if (merged == IntegerOverflowFlags::none)
+ return true;
const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
return !addConstantWouldOverflow(v0, v1, merged);
@@ -127,11 +129,37 @@ static bool subAttrsNoOverflow(Attribute c0, Attribute c1,
IntegerOverflowFlagsAttr ovf1,
IntegerOverflowFlagsAttr ovf2) {
IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+ if (merged == IntegerOverflowFlags::none)
+ return true;
const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
return !subConstantWouldOverflow(v0, v1, merged);
}
+// Returns true if c0 * c1 would NOT overflow with the merged flags.
+static bool mulAttrsNoOverflow(Attribute c0, Attribute c1,
+ IntegerOverflowFlagsAttr ovf1,
+ IntegerOverflowFlagsAttr ovf2) {
+ IntegerOverflowFlags merged = ovf1.getValue() & ovf2.getValue();
+ if (merged == IntegerOverflowFlags::none)
+ return true;
+ const APInt &v0 = llvm::cast<IntegerAttr>(c0).getValue();
+ const APInt &v1 = llvm::cast<IntegerAttr>(c1).getValue();
+ if (bitEnumContainsAny(merged, IntegerOverflowFlags::nsw)) {
+ bool overflow = false;
+ (void)v0.smul_ov(v1, overflow);
+ if (overflow)
+ return false;
+ }
+ if (bitEnumContainsAny(merged, IntegerOverflowFlags::nuw)) {
+ bool overflow = false;
+ (void)v0.umul_ov(v1, overflow);
+ if (overflow)
+ return false;
+ }
+ return true;
+}
+
/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 211196e26357b..97d053ba69f4f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1311,8 +1311,9 @@ func.func @tripleSubSub1(%arg0: index) -> index {
// CHECK-LABEL: @tripleSubSub1Ovf
// Folding subi(subi(c0, x), c1) -> subi(c0-c1, x) is unsound when c0-c1
-// unsigned-wraps and nuw is set, because the original always produces poison
-// while the folded form may not. Do not fold.
+// unsigned-wraps and nuw is set: the folded constant 17-42=-25 wraps unsigned,
+// so the folded op would produce poison on more inputs than the original.
+// Do not fold.
// CHECK: %[[c17:.+]] = arith.constant 17 : index
// CHECK: %[[c42:.+]] = arith.constant 42 : index
// CHECK: %[[sub1:.+]] = arith.subi %[[c17]], %arg0 overflow<nsw, nuw>
@@ -3607,3 +3608,98 @@ func.func @addi_fold_const_no_overflow(%arg0: i4) -> i4 {
return %r : i4
}
+// -----
+
+// addi(addi(x, c0), c1) must not fold when c0+c1 wraps unsigned and nuw is set.
+// c0=-2 (=6 unsigned), c1=3 for i3: 6+3=9 wraps unsigned (i3 unsigned max=7),
+// but -2+3=1, which does NOT overflow signed, so the nuw check alone rejects the fold.
+//
+// CHECK-LABEL: @addi_no_fold_const_unsigned_overflow_nuw
+// CHECK: %[[cm2:.+]] = arith.constant -2 : i3
+// CHECK: %[[c3:.+]] = arith.constant 3 : i3
+// CHECK: %[[t:.+]] = arith.addi %arg0, %[[cm2]] overflow<nuw>
+// CHECK: %[[r:.+]] = arith.addi %[[t]], %[[c3]] overflow<nuw>
+// CHECK: return %[[r]]
+func.func @addi_no_fold_const_unsigned_overflow_nuw(%arg0: i3) -> i3 {
+ %cm2 = arith.constant -2 : i3
+ %c3 = arith.constant 3 : i3
+ %t = arith.addi %arg0, %cm2 overflow<nuw> : i3
+ %r = arith.addi %t, %c3 overflow<nuw> : i3
+ return %r : i3
+}
+
+// -----
+
+// subi(subi(x, c0), c1) -> subi(x, c0+c1) must not fold when c0+c1 overflows
+// under the merged flags. c0=3, c1=1 for i3: 3+1=4 overflows signed (i3
+// signed max = 3). With nsw the fold is unsound.
+//
+// CHECK-LABEL: @subi_rhs_no_fold_const_overflow
+// CHECK: %[[c3:.+]] = arith.constant 3 : i3
+// CHECK: %[[c1:.+]] = arith.constant 1 : i3
+// CHECK: %[[t:.+]] = arith.subi %arg0, %[[c3]] overflow<nsw>
+// CHECK: %[[r:.+]] = arith.subi %[[t]], %[[c1]] overflow<nsw>
+// CHECK: return %[[r]]
+func.func @subi_rhs_no_fold_const_overflow(%arg0: i3) -> i3 {
+ %c3 = arith.constant 3 : i3
+ %c1 = arith.constant 1 : i3
+ %t = arith.subi %arg0, %c3 overflow<nsw> : i3
+ %r = arith.subi %t, %c1 overflow<nsw> : i3
+ return %r : i3
+}
+
+// -----
+
+// addi(subi(x, c0), c1) -> addi(x, c1-c0) must not fold when c1-c0 overflows
+// under the merged flags. c0=-3, c1=1 for i3: 1-(-3)=4 overflows signed.
+//
+// CHECK-LABEL: @addi_subi_rhs_no_fold_const_overflow
+// CHECK: %[[cm3:.+]] = arith.constant -3 : i3
+// CHECK: %[[c1:.+]] = arith.constant 1 : i3
+// CHECK: %[[t:.+]] = arith.subi %arg0, %[[cm3]] overflow<nsw>
+// CHECK: %[[r:.+]] = arith.addi %[[t]], %[[c1]] overflow<nsw>
+// CHECK: return %[[r]]
+func.func @addi_subi_rhs_no_fold_const_overflow(%arg0: i3) -> i3 {
+ %cm3 = arith.constant -3 : i3
+ %c1 = arith.constant 1 : i3
+ %t = arith.subi %arg0, %cm3 overflow<nsw> : i3
+ %r = arith.addi %t, %c1 overflow<nsw> : i3
+ return %r : i3
+}
+
+// -----
+
+// muli(muli(x, c0), c1) -> muli(x, c0*c1) must not fold when c0*c1 overflows
+// under the merged flags. c0=3, c1=3 for i4: 3*3=9 overflows signed (i4
+// signed max = 7), so the nsw check rejects the fold.
+//
+// CHECK-LABEL: @muli_no_fold_const_overflow
+// CHECK: %[[c3:.+]] = arith.constant 3 : i4
+// CHECK: %[[t:.+]] = arith.muli %arg0, %[[c3]] overflow<nsw>
+// CHECK: %[[r:.+]] = arith.muli %[[t]], %[[c3]] overflow<nsw>
+// CHECK: return %[[r]]
+func.func @muli_no_fold_const_overflow(%arg0: i4) -> i4 {
+ %c3a = arith.constant 3 : i4
+ %c3b = arith.constant 3 : i4
+ %t = arith.muli %arg0, %c3a overflow<nsw> : i4
+ %r = arith.muli %t, %c3b overflow<nsw> : i4
+ return %r : i4
+}
+
+// -----
+
+// When c0*c1 does not overflow, the fold is sound.
+// c0=2, c1=2 for i4: 2*2=4, no overflow signed or unsigned.
+//
+// CHECK-LABEL: @muli_fold_const_no_overflow
+// CHECK: %[[c4:.+]] = arith.constant 4 : i4
+// CHECK: %[[r:.+]] = arith.muli %arg0, %[[c4]] overflow<nsw, nuw>
+// CHECK: return %[[r]]
+func.func @muli_fold_const_no_overflow(%arg0: i4) -> i4 {
+ %c2a = arith.constant 2 : i4
+ %c2b = arith.constant 2 : i4
+ %t = arith.muli %arg0, %c2a overflow<nsw, nuw> : i4
+ %r = arith.muli %t, %c2b overflow<nsw, nuw> : i4
+ return %r : i4
+}
+
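The two i3 nsw subtraction tests above guard against concrete counterexamples. For `subi(subi(x, 3), 1)` take x=0, and for `addi(subi(x, -3), 1)` take x=-1. This C++ sketch hard-codes i3 and represents poison as `std::nullopt`. The function names are hypothetical and only mirror the pattern shapes named in the test comments.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model of i3 (signed range [-4, 3]); std::nullopt is poison.
constexpr int64_t kMin = -4, kMax = 3;

// Result of an nsw op: poison if the exact value leaves the i3 range.
std::optional<int64_t> nsw(int64_t exact) {
  if (exact < kMin || exact > kMax)
    return std::nullopt;
  return exact;
}

// Wrap an exact value into i3, as an unguarded constant fold would.
int64_t wrap3(int64_t v) { return ((v - kMin) % 8 + 8) % 8 + kMin; }

// subi(subi(x, c0), c1) [nsw] and its unguarded fold subi(x, c0 + c1) [nsw].
std::optional<int64_t> subSubOrig(int64_t x, int64_t c0, int64_t c1) {
  std::optional<int64_t> t = nsw(x - c0);
  if (!t)
    return std::nullopt;
  return nsw(*t - c1);
}
std::optional<int64_t> subSubFolded(int64_t x, int64_t c0, int64_t c1) {
  return nsw(x - wrap3(c0 + c1));
}

// addi(subi(x, c0), c1) [nsw] and its unguarded fold addi(x, c1 - c0) [nsw].
std::optional<int64_t> addSubOrig(int64_t x, int64_t c0, int64_t c1) {
  std::optional<int64_t> t = nsw(x - c0);
  if (!t)
    return std::nullopt;
  return nsw(*t + c1);
}
std::optional<int64_t> addSubFolded(int64_t x, int64_t c0, int64_t c1) {
  return nsw(x + wrap3(c1 - c0));
}
```

In both cases the combined constant 4 wraps to -4 in i3. Each original chain stays in range for the chosen x, while the folded op overflows.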