[Mlir-commits] [mlir] [mlir][Interfaces] Track and infer no-overflow flags in integer ranges (PR #191777)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 13 02:04:59 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hocky Yudhiono (hockyy)
<details>
<summary>Changes</summary>
Track no-overflow `OverflowFlags` on `ConstantIntRanges` and propagate them through range operations and printing. Move the shared overflow enum to `InferIntRangeInterface.h`, teach the arith dialect's integer range inference to prove additional `nsw`/`nuw` guarantees for add/sub/mul from operand ranges, and preserve user-declared overflow flags. Note that `ConstantIntRanges::operator==` now also compares the overflow flags. Add unit coverage for overflow-flag behavior (`withOverflowFlags`, union/intersection semantics — union keeps only flags common to both inputs, intersection keeps flags from either input — and textual formatting), plus a small cleanup in strided metadata range operand collection.
Assisted-by: openai/gpt-5.3-codex, cursor:anthropic/claude-opus-4.6
---
Full diff: https://github.com/llvm/llvm-project/pull/191777.diff
6 Files Affected:
- (modified) mlir/include/mlir/Interfaces/InferIntRangeInterface.h (+23-2)
- (modified) mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h (+1-7)
- (modified) mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp (+4-4)
- (modified) mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp (+106-6)
- (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+26-4)
- (modified) mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp (+65)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index a6de3d1885eec..c046ce391f22d 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -15,9 +15,18 @@
#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/BitmaskEnum.h"
#include <optional>
namespace mlir {
+/// Bitmask of no-overflow properties that may be attached to a
+/// `ConstantIntRanges`. Values combine with `|`, `&`, and `|=` thanks to
+/// LLVM_MARK_AS_BITMASK_ENUM.
+enum class OverflowFlags : uint32_t {
+ /// No no-overflow property is known.
+ None = 0,
+ /// No signed wrap: the operation does not overflow as a signed operation.
+ Nsw = 1,
+ /// No unsigned wrap: the operation does not overflow as an unsigned one.
+ Nuw = 2,
+ LLVM_MARK_AS_BITMASK_ENUM(Nuw)
+};
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+
/// A set of arbitrary-precision integers representing bounds on a given integer
/// value. These bounds are inclusive on both ends, so
/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
@@ -29,8 +38,10 @@ class ConstantIntRanges {
/// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
/// Non-integer values should be bounded by APInts of bitwidth 0.
ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
- const APInt &smax)
- : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
+ const APInt &smax,
+ OverflowFlags noOverflowFlags = OverflowFlags::None)
+ : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax),
+ noOverflowFlags(noOverflowFlags) {
assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
umaxVal.getBitWidth() == sminVal.getBitWidth() &&
sminVal.getBitWidth() == smaxVal.getBitWidth() &&
@@ -96,11 +107,21 @@ class ConstantIntRanges {
/// value.
std::optional<APInt> getConstantValue() const;
+ /// Return the no-overflow properties (nsw/nuw) proven for the operation
+ /// computing the bounded value. `OverflowFlags::None` (the default) means
+ /// nothing has been proven.
+ OverflowFlags getOverflowFlags() const { return noOverflowFlags; }
+
+ /// Return a copy of this range with identical bounds whose no-overflow
+ /// properties are replaced by `flags`. The existing flags are discarded,
+ /// not merged; OR them into `flags` first if they should be kept.
+ ConstantIntRanges withOverflowFlags(OverflowFlags flags) const {
+ return {uminVal, umaxVal, sminVal, smaxVal, flags};
+ }
+
friend raw_ostream &operator<<(raw_ostream &os,
const ConstantIntRanges &range);
private:
APInt uminVal, umaxVal, sminVal, smaxVal;
+ OverflowFlags noOverflowFlags = OverflowFlags::None;
};
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index e369c80a26ea9..8c0784a4513bb 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,7 +16,6 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/BitmaskEnum.h"
#include <optional>
namespace mlir {
@@ -39,12 +38,7 @@ static constexpr unsigned indexMaxWidth = 64;
enum class CmpMode : uint32_t { Both, Signed, Unsigned };
-enum class OverflowFlags : uint32_t {
- None = 0,
- Nsw = 1,
- Nuw = 2,
- LLVM_MARK_AS_BITMASK_ENUM(Nuw)
-};
+// The enum now lives in InferIntRangeInterface.h (namespace mlir). This alias
+// keeps existing unqualified uses of `OverflowFlags` in this header's
+// namespace compiling.
+using OverflowFlags = mlir::OverflowFlags;
/// Function that performs inference on an array of `ConstantIntRanges` while
/// taking special overflow behavior into account.
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
index 01c9dafaddf10..560683b86cbb4 100644
--- a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -104,10 +104,10 @@ LogicalResult StridedMetadataRangeAnalysis::visitOperation(
};
// Convert the arguments lattices to a vector.
- SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
- operands, [](const StridedMetadataRangeLattice *lattice) {
- return lattice->getValue();
- });
+ // Explicit inline capacity 0: presumably StridedMetadataRange became too
+ // large for SmallVector's default inline-size computation once
+ // ConstantIntRanges gained overflow flags -- TODO confirm.
+ SmallVector<StridedMetadataRange, 0> argRanges = llvm::map_to_vector<0>(
+ operands, [](const StridedMetadataRangeLattice *lattice) {
+ return lattice->getValue();
+ });
// Callback to set metadata on a result.
auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 49f89e1bd17f3..db17d35b04a0c 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -28,6 +28,94 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
return retFlags;
}
+// Helpers proving that a binary add/sub/mul cannot overflow, given the
+// operand ranges. Each reads argRanges[0] (lhs) and argRanges[1] (rhs)
+// without bounds checks, so callers must pass exactly two ranges.
+// NOTE(review): if `index` operands are analyzed at a fixed width (cf.
+// `indexMaxWidth` in InferIntRangeCommon.h), a proof at that width may not
+// hold for a narrower target index type -- confirm before using these flags
+// on index ops.
+
+/// Return true if lhs + rhs cannot overflow as a signed addition. The sum is
+/// monotone in both operands, so checking the smallest (smin + smin) and the
+/// largest (smax + smax) possible sums is sufficient.
+static bool proveNoSignedAddOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ const APInt &lhsMin = argRanges[0].smin();
+ const APInt &lhsMax = argRanges[0].smax();
+ const APInt &rhsMin = argRanges[1].smin();
+ const APInt &rhsMax = argRanges[1].smax();
+ bool overflow = false;
+ (void)lhsMin.sadd_ov(rhsMin, overflow);
+ if (overflow)
+ return false;
+ (void)lhsMax.sadd_ov(rhsMax, overflow);
+ return !overflow;
+}
+
+/// Return true if lhs + rhs cannot wrap as an unsigned addition: the largest
+/// possible sum, umax + umax, must fit in the bit width.
+static bool proveNoUnsignedAddOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ bool overflow = false;
+ (void)argRanges[0].umax().uadd_ov(argRanges[1].umax(), overflow);
+ return !overflow;
+}
+
+/// Return true if lhs - rhs cannot overflow as a signed subtraction. The
+/// extremes of the difference are smin(lhs) - smax(rhs) and
+/// smax(lhs) - smin(rhs), so only those two need checking.
+static bool proveNoSignedSubOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ const APInt &lhsMin = argRanges[0].smin();
+ const APInt &lhsMax = argRanges[0].smax();
+ const APInt &rhsMin = argRanges[1].smin();
+ const APInt &rhsMax = argRanges[1].smax();
+ bool overflow = false;
+ (void)lhsMin.ssub_ov(rhsMax, overflow);
+ if (overflow)
+ return false;
+ (void)lhsMax.ssub_ov(rhsMin, overflow);
+ return !overflow;
+}
+
+/// Return true if lhs - rhs cannot wrap as an unsigned subtraction, i.e. the
+/// smallest possible lhs is never below the largest possible rhs.
+static bool proveNoUnsignedSubOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ return argRanges[0].umin().uge(argRanges[1].umax());
+}
+
+/// Return true if lhs * rhs cannot overflow as a signed multiplication. With
+/// one operand fixed, the product is linear in the other, so its extremes
+/// over the two signed ranges lie at the four corner products. Checking all
+/// four is therefore sufficient.
+static bool proveNoSignedMulOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ const APInt &lhsMin = argRanges[0].smin();
+ const APInt &lhsMax = argRanges[0].smax();
+ const APInt &rhsMin = argRanges[1].smin();
+ const APInt &rhsMax = argRanges[1].smax();
+ bool overflow = false;
+ (void)lhsMin.smul_ov(rhsMin, overflow);
+ if (overflow)
+ return false;
+ (void)lhsMin.smul_ov(rhsMax, overflow);
+ if (overflow)
+ return false;
+ (void)lhsMax.smul_ov(rhsMin, overflow);
+ if (overflow)
+ return false;
+ (void)lhsMax.smul_ov(rhsMax, overflow);
+ return !overflow;
+}
+
+/// Return true if lhs * rhs cannot wrap as an unsigned multiplication: the
+/// largest possible product, umax * umax, must fit in the bit width.
+static bool proveNoUnsignedMulOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ bool overflow = false;
+ (void)argRanges[0].umax().umul_ov(argRanges[1].umax(), overflow);
+ return !overflow;
+}
+
+/// Return the nsw/nuw flags that the operand ranges alone prove for an add.
+static OverflowFlags proveNoOverflowForAdd(ArrayRef<ConstantIntRanges> args) {
+ OverflowFlags flags = OverflowFlags::None;
+ if (proveNoSignedAddOverflow(args))
+ flags |= OverflowFlags::Nsw;
+ if (proveNoUnsignedAddOverflow(args))
+ flags |= OverflowFlags::Nuw;
+ return flags;
+}
+
+/// Return the nsw/nuw flags that the operand ranges alone prove for a sub.
+static OverflowFlags proveNoOverflowForSub(ArrayRef<ConstantIntRanges> args) {
+ OverflowFlags flags = OverflowFlags::None;
+ if (proveNoSignedSubOverflow(args))
+ flags |= OverflowFlags::Nsw;
+ if (proveNoUnsignedSubOverflow(args))
+ flags |= OverflowFlags::Nuw;
+ return flags;
+}
+
+/// Return the nsw/nuw flags that the operand ranges alone prove for a mul.
+static OverflowFlags proveNoOverflowForMul(ArrayRef<ConstantIntRanges> args) {
+ OverflowFlags flags = OverflowFlags::None;
+ if (proveNoSignedMulOverflow(args))
+ flags |= OverflowFlags::Nsw;
+ if (proveNoUnsignedMulOverflow(args))
+ flags |= OverflowFlags::Nuw;
+ return flags;
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -65,8 +153,12 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
+ ConstantIntRanges range = inferAdd(argRanges, declaredFlags);
+ // Report the op's declared flags plus any the operand ranges prove.
+ OverflowFlags noOverflowFlags =
+ declaredFlags | proveNoOverflowForAdd(argRanges);
+ setResultRange(getResult(), range.withOverflowFlags(noOverflowFlags));
}
//===----------------------------------------------------------------------===//
@@ -75,8 +167,12 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
+ ConstantIntRanges range = inferSub(argRanges, declaredFlags);
+ // Report the op's declared flags plus any the operand ranges prove.
+ OverflowFlags noOverflowFlags =
+ declaredFlags | proveNoOverflowForSub(argRanges);
+ setResultRange(getResult(), range.withOverflowFlags(noOverflowFlags));
}
//===----------------------------------------------------------------------===//
@@ -85,8 +181,12 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
+ ConstantIntRanges range = inferMul(argRanges, declaredFlags);
+ // Report the op's declared flags plus any the operand ranges prove.
+ OverflowFlags noOverflowFlags =
+ declaredFlags | proveNoOverflowForMul(argRanges);
+ setResultRange(getResult(), range.withOverflowFlags(noOverflowFlags));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9d8e5f50a725b..327f3ed17c988 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -17,7 +17,8 @@ using namespace mlir;
+/// Ranges compare equal only if their bit widths, all four bounds, and their
+/// no-overflow flags match. Identical bounds with different proven flags are
+/// treated as different ranges.
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
return umin().getBitWidth() == other.umin().getBitWidth() &&
umin() == other.umin() && umax() == other.umax() &&
- smin() == other.smin() && smax() == other.smax();
+ smin() == other.smin() && smax() == other.smax() &&
+ getOverflowFlags() == other.getOverflowFlags();
}
const APInt &ConstantIntRanges::umin() const { return uminVal; }
@@ -94,8 +95,10 @@ ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+ // A value in the union may come from either input, so only the proofs
+ // that both inputs carry remain valid (bitwise AND).
+ OverflowFlags noOverflowUnion =
+ getOverflowFlags() & other.getOverflowFlags();
- return {uminUnion, umaxUnion, sminUnion, smaxUnion};
+ return {uminUnion, umaxUnion, sminUnion, smaxUnion, noOverflowUnion};
}
ConstantIntRanges
@@ -111,8 +114,11 @@ ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+ // Both inputs are treated as facts about the same value, so a proof from
+ // either input still holds (bitwise OR).
+ // NOTE(review): this assumes callers never intersect ranges that describe
+ // different values -- confirm at call sites.
+ OverflowFlags noOverflowIntersect =
+ getOverflowFlags() | other.getOverflowFlags();
- return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
+ return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect,
+ noOverflowIntersect};
}
std::optional<APInt> ConstantIntRanges::getConstantValue() const {
@@ -129,7 +135,23 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
range.umin().print(os, /*isSigned*/ false);
os << ", ";
range.umax().print(os, /*isSigned*/ false);
- return os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
+ os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
+ // The flag suffix is emitted only when some flag is set, so ranges without
+ // proofs print exactly as before. Otherwise it looks like
+ // " nooverflow<nsw, nuw>", with nsw listed before nuw.
+ OverflowFlags noOverflowFlags = range.getOverflowFlags();
+ if (noOverflowFlags == OverflowFlags::None)
+ return os;
+
+ os << " nooverflow<";
+ bool emitted = false;
+ if ((noOverflowFlags & OverflowFlags::Nsw) != OverflowFlags::None) {
+ os << "nsw";
+ emitted = true;
+ }
+ if ((noOverflowFlags & OverflowFlags::Nuw) != OverflowFlags::None) {
+ if (emitted)
+ os << ", ";
+ os << "nuw";
+ }
+ return os << ">";
}
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
index 97c75b3680567..19890a2d93c7b 100644
--- a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
@@ -8,6 +8,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/Support/raw_ostream.h"
#include <limits>
#include <gtest/gtest.h>
@@ -97,3 +98,67 @@ TEST(IntRangeAttrs, Join) {
ConstantIntRanges zeroOneSignedOnly(zero, uintMax, zero, one);
EXPECT_EQ(zeroOneUnsignedOnly.rangeUnion(zeroOneSignedOnly), maximal);
}
+
+// Covers the flag accessor, withOverflowFlags(), and how rangeUnion() (keeps
+// only common flags) and intersection() (keeps flags from either side)
+// combine flags.
+TEST(IntRangeAttrs, OverflowFlags) {
+ APInt zero = APInt::getZero(64);
+ APInt one = zero + 1;
+ APInt two = zero + 2;
+
+ ConstantIntRanges nswOnly(zero, one, zero, one, OverflowFlags::Nsw);
+ ConstantIntRanges nuwOnly(one, two, one, two, OverflowFlags::Nuw);
+
+ EXPECT_NE(nswOnly.getOverflowFlags() & OverflowFlags::Nsw,
+ OverflowFlags::None);
+ EXPECT_EQ(nswOnly.getOverflowFlags() & OverflowFlags::Nuw,
+ OverflowFlags::None);
+
+ // withOverflowFlags replaces the flags: Nsw becomes Nsw | Nuw.
+ ConstantIntRanges both = nswOnly.withOverflowFlags(OverflowFlags::Nsw |
+ OverflowFlags::Nuw);
+ EXPECT_NE(both.getOverflowFlags() & OverflowFlags::Nsw,
+ OverflowFlags::None);
+ EXPECT_NE(both.getOverflowFlags() & OverflowFlags::Nuw,
+ OverflowFlags::None);
+
+ // rangeUnion conservatively preserves only proofs present in both inputs.
+ EXPECT_EQ(nswOnly.rangeUnion(nuwOnly).getOverflowFlags(),
+ OverflowFlags::None);
+ EXPECT_EQ(both.rangeUnion(nswOnly).getOverflowFlags(),
+ OverflowFlags::Nsw);
+
+ // intersection preserves proofs from either input.
+ EXPECT_EQ(nswOnly.intersection(nuwOnly).getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+ EXPECT_EQ(both.intersection(nswOnly).getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+ ConstantIntRanges none(zero, two, zero, two, OverflowFlags::None);
+ EXPECT_EQ(nswOnly.intersection(none).getOverflowFlags(),
+ OverflowFlags::Nsw);
+}
+
+// Covers the textual form: no suffix without flags, otherwise a
+// " nooverflow<...>" suffix listing nsw before nuw.
+TEST(IntRangeAttrs, OverflowFlagsPrinting) {
+ APInt zero = APInt::getZero(64);
+ APInt one = zero + 1;
+
+ // Render a range through operator<< into a std::string.
+ auto toString = [](const ConstantIntRanges &r) {
+ std::string buf;
+ llvm::raw_string_ostream os(buf);
+ os << r;
+ return buf;
+ };
+
+ ConstantIntRanges noFlags(zero, one, zero, one);
+ EXPECT_EQ(toString(noFlags), "unsigned : [0, 1] signed : [0, 1]");
+
+ ConstantIntRanges nsw(zero, one, zero, one, OverflowFlags::Nsw);
+ EXPECT_EQ(toString(nsw),
+ "unsigned : [0, 1] signed : [0, 1] nooverflow<nsw>");
+
+ ConstantIntRanges nuw(zero, one, zero, one, OverflowFlags::Nuw);
+ EXPECT_EQ(toString(nuw),
+ "unsigned : [0, 1] signed : [0, 1] nooverflow<nuw>");
+
+ ConstantIntRanges both(zero, one, zero, one,
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+ EXPECT_EQ(toString(both),
+ "unsigned : [0, 1] signed : [0, 1] nooverflow<nsw, nuw>");
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/191777
More information about the Mlir-commits
mailing list