[Mlir-commits] [mlir] Fix int range analysis (PR #192235)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 15 03:51:26 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: smoothsmooth

<details>
<summary>Changes</summary>

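The pattern in the new test file makes `-int-range-optimizations` hang: a loop-carried integer range grows by one step per solver visit, so reaching a fixpoint for i32 takes on the order of 2^31 iterations.
The fix caps the number of lattice updates per (loop, lattice element) pair and widens the element to the maximum range once that cap is reached, which guarantees termination.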

---
Full diff: https://github.com/llvm/llvm-project/pull/192235.diff


3 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+23) 
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+81) 
- (added) mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir (+70) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 9820a91291fdb..a6d914dfae4ab 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -61,6 +61,21 @@ class IntegerRangeAnalysis
                  ArrayRef<const IntegerValueRangeLattice *> operands,
                  ArrayRef<IntegerValueRangeLattice *> results) override;
 
+  /// Joins the successor operands of a region branch op into `lattices`.
+  /// Non-loop ops use the base implementation unchanged. For loop-like ops,
+  /// each (loop op, lattice element) pair may only be updated
+  /// kMaxLoopVisits times before the element is widened to the max range,
+  /// which bounds the work on e.g. scf.while loops with dynamic bounds.
+  ///
+  /// Without this cap, loop-carried ranges that grow by one per worklist
+  /// visit ([0,0] -> [0,1] -> [0,2] -> ...) need O(2^31) iterations for i32.
+  /// The widening in visitOperation only covers op results used directly by
+  /// a terminator, not values routed through nested region ops like scf.if.
+  void
+  visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
+                        RegionSuccessor successor,
+                        ArrayRef<AbstractSparseLattice *> lattices) override;
+
   /// Visit block arguments or operation results of an operation with region
   /// control-flow for which values are not defined by region control-flow. This
   /// function calls `InferIntRangeInterface` to provide values for block
@@ -70,6 +85,14 @@ class IntegerRangeAnalysis
       Operation *op, const RegionSuccessor &successor,
       ValueRange nonSuccessorInputs,
       ArrayRef<IntegerValueRangeLattice *> nonSuccessorInputLattices) override;
+
+private:
+  /// Updates allowed per (loop op, lattice element) before forcing max range.
+  static constexpr int64_t kMaxLoopVisits = 4;
+
+  /// Number of recorded lattice updates, keyed by (loop op, lattice element).
+  DenseMap<std::pair<Operation *, AbstractSparseLattice *>, int64_t>
+      loopVisits;
 };
 
 /// Succeeds if an op can be converted to its unsigned equivalent without
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 818450e2bc696..c7d053d5d0cd7 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -137,6 +137,87 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
   return success();
 }
 
+/// Returns a copy of the range currently held by each lattice in `lattices`,
+/// in order. Every element is expected to be an IntegerValueRangeLattice,
+/// which holds for all lattices created by IntegerRangeAnalysis.
+static SmallVector<IntegerValueRange>
+snapshotRanges(ArrayRef<AbstractSparseLattice *> lattices) {
+  SmallVector<IntegerValueRange> ranges;
+  ranges.reserve(lattices.size());
+  for (AbstractSparseLattice *lattice : lattices)
+    ranges.push_back(
+        static_cast<IntegerValueRangeLattice *>(lattice)->getValue());
+  return ranges;
+}
+
+/// Loop-aware version of the base region-successor transfer function.
+///
+/// For loop-like region branch ops, some loop-carried ranges only grow by a
+/// small step per worklist visit (e.g. [0,0] -> [0,1] -> [0,2] -> ...), so
+/// plain fixpoint iteration can take on the order of 2^bitwidth steps. The
+/// widening in visitOperation does not help here: it only covers op results
+/// used directly by a terminator, while values routed through nested region
+/// ops such as scf.if reach the loop's successor inputs without it.
+///
+/// To bound the work, every (loop op, lattice element) pair gets a budget of
+/// kMaxLoopVisits real updates, tracked in `loopVisits` for the lifetime of
+/// this analysis. An update beyond that budget widens the element to the max
+/// range, which is always a sound over-approximation. A visit that changes
+/// an element charges its budget once, however many predecessors it joins.
+///
+/// Elements that stop changing within the budget are never widened, so
+/// loops whose ranges settle quickly keep their precise results. Raising
+/// kMaxLoopVisits trades analysis time for precision on slower loops.
+void IntegerRangeAnalysis::visitRegionSuccessors(
+    ProgramPoint *point, RegionBranchOpInterface branch,
+    RegionSuccessor successor, ArrayRef<AbstractSparseLattice *> lattices) {
+  Operation *branchOp = branch.getOperation();
+
+  // Only loop-like ops are capped. Other region ops (scf.if,
+  // scf.index_switch, ...) use the base implementation unchanged.
+  if (!isa<LoopLikeOpInterface>(branchOp)) {
+    SparseForwardDataFlowAnalysis::visitRegionSuccessors(
+        point, branch, successor, lattices);
+    return;
+  }
+
+  // Remember the current ranges so we can tell afterwards which elements the
+  // base implementation actually changed. Delegating to the base (instead of
+  // duplicating its join loop here) keeps every case it handles -- e.g.
+  // predecessors whose successor operands are unknown -- identical to the
+  // uncapped path.
+  SmallVector<IntegerValueRange> oldRanges = snapshotRanges(lattices);
+
+  SparseForwardDataFlowAnalysis::visitRegionSuccessors(
+      point, branch, successor, lattices);
+
+  // Charge every element the base implementation changed against its budget
+  // and widen the ones that have run out.
+  for (auto [lattice, oldRange] : llvm::zip_equal(lattices, oldRanges)) {
+    auto *intLattice = static_cast<IntegerValueRangeLattice *>(lattice);
+
+    // Count only real updates. Charging every visit would also widen
+    // elements that have already converged and needlessly lose precision.
+    if (intLattice->getValue() == oldRange)
+      continue;
+
+    int64_t &visits = loopVisits[std::make_pair(branchOp, lattice)];
+    if (++visits <= kMaxLoopVisits)
+      continue;
+
+    // Budget exhausted: widen to the max range (lattice top). Nothing can be
+    // joined on top of that, so this element changes at most once more and
+    // the fixpoint iteration on it terminates.
+    Value anchor = intLattice->getAnchor();
+    ChangeResult changed =
+        intLattice->join(IntegerValueRange::getMaxRange(anchor));
+    propagateIfChanged(intLattice, changed);
+    if (changed == ChangeResult::Change)
+      LDBG() << "Forcing max-range after " << visits
+             << " lattice updates for " << anchor;
+  }
+}
+
 void IntegerRangeAnalysis::visitNonControlFlowArguments(
     Operation *op, const RegionSuccessor &successor,
     ValueRange nonSuccessorInputs,
diff --git a/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
new file mode 100644
index 0000000000000..d8dabd253c698
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s
+
+// Regression test: IntegerRangeAnalysis must converge on scf.while loops
+// whose loop-carried range would otherwise ratchet one step per worklist
+// visit ([0,0] -> [0,1] -> [0,2] -> ...). The two nested scf.if layers with
+// differing addi/muli chains bounded by remui are needed to reproduce the
+// hang; the visit cap in visitRegionSuccessors must let the pass terminate.
+
+// CHECK-LABEL: func.func @grouped_gemm_while_hang
+// CHECK:         scf.while
+// CHECK:         scf.while
+// CHECK:         return
+func.func @grouped_gemm_while_hang(%n: i32, %flag: i1) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %c3 = arith.constant 3 : i32
+  %c7 = arith.constant 7 : i32
+  %c127 = arith.constant 127 : i32
+  %init = arith.cmpi slt, %c0, %n : i32
+
+  %res:2 = scf.while (%a0 = %c0, %cond = %init) : (i32, i1) -> (i32, i1) {
+    scf.condition(%cond) %a0, %cond : i32, i1
+  } do {
+  ^bb0(%b0: i32, %bc: i1):
+    %t0 = arith.addi %b0, %c1 : i32
+    %ic = arith.cmpi slt, %t0, %n : i32
+
+    %inner:2 = scf.while (%i0 = %t0, %iic = %ic) : (i32, i1) -> (i32, i1) {
+      scf.condition(%iic) %i0, %iic : i32, i1
+    } do {
+    ^bb1(%j0: i32, %jc: i1):
+
+      // Layer 0: branches must differ to prevent folding.
+      // remui bounds ranges to [0,126], preventing overflow-cascade
+      // convergence.  Both branches must have ops (not just passthrough)
+      // to generate enough worklist items.
+      %L0 = scf.if %flag -> (i32) {
+        %a0_0 = arith.addi %j0, %c1 : i32
+        %a0_1 = arith.muli %a0_0, %c7 : i32
+        %a0_r = arith.remui %a0_1, %c127 : i32
+        scf.yield %a0_r : i32
+      } else {
+        %b0_0 = arith.addi %j0, %c3 : i32
+        %b0_1 = arith.muli %b0_0, %c7 : i32
+        %b0_r = arith.remui %b0_1, %c127 : i32
+        scf.yield %b0_r : i32
+      }
+
+      // Layer 1: second nested scf.if feeds from layer 0.
+      %L1 = scf.if %flag -> (i32) {
+        %a1_0 = arith.addi %L0, %c1 : i32
+        %a1_1 = arith.muli %a1_0, %c7 : i32
+        %a1_r = arith.remui %a1_1, %c127 : i32
+        scf.yield %a1_r : i32
+      } else {
+        %b1_0 = arith.addi %L0, %c3 : i32
+        %b1_1 = arith.muli %b1_0, %c7 : i32
+        %b1_r = arith.remui %b1_1, %c127 : i32
+        scf.yield %b1_r : i32
+      }
+
+      %nic = arith.cmpi slt, %L1, %n : i32
+      scf.yield %L1, %nic : i32, i1
+    }
+
+    %nc = arith.cmpi slt, %inner#0, %n : i32
+    scf.yield %inner#0, %nc : i32, i1
+  }
+  return %res#0 : i32
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/192235


More information about the Mlir-commits mailing list