[Mlir-commits] [mlir] Fix int range analysis (PR #192235)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 15 03:51:26 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-arith
Author: smoothsmooth
<details>
<summary>Changes</summary>
The pattern in the newly added test file causes the `int-range-optimizations` pass to hang.
The fix caps the number of lattice updates per (loop, lattice element) pair; once the cap is reached, the element is forced to the maximum range, which guarantees that the analysis converges.
---
Full diff: https://github.com/llvm/llvm-project/pull/192235.diff
3 Files Affected:
- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+23)
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+81)
- (added) mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir (+70)
``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 9820a91291fdb..a6d914dfae4ab 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -61,6 +61,21 @@ class IntegerRangeAnalysis
ArrayRef<const IntegerValueRangeLattice *> operands,
ArrayRef<IntegerValueRangeLattice *> results) override;
+ // Override visitRegionSuccessors to add a visit cap on loop-carried
+ // lattice elements, preventing non-convergence on scf.while loops with
+ // dynamic bounds.
+ //
+ // Without this cap, loop-carried values whose ranges grow by +1 per
+ // worklist visit (e.g. [0,0]->[0,1]->[0,2]->...) require O(2^31)
+ // iterations to converge for i32. The existing widening in
+ // visitOperation only catches op results yielded directly to a
+ // terminator, not values propagated through nested region ops like
+ // scf.if.
+ void visitRegionSuccessors(ProgramPoint *point,
+ RegionBranchOpInterface branch,
+ RegionSuccessor successor,
+ ArrayRef<AbstractSparseLattice *> lattices) override;
+
/// Visit block arguments or operation results of an operation with region
/// control-flow for which values are not defined by region control-flow. This
/// function calls `InferIntRangeInterface` to provide values for block
@@ -70,6 +85,14 @@ class IntegerRangeAnalysis
Operation *op, const RegionSuccessor &successor,
ValueRange nonSuccessorInputs,
ArrayRef<IntegerValueRangeLattice *> nonSuccessorInputLattices) override;
+
+private:
+ // Maximum lattice updates per (loop, element) before forcing max-range.
+ static constexpr int64_t kMaxLoopVisits = 4;
+
+ // Per-(loop-op, lattice-element) visit counter.
+ DenseMap<std::pair<Operation *, AbstractSparseLattice *>, int64_t>
+ loopVisits;
};
/// Succeeds if an op can be converted to its unsigned equivalent without
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 818450e2bc696..c7d053d5d0cd7 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -137,6 +137,87 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
return success();
}
+void IntegerRangeAnalysis::visitRegionSuccessors(
+ ProgramPoint *point, RegionBranchOpInterface branch,
+ RegionSuccessor successor,
+ ArrayRef<AbstractSparseLattice *> lattices) {
+
+ Operation *branchOp = branch.getOperation();
+ bool isLoop = isa<LoopLikeOpInterface>(branchOp);
+
+ // For non-loop regions (scf.if, scf.index_switch, etc.), delegate to
+ // the upstream implementation — no visit cap needed.
+ if (!isLoop) {
+ SparseForwardDataFlowAnalysis::visitRegionSuccessors(
+ point, branch, successor, lattices);
+ return;
+ }
+
+ // For loops: replicate upstream logic with a visit cap.
+ const auto *predecessors =
+ getOrCreateFor<PredecessorState>(point, point);
+ assert(predecessors->allPredecessorsKnown() &&
+ "unexpected unresolved region successors");
+
+ for (Operation *op : predecessors->getKnownPredecessors()) {
+ std::optional<OperandRange> operands;
+ if (op == branch) {
+ operands = branch.getEntrySuccessorOperands(successor);
+ } else if (auto regionTerminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+ operands = regionTerminator.getSuccessorOperands(successor);
+ }
+ if (!operands) {
+ setAllToEntryStates(lattices);
+ return;
+ }
+
+ ValueRange inputs = predecessors->getSuccessorInputs(op);
+ assert(inputs.size() == operands->size() &&
+ "expected the same number of successor inputs as operands");
+
+ unsigned firstIndex = 0;
+ if (inputs.size() != lattices.size()) {
+ if (successor.isParent()) {
+ if (!inputs.empty())
+ firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
+ } else {
+ if (!inputs.empty())
+ firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
+ }
+ }
+
+ for (auto [oper, lattice] :
+ llvm::zip(*operands, ArrayRef(lattices).drop_front(firstIndex))) {
+ auto key = std::make_pair(branchOp, lattice);
+ int64_t &visits = loopVisits[key];
+
+ if (visits >= kMaxLoopVisits) {
+ // Force to max-range (lattice top) — guarantees convergence.
+ auto *intLattice = static_cast<IntegerValueRangeLattice *>(lattice);
+ ChangeResult changed =
+ intLattice->join(IntegerValueRange::getMaxRange(oper));
+ propagateIfChanged(intLattice, changed);
+ LLVM_DEBUG({
+ if (changed == ChangeResult::Change) {
+ LDBG() << "Forcing max-range after " << visits << " visits for ";
+ oper.printAsOperand(llvm::dbgs(), {});
+ llvm::dbgs() << "\n";
+ }
+ });
+ continue;
+ }
+
+ // Normal join with visit tracking.
+ ChangeResult changed =
+ lattice->join(*getLatticeElementFor(point, oper));
+ propagateIfChanged(lattice, changed);
+ if (changed == ChangeResult::Change)
+ ++visits;
+ }
+ }
+}
+
void IntegerRangeAnalysis::visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &successor,
ValueRange nonSuccessorInputs,
diff --git a/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
new file mode 100644
index 0000000000000..d8dabd253c698
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
@@ -0,0 +1,70 @@
+// IntegerRangeAnalysis non-convergence on scf.while with dynamic bounds.
+//
+// The carry range ratchets [0,0]->[0,1]->[0,2]->... without bound.
+// Two nested scf.if layers with differing arith chains (addi, muli)
+// bounded by remui create enough worklist cascade to prevent the
+// solver's back-to-back convergence shortcut from firing.
+//
+// After the fix (visit cap in visitRegionSuccessors), the analysis
+// converges in bounded time.
+//
+// RUN: mlir-opt -int-range-optimizations %s -o /dev/null
+
+func.func @grouped_gemm_while_hang(%n: i32, %flag: i1) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c3 = arith.constant 3 : i32
+ %c7 = arith.constant 7 : i32
+ %c127 = arith.constant 127 : i32
+ %init = arith.cmpi slt, %c0, %n : i32
+
+ %res:2 = scf.while (%a0 = %c0, %cond = %init) : (i32, i1) -> (i32, i1) {
+ scf.condition(%cond) %a0, %cond : i32, i1
+ } do {
+ ^bb0(%b0: i32, %bc: i1):
+ %t0 = arith.addi %b0, %c1 : i32
+ %ic = arith.cmpi slt, %t0, %n : i32
+
+ %inner:2 = scf.while (%i0 = %t0, %iic = %ic) : (i32, i1) -> (i32, i1) {
+ scf.condition(%iic) %i0, %iic : i32, i1
+ } do {
+ ^bb1(%j0: i32, %jc: i1):
+
+ // Layer 0: branches must differ to prevent folding.
+ // remui bounds ranges to [0,126], preventing overflow-cascade
+ // convergence. Both branches must have ops (not just passthrough)
+ // to generate enough worklist items.
+ %L0 = scf.if %flag -> (i32) {
+ %a0_0 = arith.addi %j0, %c1 : i32
+ %a0_1 = arith.muli %a0_0, %c7 : i32
+ %a0_r = arith.remui %a0_1, %c127 : i32
+ scf.yield %a0_r : i32
+ } else {
+ %b0_0 = arith.addi %j0, %c3 : i32
+ %b0_1 = arith.muli %b0_0, %c7 : i32
+ %b0_r = arith.remui %b0_1, %c127 : i32
+ scf.yield %b0_r : i32
+ }
+
+ // Layer 1: second nested scf.if feeds from layer 0.
+ %L1 = scf.if %flag -> (i32) {
+ %a1_0 = arith.addi %L0, %c1 : i32
+ %a1_1 = arith.muli %a1_0, %c7 : i32
+ %a1_r = arith.remui %a1_1, %c127 : i32
+ scf.yield %a1_r : i32
+ } else {
+ %b1_0 = arith.addi %L0, %c3 : i32
+ %b1_1 = arith.muli %b1_0, %c7 : i32
+ %b1_r = arith.remui %b1_1, %c127 : i32
+ scf.yield %b1_r : i32
+ }
+
+ %nic = arith.cmpi slt, %L1, %n : i32
+ scf.yield %L1, %nic : i32, i1
+ }
+
+ %nc = arith.cmpi slt, %inner#0, %n : i32
+ scf.yield %inner#0, %nc : i32, i1
+ }
+ return %res#0 : i32
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/192235
More information about the Mlir-commits mailing list