[Mlir-commits] [mlir] [mlir][Interfaces] Track and infer no-overflow flags in integer ranges (PR #191777)
Hocky Yudhiono
llvmlistbot at llvm.org
Tue Apr 14 01:18:19 PDT 2026
https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/191777
>From bc8d2d58e079aa9619bc51874f1c60498eca5cb6 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 13 Apr 2026 17:08:01 +0800
Subject: [PATCH 1/4] [mlir][Interfaces] Track and infer no-overflow flags in
integer ranges
---
.../mlir/Interfaces/InferIntRangeInterface.h | 25 +++-
.../Interfaces/Utils/InferIntRangeCommon.h | 8 +-
.../Arith/IR/InferIntRangeInterfaceImpls.cpp | 132 +++++++++++++++++-
.../lib/Interfaces/InferIntRangeInterface.cpp | 31 +++-
.../Interfaces/InferIntRangeInterfaceTest.cpp | 59 ++++++++
5 files changed, 236 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index a6de3d1885eec..9870cf1cbb6fa 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -15,9 +15,18 @@
#define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/BitmaskEnum.h"
#include <optional>
namespace mlir {
+enum class OverflowFlags : uint32_t {
+ None = 0,
+ Nsw = 1,
+ Nuw = 2,
+ LLVM_MARK_AS_BITMASK_ENUM(Nuw)
+};
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+
/// A set of arbitrary-precision integers representing bounds on a given integer
/// value. These bounds are inclusive on both ends, so
/// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
@@ -29,8 +38,10 @@ class ConstantIntRanges {
/// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
/// Non-integer values should be bounded by APInts of bitwidth 0.
ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
- const APInt &smax)
- : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
+ const APInt &smax,
+ OverflowFlags overflowFlags = OverflowFlags::None)
+ : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax),
+ overflowFlags(overflowFlags) {
assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
umaxVal.getBitWidth() == sminVal.getBitWidth() &&
sminVal.getBitWidth() == smaxVal.getBitWidth() &&
@@ -96,11 +107,21 @@ class ConstantIntRanges {
/// value.
std::optional<APInt> getConstantValue() const;
+ /// Return overflow properties proven for the operation computing the bounded
+ /// value.
+ OverflowFlags getOverflowFlags() const { return overflowFlags; }
+
+ /// Return this range with updated overflow properties.
+ ConstantIntRanges withOverflowFlags(OverflowFlags flags) const {
+ return {uminVal, umaxVal, sminVal, smaxVal, flags};
+ }
+
friend raw_ostream &operator<<(raw_ostream &os,
const ConstantIntRanges &range);
private:
APInt uminVal, umaxVal, sminVal, smaxVal;
+ OverflowFlags overflowFlags = OverflowFlags::None;
};
raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index e369c80a26ea9..8c0784a4513bb 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,7 +16,6 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/BitmaskEnum.h"
#include <optional>
namespace mlir {
@@ -39,12 +38,7 @@ static constexpr unsigned indexMaxWidth = 64;
enum class CmpMode : uint32_t { Both, Signed, Unsigned };
-enum class OverflowFlags : uint32_t {
- None = 0,
- Nsw = 1,
- Nuw = 2,
- LLVM_MARK_AS_BITMASK_ENUM(Nuw)
-};
+using OverflowFlags = mlir::OverflowFlags;
/// Function that performs inference on an array of `ConstantIntRanges` while
/// taking special overflow behavior into account.
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 49f89e1bd17f3..2624448b5f7b4 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -11,6 +11,7 @@
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include <optional>
+#include <utility>
#define DEBUG_TYPE "int-range-analysis"
@@ -28,6 +29,116 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
return retFlags;
}
+template <typename Op>
+static bool proveNoOverflow(const APInt &lhs, const APInt &rhs, Op op) {
+ bool overflow = false;
+ (void)op(lhs, rhs, overflow);
+ return !overflow;
+}
+
+template <typename Op>
+static bool
+proveNoOverflowForPairs(ArrayRef<std::pair<const APInt *, const APInt *>> pairs,
+ Op op) {
+ for (const auto &[lhs, rhs] : pairs) {
+ if (!proveNoOverflow(*lhs, *rhs, op))
+ return false;
+ }
+ return true;
+}
+
+static OverflowFlags proveNoOverflowFlags(
+ ArrayRef<ConstantIntRanges> args,
+ function_ref<bool(ArrayRef<ConstantIntRanges>)> proveSigned,
+ function_ref<bool(ArrayRef<ConstantIntRanges>)> proveUnsigned) {
+ OverflowFlags flags = OverflowFlags::None;
+ if (proveSigned(args))
+ flags |= OverflowFlags::Nsw;
+ if (proveUnsigned(args))
+ flags |= OverflowFlags::Nuw;
+ return flags;
+}
+
+static bool proveNoSignedAddOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ const APInt &lhsMin = argRanges[0].smin();
+ const APInt &lhsMax = argRanges[0].smax();
+ const APInt &rhsMin = argRanges[1].smin();
+ const APInt &rhsMax = argRanges[1].smax();
+ // Signed add is monotone in both operands, so it is enough to check
+ // the interval endpoints to prove no signed wrap for the whole range.
+ return proveNoOverflowForPairs(
+ {{&lhsMin, &rhsMin}, {&lhsMax, &rhsMax}},
+ [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+ return lhs.sadd_ov(rhs, overflow);
+ });
+}
+
+static bool proveNoUnsignedAddOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ return proveNoOverflow(
+ argRanges[0].umax(), argRanges[1].umax(),
+ [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+ return lhs.uadd_ov(rhs, overflow);
+ });
+}
+
+static bool proveNoSignedSubOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ const APInt &lhsMin = argRanges[0].smin();
+ const APInt &lhsMax = argRanges[0].smax();
+ const APInt &rhsMin = argRanges[1].smin();
+ const APInt &rhsMax = argRanges[1].smax();
+ // For lhs - rhs, the extrema occur at (lhsMin - rhsMax) and
+ // (lhsMax - rhsMin). If both are no-wrap, the full interval is no-wrap.
+ return proveNoOverflowForPairs(
+ {{&lhsMin, &rhsMax}, {&lhsMax, &rhsMin}},
+ [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+ return lhs.ssub_ov(rhs, overflow);
+ });
+}
+
+static bool proveNoUnsignedSubOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ return argRanges[0].umin().uge(argRanges[1].umax());
+}
+
+static bool proveNoSignedMulOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ const APInt &lhsMin = argRanges[0].smin();
+ const APInt &lhsMax = argRanges[0].smax();
+ const APInt &rhsMin = argRanges[1].smin();
+ const APInt &rhsMax = argRanges[1].smax();
+ // Signed multiply is not monotone across sign changes, so conservatively
+ // require all four corner products to be no-wrap.
+ return proveNoOverflowForPairs(
+ {{&lhsMin, &rhsMin},
+ {&lhsMin, &rhsMax},
+ {&lhsMax, &rhsMin},
+ {&lhsMax, &rhsMax}},
+ [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+ return lhs.smul_ov(rhs, overflow);
+ });
+}
+
+static bool proveNoUnsignedMulOverflow(ArrayRef<ConstantIntRanges> argRanges) {
+ return proveNoOverflow(
+ argRanges[0].umax(), argRanges[1].umax(),
+ [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+ return lhs.umul_ov(rhs, overflow);
+ });
+}
+
+static OverflowFlags proveNoOverflowForAdd(ArrayRef<ConstantIntRanges> args) {
+ return proveNoOverflowFlags(args, proveNoSignedAddOverflow,
+ proveNoUnsignedAddOverflow);
+}
+
+static OverflowFlags proveNoOverflowForSub(ArrayRef<ConstantIntRanges> args) {
+ return proveNoOverflowFlags(args, proveNoSignedSubOverflow,
+ proveNoUnsignedSubOverflow);
+}
+
+static OverflowFlags proveNoOverflowForMul(ArrayRef<ConstantIntRanges> args) {
+ return proveNoOverflowFlags(args, proveNoSignedMulOverflow,
+ proveNoUnsignedMulOverflow);
+}
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -65,8 +176,11 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
+ ConstantIntRanges range = inferAdd(argRanges, declaredFlags);
+ OverflowFlags overflowFlags =
+ proveNoOverflowForAdd(argRanges) | declaredFlags;
+ setResultRange(getResult(), range.withOverflowFlags(overflowFlags));
}
//===----------------------------------------------------------------------===//
@@ -75,8 +189,11 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
+ ConstantIntRanges range = inferSub(argRanges, declaredFlags);
+ OverflowFlags overflowFlags =
+ proveNoOverflowForSub(argRanges) | declaredFlags;
+ setResultRange(getResult(), range.withOverflowFlags(overflowFlags));
}
//===----------------------------------------------------------------------===//
@@ -85,8 +202,11 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) {
- setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
- getOverflowFlags())));
+ OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
+ ConstantIntRanges range = inferMul(argRanges, declaredFlags);
+ OverflowFlags overflowFlags =
+ proveNoOverflowForMul(argRanges) | declaredFlags;
+ setResultRange(getResult(), range.withOverflowFlags(overflowFlags));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9d8e5f50a725b..0d3e06dc4d995 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -17,7 +17,8 @@ using namespace mlir;
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
return umin().getBitWidth() == other.umin().getBitWidth() &&
umin() == other.umin() && umax() == other.umax() &&
- smin() == other.smin() && smax() == other.smax();
+ smin() == other.smin() && smax() == other.smax() &&
+ getOverflowFlags() == other.getOverflowFlags();
}
const APInt &ConstantIntRanges::umin() const { return uminVal; }
@@ -94,8 +95,10 @@ ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+ // For a union, keep only guarantees proven on both inputs.
+ OverflowFlags overflowUnion = getOverflowFlags() & other.getOverflowFlags();
- return {uminUnion, umaxUnion, sminUnion, smaxUnion};
+ return {uminUnion, umaxUnion, sminUnion, smaxUnion, overflowUnion};
}
ConstantIntRanges
@@ -111,8 +114,12 @@ ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+ // For an intersection, guarantees from either input remain valid.
+ OverflowFlags overflowIntersect =
+ getOverflowFlags() | other.getOverflowFlags();
- return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
+ return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect,
+ overflowIntersect};
}
std::optional<APInt> ConstantIntRanges::getConstantValue() const {
@@ -129,7 +136,23 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
range.umin().print(os, /*isSigned*/ false);
os << ", ";
range.umax().print(os, /*isSigned*/ false);
- return os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
+ os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
+ OverflowFlags overflowFlags = range.getOverflowFlags();
+ if (overflowFlags == OverflowFlags::None)
+ return os;
+
+ os << " overflow<";
+ bool emitted = false;
+ if ((overflowFlags & OverflowFlags::Nsw) != OverflowFlags::None) {
+ os << "nsw";
+ emitted = true;
+ }
+ if ((overflowFlags & OverflowFlags::Nuw) != OverflowFlags::None) {
+ if (emitted)
+ os << ", ";
+ os << "nuw";
+ }
+ return os << ">";
}
IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
index 97c75b3680567..1b88a3f7a74aa 100644
--- a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
@@ -8,6 +8,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/Support/raw_ostream.h"
#include <limits>
#include <gtest/gtest.h>
@@ -97,3 +98,61 @@ TEST(IntRangeAttrs, Join) {
ConstantIntRanges zeroOneSignedOnly(zero, uintMax, zero, one);
EXPECT_EQ(zeroOneUnsignedOnly.rangeUnion(zeroOneSignedOnly), maximal);
}
+
+TEST(IntRangeAttrs, OverflowFlags) {
+ APInt zero = APInt::getZero(64);
+ APInt one = zero + 1;
+ APInt two = zero + 2;
+
+ ConstantIntRanges nswOnly(zero, one, zero, one, OverflowFlags::Nsw);
+ ConstantIntRanges nuwOnly(one, two, one, two, OverflowFlags::Nuw);
+
+ EXPECT_NE(nswOnly.getOverflowFlags() & OverflowFlags::Nsw,
+ OverflowFlags::None);
+ EXPECT_EQ(nswOnly.getOverflowFlags() & OverflowFlags::Nuw,
+ OverflowFlags::None);
+
+ ConstantIntRanges both =
+ nswOnly.withOverflowFlags(OverflowFlags::Nsw | OverflowFlags::Nuw);
+ EXPECT_NE(both.getOverflowFlags() & OverflowFlags::Nsw, OverflowFlags::None);
+ EXPECT_NE(both.getOverflowFlags() & OverflowFlags::Nuw, OverflowFlags::None);
+
+ // rangeUnion conservatively preserves only proofs present in both inputs.
+ EXPECT_EQ(nswOnly.rangeUnion(nuwOnly).getOverflowFlags(),
+ OverflowFlags::None);
+ EXPECT_EQ(both.rangeUnion(nswOnly).getOverflowFlags(), OverflowFlags::Nsw);
+
+ // intersection preserves proofs from either input.
+ EXPECT_EQ(nswOnly.intersection(nuwOnly).getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+ EXPECT_EQ(both.intersection(nswOnly).getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+ ConstantIntRanges none(zero, two, zero, two, OverflowFlags::None);
+ EXPECT_EQ(nswOnly.intersection(none).getOverflowFlags(), OverflowFlags::Nsw);
+}
+
+TEST(IntRangeAttrs, OverflowFlagsPrinting) {
+ APInt zero = APInt::getZero(64);
+ APInt one = zero + 1;
+
+ auto toString = [](const ConstantIntRanges &r) {
+ std::string buf;
+ llvm::raw_string_ostream os(buf);
+ os << r;
+ return buf;
+ };
+
+ ConstantIntRanges noFlags(zero, one, zero, one);
+ EXPECT_EQ(toString(noFlags), "unsigned : [0, 1] signed : [0, 1]");
+
+ ConstantIntRanges nsw(zero, one, zero, one, OverflowFlags::Nsw);
+ EXPECT_EQ(toString(nsw), "unsigned : [0, 1] signed : [0, 1] overflow<nsw>");
+
+ ConstantIntRanges nuw(zero, one, zero, one, OverflowFlags::Nuw);
+ EXPECT_EQ(toString(nuw), "unsigned : [0, 1] signed : [0, 1] overflow<nuw>");
+
+ ConstantIntRanges both(zero, one, zero, one,
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+ EXPECT_EQ(toString(both),
+ "unsigned : [0, 1] signed : [0, 1] overflow<nsw, nuw>");
+}
>From ec713c606a81132c1547c5153047c3fdc7804e79 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 13 Apr 2026 17:58:56 +0800
Subject: [PATCH 2/4] [mlir][Interfaces] Add hasSameBounds and overflow-flag tests
---
.../mlir/Interfaces/InferIntRangeInterface.h | 4 +
.../DataFlow/StridedMetadataRangeAnalysis.cpp | 8 +-
.../lib/Interfaces/InferIntRangeInterface.cpp | 7 +-
.../Interfaces/Utils/InferIntRangeCommon.cpp | 4 +-
.../Interfaces/InferIntRangeInterfaceTest.cpp | 181 ++++++++++++++++++
5 files changed, 197 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 9870cf1cbb6fa..58682967df9b1 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -50,6 +50,10 @@ class ConstantIntRanges {
bool operator==(const ConstantIntRanges &other) const;
+ /// Returns true if all signed and unsigned bounds match, ignoring
+ /// overflow flags.
+ bool hasSameBounds(const ConstantIntRanges &other) const;
+
/// The minimum value of an integer when it is interpreted as unsigned.
const APInt &umin() const;
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
index 01c9dafaddf10..cb7305195de7a 100644
--- a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -104,10 +104,10 @@ LogicalResult StridedMetadataRangeAnalysis::visitOperation(
};
// Convert the arguments lattices to a vector.
- SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
- operands, [](const StridedMetadataRangeLattice *lattice) {
- return lattice->getValue();
- });
+ SmallVector<StridedMetadataRange, 2> argRanges;
+ argRanges.reserve(operands.size());
+ for (const StridedMetadataRangeLattice *lattice : operands)
+ argRanges.push_back(lattice->getValue());
// Callback to set metadata on a result.
auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 0d3e06dc4d995..090c13cea4599 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -15,10 +15,13 @@
using namespace mlir;
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
+ return hasSameBounds(other) && getOverflowFlags() == other.getOverflowFlags();
+}
+
+bool ConstantIntRanges::hasSameBounds(const ConstantIntRanges &other) const {
return umin().getBitWidth() == other.umin().getBitWidth() &&
umin() == other.umin() && umax() == other.umax() &&
- smin() == other.smin() && smax() == other.smax() &&
- getOverflowFlags() == other.getOverflowFlags();
+ smin() == other.smin() && smax() == other.smax();
}
const APInt &ConstantIntRanges::umin() const { return uminVal; }
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index c9f49fda726e7..ffb407a650599 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -100,7 +100,9 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
bool truncEqual = false;
switch (mode) {
case intrange::CmpMode::Both:
- truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo);
+ // Overflow flags are auxiliary guarantees and should not decide whether
+ // the 64-bit result is a truncation-compatible match.
+ truncEqual = thirtyTwo.hasSameBounds(sixtyFourAsThirtyTwo);
break;
case intrange::CmpMode::Signed:
truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() &&
diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
index 1b88a3f7a74aa..9affdeac1fb6f 100644
--- a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
@@ -7,6 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/raw_ostream.h"
#include <limits>
@@ -15,6 +18,22 @@
using namespace mlir;
+namespace {
+
+template <typename OpTy>
+ConstantIntRanges inferBinaryOpResult(OpTy op,
+ ArrayRef<ConstantIntRanges> argRanges) {
+ std::optional<ConstantIntRanges> inferred;
+ op.inferResultRanges(argRanges, [&](Value v, const ConstantIntRanges &range) {
+ if (v == op.getResult())
+ inferred = range;
+ });
+ assert(inferred.has_value() && "binary op did not produce a result range");
+ return *inferred;
+}
+
+} // namespace
+
TEST(IntRangeAttrs, BasicConstructors) {
APInt zero = APInt::getZero(64);
APInt two(64, 2);
@@ -129,6 +148,12 @@ TEST(IntRangeAttrs, OverflowFlags) {
OverflowFlags::Nsw | OverflowFlags::Nuw);
ConstantIntRanges none(zero, two, zero, two, OverflowFlags::None);
EXPECT_EQ(nswOnly.intersection(none).getOverflowFlags(), OverflowFlags::Nsw);
+
+ // Full equality tracks both bounds and overflow proofs.
+ EXPECT_FALSE(nswOnly == nswOnly.withOverflowFlags(OverflowFlags::None));
+ // Bounds-only equality intentionally ignores overflow proofs.
+ EXPECT_TRUE(nswOnly.hasSameBounds(
+ nswOnly.withOverflowFlags(OverflowFlags::Nsw | OverflowFlags::Nuw)));
}
TEST(IntRangeAttrs, OverflowFlagsPrinting) {
@@ -156,3 +181,159 @@ TEST(IntRangeAttrs, OverflowFlagsPrinting) {
EXPECT_EQ(toString(both),
"unsigned : [0, 1] signed : [0, 1] overflow<nsw, nuw>");
}
+
+TEST(IntRangeAttrs, InferIndexOpCmpBothIgnoresOverflowFlags) {
+ intrange::InferRangeFn inferFn = [](ArrayRef<ConstantIntRanges> args) {
+ unsigned width = args.front().umin().getBitWidth();
+ APInt zero = APInt::getZero(width);
+ APInt one(width, 1);
+ return ConstantIntRanges(zero, one, zero, one, OverflowFlags::Nsw);
+ };
+
+ APInt zero64 = APInt::getZero(64);
+ APInt one64(64, 1);
+ ConstantIntRanges arg(zero64, one64, zero64, one64);
+ ConstantIntRanges result =
+ intrange::inferIndexOp(inferFn, {arg}, intrange::CmpMode::Both);
+
+ EXPECT_EQ(result.umin(), zero64);
+ EXPECT_EQ(result.umax(), one64);
+ EXPECT_EQ(result.smin(), zero64);
+ EXPECT_EQ(result.smax(), one64);
+ EXPECT_EQ(result.getOverflowFlags(), OverflowFlags::Nsw);
+}
+
+TEST(IntRangeAttrs, ArithAddIOpInfersOverflowFlags) {
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+
+ OpBuilder builder(&context);
+ Location loc = builder.getUnknownLoc();
+ ModuleOp module = ModuleOp::create(loc);
+ builder.setInsertionPointToStart(module.getBody());
+
+ Value zero = arith::ConstantIntOp::create(builder, loc, 0, 8);
+ Value one = arith::ConstantIntOp::create(builder, loc, 1, 8);
+
+ arith::AddIOp add = arith::AddIOp::create(builder, loc, zero, one);
+ arith::AddIOp addWithDeclaredNsw = arith::AddIOp::create(
+ builder, loc, zero, one, arith::IntegerOverflowFlags::nsw);
+
+ APInt c0(8, 0), c1(8, 1), c2(8, 2), c10(8, 10), c20(8, 20), c120(8, 120),
+ c127(8, 127), c255(8, 255);
+
+ // Both signed and unsigned proofs succeed.
+ ConstantIntRanges addLhs(c0, c10, c0, c10);
+ ConstantIntRanges addRhs(c0, c20, c0, c20);
+ ConstantIntRanges addResult = inferBinaryOpResult(add, {addLhs, addRhs});
+ EXPECT_EQ(addResult.getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+
+ // Signed may overflow, but unsigned remains provably no-wrap.
+ ConstantIntRanges mayOverflowLhs(c120, c127, c120, c127);
+ ConstantIntRanges mayOverflowRhs(c1, c2, c1, c2);
+ ConstantIntRanges addMayOverflowResult =
+ inferBinaryOpResult(add, {mayOverflowLhs, mayOverflowRhs});
+ EXPECT_EQ(addMayOverflowResult.getOverflowFlags(), OverflowFlags::Nuw);
+
+ // Declared op flags are preserved and merged with inferred ones.
+ ConstantIntRanges addDeclaredNswResult =
+ inferBinaryOpResult(addWithDeclaredNsw, {mayOverflowLhs, mayOverflowRhs});
+ EXPECT_EQ(addDeclaredNswResult.getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+
+ // Both signed and unsigned proofs fail.
+ ConstantIntRanges fullUnsigned = ConstantIntRanges::fromUnsigned(c0, c255);
+ ConstantIntRanges addFullyOverflowingResult =
+ inferBinaryOpResult(add, {fullUnsigned, fullUnsigned});
+ EXPECT_EQ(addFullyOverflowingResult.getOverflowFlags(), OverflowFlags::None);
+}
+
+TEST(IntRangeAttrs, ArithSubIOpInfersOverflowFlags) {
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+
+ OpBuilder builder(&context);
+ Location loc = builder.getUnknownLoc();
+ ModuleOp module = ModuleOp::create(loc);
+ builder.setInsertionPointToStart(module.getBody());
+
+ Value zero = arith::ConstantIntOp::create(builder, loc, 0, 8);
+ Value one = arith::ConstantIntOp::create(builder, loc, 1, 8);
+
+ arith::SubIOp sub = arith::SubIOp::create(builder, loc, zero, one);
+ arith::SubIOp subWithDeclaredNuw = arith::SubIOp::create(
+ builder, loc, zero, one, arith::IntegerOverflowFlags::nuw);
+
+ APInt c0(8, 0), c5(8, 5), c10(8, 10), c20(8, 20);
+ APInt sMin = APInt::getSignedMinValue(8);
+ APInt sNeg120(8, -120, true);
+
+ // Both signed and unsigned proofs succeed.
+ ConstantIntRanges subLhs(c10, c20, c10, c20);
+ ConstantIntRanges subRhs(c0, c5, c0, c5);
+ ConstantIntRanges subResult = inferBinaryOpResult(sub, {subLhs, subRhs});
+ EXPECT_EQ(subResult.getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+
+ // Signed may overflow, but unsigned remains provably no-wrap.
+ ConstantIntRanges subMayOverflowLhs =
+ ConstantIntRanges::fromSigned(sMin, sNeg120);
+ ConstantIntRanges subMayOverflowRhs(c10, c20, c10, c20);
+ ConstantIntRanges subMayOverflowResult =
+ inferBinaryOpResult(sub, {subMayOverflowLhs, subMayOverflowRhs});
+ EXPECT_EQ(subMayOverflowResult.getOverflowFlags(), OverflowFlags::Nuw);
+
+ // Declared op flags are preserved.
+ ConstantIntRanges subDeclaredNuwResult = inferBinaryOpResult(
+ subWithDeclaredNuw, {subMayOverflowLhs, subMayOverflowRhs});
+ EXPECT_EQ(subDeclaredNuwResult.getOverflowFlags(), OverflowFlags::Nuw);
+}
+
+TEST(IntRangeAttrs, ArithMulIOpInfersOverflowFlags) {
+ MLIRContext context;
+ context.loadDialect<arith::ArithDialect>();
+
+ OpBuilder builder(&context);
+ Location loc = builder.getUnknownLoc();
+ ModuleOp module = ModuleOp::create(loc);
+ builder.setInsertionPointToStart(module.getBody());
+
+ Value zero = arith::ConstantIntOp::create(builder, loc, 0, 8);
+ Value one = arith::ConstantIntOp::create(builder, loc, 1, 8);
+
+ arith::MulIOp mul = arith::MulIOp::create(builder, loc, zero, one);
+ arith::MulIOp mulWithDeclaredFlags = arith::MulIOp::create(
+ builder, loc, zero, one,
+ arith::IntegerOverflowFlags::nsw | arith::IntegerOverflowFlags::nuw);
+
+ APInt c2(8, 2), c3(8, 3), c4(8, 4), c5(8, 5), c8(8, 8), c16(8, 16),
+ c20(8, 20);
+
+ // Both signed and unsigned proofs succeed.
+ ConstantIntRanges mulLhs(c2, c3, c2, c3);
+ ConstantIntRanges mulRhs(c4, c5, c4, c5);
+ ConstantIntRanges mulResult = inferBinaryOpResult(mul, {mulLhs, mulRhs});
+ EXPECT_EQ(mulResult.getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+
+ // Unsigned proof succeeds, but signed proof fails (16 * 8 = 128 in i8).
+ ConstantIntRanges mulNuwOnlyLhs(c16, c16, c16, c16);
+ ConstantIntRanges mulNuwOnlyRhs(c8, c8, c8, c8);
+ ConstantIntRanges mulNuwOnlyResult =
+ inferBinaryOpResult(mul, {mulNuwOnlyLhs, mulNuwOnlyRhs});
+ EXPECT_EQ(mulNuwOnlyResult.getOverflowFlags(), OverflowFlags::Nuw);
+
+ // Both signed and unsigned proofs fail.
+ ConstantIntRanges mulMayOverflowLhs(c16, c20, c16, c20);
+ ConstantIntRanges mulMayOverflowRhs(c16, c20, c16, c20);
+ ConstantIntRanges mulMayOverflowResult =
+ inferBinaryOpResult(mul, {mulMayOverflowLhs, mulMayOverflowRhs});
+ EXPECT_EQ(mulMayOverflowResult.getOverflowFlags(), OverflowFlags::None);
+
+ // Declared op flags are preserved.
+ ConstantIntRanges mulDeclaredResult = inferBinaryOpResult(
+ mulWithDeclaredFlags, {mulMayOverflowLhs, mulMayOverflowRhs});
+ EXPECT_EQ(mulDeclaredResult.getOverflowFlags(),
+ OverflowFlags::Nsw | OverflowFlags::Nuw);
+}
>From 0d7fc84de15ccd76185158f938d55e537238dd29 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Tue, 14 Apr 2026 16:17:38 +0800
Subject: [PATCH 3/4] [mlir][Interfaces] Move overflow-flag inference into InferIntRangeCommon
---
.../mlir/Interfaces/InferIntRangeInterface.h | 5 +
.../Interfaces/Utils/InferIntRangeCommon.h | 14 +-
.../Arith/IR/InferIntRangeInterfaceImpls.cpp | 117 +---------------
.../lib/Interfaces/InferIntRangeInterface.cpp | 6 +-
.../Interfaces/Utils/InferIntRangeCommon.cpp | 125 +++++++++++++++++-
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 84 ++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 7 +-
.../Interfaces/InferIntRangeInterfaceTest.cpp | 2 +
8 files changed, 236 insertions(+), 124 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 58682967df9b1..e2391e37ebaae 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -19,6 +19,8 @@
#include <optional>
namespace mlir {
+
+namespace intrange {
enum class OverflowFlags : uint32_t {
None = 0,
Nsw = 1,
@@ -26,6 +28,7 @@ enum class OverflowFlags : uint32_t {
LLVM_MARK_AS_BITMASK_ENUM(Nuw)
};
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+} // end namespace intrange
/// A set of arbitrary-precision integers representing bounds on a given integer
/// value. These bounds are inclusive on both ends, so
@@ -35,6 +38,8 @@ LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// unsigned semantics.
class ConstantIntRanges {
public:
+ using OverflowFlags = intrange::OverflowFlags;
+
/// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
/// Non-integer values should be bounded by APInts of bitwidth 0.
ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index 8c0784a4513bb..3f7bcf06cc79d 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -38,8 +38,6 @@ static constexpr unsigned indexMaxWidth = 64;
enum class CmpMode : uint32_t { Both, Signed, Unsigned };
-using OverflowFlags = mlir::OverflowFlags;
-
/// Function that performs inference on an array of `ConstantIntRanges` while
/// taking special overflow behavior into account.
using InferRangeWithOvfFlagsFn =
@@ -77,12 +75,24 @@ ConstantIntRanges truncRange(const ConstantIntRanges &range,
ConstantIntRanges inferAdd(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags = OverflowFlags::None);
+OverflowFlags
+inferOverflowFlagsForAdd(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags declaredFlags = OverflowFlags::None);
+
ConstantIntRanges inferSub(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags = OverflowFlags::None);
+OverflowFlags
+inferOverflowFlagsForSub(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags declaredFlags = OverflowFlags::None);
+
ConstantIntRanges inferMul(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags = OverflowFlags::None);
+OverflowFlags
+inferOverflowFlagsForMul(ArrayRef<ConstantIntRanges> argRanges,
+ OverflowFlags declaredFlags = OverflowFlags::None);
+
ConstantIntRanges inferDivS(ArrayRef<ConstantIntRanges> argRanges);
ConstantIntRanges inferDivU(ArrayRef<ConstantIntRanges> argRanges);
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 2624448b5f7b4..92b9f347a634e 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -11,7 +11,6 @@
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include <optional>
-#include <utility>
#define DEBUG_TYPE "int-range-analysis"
@@ -29,116 +28,6 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
return retFlags;
}
-template <typename Op>
-static bool proveNoOverflow(const APInt &lhs, const APInt &rhs, Op op) {
- bool overflow = false;
- (void)op(lhs, rhs, overflow);
- return !overflow;
-}
-
-template <typename Op>
-static bool
-proveNoOverflowForPairs(ArrayRef<std::pair<const APInt *, const APInt *>> pairs,
- Op op) {
- for (const auto &[lhs, rhs] : pairs) {
- if (!proveNoOverflow(*lhs, *rhs, op))
- return false;
- }
- return true;
-}
-
-static OverflowFlags proveNoOverflowFlags(
- ArrayRef<ConstantIntRanges> args,
- function_ref<bool(ArrayRef<ConstantIntRanges>)> proveSigned,
- function_ref<bool(ArrayRef<ConstantIntRanges>)> proveUnsigned) {
- OverflowFlags flags = OverflowFlags::None;
- if (proveSigned(args))
- flags |= OverflowFlags::Nsw;
- if (proveUnsigned(args))
- flags |= OverflowFlags::Nuw;
- return flags;
-}
-
-static bool proveNoSignedAddOverflow(ArrayRef<ConstantIntRanges> argRanges) {
- const APInt &lhsMin = argRanges[0].smin();
- const APInt &lhsMax = argRanges[0].smax();
- const APInt &rhsMin = argRanges[1].smin();
- const APInt &rhsMax = argRanges[1].smax();
- // Signed add is monotone in both operands, so it is enough to check
- // the interval endpoints to prove no signed wrap for the whole range.
- return proveNoOverflowForPairs(
- {{&lhsMin, &rhsMin}, {&lhsMax, &rhsMax}},
- [](const APInt &lhs, const APInt &rhs, bool &overflow) {
- return lhs.sadd_ov(rhs, overflow);
- });
-}
-
-static bool proveNoUnsignedAddOverflow(ArrayRef<ConstantIntRanges> argRanges) {
- return proveNoOverflow(
- argRanges[0].umax(), argRanges[1].umax(),
- [](const APInt &lhs, const APInt &rhs, bool &overflow) {
- return lhs.uadd_ov(rhs, overflow);
- });
-}
-
-static bool proveNoSignedSubOverflow(ArrayRef<ConstantIntRanges> argRanges) {
- const APInt &lhsMin = argRanges[0].smin();
- const APInt &lhsMax = argRanges[0].smax();
- const APInt &rhsMin = argRanges[1].smin();
- const APInt &rhsMax = argRanges[1].smax();
- // For lhs - rhs, the extrema occur at (lhsMin - rhsMax) and
- // (lhsMax - rhsMin). If both are no-wrap, the full interval is no-wrap.
- return proveNoOverflowForPairs(
- {{&lhsMin, &rhsMax}, {&lhsMax, &rhsMin}},
- [](const APInt &lhs, const APInt &rhs, bool &overflow) {
- return lhs.ssub_ov(rhs, overflow);
- });
-}
-
-static bool proveNoUnsignedSubOverflow(ArrayRef<ConstantIntRanges> argRanges) {
- return argRanges[0].umin().uge(argRanges[1].umax());
-}
-
-static bool proveNoSignedMulOverflow(ArrayRef<ConstantIntRanges> argRanges) {
- const APInt &lhsMin = argRanges[0].smin();
- const APInt &lhsMax = argRanges[0].smax();
- const APInt &rhsMin = argRanges[1].smin();
- const APInt &rhsMax = argRanges[1].smax();
- // Signed multiply is not monotone across sign changes, so conservatively
- // require all four corner products to be no-wrap.
- return proveNoOverflowForPairs(
- {{&lhsMin, &rhsMin},
- {&lhsMin, &rhsMax},
- {&lhsMax, &rhsMin},
- {&lhsMax, &rhsMax}},
- [](const APInt &lhs, const APInt &rhs, bool &overflow) {
- return lhs.smul_ov(rhs, overflow);
- });
-}
-
-static bool proveNoUnsignedMulOverflow(ArrayRef<ConstantIntRanges> argRanges) {
- return proveNoOverflow(
- argRanges[0].umax(), argRanges[1].umax(),
- [](const APInt &lhs, const APInt &rhs, bool &overflow) {
- return lhs.umul_ov(rhs, overflow);
- });
-}
-
-static OverflowFlags proveNoOverflowForAdd(ArrayRef<ConstantIntRanges> args) {
- return proveNoOverflowFlags(args, proveNoSignedAddOverflow,
- proveNoUnsignedAddOverflow);
-}
-
-static OverflowFlags proveNoOverflowForSub(ArrayRef<ConstantIntRanges> args) {
- return proveNoOverflowFlags(args, proveNoSignedSubOverflow,
- proveNoUnsignedSubOverflow);
-}
-
-static OverflowFlags proveNoOverflowForMul(ArrayRef<ConstantIntRanges> args) {
- return proveNoOverflowFlags(args, proveNoSignedMulOverflow,
- proveNoUnsignedMulOverflow);
-}
-
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -179,7 +68,7 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
ConstantIntRanges range = inferAdd(argRanges, declaredFlags);
OverflowFlags overflowFlags =
- proveNoOverflowForAdd(argRanges) | declaredFlags;
+ inferOverflowFlagsForAdd(argRanges, declaredFlags);
setResultRange(getResult(), range.withOverflowFlags(overflowFlags));
}
@@ -192,7 +81,7 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
ConstantIntRanges range = inferSub(argRanges, declaredFlags);
OverflowFlags overflowFlags =
- proveNoOverflowForSub(argRanges) | declaredFlags;
+ inferOverflowFlagsForSub(argRanges, declaredFlags);
setResultRange(getResult(), range.withOverflowFlags(overflowFlags));
}
@@ -205,7 +94,7 @@ void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags declaredFlags = convertArithOverflowFlags(getOverflowFlags());
ConstantIntRanges range = inferMul(argRanges, declaredFlags);
OverflowFlags overflowFlags =
- proveNoOverflowForMul(argRanges) | declaredFlags;
+ inferOverflowFlagsForMul(argRanges, declaredFlags);
setResultRange(getResult(), range.withOverflowFlags(overflowFlags));
}
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 090c13cea4599..13edaeb91057d 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -13,6 +13,7 @@
#include <optional>
using namespace mlir;
+using mlir::intrange::OverflowFlags;
bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
return hasSameBounds(other) && getOverflowFlags() == other.getOverflowFlags();
@@ -118,8 +119,7 @@ ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
// For an intersection, guarantees from either input remain valid.
- OverflowFlags overflowIntersect =
- getOverflowFlags() | other.getOverflowFlags();
+ auto overflowIntersect = getOverflowFlags() | other.getOverflowFlags();
return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect,
overflowIntersect};
@@ -140,7 +140,7 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
os << ", ";
range.umax().print(os, /*isSigned*/ false);
os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
- OverflowFlags overflowFlags = range.getOverflowFlags();
+ auto overflowFlags = range.getOverflowFlags();
if (overflowFlags == OverflowFlags::None)
return os;
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index ffb407a650599..b9fa1ed04c475 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -24,6 +24,7 @@
#include <iterator>
#include <optional>
+#include <utility>
using namespace mlir;
@@ -183,6 +184,118 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
// Addition
//===----------------------------------------------------------------------===//
+/// Evaluates `op(lhs, rhs, overflow)` and returns true if it reported no
+/// overflow. Callers pass lambdas that forward to an APInt `*_ov` method.
+template <typename Op>
+static bool isOverflowFree(const APInt &lhs, const APInt &rhs, Op op) {
+  bool overflow = false;
+  (void)op(lhs, rhs, overflow);
+  return !overflow;
+}
+
+/// Returns true if `op` reports no overflow for every (lhs, rhs) pair in
+/// `pairs`; used below to test the corner points of two operand intervals.
+template <typename Op>
+static bool
+areAllPairsOverflowFree(ArrayRef<std::pair<const APInt *, const APInt *>> pairs,
+                        Op op) {
+  for (const auto &[lhs, rhs] : pairs) {
+    if (!isOverflowFree(*lhs, *rhs, op))
+      return false;
+  }
+  return true;
+}
+
+/// Returns `declaredFlags` plus `Nsw` / `Nuw` when the matching predicate
+/// proves from `args` that the operation cannot wrap. A predicate is only
+/// evaluated for a flag that is not already declared, since a declared flag
+/// holds regardless of the ranges.
+static intrange::OverflowFlags updateOverflowFlags(
+    ArrayRef<ConstantIntRanges> args, intrange::OverflowFlags declaredFlags,
+    function_ref<bool(ArrayRef<ConstantIntRanges>)> isSignedOverflowFree,
+    function_ref<bool(ArrayRef<ConstantIntRanges>)> isUnsignedOverflowFree) {
+  intrange::OverflowFlags flags = declaredFlags;
+  if (!any(flags & intrange::OverflowFlags::Nsw) && isSignedOverflowFree(args))
+    flags |= intrange::OverflowFlags::Nsw;
+  if (!any(flags & intrange::OverflowFlags::Nuw) &&
+      isUnsignedOverflowFree(args))
+    flags |= intrange::OverflowFlags::Nuw;
+  return flags;
+}
+
+/// Returns true if `lhs + rhs` cannot signed-wrap for any lhs in
+/// `argRanges[0]` and rhs in `argRanges[1]`.
+static bool isSignedAddOverflowFree(ArrayRef<ConstantIntRanges> argRanges) {
+  const APInt &lhsMin = argRanges[0].smin();
+  const APInt &lhsMax = argRanges[0].smax();
+  const APInt &rhsMin = argRanges[1].smin();
+  const APInt &rhsMax = argRanges[1].smax();
+  // Signed add is monotone in both operands, so it is enough to check
+  // the interval endpoints to prove no signed wrap for the whole range.
+  return areAllPairsOverflowFree(
+      {{&lhsMin, &rhsMin}, {&lhsMax, &rhsMax}},
+      [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+        return lhs.sadd_ov(rhs, overflow);
+      });
+}
+
+/// Returns true if `lhs + rhs` cannot unsigned-wrap. Unsigned add grows with
+/// both operands, so checking the two unsigned maxima suffices.
+static bool isUnsignedAddOverflowFree(ArrayRef<ConstantIntRanges> argRanges) {
+  return isOverflowFree(argRanges[0].umax(), argRanges[1].umax(),
+                        [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+                          return lhs.uadd_ov(rhs, overflow);
+                        });
+}
+
+/// Returns true if `lhs - rhs` cannot signed-wrap for any lhs in
+/// `argRanges[0]` and rhs in `argRanges[1]`.
+static bool isSignedSubOverflowFree(ArrayRef<ConstantIntRanges> argRanges) {
+  const APInt &lhsMin = argRanges[0].smin();
+  const APInt &lhsMax = argRanges[0].smax();
+  const APInt &rhsMin = argRanges[1].smin();
+  const APInt &rhsMax = argRanges[1].smax();
+  // For lhs - rhs, the extrema occur at (lhsMin - rhsMax) and
+  // (lhsMax - rhsMin). If both are no-wrap, the full interval is no-wrap.
+  return areAllPairsOverflowFree(
+      {{&lhsMin, &rhsMax}, {&lhsMax, &rhsMin}},
+      [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+        return lhs.ssub_ov(rhs, overflow);
+      });
+}
+
+/// Returns true if `lhs - rhs` cannot unsigned-wrap, i.e. the smallest
+/// unsigned lhs is at least the largest unsigned rhs.
+static bool isUnsignedSubOverflowFree(ArrayRef<ConstantIntRanges> argRanges) {
+  return argRanges[0].umin().uge(argRanges[1].umax());
+}
+
+/// Returns true if `lhs * rhs` cannot signed-wrap for any lhs in
+/// `argRanges[0]` and rhs in `argRanges[1]`.
+static bool isSignedMulOverflowFree(ArrayRef<ConstantIntRanges> argRanges) {
+  const APInt &lhsMin = argRanges[0].smin();
+  const APInt &lhsMax = argRanges[0].smax();
+  const APInt &rhsMin = argRanges[1].smin();
+  const APInt &rhsMax = argRanges[1].smax();
+  // Signed multiply is not monotone across sign changes, so conservatively
+  // require all four corner products to be no-wrap.
+  return areAllPairsOverflowFree(
+      {{&lhsMin, &rhsMin},
+       {&lhsMin, &rhsMax},
+       {&lhsMax, &rhsMin},
+       {&lhsMax, &rhsMax}},
+      [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+        return lhs.smul_ov(rhs, overflow);
+      });
+}
+
+/// Returns true if `lhs * rhs` cannot unsigned-wrap. The product of the two
+/// unsigned maxima is the largest possible unsigned product.
+static bool isUnsignedMulOverflowFree(ArrayRef<ConstantIntRanges> argRanges) {
+  return isOverflowFree(argRanges[0].umax(), argRanges[1].umax(),
+                        [](const APInt &lhs, const APInt &rhs, bool &overflow) {
+                          return lhs.umul_ov(rhs, overflow);
+                        });
+}
+
+/// Returns `declaredFlags` plus any no-wrap flags proven for `args[0] + args[1]`.
+intrange::OverflowFlags mlir::intrange::inferOverflowFlagsForAdd(
+    ArrayRef<ConstantIntRanges> args, intrange::OverflowFlags declaredFlags) {
+  return updateOverflowFlags(args, declaredFlags, isSignedAddOverflowFree,
+                             isUnsignedAddOverflowFree);
+}
+
+/// Returns `declaredFlags` plus any no-wrap flags proven for `args[0] - args[1]`.
+intrange::OverflowFlags mlir::intrange::inferOverflowFlagsForSub(
+    ArrayRef<ConstantIntRanges> args, intrange::OverflowFlags declaredFlags) {
+  return updateOverflowFlags(args, declaredFlags, isSignedSubOverflowFree,
+                             isUnsignedSubOverflowFree);
+}
+
+/// Returns `declaredFlags` plus any no-wrap flags proven for `args[0] * args[1]`.
+intrange::OverflowFlags mlir::intrange::inferOverflowFlagsForMul(
+    ArrayRef<ConstantIntRanges> args, intrange::OverflowFlags declaredFlags) {
+  return updateOverflowFlags(args, declaredFlags, isSignedMulOverflowFree,
+                             isUnsignedMulOverflowFree);
+}
+
ConstantIntRanges
mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
OverflowFlags ovfFlags) {
@@ -813,7 +926,11 @@ mlir::intrange::inferAffineExpr(AffineExpr expr,
inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
ConstantIntRanges rhs =
inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
- return inferAdd({lhs, rhs}, OverflowFlags::Nsw);
+ OverflowFlags declaredFlags = OverflowFlags::Nsw;
+ ConstantIntRanges range = inferAdd({lhs, rhs}, declaredFlags);
+ OverflowFlags overflowFlags =
+ inferOverflowFlagsForAdd({lhs, rhs}, declaredFlags);
+ return range.withOverflowFlags(overflowFlags);
}
case AffineExprKind::Mul: {
auto binExpr = cast<AffineBinaryOpExpr>(expr);
@@ -821,7 +938,11 @@ mlir::intrange::inferAffineExpr(AffineExpr expr,
inferAffineExpr(binExpr.getLHS(), dimRanges, symbolRanges);
ConstantIntRanges rhs =
inferAffineExpr(binExpr.getRHS(), dimRanges, symbolRanges);
- return inferMul({lhs, rhs}, OverflowFlags::Nsw);
+ OverflowFlags declaredFlags = OverflowFlags::Nsw;
+ ConstantIntRanges range = inferMul({lhs, rhs}, declaredFlags);
+ OverflowFlags overflowFlags =
+ inferOverflowFlagsForMul({lhs, rhs}, declaredFlags);
+ return range.withOverflowFlags(overflowFlags);
}
case AffineExprKind::Mod: {
auto binExpr = cast<AffineBinaryOpExpr>(expr);
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 5fee060689d24..8d628b8d56e7d 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1005,9 +1005,93 @@ void TestReflectBoundsOp::inferResultRanges(
setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
+ setOverflowAttr(
+ b.getUI32IntegerAttr(static_cast<uint32_t>(range.getOverflowFlags())));
setResultRanges(getResult(), range);
}
+ParseResult TestReflectBoundsOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  // Syntax: attr-dict (`overflow` `<` flag-list `>`)? $value `:` type.
+  // The overflow clause is parsed before the operand because print() emits
+  // it there; parsing it after the operand breaks round-tripping.
+  OpAsmParser::UnresolvedOperand value;
+  Type type;
+  Builder &builder = parser.getBuilder();
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  intrange::OverflowFlags flags = intrange::OverflowFlags::None;
+  if (succeeded(parser.parseOptionalKeyword("overflow"))) {
+    if (parser.parseLess())
+      return failure();
+
+    // Parses one flag keyword into `flags`. `none` is only accepted when
+    // `allowNone` is set, i.e. as the first entry of the list.
+    auto parseOverflowFlag = [&](bool allowNone) -> ParseResult {
+      if (succeeded(parser.parseOptionalKeyword("nsw"))) {
+        flags |= intrange::OverflowFlags::Nsw;
+        return success();
+      }
+      if (succeeded(parser.parseOptionalKeyword("nuw"))) {
+        flags |= intrange::OverflowFlags::Nuw;
+        return success();
+      }
+      if (allowNone && succeeded(parser.parseOptionalKeyword("none")))
+        return success();
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected overflow flag 'none', 'nsw', or 'nuw'";
+    };
+
+    if (failed(parseOverflowFlag(/*allowNone=*/true)))
+      return failure();
+
+    // A leading `none` ends the list; otherwise accept `, nsw` / `, nuw`.
+    while (flags != intrange::OverflowFlags::None &&
+           succeeded(parser.parseOptionalComma())) {
+      if (failed(parseOverflowFlag(/*allowNone=*/false)))
+        return failure();
+    }
+
+    if (parser.parseGreater())
+      return failure();
+    result.addAttribute(
+        TestReflectBoundsOp::getOverflowAttrName(result.name),
+        builder.getUI32IntegerAttr(static_cast<uint32_t>(flags)));
+  }
+
+  if (parser.parseOperand(value) || parser.parseColonType(type) ||
+      parser.resolveOperand(value, type, result.operands))
+    return failure();
+  result.addTypes(type);
+  return success();
+}
+
+/// Prints `attr-dict overflow<flags> $value : type`. The raw overflow
+/// attribute is elided from the dictionary because it is shown as the
+/// `overflow<...>` clause; a missing attribute is printed as `overflow<none>`.
+void TestReflectBoundsOp::print(OpAsmPrinter &p) {
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elidedAttrs=*/{getOverflowAttrName().getValue()});
+  intrange::OverflowFlags flags = intrange::OverflowFlags::None;
+  if (auto overflowAttr = getOverflowAttr())
+    flags = static_cast<intrange::OverflowFlags>(
+        overflowAttr.getValue().getZExtValue());
+  p << " overflow<";
+  bool emitted = false;
+  // Prints `keyword` if `flag` is set, comma-separating successive flags.
+  auto emitOverflowFlag = [&](intrange::OverflowFlags flag,
+                              StringLiteral keyword) {
+    if ((flags & flag) == intrange::OverflowFlags::None)
+      return;
+    if (emitted)
+      p << ", ";
+    p << keyword;
+    emitted = true;
+  };
+  emitOverflowFlag(intrange::OverflowFlags::Nsw, "nsw");
+  emitOverflowFlag(intrange::OverflowFlags::Nuw, "nuw");
+  if (!emitted)
+    p << "none";
+  p << "> " << getValue() << " : " << getType();
+}
+
//===----------------------------------------------------------------------===//
// ConversionFuncOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 774329b9d2736..c517809fa6a4b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3347,7 +3347,7 @@ def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
Example:
```mlir
- CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+ CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} overflow<none>
%1 = test.reflect_bounds %0 : index
```
}];
@@ -3356,10 +3356,11 @@ def TestReflectBoundsOp : TEST_Op<"reflect_bounds",
OptionalAttr<APIntAttr>:$umin,
OptionalAttr<APIntAttr>:$umax,
OptionalAttr<APIntAttr>:$smin,
- OptionalAttr<APIntAttr>:$smax);
+ OptionalAttr<APIntAttr>:$smax,
+ OptionalAttr<UI32Attr>:$overflow);
let results = (outs InferIntRangeType:$result);
- let assemblyFormat = "attr-dict $value `:` type($result)";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
index 9affdeac1fb6f..daff80ffa4f76 100644
--- a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
@@ -20,6 +20,8 @@ using namespace mlir;
namespace {
+using mlir::intrange::OverflowFlags;
+
template <typename OpTy>
ConstantIntRanges inferBinaryOpResult(OpTy op,
ArrayRef<ConstantIntRanges> argRanges) {
>From 3f5f892de1445c8eee24857c6b5577a0a3d5f2a4 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Tue, 14 Apr 2026 16:17:45 +0800
Subject: [PATCH 4/4] [mlir][interfaces] Add testcases
---
.../Dialect/Affine/int-range-interface.mlir | 44 ++++++++++++-------
.../Dialect/Arith/int-range-interface.mlir | 43 ++++++++++++++----
.../Dialect/MemRef/int-range-inference.mlir | 25 ++++++++---
.../Dialect/Tensor/int-range-inference.mlir | 25 ++++++++---
.../Dialect/Vector/int-range-interface.mlir | 12 ++++-
.../infer-int-range-test-ops.mlir | 2 +-
6 files changed, 116 insertions(+), 35 deletions(-)
diff --git a/mlir/test/Dialect/Affine/int-range-interface.mlir b/mlir/test/Dialect/Affine/int-range-interface.mlir
index ac64ad09ee244..bd4c8b2e4611d 100644
--- a/mlir/test/Dialect/Affine/int-range-interface.mlir
+++ b/mlir/test/Dialect/Affine/int-range-interface.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s
// CHECK-LABEL: func @affine_apply_constant
-// CHECK: test.reflect_bounds {smax = 42 : index, smin = 42 : index, umax = 42 : index, umin = 42 : index}
+// CHECK: test.reflect_bounds {smax = 42 : index, smin = 42 : index, umax = 42 : index, umin = 42 : index} overflow<none>
func.func @affine_apply_constant() -> index {
%0 = affine.apply affine_map<() -> (42)>()
%1 = test.reflect_bounds %0 : index
@@ -9,7 +9,7 @@ func.func @affine_apply_constant() -> index {
}
// CHECK-LABEL: func @affine_apply_add
-// CHECK: test.reflect_bounds {smax = 15 : index, smin = 6 : index, umax = 15 : index, umin = 6 : index}
+// CHECK: test.reflect_bounds {smax = 15 : index, smin = 6 : index, umax = 15 : index, umin = 6 : index} overflow<nsw, nuw>
func.func @affine_apply_add() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
@@ -20,8 +20,22 @@ func.func @affine_apply_add() -> index {
func.return %1 : index
}
+// Checks that the `nuw` declared on arith.addi is kept and that `nsw` is
+// inferred as well, because the affine.apply result [6, 15] plus 1 cannot
+// signed-wrap (no overflow actually occurs in this test).
+// CHECK-LABEL: func @affine_apply_add_overflow
+// CHECK: test.reflect_bounds {smax = 16 : index, smin = 7 : index, umax = 16 : index, umin = 7 : index} overflow<nsw, nuw>
+func.func @affine_apply_add_overflow() -> index {
+  %d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
+                           smin = 2 : index, smax = 5 : index } : index
+  %d1 = test.with_bounds { umin = 4 : index, umax = 10 : index,
+                           smin = 4 : index, smax = 10 : index } : index
+  %c1 = arith.constant 1 : index
+  %0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%d0, %d1)
+  %1 = arith.addi %0, %c1 overflow<nuw> : index
+  %2 = test.reflect_bounds %1 : index
+  func.return %2 : index
+}
+
// CHECK-LABEL: func @affine_apply_mul
-// CHECK: test.reflect_bounds {smax = 30 : index, smin = 12 : index, umax = 30 : index, umin = 12 : index}
+// CHECK: test.reflect_bounds {smax = 30 : index, smin = 12 : index, umax = 30 : index, umin = 12 : index} overflow<nsw, nuw>
func.func @affine_apply_mul() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
@@ -33,7 +47,7 @@ func.func @affine_apply_mul() -> index {
}
// CHECK-LABEL: func @affine_apply_floordiv
-// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
+// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index} overflow<none>
func.func @affine_apply_floordiv() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 10 : index,
smin = 5 : index, smax = 10 : index } : index
@@ -43,7 +57,7 @@ func.func @affine_apply_floordiv() -> index {
}
// CHECK-LABEL: func @affine_apply_ceildiv
-// CHECK: test.reflect_bounds {smax = 3 : index, smin = 2 : index, umax = 3 : index, umin = 2 : index}
+// CHECK: test.reflect_bounds {smax = 3 : index, smin = 2 : index, umax = 3 : index, umin = 2 : index} overflow<none>
func.func @affine_apply_ceildiv() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 10 : index,
smin = 5 : index, smax = 10 : index } : index
@@ -53,7 +67,7 @@ func.func @affine_apply_ceildiv() -> index {
}
// CHECK-LABEL: func @affine_apply_mod
-// CHECK: test.reflect_bounds {smax = 3 : index, smin = 0 : index, umax = 3 : index, umin = 0 : index}
+// CHECK: test.reflect_bounds {smax = 3 : index, smin = 0 : index, umax = 3 : index, umin = 0 : index} overflow<none>
func.func @affine_apply_mod() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 27 : index,
smin = 5 : index, smax = 27 : index } : index
@@ -63,7 +77,7 @@ func.func @affine_apply_mod() -> index {
}
// CHECK-LABEL: func @affine_apply_complex
-// CHECK: test.reflect_bounds {smax = 13 : index, smin = 5 : index, umax = 13 : index, umin = 5 : index}
+// CHECK: test.reflect_bounds {smax = 13 : index, smin = 5 : index, umax = 13 : index, umin = 5 : index} overflow<nsw, nuw>
func.func @affine_apply_complex() -> index {
%d0 = test.with_bounds { umin = 10 : index, umax = 20 : index,
smin = 10 : index, smax = 20 : index } : index
@@ -76,7 +90,7 @@ func.func @affine_apply_complex() -> index {
}
// CHECK-LABEL: func @affine_apply_with_symbols
-// CHECK: test.reflect_bounds {smax = 24 : index, smin = 9 : index, umax = 24 : index, umin = 9 : index}
+// CHECK: test.reflect_bounds {smax = 24 : index, smin = 9 : index, umax = 24 : index, umin = 9 : index} overflow<nsw, nuw>
func.func @affine_apply_with_symbols() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
@@ -89,7 +103,7 @@ func.func @affine_apply_with_symbols() -> index {
}
// CHECK-LABEL: func @affine_apply_sub
-// CHECK: test.reflect_bounds {smax = 1 : index, smin = -8 : index
+// CHECK: test.reflect_bounds {smax = 1 : index, smin = -8 : index, umax = -1 : index, umin = 0 : index} overflow<nsw>
func.func @affine_apply_sub() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
@@ -102,7 +116,7 @@ func.func @affine_apply_sub() -> index {
}
// CHECK-LABEL: func @affine_apply_mul_constant
-// CHECK: test.reflect_bounds {smax = 20 : index, smin = 8 : index, umax = 20 : index, umin = 8 : index}
+// CHECK: test.reflect_bounds {smax = 20 : index, smin = 8 : index, umax = 20 : index, umin = 8 : index} overflow<nsw, nuw>
func.func @affine_apply_mul_constant() -> index {
%d0 = test.with_bounds { umin = 2 : index, umax = 5 : index,
smin = 2 : index, smax = 5 : index } : index
@@ -113,7 +127,7 @@ func.func @affine_apply_mul_constant() -> index {
}
// CHECK-LABEL: func @affine_apply_mod_small_range
-// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index}
+// CHECK: test.reflect_bounds {smax = 2 : index, smin = 1 : index, umax = 2 : index, umin = 1 : index} overflow<none>
func.func @affine_apply_mod_small_range() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 6 : index,
smin = 5 : index, smax = 6 : index } : index
@@ -124,7 +138,7 @@ func.func @affine_apply_mod_small_range() -> index {
}
// CHECK-LABEL: func @affine_apply_mod_already_in_range
-// CHECK: test.reflect_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index}
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index} overflow<none>
func.func @affine_apply_mod_already_in_range() -> index {
%d0 = test.with_bounds { umin = 5 : index, umax = 7 : index,
smin = 5 : index, smax = 7 : index } : index
@@ -135,7 +149,7 @@ func.func @affine_apply_mod_already_in_range() -> index {
}
// CHECK-LABEL: func @affine_apply_mod_variable_divisor
-// CHECK: test.reflect_bounds {smax = 4 : index, smin = 0 : index, umax = 4 : index, umin = 0 : index}
+// CHECK: test.reflect_bounds {smax = 4 : index, smin = 0 : index, umax = 4 : index, umin = 0 : index} overflow<none>
func.func @affine_apply_mod_variable_divisor() -> index {
%d0 = test.with_bounds { umin = 10 : index, umax = 20 : index,
smin = 10 : index, smax = 20 : index } : index
@@ -148,7 +162,7 @@ func.func @affine_apply_mod_variable_divisor() -> index {
}
// CHECK-LABEL: func @affine_apply_mod_cross_boundary
-// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
+// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} overflow<none>
func.func @affine_apply_mod_cross_boundary() -> index {
%d0 = test.with_bounds { umin = 14 : index, umax = 17 : index,
smin = 14 : index, smax = 17 : index } : index
@@ -160,7 +174,7 @@ func.func @affine_apply_mod_cross_boundary() -> index {
}
// CHECK-LABEL: func @affine_apply_mod_negative_dividend
-// CHECK: test.reflect_bounds {smax = 3 : index, smin = 0 : index, umax = 3 : index, umin = 0 : index}
+// CHECK: test.reflect_bounds {smax = 3 : index, smin = 0 : index, umax = 3 : index, umin = 0 : index} overflow<none>
func.func @affine_apply_mod_negative_dividend() -> index {
%d0 = test.with_bounds { umin = 0 : index, umax = 2 : index,
smin = -2 : index, smax = 2 : index } : index
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index dd8240299ef7e..ce2b86540f20f 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -886,7 +886,7 @@ func.func @test_add_1() -> i8 {
// Tests below check inference with overflow flags.
// CHECK-LABEL: func @test_add_i8_wrap1
-// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8} overflow<nuw>
func.func @test_add_i8_wrap1() -> i8 {
%cst1 = arith.constant 1 : i8
%0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
@@ -897,7 +897,7 @@ func.func @test_add_i8_wrap1() -> i8 {
}
// CHECK-LABEL: func @test_add_i8_wrap2
-// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8} overflow<nuw>
func.func @test_add_i8_wrap2() -> i8 {
%cst1 = arith.constant 1 : i8
%0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
@@ -908,7 +908,7 @@ func.func @test_add_i8_wrap2() -> i8 {
}
// CHECK-LABEL: func @test_add_i8_nowrap
-// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8} overflow<nsw, nuw>
func.func @test_add_i8_nowrap() -> i8 {
%cst1 = arith.constant 1 : i8
%0 = test.with_bounds { umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8 } : i8
@@ -919,7 +919,7 @@ func.func @test_add_i8_nowrap() -> i8 {
}
// CHECK-LABEL: func @test_sub_i8_wrap1
-// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} overflow<nsw>
func.func @test_sub_i8_wrap1() -> i8 {
%cst10 = arith.constant 10 : i8
%0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
@@ -930,7 +930,7 @@ func.func @test_sub_i8_wrap1() -> i8 {
}
// CHECK-LABEL: func @test_sub_i8_wrap2
-// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} %1 : i8
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = -10 : si8, umax = 255 : ui8, umin = 0 : ui8} overflow<nsw>
func.func @test_sub_i8_wrap2() -> i8 {
%cst10 = arith.constant 10 : i8
%0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
@@ -941,7 +941,7 @@ func.func @test_sub_i8_wrap2() -> i8 {
}
// CHECK-LABEL: func @test_sub_i8_nowrap
-// CHECK: test.reflect_bounds {smax = 5 : si8, smin = 0 : si8, umax = 5 : ui8, umin = 0 : ui8}
+// CHECK: test.reflect_bounds {smax = 5 : si8, smin = 0 : si8, umax = 5 : ui8, umin = 0 : ui8} overflow<nsw, nuw>
func.func @test_sub_i8_nowrap() -> i8 {
%cst10 = arith.constant 10 : i8
%0 = test.with_bounds { umin = 0 : i8, umax = 15 : i8, smin = 0 : i8, smax = 15 : i8 } : i8
@@ -952,7 +952,7 @@ func.func @test_sub_i8_nowrap() -> i8 {
}
// CHECK-LABEL: func @test_mul_i8_wrap
-// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 200 : ui8, umin = 100 : ui8} overflow<nuw>
func.func @test_mul_i8_wrap() -> i8 {
%cst10 = arith.constant 10 : i8
%0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
@@ -963,7 +963,7 @@ func.func @test_mul_i8_wrap() -> i8 {
}
// CHECK-LABEL: func @test_mul_i8_nowrap
-// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 100 : si8, umax = 127 : ui8, umin = 100 : ui8}
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = 100 : si8, umax = 127 : ui8, umin = 100 : ui8} overflow<nsw, nuw>
func.func @test_mul_i8_nowrap() -> i8 {
%cst10 = arith.constant 10 : i8
%0 = test.with_bounds { umin = 10 : i8, umax = 20 : i8, smin = 10 : i8, smax = 20 : i8 } : i8
@@ -1060,3 +1060,30 @@ func.func @noninteger_operation_result(%lb: index, %ub: index, %step: index, %co
"use"(%result, %outs#1) : (i32, f32) -> ()
return
}
+
+// CHECK-LABEL: func @test_add_i8_reflect_overflow_kinds
+func.func @test_add_i8_reflect_overflow_kinds() -> (i8, i8, i8, i8) {
+ %cst1 = arith.constant 1 : i8
+
+ // CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8} overflow<nuw>
+ %0 = test.with_bounds {umin = 0 : i8, umax = 127 : i8, smin = 0 : i8, smax = 127 : i8} : i8
+ %1 = arith.addi %0, %cst1 overflow<nuw> : i8  // 127 + 1 = 128 fits unsigned i8 but wraps signed
+ %2 = test.reflect_bounds %1 : i8
+
+ // CHECK: test.reflect_bounds {smax = 0 : si8, smin = -9 : si8, umax = 255 : ui8, umin = 0 : ui8} overflow<nsw>
+ %3 = test.with_bounds {umin = 246 : i8, umax = 255 : i8, smin = -10 : i8, smax = -1 : i8} : i8
+ %4 = arith.addi %3, %cst1 overflow<nsw> : i8  // -1 + 1 = 0 fits signed, but 255 + 1 wraps unsigned
+ %5 = test.reflect_bounds %4 : i8
+
+ // CHECK: test.reflect_bounds {smax = 127 : si8, smin = 1 : si8, umax = 127 : ui8, umin = 1 : ui8} overflow<nsw, nuw>
+ %6 = test.with_bounds {umin = 0 : i8, umax = 126 : i8, smin = 0 : i8, smax = 126 : i8} : i8
+ %7 = arith.addi %6, %cst1 overflow<nsw, nuw> : i8  // 126 + 1 = 127 fits both interpretations
+ %8 = test.reflect_bounds %7 : i8
+
+ // CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 255 : ui8, umin = 0 : ui8} overflow<none>
+ %9 = test.with_bounds {umin = 0 : i8, umax = 255 : i8, smin = -128 : i8, smax = 127 : i8} : i8
+ %10 = arith.addi %9, %cst1 : i8  // full i8 range: may wrap either way, so no flag is provable
+ %11 = test.reflect_bounds %10 : i8
+
+ return %2, %5, %8, %11 : i8, i8, i8, i8
+}
diff --git a/mlir/test/Dialect/MemRef/int-range-inference.mlir b/mlir/test/Dialect/MemRef/int-range-inference.mlir
index 34568d1d1d520..214747383c483 100644
--- a/mlir/test/Dialect/MemRef/int-range-inference.mlir
+++ b/mlir/test/Dialect/MemRef/int-range-inference.mlir
@@ -13,7 +13,7 @@ func.func @dim_const(%m: memref<3x5xi32>) -> index {
// CHECK-LABEL: @dim_any_static
// CHECK: %[[op:.+]] = memref.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
%0 = memref.dim %m, %x : memref<3x5xi32>
@@ -25,7 +25,7 @@ func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
// CHECK-LABEL: @dim_dynamic
// CHECK: %[[op:.+]] = memref.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_dynamic(%m: memref<?x5xi32>) -> index {
%c0 = arith.constant 0 : index
@@ -38,7 +38,7 @@ func.func @dim_dynamic(%m: memref<?x5xi32>) -> index {
// CHECK-LABEL: @dim_any_dynamic
// CHECK: %[[op:.+]] = memref.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_any_dynamic(%m: memref<?x5xi32>, %x: index) -> index {
%0 = memref.dim %m, %x : memref<?x5xi32>
@@ -50,7 +50,7 @@ func.func @dim_any_dynamic(%m: memref<?x5xi32>, %x: index) -> index {
// CHECK-LABEL: @dim_some_omitting_dynamic
// CHECK: %[[op:.+]] = memref.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index {
%c1 = arith.constant 1 : index
@@ -64,7 +64,7 @@ func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index
// CHECK-LABEL: @dim_unranked
// CHECK: %[[op:.+]] = memref.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_unranked(%m: memref<*xi32>) -> index {
%c0 = arith.constant 0 : index
@@ -72,3 +72,18 @@ func.func @dim_unranked(%m: memref<*xi32>) -> index {
%1 = test.reflect_bounds %0 : index
return %1 : index
}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static_add_overflow
+// CHECK: %[[dim:.+]] = memref.dim
+// CHECK: %[[add:.+]] = arith.addi %[[dim]], %{{.*}} overflow<nuw> : index
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 6 : index, smin = 4 : index, umax = 6 : index, umin = 4 : index} overflow<nsw, nuw> %[[add]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static_add_overflow(%m: memref<3x5xi32>, %x: index) -> index {
+ %one = arith.constant 1 : index
+ %size = memref.dim %m, %x : memref<3x5xi32>  // %x is unknown, so the size is 3 or 5
+ %next = arith.addi %size, %one overflow<nuw> : index  // result in [4, 6]: nsw is inferred alongside nuw
+ %reflected = test.reflect_bounds %next : index
+ return %reflected : index
+}
diff --git a/mlir/test/Dialect/Tensor/int-range-inference.mlir b/mlir/test/Dialect/Tensor/int-range-inference.mlir
index e90ebf5fccb8e..f32963bc3bdbe 100644
--- a/mlir/test/Dialect/Tensor/int-range-inference.mlir
+++ b/mlir/test/Dialect/Tensor/int-range-inference.mlir
@@ -13,7 +13,7 @@ func.func @dim_const(%t: tensor<3x5xi32>) -> index {
// CHECK-LABEL: @dim_any_static
// CHECK: %[[op:.+]] = tensor.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_any_static(%t: tensor<3x5xi32>, %x: index) -> index {
%0 = tensor.dim %t, %x : tensor<3x5xi32>
@@ -25,7 +25,7 @@ func.func @dim_any_static(%t: tensor<3x5xi32>, %x: index) -> index {
// CHECK-LABEL: @dim_dynamic
// CHECK: %[[op:.+]] = tensor.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_dynamic(%t: tensor<?x5xi32>) -> index {
%c0 = arith.constant 0 : index
@@ -38,7 +38,7 @@ func.func @dim_dynamic(%t: tensor<?x5xi32>) -> index {
// CHECK-LABEL: @dim_any_dynamic
// CHECK: %[[op:.+]] = tensor.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_any_dynamic(%t: tensor<?x5xi32>, %x: index) -> index {
%0 = tensor.dim %t, %x : tensor<?x5xi32>
@@ -50,7 +50,7 @@ func.func @dim_any_dynamic(%t: tensor<?x5xi32>, %x: index) -> index {
// CHECK-LABEL: @dim_some_omitting_dynamic
// CHECK: %[[op:.+]] = tensor.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_some_omitting_dynamic(%t: tensor<?x3x5xi32>, %x: index) -> index {
%c1 = arith.constant 1 : index
@@ -64,7 +64,7 @@ func.func @dim_some_omitting_dynamic(%t: tensor<?x3x5xi32>, %x: index) -> index
// CHECK-LABEL: @dim_unranked
// CHECK: %[[op:.+]] = tensor.dim
-// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} overflow<none> %[[op]]
// CHECK: return %[[ret]]
func.func @dim_unranked(%t: tensor<*xi32>) -> index {
%c0 = arith.constant 0 : index
@@ -72,3 +72,18 @@ func.func @dim_unranked(%t: tensor<*xi32>) -> index {
%1 = test.reflect_bounds %0 : index
return %1 : index
}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static_add_overflow
+// CHECK: %[[dim:.+]] = tensor.dim
+// CHECK: %[[add:.+]] = arith.addi %[[dim]], %{{.*}} overflow<nuw> : index
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 6 : index, smin = 4 : index, umax = 6 : index, umin = 4 : index} overflow<nsw, nuw> %[[add]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static_add_overflow(%t: tensor<3x5xi32>, %x: index) -> index {
+ %one = arith.constant 1 : index
+ %size = tensor.dim %t, %x : tensor<3x5xi32>  // %x is unknown, so the size is 3 or 5
+ %next = arith.addi %size, %one overflow<nuw> : index  // result in [4, 6]: nsw is inferred alongside nuw
+ %reflected = test.reflect_bounds %next : index
+ return %reflected : index
+}
diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir
index 4da8d8a967c73..553a1b09fbb45 100644
--- a/mlir/test/Dialect/Vector/int-range-interface.mlir
+++ b/mlir/test/Dialect/Vector/int-range-interface.mlir
@@ -79,6 +79,16 @@ func.func @vector_add() -> vector<4xindex> {
func.return %3 : vector<4xindex>
}
+// CHECK-LABEL: func @vector_add_i8_reflect_overflow
+// CHECK: test.reflect_bounds {smax = 127 : si8, smin = -128 : si8, umax = 128 : ui8, umin = 1 : ui8} overflow<nuw>
+func.func @vector_add_i8_reflect_overflow() -> vector<4xi8> {
+ %ones = arith.constant dense<1> : vector<4xi8>
+ %lanes = test.with_bounds {umin = 0 : ui8, umax = 127 : ui8, smin = 0 : si8, smax = 127 : si8} : vector<4xi8>  // elements in [0, 127]
+ %sum = arith.addi %lanes, %ones overflow<nuw> : vector<4xi8>  // 127 + 1 fits unsigned i8 but wraps signed
+ %reflected = test.reflect_bounds %sum : vector<4xi8>
+ func.return %reflected : vector<4xi8>
+}
+
// CHECK-LABEL: func @vector_insert
// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
func.func @vector_insert() -> vector<4xindex> {
@@ -91,7 +101,7 @@ func.func @vector_insert() -> vector<4xindex> {
// CHECK-LABEL: func @test_loaded_vector_extract
// No bounds
-// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} %{{.*}} : i32
+// CHECK: test.reflect_bounds {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32} overflow<none> %{{.*}} : i32
func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
%c0 = arith.constant 0 : index
%v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index c6344447d9f74..e4998f396e0de 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -148,7 +148,7 @@ func.func @dont_propagate_across_infinite_loop() -> index {
^bb0(%i1: index):
scf.yield
}
- // CHECK: %[[ret:.*]] = test.reflect_bounds %[[loopRes]] : index
+ // CHECK: %[[ret:.*]] = test.reflect_bounds overflow<none> %[[loopRes]] : index
%2 = test.reflect_bounds %1 : index
// CHECK: return %[[ret]]
return %2 : index
More information about the Mlir-commits
mailing list