[Mlir-commits] [mlir] [mlir][Interfaces] Track and infer no-overflow flags in integer ranges (PR #191777)

Hocky Yudhiono llvmlistbot at llvm.org
Mon Apr 13 02:24:36 PDT 2026


https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/191777

>From 250978c3340338327a868df5fe2dd19842627370 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Mon, 13 Apr 2026 17:08:01 +0800
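Subject: [PATCH] [mlir][Interfaces] Track and infer no-overflow flags in
 integer ranges

Move `OverflowFlags` from InferIntRangeCommon.h into InferIntRangeInterface.h
(namespace `mlir`) and let `ConstantIntRanges` carry a set of no-overflow
flags. `arith.addi`, `arith.subi` and `arith.muli` now attach the flags
declared on the op together with `nsw`/`nuw` facts proven from the operand
bounds. `rangeUnion` keeps only flags present in both inputs, `intersection`
keeps flags from either, equality compares flags, and ranges with flags print
an ` overflow<...>` suffix.

---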
 .../mlir/Interfaces/InferIntRangeInterface.h  |  25 +++-
 .../Interfaces/Utils/InferIntRangeCommon.h    |   8 +-
 .../DataFlow/StridedMetadataRangeAnalysis.cpp |   8 +-
 .../Arith/IR/InferIntRangeInterfaceImpls.cpp  | 123 +++++++++++++++++-
 .../lib/Interfaces/InferIntRangeInterface.cpp |  29 ++++-
 .../Interfaces/InferIntRangeInterfaceTest.cpp |  59 +++++++++
 6 files changed, 229 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index a6de3d1885eec..9870cf1cbb6fa 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -15,9 +15,23 @@
 #define MLIR_INTERFACES_INFERINTRANGEINTERFACE_H
 
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include <optional>
 
 namespace mlir {
+/// Overflow properties of an integer operation, combinable as a bitmask.
+/// `Nsw` (no signed wrap) and `Nuw` (no unsigned wrap) state that the
+/// operation does not overflow when its operands and result are interpreted
+/// as signed or unsigned integers, respectively. Used both for flags declared
+/// on an operation and for facts proven by integer range inference.
+enum class OverflowFlags : uint32_t {
+  None = 0,
+  Nsw = 1,
+  Nuw = 2,
+  LLVM_MARK_AS_BITMASK_ENUM(Nuw)
+};
+LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
+
 /// A set of arbitrary-precision integers representing bounds on a given integer
 /// value. These bounds are inclusive on both ends, so
 /// bounds of [4, 5] mean 4 <= x <= 5. Separate bounds are tracked for
@@ -29,8 +43,12 @@ class ConstantIntRanges {
   /// Bound umin <= (unsigned)x <= umax and smin <= signed(x) <= smax.
   /// Non-integer values should be bounded by APInts of bitwidth 0.
+  /// `overflowFlags` lists the overflow properties proven for the operation
+  /// computing the bounded value.
   ConstantIntRanges(const APInt &umin, const APInt &umax, const APInt &smin,
-                    const APInt &smax)
-      : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax) {
+                    const APInt &smax,
+                    OverflowFlags overflowFlags = OverflowFlags::None)
+      : uminVal(umin), umaxVal(umax), sminVal(smin), smaxVal(smax),
+        overflowFlags(overflowFlags) {
     assert(uminVal.getBitWidth() == umaxVal.getBitWidth() &&
            umaxVal.getBitWidth() == sminVal.getBitWidth() &&
            sminVal.getBitWidth() == smaxVal.getBitWidth() &&
@@ -96,11 +114,22 @@ class ConstantIntRanges {
   /// value.
   std::optional<APInt> getConstantValue() const;
 
+  /// Return overflow properties proven for the operation computing the bounded
+  /// value.
+  OverflowFlags getOverflowFlags() const { return overflowFlags; }
+
+  /// Return a copy of this range with the same bounds whose overflow
+  /// properties are replaced (not merged) with `flags`.
+  ConstantIntRanges withOverflowFlags(OverflowFlags flags) const {
+    return {uminVal, umaxVal, sminVal, smaxVal, flags};
+  }
+
   friend raw_ostream &operator<<(raw_ostream &os,
                                  const ConstantIntRanges &range);
 
 private:
   APInt uminVal, umaxVal, sminVal, smaxVal;
+  OverflowFlags overflowFlags = OverflowFlags::None;
 };
 
 raw_ostream &operator<<(raw_ostream &, const ConstantIntRanges &);
diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
index e369c80a26ea9..8c0784a4513bb 100644
--- a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
+++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h
@@ -16,7 +16,6 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/BitmaskEnum.h"
 #include <optional>
 
 namespace mlir {
@@ -39,12 +38,9 @@ static constexpr unsigned indexMaxWidth = 64;
 
 enum class CmpMode : uint32_t { Both, Signed, Unsigned };
 
-enum class OverflowFlags : uint32_t {
-  None = 0,
-  Nsw = 1,
-  Nuw = 2,
-  LLVM_MARK_AS_BITMASK_ENUM(Nuw)
-};
+/// `OverflowFlags` now lives in InferIntRangeInterface.h (namespace `mlir`);
+/// re-export it so uses qualified with this namespace keep compiling.
+using OverflowFlags = mlir::OverflowFlags;
 
 /// Function that performs inference on an array of `ConstantIntRanges` while
 /// taking special overflow behavior into account.
diff --git a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
index 01c9dafaddf10..560683b86cbb4 100644
--- a/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/StridedMetadataRangeAnalysis.cpp
@@ -104,10 +104,12 @@ LogicalResult StridedMetadataRangeAnalysis::visitOperation(
   };
 
   // Convert the arguments lattices to a vector.
-  SmallVector<StridedMetadataRange> argRanges = llvm::map_to_vector(
+  // An explicit inline size of 0 avoids SmallVector's default inline-size
+  // computation for the (large) StridedMetadataRange element type.
+  SmallVector<StridedMetadataRange, 0> argRanges = llvm::map_to_vector<0>(
       operands, [](const StridedMetadataRangeLattice *lattice) {
         return lattice->getValue();
       });
 
   // Callback to set metadata on a result.
   auto joinCallback = [&](Value v, const StridedMetadataRange &md) {
diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
index 49f89e1bd17f3..c1ad9544b7102 100644
--- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
+++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 
 #include <optional>
+#include <utility>
 
 #define DEBUG_TYPE "int-range-analysis"
 
@@ -28,6 +29,89 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
   return retFlags;
 }
 
+/// Returns true if `op`, a wrapper around one of APInt's `*_ov` methods,
+/// reports no overflow for any (lhs, rhs) pair of operand bounds in `corners`.
+template <typename OverflowingOp>
+static bool
+noOverflowAtCorners(ArrayRef<std::pair<const APInt *, const APInt *>> corners,
+                    OverflowingOp op) {
+  for (const auto &[lhs, rhs] : corners) {
+    bool overflow = false;
+    (void)op(*lhs, *rhs, overflow);
+    if (overflow)
+      return false;
+  }
+  return true;
+}
+
+/// Packs independent signed / unsigned no-wrap proofs into a bitmask.
+static OverflowFlags makeOverflowFlags(bool noSignedWrap,
+                                       bool noUnsignedWrap) {
+  OverflowFlags flags = OverflowFlags::None;
+  if (noSignedWrap)
+    flags |= OverflowFlags::Nsw;
+  if (noUnsignedWrap)
+    flags |= OverflowFlags::Nuw;
+  return flags;
+}
+
+/// No-wrap facts provable for `lhs + rhs` from the operand bounds. The sum is
+/// monotonic in both operands, so only the two extreme sums need checking.
+static OverflowFlags proveAddOverflowFlags(ArrayRef<ConstantIntRanges> args) {
+  const ConstantIntRanges &lhs = args[0];
+  const ConstantIntRanges &rhs = args[1];
+  bool nsw = noOverflowAtCorners(
+      {{&lhs.smin(), &rhs.smin()}, {&lhs.smax(), &rhs.smax()}},
+      [](const APInt &a, const APInt &b, bool &ov) {
+        return a.sadd_ov(b, ov);
+      });
+  bool nuw = noOverflowAtCorners(
+      {{&lhs.umax(), &rhs.umax()}},
+      [](const APInt &a, const APInt &b, bool &ov) {
+        return a.uadd_ov(b, ov);
+      });
+  return makeOverflowFlags(nsw, nuw);
+}
+
+/// No-wrap facts provable for `lhs - rhs` from the operand bounds. The
+/// difference grows with `lhs` and shrinks with `rhs`, so its extremes are
+/// smin(lhs) - smax(rhs) and smax(lhs) - smin(rhs).
+static OverflowFlags proveSubOverflowFlags(ArrayRef<ConstantIntRanges> args) {
+  const ConstantIntRanges &lhs = args[0];
+  const ConstantIntRanges &rhs = args[1];
+  bool nsw = noOverflowAtCorners(
+      {{&lhs.smin(), &rhs.smax()}, {&lhs.smax(), &rhs.smin()}},
+      [](const APInt &a, const APInt &b, bool &ov) {
+        return a.ssub_ov(b, ov);
+      });
+  // Unsigned subtraction never wraps when the smallest possible lhs is at
+  // least the largest possible rhs.
+  bool nuw = lhs.umin().uge(rhs.umax());
+  return makeOverflowFlags(nsw, nuw);
+}
+
+/// No-wrap facts provable for `lhs * rhs` from the operand bounds. The
+/// product is bilinear, so over the box of operand bounds it is extremal at a
+/// corner; checking the four corner products covers every signed pair.
+static OverflowFlags proveMulOverflowFlags(ArrayRef<ConstantIntRanges> args) {
+  const ConstantIntRanges &lhs = args[0];
+  const ConstantIntRanges &rhs = args[1];
+  bool nsw = noOverflowAtCorners(
+      {{&lhs.smin(), &rhs.smin()},
+       {&lhs.smin(), &rhs.smax()},
+       {&lhs.smax(), &rhs.smin()},
+       {&lhs.smax(), &rhs.smax()}},
+      [](const APInt &a, const APInt &b, bool &ov) {
+        return a.smul_ov(b, ov);
+      });
+  bool nuw = noOverflowAtCorners(
+      {{&lhs.umax(), &rhs.umax()}},
+      [](const APInt &a, const APInt &b, bool &ov) {
+        return a.umul_ov(b, ov);
+      });
+  return makeOverflowFlags(nsw, nuw);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -65,8 +149,10 @@ void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferAdd(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  OverflowFlags declared = convertArithOverflowFlags(getOverflowFlags());
+  OverflowFlags known = declared | proveAddOverflowFlags(argRanges);
+  setResultRange(getResult(),
+                 inferAdd(argRanges, declared).withOverflowFlags(known));
 }
 
 //===----------------------------------------------------------------------===//
@@ -75,8 +161,10 @@ void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferSub(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  OverflowFlags declared = convertArithOverflowFlags(getOverflowFlags());
+  OverflowFlags known = declared | proveSubOverflowFlags(argRanges);
+  setResultRange(getResult(),
+                 inferSub(argRanges, declared).withOverflowFlags(known));
 }
 
 //===----------------------------------------------------------------------===//
@@ -85,8 +173,10 @@ void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
-  setResultRange(getResult(), inferMul(argRanges, convertArithOverflowFlags(
-                                                      getOverflowFlags())));
+  OverflowFlags declared = convertArithOverflowFlags(getOverflowFlags());
+  OverflowFlags known = declared | proveMulOverflowFlags(argRanges);
+  setResultRange(getResult(),
+                 inferMul(argRanges, declared).withOverflowFlags(known));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9d8e5f50a725b..16283ab553bf3 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -17,7 +17,8 @@ using namespace mlir;
 bool ConstantIntRanges::operator==(const ConstantIntRanges &other) const {
   return umin().getBitWidth() == other.umin().getBitWidth() &&
          umin() == other.umin() && umax() == other.umax() &&
-         smin() == other.smin() && smax() == other.smax();
+         smin() == other.smin() && smax() == other.smax() &&
+         getOverflowFlags() == other.getOverflowFlags();
 }
 
 const APInt &ConstantIntRanges::umin() const { return uminVal; }
@@ -94,8 +95,10 @@ ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
   const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
   const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
   const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  // A no-overflow proof survives the union only if both inputs carry it.
+  OverflowFlags commonFlags = getOverflowFlags() & other.getOverflowFlags();
 
-  return {uminUnion, umaxUnion, sminUnion, smaxUnion};
+  return {uminUnion, umaxUnion, sminUnion, smaxUnion, commonFlags};
 }
 
 ConstantIntRanges
@@ -111,8 +114,11 @@ ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
   const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
   const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
   const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  // Intersection only adds knowledge, so keep proofs carried by either input.
+  OverflowFlags knownFlags = getOverflowFlags() | other.getOverflowFlags();
 
-  return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
+  return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect,
+          knownFlags};
 }
 
 std::optional<APInt> ConstantIntRanges::getConstantValue() const {
@@ -129,7 +135,22 @@ raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
   range.umin().print(os, /*isSigned*/ false);
   os << ", ";
   range.umax().print(os, /*isSigned*/ false);
-  return os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
+  os << "] signed : [" << range.smin() << ", " << range.smax() << "]";
+
+  // Append the proven flags, e.g. " overflow<nsw, nuw>", when any are set.
+  OverflowFlags flags = range.getOverflowFlags();
+  if (flags == OverflowFlags::None)
+    return os;
+  bool hasNsw = (flags & OverflowFlags::Nsw) != OverflowFlags::None;
+  bool hasNuw = (flags & OverflowFlags::Nuw) != OverflowFlags::None;
+  os << " overflow<";
+  if (hasNsw)
+    os << "nsw";
+  if (hasNsw && hasNuw)
+    os << ", ";
+  if (hasNuw)
+    os << "nuw";
+  return os << ">";
 }
 
 IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
diff --git a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
index 97c75b3680567..1b88a3f7a74aa 100644
--- a/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
+++ b/mlir/unittests/Interfaces/InferIntRangeInterfaceTest.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/Support/raw_ostream.h"
 #include <limits>
 
 #include <gtest/gtest.h>
@@ -97,3 +98,66 @@ TEST(IntRangeAttrs, Join) {
   ConstantIntRanges zeroOneSignedOnly(zero, uintMax, zero, one);
   EXPECT_EQ(zeroOneUnsignedOnly.rangeUnion(zeroOneSignedOnly), maximal);
 }
+
+TEST(IntRangeAttrs, OverflowFlags) {
+  APInt zero = APInt::getZero(64);
+  APInt one = zero + 1;
+  APInt two = zero + 2;
+
+  ConstantIntRanges nswOnly(zero, one, zero, one, OverflowFlags::Nsw);
+  ConstantIntRanges nuwOnly(one, two, one, two, OverflowFlags::Nuw);
+
+  EXPECT_NE(nswOnly.getOverflowFlags() & OverflowFlags::Nsw,
+            OverflowFlags::None);
+  EXPECT_EQ(nswOnly.getOverflowFlags() & OverflowFlags::Nuw,
+            OverflowFlags::None);
+
+  ConstantIntRanges both =
+      nswOnly.withOverflowFlags(OverflowFlags::Nsw | OverflowFlags::Nuw);
+  EXPECT_NE(both.getOverflowFlags() & OverflowFlags::Nsw, OverflowFlags::None);
+  EXPECT_NE(both.getOverflowFlags() & OverflowFlags::Nuw, OverflowFlags::None);
+
+  // rangeUnion conservatively preserves only proofs present in both inputs.
+  EXPECT_EQ(nswOnly.rangeUnion(nuwOnly).getOverflowFlags(),
+            OverflowFlags::None);
+  EXPECT_EQ(both.rangeUnion(nswOnly).getOverflowFlags(), OverflowFlags::Nsw);
+
+  // intersection preserves proofs from either input.
+  EXPECT_EQ(nswOnly.intersection(nuwOnly).getOverflowFlags(),
+            OverflowFlags::Nsw | OverflowFlags::Nuw);
+  EXPECT_EQ(both.intersection(nswOnly).getOverflowFlags(),
+            OverflowFlags::Nsw | OverflowFlags::Nuw);
+  ConstantIntRanges none(zero, two, zero, two, OverflowFlags::None);
+  EXPECT_EQ(nswOnly.intersection(none).getOverflowFlags(), OverflowFlags::Nsw);
+
+  // Overflow flags take part in equality; withOverflowFlags replaces only the
+  // flags and keeps the bounds.
+  EXPECT_FALSE(nswOnly == nswOnly.withOverflowFlags(OverflowFlags::None));
+  EXPECT_EQ(both.withOverflowFlags(OverflowFlags::Nsw), nswOnly);
+}
+
+TEST(IntRangeAttrs, OverflowFlagsPrinting) {
+  APInt zero = APInt::getZero(64);
+  APInt one = zero + 1;
+
+  auto toString = [](const ConstantIntRanges &r) {
+    std::string buf;
+    llvm::raw_string_ostream os(buf);
+    os << r;
+    return buf;
+  };
+
+  ConstantIntRanges noFlags(zero, one, zero, one);
+  EXPECT_EQ(toString(noFlags), "unsigned : [0, 1] signed : [0, 1]");
+
+  ConstantIntRanges nsw(zero, one, zero, one, OverflowFlags::Nsw);
+  EXPECT_EQ(toString(nsw), "unsigned : [0, 1] signed : [0, 1] overflow<nsw>");
+
+  ConstantIntRanges nuw(zero, one, zero, one, OverflowFlags::Nuw);
+  EXPECT_EQ(toString(nuw), "unsigned : [0, 1] signed : [0, 1] overflow<nuw>");
+
+  ConstantIntRanges both(zero, one, zero, one,
+                         OverflowFlags::Nsw | OverflowFlags::Nuw);
+  EXPECT_EQ(toString(both),
+            "unsigned : [0, 1] signed : [0, 1] overflow<nsw, nuw>");
+}



More information about the Mlir-commits mailing list