[Mlir-commits] [mlir] [mlir][dataflow] Overload visitNonControlFlowArguments (PR #178383)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 28 01:14:38 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: lonely eagle (linuxlonelyeagle)
<details>
<summary>Changes</summary>
This PR splits `visitNonControlFlowArguments` into two overloads, one for each scenario:
* The existing overload takes a `RegionSuccessor` parameter and handles ops that implement `RegionBranchOpInterface`.
* The new overload takes a `Region *` instead of a `RegionSuccessor`. It infers lattices for the entry block arguments of region-holding ops that do not implement `RegionBranchOpInterface` (a `RegionSuccessor` cannot be constructed for such ops).
RFC: https://discourse.llvm.org/t/rfc-drop-the-firstindex-argument-of-visitnoncontrolflowarguments-of-sparseforwarddataflowanalysis/89419/5
---
Full diff: https://github.com/llvm/llvm-project/pull/178383.diff
4 Files Affected:
- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+8-2)
- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+26)
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+45-45)
- (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 9820a91291fdb..2ebf63fb8833b 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -63,13 +63,19 @@ class IntegerRangeAnalysis
/// Visit block arguments or operation results of an operation with region
/// control-flow for which values are not defined by region control-flow. This
- /// function calls `InferIntRangeInterface` to provide values for block
- /// arguments or tries to reduce the range on loop induction variables with
+ /// function tries to reduce the range on loop induction variables with
/// known bounds.
void visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &successor,
ValueRange nonSuccessorInputs,
ArrayRef<IntegerValueRangeLattice *> nonSuccessorInputLattices) override;
+
+ /// Provides values for the entry block arguments of `region` when its parent
+ /// op does not implement `RegionBranchOpInterface` (e.g. gpu.launch), by
+ /// calling `InferIntRangeInterface` if the op implements it.
+ void visitNonControlFlowArguments(
+ Operation *op, Region *const region, ValueRange arguments,
+ ArrayRef<IntegerValueRangeLattice *> argLattices) override;
};
/// Succeeds if an op can be converted to its unsigned equivalent without
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index df50d8d193aeb..fb21c5bbb1310 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -218,6 +218,13 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
ValueRange nonSuccessorInputs,
ArrayRef<AbstractSparseLattice *> nonSuccessorInputLattices) = 0;
+ /// Given an operation `op` that does not implement `RegionBranchOpInterface`,
+ /// compute the lattice values of the entry block arguments of `region` (e.g.
+ /// the block arguments of gpu.launch).
+ virtual void visitNonControlFlowArgumentsImpl(
+ Operation *op, Region *const region, ValueRange arguments,
+ ArrayRef<AbstractSparseLattice *> argLattices) = 0;
+
/// Get the lattice element of a value.
virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
@@ -335,6 +342,16 @@ class SparseForwardDataFlowAnalysis
setAllToEntryStates(nonSuccessorInputLattices);
}
+ /// Compute the lattice values of the entry block arguments of `region`, whose
+ /// parent `op` does not implement `RegionBranchOpInterface` (e.g. the block
+ /// arguments of gpu.launch). By default, this method marks all lattice
+ /// elements as having reached a pessimistic fixpoint.
+ virtual void visitNonControlFlowArguments(Operation *op, Region *const region,
+ ValueRange arguments,
+ ArrayRef<StateT *> argLattices) {
+ setAllToEntryStates(argLattices);
+ }
+
protected:
/// Get the lattice element for a value.
StateT *getLatticeElement(Value value) override {
@@ -391,6 +408,17 @@ class SparseForwardDataFlowAnalysis
nonSuccessorInputLattices.size()});
}
+  /// Unwraps the type-erased lattices into `StateT *` and dispatches to the
+  /// `Region *` overload of `visitNonControlFlowArguments`.
+  void visitNonControlFlowArgumentsImpl(
+      Operation *op, Region *const region, ValueRange arguments,
+      ArrayRef<AbstractSparseLattice *> argLattices) override {
+    visitNonControlFlowArguments(
+        op, region, arguments,
+        {reinterpret_cast<StateT *const *>(argLattices.begin()),
+         argLattices.size()});
+  }
+
void setToEntryState(AbstractSparseLattice *lattice) override {
return setToEntryState(reinterpret_cast<StateT *>(lattice));
}
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 7b567f043577a..79f31ea311211 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -143,50 +143,6 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
ArrayRef<IntegerValueRangeLattice *> nonSuccessorInputLattices) {
assert(nonSuccessorInputs.size() == nonSuccessorInputLattices.size() &&
"size mismatch");
- if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
- LDBG() << "Inferring ranges for "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
-
- auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
- return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
- });
-
- auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
- auto arg = dyn_cast<BlockArgument>(v);
- if (!arg)
- return;
- if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
- return;
-
- LDBG() << "Inferred range " << attrs;
- auto it = llvm::find(successor.getSuccessor()->getArguments(), arg);
- unsigned nonSuccessorInputIdx =
- std::distance(successor.getSuccessor()->getArguments().begin(), it);
- IntegerValueRangeLattice *lattice =
- nonSuccessorInputLattices[nonSuccessorInputIdx];
- IntegerValueRange oldRange = lattice->getValue();
-
- ChangeResult changed = lattice->join(attrs);
-
- // Catch loop results with loop variant bounds and conservatively make
- // them [-inf, inf] so we don't circle around infinitely often (because
- // the dataflow analysis in MLIR doesn't attempt to work out trip counts
- // and often can't).
- bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
- return op->hasTrait<OpTrait::IsTerminator>();
- });
- if (isYieldedValue && !oldRange.isUninitialized() &&
- !(lattice->getValue() == oldRange)) {
- LDBG() << "Loop variant loop result detected";
- changed |= lattice->join(IntegerValueRange::getMaxRange(v));
- }
- propagateIfChanged(lattice, changed);
- };
-
- inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
- return;
- }
-
/// Given a lower bound, upper bound, or step from a LoopLikeInterface return
/// the lower/upper bound for that result if possible.
auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType,
@@ -251,7 +207,64 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
}
return;
}
-
return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
op, successor, nonSuccessorInputs, nonSuccessorInputLattices);
}
+
+/// Computes ranges for the entry block arguments of `region`, whose parent `op`
+/// does not implement `RegionBranchOpInterface` (e.g. gpu.launch). If `op`
+/// implements `InferIntRangeInterface`, the interface supplies the ranges;
+/// otherwise all arguments are conservatively set to their entry state.
+void IntegerRangeAnalysis::visitNonControlFlowArguments(
+    Operation *op, Region *const region, ValueRange arguments,
+    ArrayRef<IntegerValueRangeLattice *> argLattices) {
+  assert(arguments.size() == argLattices.size() && "size mismatch");
+  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
+  if (!inferrable) {
+    // The op supplies no range information for these arguments; fall back to
+    // the default, which marks them as having reached a pessimistic fixpoint
+    // rather than leaving them uninitialized.
+    return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+        op, region, arguments, argLattices);
+  }
+
+  LDBG() << "Inferring ranges for "
+         << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+  auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
+    return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
+  });
+
+  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
+    auto arg = dyn_cast<BlockArgument>(v);
+    if (!arg)
+      return;
+    // Ignore values that are not entry block arguments of `region`; a single
+    // lookup yields both membership and the lattice index.
+    auto it = llvm::find(arguments, arg);
+    if (it == arguments.end())
+      return;
+
+    LDBG() << "Inferred range " << attrs;
+    unsigned argIndex = std::distance(arguments.begin(), it);
+    IntegerValueRangeLattice *lattice = argLattices[argIndex];
+    IntegerValueRange oldRange = lattice->getValue();
+
+    ChangeResult changed = lattice->join(attrs);
+
+    // Catch loop results with loop variant bounds and conservatively make
+    // them [-inf, inf] so we don't circle around infinitely often (because
+    // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+    // and often can't).
+    bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *user) {
+      return user->hasTrait<OpTrait::IsTerminator>();
+    });
+    if (isYieldedValue && !oldRange.isUninitialized() &&
+        !(lattice->getValue() == oldRange)) {
+      LDBG() << "Loop variant loop result detected";
+      changed |= lattice->join(IntegerValueRange::getMaxRange(v));
+    }
+    propagateIfChanged(lattice, changed);
+  };
+  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
+}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 90f2a588d1ca4..b583231aca9af 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -187,7 +187,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// All block arguments are non-successor-inputs.
return visitNonControlFlowArgumentsImpl(block->getParentOp(),
- RegionSuccessor(block->getParent()),
+ block->getParent(),
block->getArguments(), argLattices);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/178383
More information about the Mlir-commits mailing list