[Mlir-commits] [mlir] [mlir][dataflow] IntRange: Replace yield-based widening with per-state lattice budget (PR #196616)

Ivan Butygin llvmlistbot at llvm.org
Fri May 8 12:48:45 PDT 2026


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/196616

>From 49fff6cf3d47c20924d55c97180325e4277e96c3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 8 May 2026 20:43:33 +0200
Subject: [PATCH] [mlir][dataflow] Replace yield-based widening with per-state
 lattice budget

IntegerRangeAnalysis can hang on `scf.while` loops with dynamic bounds: a
loop-carried range ratchets [0,0]->[0,1]->[0,2]->... by one per worklist
visit, requiring up to 2^31 iterations on i32. The new
int-range-analysis-convergence.mlir test reproduces this.

The ratchet lives at framework merge sites (region successors, callable
args) where the solver joins lattices via virtual
`Lattice::join(const AbstractSparseLattice &)`. The pre-existing
`isYieldedResult`/`isYieldedValue` heuristic in
`IntegerRangeAnalysis::visitOperation` doesn't help: it runs in the
transfer-function callback for inferrable-op results used by a terminator,
not on the merge path. It is also harmful where it does fire: it slams the
range to maxRange on the first change after initialization (e.g.
[1,1]->[1,2]), so naturally bounded accumulators (e.g. `arith.minsi`-clamped
iter args) widen to [INT_MIN, INT_MAX] today, blocking downstream folds.

Replace it with a per-state widening budget on `IntegerValueRangeLattice`:
the lattice counts merge-site joins that change its value and forces the
range to its max once that budget, kIntegerRangeWideningBudget (128), is
exhausted. Only the virtual `join(const AbstractSparseLattice &)` overload is
overridden, so transfer-function joins via the non-virtual
`join(const ValueT &)` are unaffected. The new int-range-loop-iter-args.mlir
test pins the tighter bounds; the convergence test verifies termination.
---
 .../Analysis/DataFlow/IntegerRangeAnalysis.h  | 28 ++++++++
 .../DataFlow/IntegerRangeAnalysis.cpp         | 59 +++++++----------
 .../Arith/int-range-analysis-convergence.mlir | 65 +++++++++++++++++++
 .../Arith/int-range-loop-iter-args.mlir       | 63 ++++++++++++++++++
 4 files changed, 181 insertions(+), 34 deletions(-)
 create mode 100644 mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
 create mode 100644 mlir/test/Dialect/Arith/int-range-loop-iter-args.mlir

diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 5b6ae9bf84265..8d75c4016355b 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -26,9 +26,37 @@ class RewriterBase;
 namespace dataflow {
 
 /// This lattice element represents the integer value range of an SSA value.
+///
+/// The virtual `join` override applies per-state widening: after a fixed
+/// budget of range-growing merges the range is forced to its max as a sound
+/// over-approximation. This is the sole convergence guarantee for
+/// `IntegerRangeAnalysis` on loop-carried values; without it, `scf.while`
+/// loops with dynamic bounds and nested region ops can keep the solver
+/// ratcheting a loop-carried range by +1 per worklist visit for up to 2^31
+/// iterations on i32. The budget is sized to be much larger than realistic
+/// merge counts on naturally bounded accumulators (e.g.
+/// `arith.minsi`/`arith.andi`-clamped iter args) so the analysis still
+/// converges to a tight range on those.
+///
+/// Note that only the `(const AbstractSparseLattice &)` overload is
+/// overridden, so the widening fires only at framework merge sites
+/// (block-arg / region-successor / callable-arg joins);
+/// transfer-function updates that go through the non-virtual
+/// `join(const ValueT &)` overload are unaffected.
 class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
 public:
   using Lattice::Lattice;
+  // The override below would otherwise hide the inherited
+  // `join(const ValueT &)` overload that callers (e.g. transfer functions)
+  // rely on for direct-value joins.
+  using Lattice::join;
+
+  ChangeResult join(const AbstractSparseLattice &rhs) override;
+
+private:
+  /// Number of merge-site joins that changed this lattice's value; compared
+  /// against the widening budget in `join`.
+  unsigned mergeChangeCount = 0;
 };
 
 /// Integer range analysis determines the integer value range of SSA values
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index b29fc28131806..613772c2b7404 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -58,6 +58,29 @@ LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) {
 }
 } // namespace mlir::dataflow
 
+/// Upper bound on how many range-growing merge-site joins a single
+/// integer-range lattice element absorbs before
+/// `IntegerValueRangeLattice::join` widens it to the maximal range.
+///
+/// Choosing the value is a trade-off. It must be large enough that loops
+/// whose carried values are naturally bounded (and so settle after a
+/// handful of merges) keep a precise range, yet small enough that the +1
+/// ratchet this widening guards against (a loop-carried range growing by
+/// one per worklist visit) stops after at most this many extra solver
+/// iterations instead of ~2^31.
+static constexpr unsigned kIntegerRangeWideningBudget{128};
+
+ChangeResult IntegerValueRangeLattice::join(const AbstractSparseLattice &rhs) {
+  // Only range-growing joins are counted and only such a join can widen, so a
+  // state at its fixpoint is never forced to max range by a later no-op join.
+  ChangeResult changed = Lattice::join(rhs);
+  if (changed == ChangeResult::NoChange ||
+      ++mergeChangeCount <= kIntegerRangeWideningBudget)
+    return changed;
+  return changed | Lattice::join(IntegerValueRange::getMaxRange(
+                       cast<Value>(getAnchor())));
+}
+
 LogicalResult IntegerRangeAnalysis::visitOperation(
     Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
     ArrayRef<IntegerValueRangeLattice *> results) {
@@ -82,23 +105,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
 
     LDBG() << "Inferred range " << attrs;
     IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
-    IntegerValueRange oldRange = lattice->getValue();
-
-    ChangeResult changed = lattice->join(attrs);
-
-    // Catch loop results with loop variant bounds and conservatively make
-    // them [-inf, inf] so we don't circle around infinitely often (because
-    // the dataflow analysis in MLIR doesn't attempt to work out trip counts
-    // and often can't).
-    bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
-      return op->hasTrait<OpTrait::IsTerminator>();
-    });
-    if (isYieldedResult && !oldRange.isUninitialized() &&
-        !(lattice->getValue() == oldRange)) {
-      LDBG() << "Loop variant loop result detected";
-      changed |= lattice->join(IntegerValueRange::getMaxRange(v));
-    }
-    propagateIfChanged(lattice, changed);
+    propagateIfChanged(lattice, lattice->join(attrs));
   };
 
   inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
@@ -132,23 +139,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
           std::distance(successor.getSuccessor()->getArguments().begin(), it);
       IntegerValueRangeLattice *lattice =
           nonSuccessorInputLattices[nonSuccessorInputIdx];
-      IntegerValueRange oldRange = lattice->getValue();
-
-      ChangeResult changed = lattice->join(attrs);
-
-      // Catch loop results with loop variant bounds and conservatively make
-      // them [-inf, inf] so we don't circle around infinitely often (because
-      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
-      // and often can't).
-      bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
-        return op->hasTrait<OpTrait::IsTerminator>();
-      });
-      if (isYieldedValue && !oldRange.isUninitialized() &&
-          !(lattice->getValue() == oldRange)) {
-        LDBG() << "Loop variant loop result detected";
-        changed |= lattice->join(IntegerValueRange::getMaxRange(v));
-      }
-      propagateIfChanged(lattice, changed);
+      propagateIfChanged(lattice, lattice->join(attrs));
     };
 
     inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
diff --git a/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
new file mode 100644
index 0000000000000..75b75546c0051
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
@@ -0,0 +1,65 @@
+// Regression test: IntegerRangeAnalysis must terminate on scf.while loops
+// with dynamic bounds. Without widening, the carried range ratchets
+// [0,0]->[0,1]->[0,2]->... by one per worklist visit, i.e. up to ~2^31
+// steps on i32. Two nested scf.if layers with differing arith chains
+// (addi, muli) bounded by remui create enough worklist cascade to prevent
+// the solver's back-to-back convergence shortcut from firing.
+//
+// With framework-level merge-site widening, the analysis converges in
+// bounded time. The output is discarded: this test only checks that
+// mlir-opt runs to completion.
+// RUN: mlir-opt --int-range-optimizations %s -o /dev/null
+
+func.func @grouped_gemm_while_hang(%n: i32, %flag: i1) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c3 = arith.constant 3 : i32
+  %c7 = arith.constant 7 : i32
+  %c127 = arith.constant 127 : i32
+  %init = arith.cmpi slt, %c0, %n : i32
+
+  %res:2 = scf.while (%a0 = %c0, %cond = %init) : (i32, i1) -> (i32, i1) {
+    scf.condition(%cond) %a0, %cond : i32, i1
+  } do {
+  ^bb0(%b0: i32, %bc: i1):
+    %t0 = arith.addi %b0, %c1 : i32
+    %ic = arith.cmpi slt, %t0, %n : i32
+
+    %inner:2 = scf.while (%i0 = %t0, %iic = %ic) : (i32, i1) -> (i32, i1) {
+      scf.condition(%iic) %i0, %iic : i32, i1
+    } do {
+    ^bb1(%j0: i32, %jc: i1):
+
+      %L0 = scf.if %flag -> (i32) {
+        %a0_0 = arith.addi %j0, %c1 : i32
+        %a0_1 = arith.muli %a0_0, %c7 : i32
+        %a0_r = arith.remui %a0_1, %c127 : i32
+        scf.yield %a0_r : i32
+      } else {
+        %b0_0 = arith.addi %j0, %c3 : i32
+        %b0_1 = arith.muli %b0_0, %c7 : i32
+        %b0_r = arith.remui %b0_1, %c127 : i32
+        scf.yield %b0_r : i32
+      }
+
+      %L1 = scf.if %flag -> (i32) {
+        %a1_0 = arith.addi %L0, %c1 : i32
+        %a1_1 = arith.muli %a1_0, %c7 : i32
+        %a1_r = arith.remui %a1_1, %c127 : i32
+        scf.yield %a1_r : i32
+      } else {
+        %b1_0 = arith.addi %L0, %c3 : i32
+        %b1_1 = arith.muli %b1_0, %c7 : i32
+        %b1_r = arith.remui %b1_1, %c127 : i32
+        scf.yield %b1_r : i32
+      }
+
+      %nic = arith.cmpi slt, %L1, %n : i32
+      scf.yield %L1, %nic : i32, i1
+    }
+
+    %nc = arith.cmpi slt, %inner#0, %n : i32
+    scf.yield %inner#0, %nc : i32, i1
+  }
+  return %res#0 : i32
+}
diff --git a/mlir/test/Dialect/Arith/int-range-loop-iter-args.mlir b/mlir/test/Dialect/Arith/int-range-loop-iter-args.mlir
new file mode 100644
index 0000000000000..24801875e257c
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-loop-iter-args.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt --int-range-optimizations %s | FileCheck %s
+
+// Verify that `IntegerRangeAnalysis` infers tight bounds for loop-carried
+// values that are structurally bounded inside the loop body (via
+// `arith.minsi`, `arith.andi`, etc.). Convergence is guaranteed by the
+// per-state widening budget on `IntegerValueRangeLattice`; the budget is
+// large enough that these naturally bounded ratchets reach a fixpoint
+// without being widened to `[INT_MIN, INT_MAX]`.
+
+// CHECK-LABEL: func @bounded_acc_for
+// CHECK: test.reflect_bounds {smax = 10 : si32, smin = 0 : si32, umax = 10 : ui32, umin = 0 : ui32}
+func.func @bounded_acc_for(%n: i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c10 = arith.constant 10 : i32
+  %res = scf.for %i = %c0 to %n step %c1 iter_args(%acc = %c0) -> i32  : i32 {
+    %incr = arith.addi %acc, %c1 : i32 // On its own, a +1 ratchet.
+    %clamped = arith.minsi %incr, %c10 : i32 // Caps the carried value at 10.
+    scf.yield %clamped : i32
+  }
+  %r = test.reflect_bounds %res : i32 // Init 0 (zero trips) joined with <= 10.
+  return %r : i32
+}
+
+// `arith.cmpi slt, %acc, %c100` should fold to `true` once the analysis
+// proves the iter arg stays in `[0, 10]`, exposing a downstream
+// optimization that the previous yield-based widening masked.
+// CHECK-LABEL: func @bounded_acc_while
+// CHECK: %[[TRUE:.*]] = arith.constant true
+// CHECK: scf.condition(%[[TRUE]])
+// CHECK: test.reflect_bounds {smax = 10 : si32, smin = 0 : si32, umax = 10 : ui32, umin = 0 : ui32}
+func.func @bounded_acc_while() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c10 = arith.constant 10 : i32
+  %c100 = arith.constant 100 : i32
+  %res = scf.while (%acc = %c0) : (i32) -> i32 {
+    %cond = arith.cmpi slt, %acc, %c100 : i32 // Always true for %acc in [0, 10].
+    scf.condition(%cond) %acc : i32
+  } do {
+  ^bb0(%a: i32):
+    %incr = arith.addi %a, %c1 : i32
+    %clamped = arith.minsi %incr, %c10 : i32 // Caps the carried value at 10.
+    scf.yield %clamped : i32
+  }
+  %r = test.reflect_bounds %res : i32
+  return %r : i32
+}
+
+// CHECK-LABEL: func @bounded_mask_for
+// CHECK: test.reflect_bounds {smax = 15 : si32, smin = 0 : si32, umax = 15 : ui32, umin = 0 : ui32}
+func.func @bounded_mask_for(%n: i32) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c15 = arith.constant 15 : i32
+  %res = scf.for %i = %c0 to %n step %c1 iter_args(%acc = %c0) -> i32  : i32 {
+    %incr = arith.addi %acc, %c1 : i32 // On its own, a +1 ratchet.
+    %masked = arith.andi %incr, %c15 : i32 // Keeps only the low 4 bits: [0, 15].
+    scf.yield %masked : i32
+  }
+  %r = test.reflect_bounds %res : i32
+  return %r : i32
+}



More information about the Mlir-commits mailing list