[Mlir-commits] [mlir] [mlir][dataflow] Overload visitNonControlFlowArguments (PR #178383)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 28 01:14:38 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

This PR splits `visitNonControlFlowArguments` into two overloads, one for each of two scenarios:
* The existing overload takes a `RegionSuccessor` parameter and handles ops that implement `RegionBranchOpInterface`.
* The new overload takes a `Region *` instead of a `RegionSuccessor`. It infers lattices for the entry block arguments of region ops that do not implement `RegionBranchOpInterface` (a `RegionSuccessor` cannot be constructed for such ops).

RFC:  https://discourse.llvm.org/t/rfc-drop-the-firstindex-argument-of-visitnoncontrolflowarguments-of-sparseforwarddataflowanalysis/89419/5


---
Full diff: https://github.com/llvm/llvm-project/pull/178383.diff


4 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+8-2) 
- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+26) 
- (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+45-45) 
- (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 9820a91291fdb..2ebf63fb8833b 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -63,13 +63,19 @@ class IntegerRangeAnalysis
 
   /// Visit block arguments or operation results of an operation with region
   /// control-flow for which values are not defined by region control-flow. This
-  /// function calls `InferIntRangeInterface` to provide values for block
-  /// arguments or tries to reduce the range on loop induction variables with
+  /// function tries to reduce the range on loop induction variables with
   /// known bounds.
   void visitNonControlFlowArguments(
       Operation *op, const RegionSuccessor &successor,
       ValueRange nonSuccessorInputs,
       ArrayRef<IntegerValueRangeLattice *> nonSuccessorInputLattices) override;
+
+  /// Infer ranges for `arguments`, the entry block arguments of `region`, when
+  /// `op` does not implement `RegionBranchOpInterface` (e.g., gpu.launch). The
+  /// ranges come from `InferIntRangeInterface` if `op` implements it.
+  void visitNonControlFlowArguments(
+      Operation *op, Region *const region, ValueRange arguments,
+      ArrayRef<IntegerValueRangeLattice *> argLattices) override;
 };
 
 /// Succeeds if an op can be converted to its unsigned equivalent without
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index df50d8d193aeb..fb21c5bbb1310 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -218,6 +218,13 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
       ValueRange nonSuccessorInputs,
       ArrayRef<AbstractSparseLattice *> nonSuccessorInputLattices) = 0;
 
+  /// Given an op that does not implement `RegionBranchOpInterface`, compute
+  /// the lattice values of `arguments`, the entry block arguments of `region`
+  /// (e.g. the block arguments of gpu.launch).
+  virtual void visitNonControlFlowArgumentsImpl(
+      Operation *op, Region *const region, ValueRange arguments,
+      ArrayRef<AbstractSparseLattice *> argLattices) = 0;
+
   /// Get the lattice element of a value.
   virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
 
@@ -335,6 +342,16 @@ class SparseForwardDataFlowAnalysis
     setAllToEntryStates(nonSuccessorInputLattices);
   }
 
+  /// Given an op that does not implement `RegionBranchOpInterface`, compute
+  /// the lattice values of `arguments`, the entry block arguments of `region`
+  /// (e.g. the block arguments of gpu.launch). By default, this method marks
+  /// all lattice elements as having reached a pessimistic fixpoint.
+  virtual void visitNonControlFlowArguments(Operation *op, Region *const region,
+                                            ValueRange arguments,
+                                            ArrayRef<StateT *> argLattices) {
+    setAllToEntryStates(argLattices);
+  }
+
 protected:
   /// Get the lattice element for a value.
   StateT *getLatticeElement(Value value) override {
@@ -391,6 +408,15 @@ class SparseForwardDataFlowAnalysis
          nonSuccessorInputLattices.size()});
   }
 
+  void visitNonControlFlowArgumentsImpl(
+      Operation *op, Region *const region, ValueRange arguments,
+      ArrayRef<AbstractSparseLattice *> argLattices) override {
+    visitNonControlFlowArguments(
+        op, region, arguments,
+        {reinterpret_cast<StateT *const *>(argLattices.begin()),
+         argLattices.size()});
+  }
+
   void setToEntryState(AbstractSparseLattice *lattice) override {
     return setToEntryState(reinterpret_cast<StateT *>(lattice));
   }
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 7b567f043577a..79f31ea311211 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -143,50 +143,6 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
     ArrayRef<IntegerValueRangeLattice *> nonSuccessorInputLattices) {
   assert(nonSuccessorInputs.size() == nonSuccessorInputLattices.size() &&
          "size mismatch");
-  if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
-    LDBG() << "Inferring ranges for "
-           << OpWithFlags(op, OpPrintingFlags().skipRegions());
-
-    auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
-      return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
-    });
-
-    auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
-      auto arg = dyn_cast<BlockArgument>(v);
-      if (!arg)
-        return;
-      if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg))
-        return;
-
-      LDBG() << "Inferred range " << attrs;
-      auto it = llvm::find(successor.getSuccessor()->getArguments(), arg);
-      unsigned nonSuccessorInputIdx =
-          std::distance(successor.getSuccessor()->getArguments().begin(), it);
-      IntegerValueRangeLattice *lattice =
-          nonSuccessorInputLattices[nonSuccessorInputIdx];
-      IntegerValueRange oldRange = lattice->getValue();
-
-      ChangeResult changed = lattice->join(attrs);
-
-      // Catch loop results with loop variant bounds and conservatively make
-      // them [-inf, inf] so we don't circle around infinitely often (because
-      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
-      // and often can't).
-      bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
-        return op->hasTrait<OpTrait::IsTerminator>();
-      });
-      if (isYieldedValue && !oldRange.isUninitialized() &&
-          !(lattice->getValue() == oldRange)) {
-        LDBG() << "Loop variant loop result detected";
-        changed |= lattice->join(IntegerValueRange::getMaxRange(v));
-      }
-      propagateIfChanged(lattice, changed);
-    };
-
-    inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
-    return;
-  }
-
   /// Given a lower bound, upper bound, or step from a LoopLikeInterface return
   /// the lower/upper bound for that result if possible.
   auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType,
@@ -251,7 +207,51 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
     }
     return;
   }
-
   return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
       op, successor, nonSuccessorInputs, nonSuccessorInputLattices);
 }
+
+void IntegerRangeAnalysis::visitNonControlFlowArguments(
+    Operation *op, Region *const region, ValueRange arguments,
+    ArrayRef<IntegerValueRangeLattice *> argLattices) {
+  assert(arguments.size() == argLattices.size() && "size mismatch");
+  // Without `InferIntRangeInterface`, fall back to the default entry state.
+  auto inferrable = dyn_cast<InferIntRangeInterface>(op);
+  if (!inferrable)
+    return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments(
+        op, region, arguments, argLattices);
+
+  LDBG() << "Inferring ranges for "
+         << OpWithFlags(op, OpPrintingFlags().skipRegions());
+  auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) {
+    return getLatticeElementFor(getProgramPointAfter(op), value)->getValue();
+  });
+  auto joinCallback = [&](Value v, const IntegerValueRange &attrs) {
+    auto arg = dyn_cast<BlockArgument>(v);
+    if (!arg)
+      return;
+    auto it = llvm::find(arguments, arg);
+    if (it == arguments.end())
+      return;
+    LDBG() << "Inferred range " << attrs;
+    IntegerValueRangeLattice *lattice =
+        argLattices[std::distance(arguments.begin(), it)];
+    IntegerValueRange oldRange = lattice->getValue();
+    ChangeResult changed = lattice->join(attrs);
+
+    // Catch loop results with loop variant bounds and conservatively make
+    // them [-inf, inf] so we don't circle around infinitely often (because
+    // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+    // and often can't).
+    bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *user) {
+      return user->hasTrait<OpTrait::IsTerminator>();
+    });
+    if (isYieldedValue && !oldRange.isUninitialized() &&
+        !(lattice->getValue() == oldRange)) {
+      LDBG() << "Loop variant loop result detected";
+      changed |= lattice->join(IntegerValueRange::getMaxRange(v));
+    }
+    propagateIfChanged(lattice, changed);
+  };
+  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
+}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 90f2a588d1ca4..b583231aca9af 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -187,7 +187,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
 
     // All block arguments are non-successor-inputs.
     return visitNonControlFlowArgumentsImpl(block->getParentOp(),
-                                            RegionSuccessor(block->getParent()),
+                                            /*region=*/block->getParent(),
                                             block->getArguments(), argLattices);
   }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/178383


More information about the Mlir-commits mailing list