[Mlir-commits] [mlir] 1e84219 - [mlir][dataflow] IntRange: Replace yield-based widening with per-state lattice budget (#196616)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 11 04:53:46 PDT 2026
Author: Ivan Butygin
Date: 2026-05-11T14:53:41+03:00
New Revision: 1e8421904e34ce58b4b09a115aa058c6692395c5
URL: https://github.com/llvm/llvm-project/commit/1e8421904e34ce58b4b09a115aa058c6692395c5
DIFF: https://github.com/llvm/llvm-project/commit/1e8421904e34ce58b4b09a115aa058c6692395c5.diff
LOG: [mlir][dataflow] IntRange: Replace yield-based widening with per-state lattice budget (#196616)
IntegerRangeAnalysis can hang on `scf.while` loops with dynamic bounds: a
loop-carried range ratchets [0,0]->[0,1]->[0,2]->... by one per worklist
visit, requiring up to 2^31 iterations on i32. The new
`int-range-analysis-convergence.mlir` test reproduces this.
The ratchet lives at framework merge sites (region successors, callable
args) where the solver joins lattices via the virtual
`Lattice::join(const AbstractSparseLattice &)`. The pre-existing
`isYieldedResult`/`isYieldedValue` heuristic in
`IntegerRangeAnalysis::visitOperation` and `visitNonControlFlowArguments`
doesn't help: it runs in the transfer-function callback for inferrable-op
results used by a terminator, not on the merge path. It is also harmful
where it fires: it slams the range to maxRange on the *second* visit
(after, say, [1,1]->[1,2]), so naturally bounded accumulators (e.g.
`arith.minsi`-clamped iter args) widen to [INT_MIN, INT_MAX].
Replace it with a per-state widening budget on `IntegerValueRangeLattice`:
the lattice counts the merge-site joins that change it and forces the
range to its max once the count hits `kIntegerRangeWideningBudget` (128).
Only the virtual overload is overridden, so transfer-function joins via
the non-virtual `join(const ValueT &)` are unaffected. The new
`int-range-loop-iter-args.mlir` test pins the tighter bounds; the
convergence test verifies termination.
Added:
mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
mlir/test/Dialect/Arith/int-range-loop-iter-args.mlir
Modified:
mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 5b6ae9bf84265..8d75c4016355b 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -26,9 +26,37 @@ class RewriterBase;
namespace dataflow {
/// This lattice element represents the integer value range of an SSA value.
+///
+/// `join` overrides the base behaviour to apply per-state widening: once
+/// the lattice has absorbed enough strictly-increasing merges the range is
+/// forced to its max as a sound over-approximation. This is the sole
+/// convergence guarantee for `IntegerRangeAnalysis` on loop-carried
+/// values; without it, `scf.while` loops with dynamic bounds and nested
+/// region ops can keep the solver ratcheting a loop-carried range by +1
+/// per worklist visit for up to 2^31 iterations on i32. The budget is
+/// sized to be much larger than realistic merge counts on naturally
+/// bounded accumulators (e.g. `arith.minsi`/`arith.andi`-clamped iter
+/// args) so the analysis still converges to a tight range on those.
+///
+/// Note that only the `(const AbstractSparseLattice &)` overload is
+/// overridden, so the widening fires only at framework merge sites
+/// (block-arg / region-successor / callable-arg joins) —
+/// transfer-function updates that go through the non-virtual
+/// `join(const ValueT &)` overload are unaffected.
class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
public:
using Lattice::Lattice;
+ // Re-export the base-class `join` overloads: declaring the override below
+ // would otherwise hide the inherited `join(const ValueT &)` overload that
+ // callers (e.g. transfer functions) rely on for direct-value joins.
+ using Lattice::join;
+
+ ChangeResult join(const AbstractSparseLattice &rhs) override;
+
+private:
+ /// Number of merge-site joins so far that changed this lattice. `join`
+ /// compares it against the widening budget to decide when to widen.
+ unsigned mergeChangeCount = 0;
};
/// Integer range analysis determines the integer value range of SSA values
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index b29fc28131806..613772c2b7404 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -58,6 +58,29 @@ LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) {
}
} // namespace mlir::dataflow
+/// Number of lattice-changing merge-site joins a single integer-range
+/// lattice element may absorb before `IntegerValueRangeLattice::join`
+/// forces it to its max range as a sound over-approximation. Joins that
+/// leave the lattice unchanged do not consume the budget.
+///
+/// Trade-off: high enough that realistic loops with dynamic bounds (which
+/// typically converge to a tight range in a few merge iterations) are not
+/// widened prematurely; low enough that the +1 ratchet pathology (a
+/// loop-carried range growing by one per worklist visit) is cut off after
+/// about this many changes per lattice element rather than ~2^31.
+static constexpr unsigned kIntegerRangeWideningBudget = 128;
+
+ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) {
+ ChangeResult changed = Lattice::join(rhs); // Ordinary merge-site join first.
+ if (mergeChangeCount >= kIntegerRangeWideningBudget) { // Budget exhausted.
+ return changed | Lattice::join(IntegerValueRange::getMaxRange(
+ cast<Value>(getAnchor()))); // Widen to the anchor value's max range.
+ }
+ if (changed == ChangeResult::Change) // Only lattice-changing joins ...
+ ++mergeChangeCount; // ... count against the budget.
+ return changed;
+}
+
LogicalResult IntegerRangeAnalysis::visitOperation(
Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
ArrayRef<IntegerValueRangeLattice *> results) {
@@ -82,23 +105,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
LDBG() << "Inferred range " << attrs;
IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
- IntegerValueRange oldRange = lattice->getValue();
-
- ChangeResult changed = lattice->join(attrs);
-
- // Catch loop results with loop variant bounds and conservatively make
- // them [-inf, inf] so we don't circle around infinitely often (because
- // the dataflow analysis in MLIR doesn't attempt to work out trip counts
- // and often can't).
- bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
- return op->hasTrait<OpTrait::IsTerminator>();
- });
- if (isYieldedResult && !oldRange.isUninitialized() &&
- !(lattice->getValue() == oldRange)) {
- LDBG() << "Loop variant loop result detected";
- changed |= lattice->join(IntegerValueRange::getMaxRange(v));
- }
- propagateIfChanged(lattice, changed);
+ propagateIfChanged(lattice, lattice->join(attrs));
};
inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
@@ -132,23 +139,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
std::distance(successor.getSuccessor()->getArguments().begin(), it);
IntegerValueRangeLattice *lattice =
nonSuccessorInputLattices[nonSuccessorInputIdx];
- IntegerValueRange oldRange = lattice->getValue();
-
- ChangeResult changed = lattice->join(attrs);
-
- // Catch loop results with loop variant bounds and conservatively make
- // them [-inf, inf] so we don't circle around infinitely often (because
- // the dataflow analysis in MLIR doesn't attempt to work out trip counts
- // and often can't).
- bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
- return op->hasTrait<OpTrait::IsTerminator>();
- });
- if (isYieldedValue && !oldRange.isUninitialized() &&
- !(lattice->getValue() == oldRange)) {
- LDBG() << "Loop variant loop result detected";
- changed |= lattice->join(IntegerValueRange::getMaxRange(v));
- }
- propagateIfChanged(lattice, changed);
+ propagateIfChanged(lattice, lattice->join(attrs));
};
inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
diff --git a/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
new file mode 100644
index 0000000000000..a932d4b699a89
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
@@ -0,0 +1,91 @@
+// IntegerRangeAnalysis convergence on scf.while with dynamic bounds.
+//
+// The carry range ratchets [0,0]->[0,1]->[0,2]->... per worklist visit;
+// nested scf.if layers with arith chains (addi, muli) bounded by remui
+// create enough worklist cascade to defeat the solver's back-to-back
+// convergence shortcut. The per-state widening budget on
+// IntegerValueRangeLattice forces the range to its max after a bounded
+// number of strict refinements, so the analysis terminates instead of
+// hanging for ~minutes (or 2^31 iterations).
+//
+// We assert:
+// - the analysis terminates and produces well-formed IR;
+// - the loop-carried iter arg of the outer scf.while widens to
+// [INT_MIN, INT_MAX] (the only sound result once the budget fires);
+// - transfer-function results inside the body stay tight (e.g.
+// `arith.remui ..., %c127` = [0, 126]), verifying the widening is
+// scoped to framework merge sites, not transfer-function joins.
+//
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
+
+// CHECK-LABEL: func.func @grouped_gemm_while_hang
+// CHECK-SAME: (%[[N:.*]]: i32, %{{.*}}: i1) -> i32
+func.func @grouped_gemm_while_hang(%n: i32, %flag: i1) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c3 = arith.constant 3 : i32
+ %c7 = arith.constant 7 : i32
+ %c127 = arith.constant 127 : i32
+ %init = arith.cmpi slt, %c0, %n : i32
+
+ // CHECK: %[[OUTER:.*]]:2 = scf.while
+ %res:2 = scf.while (%a0 = %c0, %cond = %init) : (i32, i1) -> (i32, i1) {
+ scf.condition(%cond) %a0, %cond : i32, i1
+ } do {
+ ^bb0(%b0: i32, %bc: i1):
+ %t0 = arith.addi %b0, %c1 : i32
+ %ic = arith.cmpi slt, %t0, %n : i32
+
+ // CHECK: scf.while
+ %inner:2 = scf.while (%i0 = %t0, %iic = %ic) : (i32, i1) -> (i32, i1) {
+ scf.condition(%iic) %i0, %iic : i32, i1
+ } do {
+ ^bb1(%j0: i32, %jc: i1):
+
+ %L0 = scf.if %flag -> (i32) {
+ %a0_0 = arith.addi %j0, %c1 : i32
+ %a0_1 = arith.muli %a0_0, %c7 : i32
+ %a0_r = arith.remui %a0_1, %c127 : i32
+ scf.yield %a0_r : i32
+ } else {
+ %b0_0 = arith.addi %j0, %c3 : i32
+ %b0_1 = arith.muli %b0_0, %c7 : i32
+ %b0_r = arith.remui %b0_1, %c127 : i32
+ scf.yield %b0_r : i32
+ }
+
+ %L1 = scf.if %flag -> (i32) {
+ %a1_0 = arith.addi %L0, %c1 : i32
+ %a1_1 = arith.muli %a1_0, %c7 : i32
+ %a1_r = arith.remui %a1_1, %c127 : i32
+ scf.yield %a1_r : i32
+ } else {
+ %b1_0 = arith.addi %L0, %c3 : i32
+ %b1_1 = arith.muli %b1_0, %c7 : i32
+ %b1_r = arith.remui %b1_1, %c127 : i32
+ scf.yield %b1_r : i32
+ }
+
+ %nic = arith.cmpi slt, %L1, %n : i32
+ // `%L1` joins two `arith.remui ..., %c127` results and stays at
+ // [0, 126]: the widening budget only fires on the virtual `Lattice::join`
+ // at framework merge sites, not on transfer-function joins.
+ // CHECK: test.reflect_bounds {smax = 126 : si32, smin = 0 : si32, umax = 126 : ui32, umin = 0 : ui32}
+ %r_l1 = test.reflect_bounds %L1 : i32
+ scf.yield %L1, %nic : i32, i1
+ }
+
+ %nc = arith.cmpi slt, %inner#0, %n : i32
+ scf.yield %inner#0, %nc : i32, i1
+ }
+ // The outer loop's carried value goes through region-successor merges
+ // and is widened to maxRange once the budget is exhausted, so its result
+ // reflects the full i32 range. Reaching this check at all is the
+ // convergence assertion: without widening the analysis would not finish.
+ // CHECK: %[[BOUNDED:.*]] = test.reflect_bounds
+ // CHECK-SAME: {smax = 2147483647 : si32, smin = -2147483648 : si32, umax = 4294967295 : ui32, umin = 0 : ui32}
+ // CHECK-SAME: %[[OUTER]]#0 : i32
+ %r = test.reflect_bounds %res#0 : i32
+ // CHECK: return %[[BOUNDED]] : i32
+ return %r : i32
+}
diff --git a/mlir/test/Dialect/Arith/int-range-loop-iter-args.mlir b/mlir/test/Dialect/Arith/int-range-loop-iter-args.mlir
new file mode 100644
index 0000000000000..24801875e257c
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-loop-iter-args.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s
+
+// Verify that `IntegerRangeAnalysis` infers tight bounds for loop-carried
+// values that are structurally bounded inside the loop body (via
+// `arith.minsi`, `arith.andi`, etc.). Convergence is guaranteed by the
+// per-state widening budget on `IntegerValueRangeLattice`; the budget is
+// large enough that these naturally bounded ratchets reach a fixpoint
+// without being widened to `[INT_MIN, INT_MAX]`.
+
+// CHECK-LABEL: func @bounded_acc_for
+// CHECK: test.reflect_bounds {smax = 10 : si32, smin = 0 : si32, umax = 10 : ui32, umin = 0 : ui32}
+func.func @bounded_acc_for(%n: i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c10 = arith.constant 10 : i32
+ %res = scf.for %i = %c0 to %n step %c1 iter_args(%acc = %c0) -> i32 : i32 {
+ %incr = arith.addi %acc, %c1 : i32  // +1 per iteration: the ratchet source
+ %clamped = arith.minsi %incr, %c10 : i32  // signed min caps the value at 10
+ scf.yield %clamped : i32
+ }
+ %r = test.reflect_bounds %res : i32  // reflects the range inferred for %res
+ return %r : i32
+}
+
+// The `arith.cmpi slt, %acc, 100` should fold to `true` once the analysis
+// proves the iter arg stays in `[0, 10]`, exposing a downstream
+// optimization that the previous yield-based widening masked.
+// CHECK-LABEL: func @bounded_acc_while
+// CHECK: %[[TRUE:.*]] = arith.constant true
+// CHECK: scf.condition(%[[TRUE]])
+// CHECK: test.reflect_bounds {smax = 10 : si32, smin = 0 : si32, umax = 10 : ui32, umin = 0 : ui32}
+func.func @bounded_acc_while() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c10 = arith.constant 10 : i32
+ %c100 = arith.constant 100 : i32
+ %res = scf.while (%acc = %c0) : (i32) -> i32 {
+ %cond = arith.cmpi slt, %acc, %c100 : i32  // expected to fold to `true`
+ scf.condition(%cond) %acc : i32
+ } do {
+ ^bb0(%a: i32):
+ %incr = arith.addi %a, %c1 : i32  // +1 per iteration: the ratchet source
+ %clamped = arith.minsi %incr, %c10 : i32  // signed min caps the value at 10
+ scf.yield %clamped : i32
+ }
+ %r = test.reflect_bounds %res : i32  // reflects the range inferred for %res
+ return %r : i32
+}
+
+// CHECK-LABEL: func @bounded_mask_for
+// CHECK: test.reflect_bounds {smax = 15 : si32, smin = 0 : si32, umax = 15 : ui32, umin = 0 : ui32}
+func.func @bounded_mask_for(%n: i32) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c15 = arith.constant 15 : i32
+ %res = scf.for %i = %c0 to %n step %c1 iter_args(%acc = %c0) -> i32 : i32 {
+ %incr = arith.addi %acc, %c1 : i32  // +1 per iteration: the ratchet source
+ %masked = arith.andi %incr, %c15 : i32  // mask with 15 keeps it in [0, 15]
+ scf.yield %masked : i32
+ }
+ %r = test.reflect_bounds %res : i32  // reflects the range inferred for %res
+ return %r : i32
+}
More information about the Mlir-commits
mailing list