[Mlir-commits] [mlir] 5d8813d - [mlir] allow dense dataflow to customize call and region operations

Alex Zinenko llvmlistbot at llvm.org
Fri Jul 21 02:16:09 PDT 2023


Author: Alex Zinenko
Date: 2023-07-21T09:16:03Z
New Revision: 5d8813dec69360fce897f063a4a65106ae8ea22b

URL: https://github.com/llvm/llvm-project/commit/5d8813dec69360fce897f063a4a65106ae8ea22b
DIFF: https://github.com/llvm/llvm-project/commit/5d8813dec69360fce897f063a4a65106ae8ea22b.diff

LOG: [mlir] allow dense dataflow to customize call and region operations

Initial implementations of dense dataflow analyses feature special cases
for operations that have region- or call-based control flow by
leveraging the corresponding interfaces. This is not always sufficient,
as these operations may influence the dataflow state by themselves as
well as through the control flow. For example, `linalg.generic` and
similar operations have region-based control flow and also their own
memory effects, so memory-related analyses such as last-writer must
process `linalg.generic` directly instead of, or in addition to, the
region-based flow.

Provide hooks to customize the processing of operations with region- and
call-based control flow in forward and backward dense dataflow analyses.
These hooks are triggered when control flow is transferred between the
"main" operation, i.e. the call or the region owner, and a region (one of
the owner's regions or the callee body). This approach allows the
analyses to update the lattice before and/or after the regions. In the
`linalg.generic` example, the reads from memory are interpreted as
happening before the body region and the writes to memory are
interpreted as happening after the body region. Using these hooks in
generic analyses may require introducing additional interfaces, but for
now assume that specific analyses have special cases for the (rare)
operations with call- and region-based control flow that need additional
processing.
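As an illustration only, the "reads before the body, writes after the body" split can be sketched in plain C++ without MLIR. `LastWriterLattice`, `defaultTransfer` and `genericLikeTransfer` are hypothetical names, not MLIR APIs; as in the new hooks, `std::nullopt` for `regionFrom`/`regionTo` stands for the parent operation, and the default behavior simply joins `before` into `after`.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <string>

// Hypothetical toy lattice (not an MLIR class): the set of operations that may
// have last written one memory location.
struct LastWriterLattice {
  std::set<std::string> writers;

  // Join is set union; returns true if the lattice changed.
  bool join(const LastWriterLattice &other) {
    auto oldSize = writers.size();
    writers.insert(other.writers.begin(), other.writers.end());
    return writers.size() != oldSize;
  }
};

// Default edge behavior: just join `before` into `after`. std::nullopt for
// `regionFrom`/`regionTo` stands for the parent operation itself.
void defaultTransfer(std::optional<unsigned> /*regionFrom*/,
                     std::optional<unsigned> /*regionTo*/,
                     const LastWriterLattice &before,
                     LastWriterLattice *after) {
  after->join(before);
}

// Hypothetical override for a linalg.generic-like op `opName` that writes the
// location: reads are placed before the body (no effect on last-writer), and
// the write is placed on the body -> parent edge, i.e. after the body region.
void genericLikeTransfer(const std::string &opName,
                         std::optional<unsigned> regionFrom,
                         std::optional<unsigned> regionTo,
                         const LastWriterLattice &before,
                         LastWriterLattice *after) {
  if (regionFrom && !regionTo) {
    // The op's own write is the last write once control leaves the body.
    LastWriterLattice written;
    written.writers.insert(opName);
    after->join(written);
    return;
  }
  defaultTransfer(regionFrom, regionTo, before, after);
}
```

Entering the body passes the incoming state through unchanged, while leaving it reports the op itself as the last writer.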

Reviewed By: Mogball, phisiart

Differential Revision: https://reviews.llvm.org/D155757

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
    mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
    mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
    mlir/test/Analysis/DataFlow/test-last-modified.mlir
    mlir/test/Analysis/DataFlow/test-next-access.mlir
    mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
    mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index c93ab248970b5a..b6e64f3943b9d0 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -17,12 +17,18 @@
 
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 
 namespace mlir {
+namespace dataflow {
 
-class RegionBranchOpInterface;
+//===----------------------------------------------------------------------===//
+// CallControlFlowAction
+//===----------------------------------------------------------------------===//
 
-namespace dataflow {
+/// Indicates whether the control enters or exits the callee.
+enum class CallControlFlowAction { EnterCallee, ExitCallee };
 
 //===----------------------------------------------------------------------===//
 // AbstractDenseLattice
@@ -109,6 +115,32 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
   /// operation transfer function.
   virtual void processOperation(Operation *op);
 
+  /// Propagate the dense lattice forward along the control flow edge from
+  /// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
+  /// values correspond to control flow branches originating at or targeting the
+  /// `branch` operation itself. Default implementation just joins the states,
+  /// meaning that operations implementing `RegionBranchOpInterface` don't have
+  /// any effect on the lattice that isn't already expressed by the interface
+  /// itself.
+  virtual void visitRegionBranchControlFlowTransfer(
+      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+      std::optional<unsigned> regionTo, const AbstractDenseLattice &before,
+      AbstractDenseLattice *after) {
+    join(after, before);
+  }
+
+  /// Propagate the dense lattice forward along the call control flow edge,
+  /// which can be either entering or exiting the callee. Default implementation
+  /// just joins the states, meaning that operations implementing
+  /// `CallOpInterface` don't have any effect on the lattice that isn't already
+  /// expressed by the interface itself.
+  virtual void visitCallControlFlowTransfer(CallOpInterface call,
+                                            CallControlFlowAction action,
+                                            const AbstractDenseLattice &before,
+                                            AbstractDenseLattice *after) {
+    join(after, before);
+  }
+
   /// Visit a program point within a region branch operation with predecessors
   /// in it. This can either be an entry block of one of the regions or the
   /// parent operation itself.
@@ -120,6 +152,10 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
   /// Visit a block. The state at the start of the block is propagated from
   /// control-flow predecessors or callsites.
   void visitBlock(Block *block);
+
+  /// Visit an operation for which the data flow is described by the
+  /// `CallOpInterface`.
+  void visitCallOperation(CallOpInterface call, AbstractDenseLattice *after);
 };
 
 //===----------------------------------------------------------------------===//
@@ -146,6 +182,60 @@ class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
   virtual void visitOperation(Operation *op, const LatticeT &before,
                               LatticeT *after) = 0;
 
+  /// Hook for customizing the behavior of lattice propagation along the call
+  /// control flow edges. Two types of (forward) propagation are possible here:
+  ///   - `action == CallControlFlowAction::EnterCallee` indicates that:
+  ///     - `before` is the state before the call operation;
+  ///     - `after` is the state at the beginning of the callee entry block;
+  ///   - `action == CallControlFlowAction::ExitCallee` indicates that:
+  ///     - `before` is the state at the end of a callee exit block;
+  ///     - `after` is the state after the call operation.
+  /// By default, the `after` state is simply joined with the `before` state.
+  /// Concrete analyses can override this behavior or delegate to the parent
+  /// call for the default behavior. Specifically, if the `call` op may affect
+  /// the lattice prior to entering the callee, the custom behavior can be added
+  /// for `action == CallControlFlowAction::EnterCallee`. If the `call` op may
+  /// affect the lattice after exiting the callee, the custom behavior can be
+  /// added for `action == CallControlFlowAction::ExitCallee`.
+  virtual void visitCallControlFlowTransfer(CallOpInterface call,
+                                            CallControlFlowAction action,
+                                            const LatticeT &before,
+                                            LatticeT *after) {
+    AbstractDenseDataFlowAnalysis::visitCallControlFlowTransfer(call, action,
+                                                                before, after);
+  }
+
+  /// Hook for customizing the behavior of lattice propagation along the control
+  /// flow edges between regions and their parent op. The control flows from
+  /// `regionFrom` to `regionTo`, both of which may be `nullopt` to indicate the
+  /// parent op. The lattice is propagated forward along this edge. The lattices
+  /// are as follows:
+  ///   - `before`:
+  ///     - if `regionFrom` is a region, this is the lattice at the end of the
+  ///       block that exits the region; note that for multi-exit regions, the
+  ///       lattices are equal at the end of all exiting blocks, but they are
+  ///       associated with different program points.
+  ///     - otherwise, this is the lattice before the parent op.
+  ///   - `after`:
+  ///     - if `regionTo` is a region, this is the lattice at the beginning of
+  ///       the entry block of that region;
+  ///     - otherwise, this is the lattice after the parent op.
+  /// By default, the `after` state is simply joined with the `before` state.
+  /// Concrete analyses can override this behavior or delegate to the parent
+  /// call for the default behavior. Specifically, if the `branch` op may affect
+  /// the lattice before entering any region, the custom behavior can be added
+  /// for `regionFrom == nullopt`. If the `branch` op may affect the lattice
+  /// after the regions terminate, the custom behavior can be added for
+  /// `regionTo == nullopt`. The behavior can be further refined for specific
+  /// pairs of "from" and "to" regions.
+  virtual void visitRegionBranchControlFlowTransfer(
+      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+      std::optional<unsigned> regionTo, const LatticeT &before,
+      LatticeT *after) {
+    AbstractDenseDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
+        branch, regionFrom, regionTo, before, after);
+  }
+
 protected:
   /// Get the dense lattice after this program point.
   LatticeT *getLattice(ProgramPoint point) override {
@@ -162,10 +252,27 @@ class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
   /// Type-erased wrappers that convert the abstract dense lattice to a derived
   /// lattice and invoke the virtual hooks operating on the derived lattice.
   void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
-                          AbstractDenseLattice *after) override {
+                          AbstractDenseLattice *after) final {
     visitOperation(op, static_cast<const LatticeT &>(before),
                    static_cast<LatticeT *>(after));
   }
+  void visitCallControlFlowTransfer(CallOpInterface call,
+                                    CallControlFlowAction action,
+                                    const AbstractDenseLattice &before,
+                                    AbstractDenseLattice *after) final {
+    visitCallControlFlowTransfer(call, action,
+                                 static_cast<const LatticeT &>(before),
+                                 static_cast<LatticeT *>(after));
+  }
+  void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
+                                            std::optional<unsigned> regionFrom,
+                                            std::optional<unsigned> regionTo,
+                                            const AbstractDenseLattice &before,
+                                            AbstractDenseLattice *after) final {
+    visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
+                                         static_cast<const LatticeT &>(before),
+                                         static_cast<LatticeT *>(after));
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -231,12 +338,42 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
     propagateIfChanged(lhs, lhs->meet(rhs));
   }
 
-  /// Visit an operation. If this is a call operation or region control-flow
-  /// operation, then the state after the execution of the operation is set by
-  /// control-flow or the callgraph. Otherwise, this function invokes the
-  /// operation transfer function.
+  /// Visit an operation. Dispatches to specialized methods for call or region
+  /// control-flow operations. Otherwise, this function invokes the operation
+  /// transfer function.
   virtual void processOperation(Operation *op);
 
+  /// Propagate the dense lattice backwards along the control flow edge from
+  /// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
+  /// values correspond to control flow branches originating at or targeting the
+  /// `branch` operation itself. Default implementation just meets the states,
+  /// meaning that operations implementing `RegionBranchOpInterface` don't have
+  /// any effect on the lattice that isn't already expressed by the interface
+  /// itself.
+  virtual void visitRegionBranchControlFlowTransfer(
+      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+      std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
+      AbstractDenseLattice *before) {
+    meet(before, after);
+  }
+
+  /// Propagate the dense lattice backwards along the call control flow edge,
+  /// which can be either entering or exiting the callee. Default implementation
+  /// just meets the states, meaning that operations implementing
+  /// `CallOpInterface` don't have any effect on the lattice that isn't already
+  /// expressed by the interface itself.
+  virtual void visitCallControlFlowTransfer(CallOpInterface call,
+                                            CallControlFlowAction action,
+                                            const AbstractDenseLattice &after,
+                                            AbstractDenseLattice *before) {
+    meet(before, after);
+  }
+
+private:
+  /// Visit a block. The state at the end of the block is propagated from
+  /// control-flow successors of the block or callsites.
+  void visitBlock(Block *block);
+
   /// Visit a program point within a region branch operation with successors
   /// (from which the state is propagated) in or after it. `regionNo` indicates
   /// the region that contains the successor, `nullopt` indicating the successor
@@ -246,10 +383,16 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
                                   std::optional<unsigned> regionNo,
                                   AbstractDenseLattice *before);
 
-private:
-  /// VIsit a block. The state and the end of the block is propagated from
-  /// control-flow successors of the block or callsites.
-  void visitBlock(Block *block);
+  /// Visit an operation for which the data flow is described by the
+  /// `CallOpInterface`. Performs inter-procedural data flow as follows:
+  ///
+  ///   - find the callable (resolve via the symbol table),
+  ///   - get the entry block of the callable region,
+  ///   - take the state before the first operation if present or at block end
+  ///     otherwise,
+  ///   - meet that state with the state before the call-like op, or use the
+  ///     custom logic if overridden by concrete analyses.
+  void visitCallOperation(CallOpInterface call, AbstractDenseLattice *before);
 
   /// Symbol table for call-level control flow.
   SymbolTableCollection &symbolTable;
@@ -280,6 +423,60 @@ class DenseBackwardDataFlowAnalysis
   virtual void visitOperation(Operation *op, const LatticeT &after,
                               LatticeT *before) = 0;
 
+  /// Hook for customizing the behavior of lattice propagation along the call
+  /// control flow edges. Two types of (back) propagation are possible here:
+  ///   - `action == CallControlFlowAction::EnterCallee` indicates that:
+  ///     - `after` is the state at the top of the callee entry block;
+  ///     - `before` is the state before the call operation;
+  ///   - `action == CallControlFlowAction::ExitCallee` indicates that:
+  ///     - `after` is the state after the call operation;
+  ///     - `before` is the state at the end of the callee exit blocks.
+  /// By default, the `before` state is simply met with the `after` state.
+  /// Concrete analyses can override this behavior or delegate to the parent
+  /// call for the default behavior. Specifically, if the `call` op may affect
+  /// the lattice prior to entering the callee, the custom behavior can be added
+  /// for `action == CallControlFlowAction::EnterCallee`. If the `call` op may
+  /// affect the lattice after exiting the callee, the custom behavior can be
+  /// added for `action == CallControlFlowAction::ExitCallee`.
+  virtual void visitCallControlFlowTransfer(CallOpInterface call,
+                                            CallControlFlowAction action,
+                                            const LatticeT &after,
+                                            LatticeT *before) {
+    AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
+        call, action, after, before);
+  }
+
+  /// Hook for customizing the behavior of lattice propagation along the control
+  /// flow edges between regions and their parent op. The control flows from
+  /// `regionFrom` to `regionTo`, both of which may be `nullopt` to indicate the
+  /// parent op. The lattice is propagated back along this edge. The lattices
+  /// are as follows:
+  ///   - `after`:
+  ///     - if `regionTo` is a region, this is the lattice at the beginning of
+  ///       the entry block of that region;
+  ///     - otherwise, this is the lattice after the parent op.
+  ///   - `before`:
+  ///     - if `regionFrom` is a region, this is the lattice at the end of the
+  ///       block that exits the region; note that for multi-exit regions, the
+  ///       lattices are equal at the end of all exiting blocks, but they are
+  ///       associated with different program points.
+  ///     - otherwise, this is the lattice before the parent op.
+  /// By default, the `before` state is simply met with the `after` state.
+  /// Concrete analyses can override this behavior or delegate to the parent
+  /// call for the default behavior. Specifically, if the `branch` op may affect
+  /// the lattice before entering any region, the custom behavior can be added
+  /// for `regionFrom == nullopt`. If the `branch` op may affect the lattice
+  /// after the regions terminate, the custom behavior can be added for
+  /// `regionTo == nullopt`. The behavior can be further refined for specific
+  /// pairs of "from" and "to" regions.
+  virtual void visitRegionBranchControlFlowTransfer(
+      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+      std::optional<unsigned> regionTo, const LatticeT &after,
+      LatticeT *before) {
+    AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
+        branch, regionFrom, regionTo, after, before);
+  }
+
 protected:
   /// Get the dense lattice at the given program point.
   LatticeT *getLattice(ProgramPoint point) override {
@@ -289,17 +486,33 @@ class DenseBackwardDataFlowAnalysis
   /// Set the dense lattice at control flow exit point (after the terminator)
   /// and propagate an update if it changed.
   virtual void setToExitState(LatticeT *lattice) = 0;
-  void setToExitState(AbstractDenseLattice *lattice) override {
+  void setToExitState(AbstractDenseLattice *lattice) final {
     setToExitState(static_cast<LatticeT *>(lattice));
   }
 
-  /// Type-erased wrapper that convert the abstract dense lattice to a derived
+  /// Type-erased wrappers that convert the abstract dense lattice to a derived
   /// lattice and invoke the virtual hooks operating on the derived lattice.
   void visitOperationImpl(Operation *op, const AbstractDenseLattice &after,
-                          AbstractDenseLattice *before) override {
+                          AbstractDenseLattice *before) final {
     visitOperation(op, static_cast<const LatticeT &>(after),
                    static_cast<LatticeT *>(before));
   }
+  void visitCallControlFlowTransfer(CallOpInterface call,
+                                    CallControlFlowAction action,
+                                    const AbstractDenseLattice &after,
+                                    AbstractDenseLattice *before) final {
+    visitCallControlFlowTransfer(call, action,
+                                 static_cast<const LatticeT &>(after),
+                                 static_cast<LatticeT *>(before));
+  }
+  void visitRegionBranchControlFlowTransfer(
+      RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+      std::optional<unsigned> regionTo, const AbstractDenseLattice &after,
+      AbstractDenseLattice *before) final {
+    visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
+                                         static_cast<const LatticeT &>(after),
+                                         static_cast<LatticeT *>(before));
+  }
 };
 
 } // end namespace dataflow

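The header changes pair each typed hook with a `final` type-erased override that `static_cast`s the abstract lattice and forwards to the typed hook, whose default delegates back to the abstract base. A minimal, MLIR-free sketch of that layering, assuming hypothetical classes `AbstractAnalysis`, `TypedAnalysis<LatticeT>` and `CounterAnalysis` in place of the real ones; the counter lattice joins with max, and the override models a parent op that bumps the state before entering any region.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for AbstractDenseLattice and a concrete lattice.
struct AbstractLattice {
  virtual ~AbstractLattice() = default;
};

struct CounterLattice : AbstractLattice {
  int value = 0;
};

// Type-erased driver with a default (join) edge hook.
class AbstractAnalysis {
public:
  virtual ~AbstractAnalysis() = default;
  virtual void visitTransfer(std::optional<unsigned> from,
                             std::optional<unsigned> to,
                             const AbstractLattice &before,
                             AbstractLattice *after) {
    join(after, before);
  }

protected:
  virtual void join(AbstractLattice *lhs, const AbstractLattice &rhs) = 0;
};

// Typed layer: a typed hook defaulting to the base behavior, plus a `final`
// type-erased override that downcasts and forwards to the typed hook.
template <typename LatticeT>
class TypedAnalysis : public AbstractAnalysis {
public:
  virtual void visitTransfer(std::optional<unsigned> from,
                             std::optional<unsigned> to,
                             const LatticeT &before, LatticeT *after) {
    // Qualified call: statically bound, so it cannot re-enter the wrapper.
    AbstractAnalysis::visitTransfer(from, to, before, after);
  }
  void visitTransfer(std::optional<unsigned> from, std::optional<unsigned> to,
                     const AbstractLattice &before,
                     AbstractLattice *after) final {
    visitTransfer(from, to, static_cast<const LatticeT &>(before),
                  static_cast<LatticeT *>(after));
  }
};

// Concrete analysis: the parent op bumps the counter before entering a region.
class CounterAnalysis : public TypedAnalysis<CounterLattice> {
public:
  using TypedAnalysis<CounterLattice>::visitTransfer;
  void visitTransfer(std::optional<unsigned> from, std::optional<unsigned> to,
                     const CounterLattice &before,
                     CounterLattice *after) override {
    if (!from && to) { // parent -> region edge
      CounterLattice bumped;
      bumped.value = before.value + 1;
      join(after, bumped);
      return;
    }
    TypedAnalysis<CounterLattice>::visitTransfer(from, to, before, after);
  }

protected:
  // Join keeps the larger counter value.
  void join(AbstractLattice *lhs, const AbstractLattice &rhs) override {
    auto *l = static_cast<CounterLattice *>(lhs);
    const auto &r = static_cast<const CounterLattice &>(rhs);
    if (r.value > l->value)
      l->value = r.value;
  }
};
```

Calls made through an `AbstractAnalysis &` reach the concrete override via the `final` wrapper, while edges the override does not handle fall back to the plain join.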
diff  --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 36920f84e8742e..e6716b88e9cba3 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -42,6 +42,38 @@ LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
   return success();
 }
 
+void AbstractDenseDataFlowAnalysis::visitCallOperation(
+    CallOpInterface call, AbstractDenseLattice *after) {
+
+  const auto *predecessors =
+      getOrCreateFor<PredecessorState>(call.getOperation(), call);
+  // If not all return sites are known, then conservatively assume we can't
+  // reason about the data-flow.
+  if (!predecessors->allPredecessorsKnown())
+    return setToEntryState(after);
+
+  for (Operation *predecessor : predecessors->getKnownPredecessors()) {
+    // Get the lattices at callee return:
+    //
+    //   func.func @callee() {
+    //     ...
+    //     return  // predecessor
+    //     // latticeAtCalleeReturn
+    //   }
+    //   func.func @caller() {
+    //     ...
+    //     call @callee
+    //     // latticeAfterCall
+    //     ...
+    //   }
+    AbstractDenseLattice *latticeAfterCall = after;
+    const AbstractDenseLattice *latticeAtCalleeReturn =
+        getLatticeFor(call.getOperation(), predecessor);
+    visitCallControlFlowTransfer(call, CallControlFlowAction::ExitCallee,
+                                 *latticeAtCalleeReturn, latticeAfterCall);
+  }
+}
+
 void AbstractDenseDataFlowAnalysis::processOperation(Operation *op) {
   // If the containing block is not executable, bail out.
   if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
@@ -50,6 +82,13 @@ void AbstractDenseDataFlowAnalysis::processOperation(Operation *op) {
   // Get the dense lattice to update.
   AbstractDenseLattice *after = getLattice(op);
 
+  // Get the dense state before the execution of the op.
+  const AbstractDenseLattice *before;
+  if (Operation *prev = op->getPrevNode())
+    before = getLatticeFor(op, prev);
+  else
+    before = getLatticeFor(op, op->getBlock());
+
   // If this op implements region control-flow, then control-flow dictates its
   // transfer function.
   if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
@@ -57,23 +96,8 @@ void AbstractDenseDataFlowAnalysis::processOperation(Operation *op) {
 
   // If this is a call operation, then join its lattices across known return
   // sites.
-  if (auto call = dyn_cast<CallOpInterface>(op)) {
-    const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
-    // If not all return sites are known, then conservatively assume we can't
-    // reason about the data-flow.
-    if (!predecessors->allPredecessorsKnown())
-      return setToEntryState(after);
-    for (Operation *predecessor : predecessors->getKnownPredecessors())
-      join(after, *getLatticeFor(op, predecessor));
-    return;
-  }
-
-  // Get the dense state before the execution of the op.
-  const AbstractDenseLattice *before;
-  if (Operation *prev = op->getPrevNode())
-    before = getLatticeFor(op, prev);
-  else
-    before = getLatticeFor(op, op->getBlock());
+  if (auto call = dyn_cast<CallOpInterface>(op))
+    return visitCallOperation(call, after);
 
   // Invoke the operation transfer function.
   visitOperationImpl(op, *before, after);
@@ -100,10 +124,15 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
         return setToEntryState(after);
       for (Operation *callsite : callsites->getKnownPredecessors()) {
         // Get the dense lattice before the callsite.
+        const AbstractDenseLattice *before;
         if (Operation *prev = callsite->getPrevNode())
-          join(after, *getLatticeFor(block, prev));
+          before = getLatticeFor(block, prev);
         else
-          join(after, *getLatticeFor(block, callsite->getBlock()));
+          before = getLatticeFor(block, callsite->getBlock());
+
+        visitCallControlFlowTransfer(cast<CallOpInterface>(callsite),
+                                     CallControlFlowAction::EnterCallee,
+                                     *before, after);
       }
       return;
     }
@@ -152,7 +181,41 @@ void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation(
     } else {
       before = getLatticeFor(point, op);
     }
-    join(after, *before);
+
+    // This function is called in two cases:
+    //   1. when visiting the block (point = block);
+    //   2. when visiting the parent operation (point = parent op).
+    // In both cases, we are looking for predecessor operations of the point:
+    //   1. the predecessor may be the terminator of a block in another region
+    //   (an assertion below checks that the block does belong to another
+    //   region) or the parent (when the parent can transfer control to this
+    //   region);
+    //   2. the predecessor may be the terminator of a block that exits the
+    //   region (when the region transfers control to the parent) or the
+    //   operation before the parent.
+    // In the latter case, just perform the join, as this control flow is not
+    // affected by the region.
+    std::optional<unsigned> regionFrom =
+        op == branch ? std::optional<unsigned>()
+                     : op->getBlock()->getParent()->getRegionNumber();
+    if (auto *toBlock = point.dyn_cast<Block *>()) {
+      assert(op == branch ||
+             toBlock->getParent() != op->getBlock()->getParent());
+      unsigned regionTo = toBlock->getParent()->getRegionNumber();
+      visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
+                                           *before, after);
+    } else {
+      assert(point.get<Operation *>() == branch &&
+             "expected to be visiting the branch itself");
+      // Only need to call the arc transfer when the predecessor is the region
+      // or the op itself, not the previous op.
+      if (op->getParentOp() == branch || op == branch) {
+        visitRegionBranchControlFlowTransfer(
+            branch, regionFrom, /*regionTo=*/std::nullopt, *before, after);
+      } else {
+        join(after, *before);
+      }
+    }
   }
 }
 
@@ -194,6 +257,44 @@ LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
   return success();
 }
 
+void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
+    CallOpInterface call, AbstractDenseLattice *before) {
+  // Find the callee.
+  Operation *callee = call.resolveCallable(&symbolTable);
+  auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
+  if (!callable)
+    return setToExitState(before);
+
+  // No region means the callee is only declared in this module and we shouldn't
+  // assume anything about it.
+  Region *region = callable.getCallableRegion();
+  if (!region || region->empty())
+    return setToExitState(before);
+
+  // Call-level control flow specifies the data flow here.
+  //
+  //   func.func @callee() {
+  //     ^calleeEntryBlock:
+  //     // latticeAtCalleeEntry
+  //     ...
+  //   }
+  //   func.func @caller() {
+  //     ...
+  //     // latticeBeforeCall
+  //     call @callee
+  //     ...
+  //   }
+  Block *calleeEntryBlock = &region->front();
+  ProgramPoint calleeEntry = calleeEntryBlock->empty()
+                                 ? ProgramPoint(calleeEntryBlock)
+                                 : &calleeEntryBlock->front();
+  const AbstractDenseLattice &latticeAtCalleeEntry =
+      *getLatticeFor(call.getOperation(), calleeEntry);
+  AbstractDenseLattice *latticeBeforeCall = before;
+  visitCallControlFlowTransfer(call, CallControlFlowAction::EnterCallee,
+                               latticeAtCalleeEntry, latticeBeforeCall);
+}
+
 void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
   // If the containing block is not executable, bail out.
   if (!getOrCreateFor<Executable>(op, op->getBlock())->isLive())
@@ -202,39 +303,6 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
   // Get the dense lattice to update.
   AbstractDenseLattice *before = getLattice(op);
 
-  // If the op implements region control flow, then the interface specifies the
-  // control function.
-  // TODO: this is not always true, e.g. linalg.generic, but is implement this
-  // way for consistency with the dense forward analysis.
-  if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
-    return visitRegionBranchOperation(op, branch, std::nullopt, before);
-
-  // If the op is a call-like, do inter-procedural data flow as follows:
-  //
-  //   - find the callable (resolve via the symbol table),
-  //   - get the entry block of the callable region,
-  //   - take the state before the first operation if present or at block end
-  //   otherwise,
-  //   - meet that state with the state before the call-like op.
-  if (auto call = dyn_cast<CallOpInterface>(op)) {
-    Operation *callee = call.resolveCallable(&symbolTable);
-    if (auto callable = dyn_cast<CallableOpInterface>(callee)) {
-      Region *region = callable.getCallableRegion();
-      if (region && !region->empty()) {
-        Block *entryBlock = &region->front();
-        if (entryBlock->empty())
-          meet(before, *getLatticeFor(op, entryBlock));
-        else
-          meet(before, *getLatticeFor(op, &entryBlock->front()));
-      } else {
-        setToExitState(before);
-      }
-    } else {
-      setToExitState(before);
-    }
-    return;
-  }
-
   // Get the dense state after execution of this op.
   const AbstractDenseLattice *after;
   if (Operation *next = op->getNextNode())
@@ -242,6 +310,12 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
   else
     after = getLatticeFor(op, op->getBlock());
 
+  // Special cases where control flow may dictate data flow.
+  if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
+    return visitRegionBranchOperation(op, branch, std::nullopt, before);
+  if (auto call = dyn_cast<CallOpInterface>(op))
+    return visitCallOperation(call, before);
+
   // Invoke the operation transfer function.
   visitOperationImpl(op, *after, before);
 }
@@ -280,16 +354,20 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
         return setToExitState(before);
 
       for (Operation *callsite : callsites->getKnownPredecessors()) {
+        const AbstractDenseLattice *after;
         if (Operation *next = callsite->getNextNode())
-          meet(before, *getLatticeFor(block, next));
+          after = getLatticeFor(block, next);
         else
-          meet(before, *getLatticeFor(block, callsite->getBlock()));
+          after = getLatticeFor(block, callsite->getBlock());
+        visitCallControlFlowTransfer(cast<CallOpInterface>(callsite),
+                                     CallControlFlowAction::ExitCallee, *after,
+                                     before);
       }
       return;
     }
 
     // If this block is exiting from an operation with region-based control
-    // flow, follow that flow.
+    // flow, propagate the lattice back along the control flow edge.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
       visitRegionBranchOperation(block, branch,
                                  block->getParent()->getRegionNumber(), before);
@@ -346,7 +424,11 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
       else
         after = getLatticeFor(point, &successorBlock->front());
     }
-    meet(before, *after);
+    std::optional<unsigned> successorNo =
+        successor.isParent() ? std::optional<unsigned>()
+                             : successor.getSuccessor()->getRegionNumber();
+    visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after,
+                                         before);
   }
 }
 

diff --git a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
index c1fdf82e4bc766..4a243571c231a1 100644
--- a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
+++ b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s
+// RUN: mlir-opt -test-last-modified --split-input-file %s 2>&1 | FileCheck %s
 
 // CHECK-LABEL: test_tag: test_callsite
 // CHECK: operand #0
@@ -64,4 +64,84 @@ func.func private @multiple_return_site_fn(%cond: i1, %a: i32, %ptr: memref<i32>
 func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref<i32>) -> memref<i32> {
   %0 = func.call @multiple_return_site_fn(%cond, %a, %ptr) : (i1, i32, memref<i32>) -> memref<i32>
   return {tag = "test_multiple_return_sites"} %0 : memref<i32>
-}
\ No newline at end of file
+}
+
+// -----
+
+
+func.func private @callee(%arg0: memref<f32>) -> memref<f32> {
+  %2 = arith.constant 2.0 : f32
+  memref.load %arg0[] {tag = "call_and_store_before::enter_callee"} : memref<f32>
+  memref.store %2, %arg0[] {tag_name = "callee"} : memref<f32>
+  memref.load %arg0[] {tag = "exit_callee"} : memref<f32>
+  return %arg0 : memref<f32>
+}
+// In this test, the "call" operation also stores to %arg0 itself before
+// transferring control flow to the callee. Therefore, the order of accesses is
+// "pre" -> "call" -> "callee" -> "post"
+
+// CHECK-LABEL: test_tag: call_and_store_before::enter_callee:
+// CHECK:  operand #0
+// CHECK:   - call
+// CHECK: test_tag: exit_callee:
+// CHECK:  operand #0
+// CHECK:   - callee
+// CHECK: test_tag: before_call:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: after_call:
+// CHECK:  operand #0
+// CHECK:   - callee
+// CHECK: test_tag: return:
+// CHECK:  operand #0
+// CHECK:   - post
+func.func @call_and_store_before(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "before_call"} : memref<f32>
+  test.call_and_store @callee(%arg0), %arg0 {tag_name = "call", store_before_call = true} : (memref<f32>, memref<f32>) -> ()
+  memref.load %arg0[] {tag = "after_call"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}
+
+// -----
+
+func.func private @callee(%arg0: memref<f32>) -> memref<f32> {
+  %2 = arith.constant 2.0 : f32
+  memref.load %arg0[] {tag = "call_and_store_after::enter_callee"} : memref<f32>
+  memref.store %2, %arg0[] {tag_name = "callee"} : memref<f32>
+  memref.load %arg0[] {tag = "exit_callee"} : memref<f32>
+  return %arg0 : memref<f32>
+}
+
+// In this test, the "call" operation also stores to %arg0 itself after getting
+// control flow back from the callee. Therefore, the order of accesses is
+// "pre" -> "callee" -> "call" -> "post"
+
+// CHECK-LABEL: test_tag: call_and_store_after::enter_callee:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: exit_callee:
+// CHECK:  operand #0
+// CHECK:   - callee
+// CHECK: test_tag: before_call:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: after_call:
+// CHECK:  operand #0
+// CHECK:   - call
+// CHECK: test_tag: return:
+// CHECK:  operand #0
+// CHECK:   - post
+func.func @call_and_store_after(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "before_call"} : memref<f32>
+  test.call_and_store @callee(%arg0), %arg0 {tag_name = "call", store_before_call = false} : (memref<f32>, memref<f32>) -> ()
+  memref.load %arg0[] {tag = "after_call"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}

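The two `call_and_store` tests above pin down one ordering rule: the call op's own store is observed on the caller-to-callee edge when `store_before_call = true` and on the callee-to-caller edge when it is `false`. The following is a hypothetical toy model (plain C++, not the MLIR dataflow API; `Event`, `buildCallTrace` and `lastModified` are illustrative names). It flattens the call into one linear trace and reproduces the last writers expected by the CHECK lines.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy model, not the MLIR API: every access to the single memref
// is flattened into one linear trace, in execution order.
struct Event {
  enum Kind { Store, Load } kind;
  std::string tag; // mirrors the `tag_name` / `tag` attributes in the test
};

// Trace of @call_and_store_{before,after}. The call op's own store is placed
// either on the caller->callee edge (what the EnterCallee hook models) or on
// the callee->caller edge (what the ExitCallee hook models).
std::vector<Event> buildCallTrace(bool storeBeforeCall) {
  std::vector<Event> trace = {{Event::Store, "pre"},
                              {Event::Load, "before_call"}};
  if (storeBeforeCall)
    trace.push_back({Event::Store, "call"});
  trace.push_back({Event::Load, "enter_callee"});
  trace.push_back({Event::Store, "callee"});
  trace.push_back({Event::Load, "exit_callee"});
  if (!storeBeforeCall)
    trace.push_back({Event::Store, "call"});
  trace.push_back({Event::Load, "after_call"});
  trace.push_back({Event::Store, "post"});
  trace.push_back({Event::Load, "return"});
  return trace;
}

// Forward walk: the "lattice" is just the most recent store. Each load records
// the last writer visible at that point, keyed by the load's tag.
std::map<std::string, std::string>
lastModified(const std::vector<Event> &trace) {
  std::map<std::string, std::string> result;
  std::string lastWriter = "<unknown>"; // entry state: no known writer
  for (const Event &event : trace) {
    assert(!event.tag.empty() && "every event needs a tag");
    if (event.kind == Event::Store)
      lastWriter = event.tag;
    else
      result[event.tag] = lastWriter;
  }
  return result;
}
```

Flattening is only valid here because there is a single call site and no branching; the real analysis instead merges lattices across all known call sites and return points.
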
diff --git a/mlir/test/Analysis/DataFlow/test-last-modified.mlir b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
index 69fb7125f0c5bf..069cbbcc0cc168 100644
--- a/mlir/test/Analysis/DataFlow/test-last-modified.mlir
+++ b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
@@ -113,3 +113,119 @@ func.func @unknown_memory_effects(%ptr: memref<i32>) -> memref<i32> {
   "test.unknown_effects"() : () -> ()
   return {tag = "unknown_memory_effects_b"} %ptr : memref<i32>
 }
+
+// CHECK-LABEL: test_tag: store_with_a_region_before::before:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: inside_region:
+// CHECK:  operand #0
+// CHECK:   - region
+// CHECK: test_tag: after:
+// CHECK:  operand #0
+// CHECK:   - region
+// CHECK: test_tag: return:
+// CHECK:  operand #0
+// CHECK:   - post
+func.func @store_with_a_region_before(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "store_with_a_region_before::before"} : memref<f32>
+  test.store_with_a_region %arg0 attributes { tag_name = "region", store_before_region = true } {
+    memref.load %arg0[] {tag = "inside_region"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {tag = "after"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}
+
+// CHECK-LABEL: test_tag: store_with_a_region_after::before:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: inside_region:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: after:
+// CHECK:  operand #0
+// CHECK:   - region
+// CHECK: test_tag: return:
+// CHECK:  operand #0
+// CHECK:   - post
+func.func @store_with_a_region_after(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "store_with_a_region_after::before"} : memref<f32>
+  test.store_with_a_region %arg0 attributes { tag_name = "region", store_before_region = false } {
+    memref.load %arg0[] {tag = "inside_region"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {tag = "after"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}
+
+// CHECK-LABEL: test_tag: store_with_a_region_before_containing_a_store::before:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: enter_region:
+// CHECK:  operand #0
+// CHECK:   - region
+// CHECK: test_tag: exit_region:
+// CHECK:  operand #0
+// CHECK:   - inner
+// CHECK: test_tag: after:
+// CHECK:  operand #0
+// CHECK:   - inner
+// CHECK: test_tag: return:
+// CHECK:  operand #0
+// CHECK:   - post
+func.func @store_with_a_region_before_containing_a_store(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "store_with_a_region_before_containing_a_store::before"} : memref<f32>
+  test.store_with_a_region %arg0 attributes { tag_name = "region", store_before_region = true } {
+    memref.load %arg0[] {tag = "enter_region"} : memref<f32>
+    %2 = arith.constant 2.0 : f32
+    memref.store %2, %arg0[] {tag_name = "inner"} : memref<f32>
+    memref.load %arg0[] {tag = "exit_region"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {tag = "after"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}
+
+// CHECK-LABEL: test_tag: store_with_a_region_after_containing_a_store::before:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: enter_region:
+// CHECK:  operand #0
+// CHECK:   - pre
+// CHECK: test_tag: exit_region:
+// CHECK:  operand #0
+// CHECK:   - inner
+// CHECK: test_tag: after:
+// CHECK:  operand #0
+// CHECK:   - region
+// CHECK: test_tag: return:
+// CHECK:  operand #0
+// CHECK:   - post
+func.func @store_with_a_region_after_containing_a_store(%arg0: memref<f32>) -> memref<f32> {
+  %0 = arith.constant 0.0 : f32
+  %1 = arith.constant 1.0 : f32
+  memref.store %0, %arg0[] {tag_name = "pre"} : memref<f32>
+  memref.load %arg0[] {tag = "store_with_a_region_after_containing_a_store::before"} : memref<f32>
+  test.store_with_a_region %arg0 attributes { tag_name = "region", store_before_region = false } {
+    memref.load %arg0[] {tag = "enter_region"} : memref<f32>
+    %2 = arith.constant 2.0 : f32
+    memref.store %2, %arg0[] {tag_name = "inner"} : memref<f32>
+    memref.load %arg0[] {tag = "exit_region"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {tag = "after"} : memref<f32>
+  memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
+  return {tag = "return"} %arg0 : memref<f32>
+}

diff --git a/mlir/test/Analysis/DataFlow/test-next-access.mlir b/mlir/test/Analysis/DataFlow/test-next-access.mlir
index 31294f0bbc2a6b..a029baf8988b41 100644
--- a/mlir/test/Analysis/DataFlow/test-next-access.mlir
+++ b/mlir/test/Analysis/DataFlow/test-next-access.mlir
@@ -357,3 +357,157 @@ func.func @conditonal_call(%arg0: memref<f32>, %cond: i1) {
   memref.load %arg0[] {name = "post"} : memref<f32>
   return
 }
+
+// -----
+
+
+// In this test, the "call" operation also accesses %arg0 itself before
+// transferring control flow to the callee. Therefore, the order of accesses is
+// "caller" -> "call" -> "callee" -> "post"
+
+func.func private @callee(%arg0: memref<f32>) {
+  // CHECK:              name = "callee"
+  // CHECK-SAME-LITERAL: next_access = [["post"]]
+  memref.load %arg0[] {name = "callee"} : memref<f32>
+  return
+}
+
+// CHECK-LABEL: @call_and_store_before
+func.func @call_and_store_before(%arg0: memref<f32>) {
+  // CHECK:              name = "caller"
+  // CHECK-SAME-LITERAL: next_access = [["call"]]
+  memref.load %arg0[] {name = "caller"} : memref<f32>
+  // Note that the access after the entire call is "post".
+  // CHECK:              name = "call"
+  // CHECK-SAME-LITERAL: next_access = [["post"], ["post"]]
+  test.call_and_store @callee(%arg0), %arg0 {name = "call", store_before_call = true} : (memref<f32>, memref<f32>) -> ()
+  // CHECK:              name = "post"
+  // CHECK-SAME-LITERAL: next_access = ["unknown"]
+  memref.load %arg0[] {name = "post"} : memref<f32>
+  return
+}
+
+// -----
+
+// In this test, the "call" operation also accesses %arg0 itself after getting
+// control flow back from the callee. Therefore, the order of accesses is
+// "caller" -> "callee" -> "call" -> "post"
+
+func.func private @callee(%arg0: memref<f32>) {
+  // CHECK:              name = "callee"
+  // CHECK-SAME-LITERAL: next_access = [["call"]]
+  memref.load %arg0[] {name = "callee"} : memref<f32>
+  return
+}
+
+// CHECK-LABEL: @call_and_store_after
+func.func @call_and_store_after(%arg0: memref<f32>) {
+  // CHECK:              name = "caller"
+  // CHECK-SAME-LITERAL: next_access = [["callee"]]
+  memref.load %arg0[] {name = "caller"} : memref<f32>
+  // CHECK:              name = "call"
+  // CHECK-SAME-LITERAL: next_access = [["post"], ["post"]]
+  test.call_and_store @callee(%arg0), %arg0 {name = "call", store_before_call = false} : (memref<f32>, memref<f32>) -> ()
+  // CHECK:              name = "post"
+  // CHECK-SAME-LITERAL: next_access = ["unknown"]
+  memref.load %arg0[] {name = "post"} : memref<f32>
+  return
+}
+
+// -----
+
+// In this test, the "region" operation also accesses %arg0 itself before
+// entering the region. Therefore:
+//   - the next access of "pre" is the "region" operation itself;
+//   - at the entry of the block, the next access is "post".
+// CHECK-LABEL: @store_with_a_region_before
+func.func @store_with_a_region_before(%arg0: memref<f32>) {
+  // CHECK:              name = "pre"
+  // CHECK-SAME-LITERAL: next_access = [["region"]]
+  memref.load %arg0[] {name = "pre"} : memref<f32>
+  // CHECK:              name = "region"
+  // CHECK-SAME-LITERAL: next_access = [["post"]]
+  // CHECK-SAME-LITERAL: next_at_entry_point = [[["post"]]]
+  test.store_with_a_region %arg0 attributes { name = "region", store_before_region = true } {
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {name = "post"} : memref<f32>
+  return
+}
+
+// In this test, the "region" operation also accesses %arg0 itself after
+// exiting from the region. Therefore:
+//   - the next access of "pre" is the "region" operation itself;
+//   - at the entry of the block, the next access is "region".
+// CHECK-LABEL: @store_with_a_region_after
+func.func @store_with_a_region_after(%arg0: memref<f32>) {
+  // CHECK:              name = "pre"
+  // CHECK-SAME-LITERAL: next_access = [["region"]]
+  memref.load %arg0[] {name = "pre"} : memref<f32>
+  // CHECK:              name = "region"
+  // CHECK-SAME-LITERAL: next_access = [["post"]]
+  // CHECK-SAME-LITERAL: next_at_entry_point = [[["region"]]]
+  test.store_with_a_region %arg0 attributes { name = "region", store_before_region = false } {
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  memref.load %arg0[] {name = "post"} : memref<f32>
+  return
+}
+
+// In this test, the operation with a region stores to %arg0 before going to the
+// region. Therefore:
+//   - the next access of "pre" is the "region" operation itself;
+//   - the next access of the "region" operation (computed as the next access
+//     *after* said operation) is the "post" operation;
+//   - the next access of the "inner" operation is also "post";
+//   - the next access at the entry point of the region of the "region" operation
+//     is the "inner" operation.
+// That is, the order of access is: "pre" -> "region" -> "inner" -> "post".
+// CHECK-LABEL: @store_with_a_region_before_containing_a_load
+func.func @store_with_a_region_before_containing_a_load(%arg0: memref<f32>) {
+  // CHECK:              name = "pre"
+  // CHECK-SAME-LITERAL: next_access = [["region"]]
+  memref.load %arg0[] {name = "pre"} : memref<f32>
+  // CHECK:              name = "region"
+  // CHECK-SAME-LITERAL: next_access = [["post"]]
+  // CHECK-SAME-LITERAL: next_at_entry_point = [[["inner"]]]
+  test.store_with_a_region %arg0 attributes { name = "region", store_before_region = true } {
+    // CHECK:              name = "inner"
+    // CHECK-SAME-LITERAL: next_access = [["post"]]
+    memref.load %arg0[] {name = "inner"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  // CHECK:              name = "post"
+  // CHECK-SAME-LITERAL: next_access = ["unknown"]
+  memref.load %arg0[] {name = "post"} : memref<f32>
+  return
+}
+
+// In this test, the operation with a region stores to %arg0 after exiting from
+// the region. Therefore:
+//   - the next access of "pre" is "inner";
+//   - the next access of the "region" operation (computed as the next access
+//     *after* said operation) is the "post" operation;
+//   - the next access at the entry point of the region of the "region" operation
+//     is the "inner" operation;
+//   - the next access of the "inner" operation is the "region" operation itself.
+// That is, the order of access is "pre" -> "inner" -> "region" -> "post".
+// CHECK-LABEL: @store_with_a_region_after_containing_a_load
+func.func @store_with_a_region_after_containing_a_load(%arg0: memref<f32>) {
+  // CHECK:              name = "pre"
+  // CHECK-SAME-LITERAL: next_access = [["inner"]]
+  memref.load %arg0[] {name = "pre"} : memref<f32>
+  // CHECK:              name = "region"
+  // CHECK-SAME-LITERAL: next_access = [["post"]]
+  // CHECK-SAME-LITERAL: next_at_entry_point = [[["inner"]]]
+  test.store_with_a_region %arg0 attributes { name = "region", store_before_region = false } {
+    // CHECK:              name = "inner"
+    // CHECK-SAME-LITERAL: next_access = [["region"]]
+    memref.load %arg0[] {name = "inner"} : memref<f32>
+    test.store_with_a_region_terminator
+  } : memref<f32>
+  // CHECK:              name = "post"
+  // CHECK-SAME-LITERAL: next_access = ["unknown"]
+  memref.load %arg0[] {name = "post"} : memref<f32>
+  return
+}

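The comments in the two `*_containing_a_load` tests above reason about a single linearized order of accesses. As a hypothetical illustration (plain C++, not the MLIR dataflow API; `buildRegionTrace` and `nextAccess` are made-up names), that order can be replayed with a backward walk to recover the expected `next_access` values for `pre` and `inner` and the `next_at_entry_point` value. One caveat: the test pass reports `next_access` for the `region` op itself at `op->getNextNode()`, i.e. after the whole operation, so it is `"post"` in both tests; the toy model's entry for `"region"` is the access following the op's own store, which only coincides with that when the store happens after the region.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy model, not the MLIR API: accesses to %arg0 in execution
// order for @store_with_a_region_{before,after}_containing_a_load. The marker
// "<entry>" is not an access; it stands for the program point at the start of
// the region body.
std::vector<std::string> buildRegionTrace(bool storeBeforeRegion) {
  std::vector<std::string> trace = {"pre"};
  if (storeBeforeRegion)
    trace.push_back("region"); // op's own access on the parent->region edge
  trace.push_back("<entry>");
  trace.push_back("inner");
  if (!storeBeforeRegion)
    trace.push_back("region"); // op's own access on the region->parent edge
  trace.push_back("post");
  return trace;
}

// Backward walk: each point records the first access that follows it.
std::map<std::string, std::string>
nextAccess(const std::vector<std::string> &trace) {
  std::map<std::string, std::string> result;
  std::string next = "unknown"; // exit state: nothing is known after `return`
  for (auto it = trace.rbegin(); it != trace.rend(); ++it) {
    assert(!it->empty() && "every point needs a name");
    result[*it] = next;
    if (*it != "<entry>") // the entry marker is a point, not an access
      next = *it;
  }
  return result;
}
```
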
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index 2b20b01913c39e..a33b523d5d192f 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -11,11 +11,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDenseDataFlowAnalysis.h"
+#include "TestDialect.h"
 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/TypeID.h"
@@ -27,14 +31,14 @@ using namespace mlir::dataflow::test;
 
 namespace {
 
-class NextAccess : public AbstractDenseLattice, public test::AccessLatticeBase {
+class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess)
 
   using dataflow::AbstractDenseLattice::AbstractDenseLattice;
 
   ChangeResult meet(const AbstractDenseLattice &lattice) override {
-    return AccessLatticeBase::merge(static_cast<test::AccessLatticeBase>(
+    return AccessLatticeBase::merge(static_cast<AccessLatticeBase>(
         static_cast<const NextAccess &>(lattice)));
   }
 
@@ -50,6 +54,17 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
   void visitOperation(Operation *op, const NextAccess &after,
                       NextAccess *before) override;
 
+  void visitCallControlFlowTransfer(CallOpInterface call,
+                                    CallControlFlowAction action,
+                                    const NextAccess &after,
+                                    NextAccess *before) override;
+
+  void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
+                                            std::optional<unsigned> regionFrom,
+                                            std::optional<unsigned> regionTo,
+                                            const NextAccess &after,
+                                            NextAccess *before) override;
+
   // TODO: this isn't ideal for the analysis. When there is no next access, it
   // means "we don't know what the next access is" rather than "there is no next
   // access". But it's unclear how to differentiate the two cases...
@@ -78,18 +93,53 @@ void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after,
     if (!value)
       return setToExitState(before);
 
+    // If we cannot find the most underlying value, we cannot assume anything
+    // about the next accesses.
     value = UnderlyingValueAnalysis::getMostUnderlyingValue(
         value, [&](Value value) {
           return getOrCreateFor<UnderlyingValueLattice>(op, value);
         });
     if (!value)
-      return;
+      return setToExitState(before);
 
     result |= before->set(value, op);
   }
   propagateIfChanged(before, result);
 }
 
+void NextAccessAnalysis::visitCallControlFlowTransfer(
+    CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
+    NextAccess *before) {
+  auto testCallAndStore =
+      dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
+  if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
+                            testCallAndStore.getStoreBeforeCall()) ||
+                           (action == CallControlFlowAction::ExitCallee &&
+                            !testCallAndStore.getStoreBeforeCall()))) {
+    visitOperation(call, after, before);
+  } else {
+    AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
+        call, action, after, before);
+  }
+}
+
+void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
+    RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+    std::optional<unsigned> regionTo, const NextAccess &after,
+    NextAccess *before) {
+  auto testStoreWithARegion =
+      dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
+
+  if (testStoreWithARegion &&
+      ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
+       (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
+    visitOperation(branch, static_cast<const NextAccess &>(after),
+                   static_cast<NextAccess *>(before));
+  } else {
+    propagateIfChanged(before, before->meet(after));
+  }
+}
+
 namespace {
 struct TestNextAccessPass
     : public PassWrapper<TestNextAccessPass, OperationPass<>> {
@@ -99,6 +149,45 @@ struct TestNextAccessPass
 
   static constexpr llvm::StringLiteral kTagAttrName = "name";
   static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access";
+  static constexpr llvm::StringLiteral kAtEntryPointAttrName =
+      "next_at_entry_point";
+
+  static Attribute makeNextAccessAttribute(Operation *op,
+                                           const DataFlowSolver &solver,
+                                           const NextAccess *nextAccess) {
+    if (!nextAccess)
+      return StringAttr::get(op->getContext(), "not computed");
+
+    SmallVector<Attribute> attrs;
+    for (Value operand : op->getOperands()) {
+      Value value = UnderlyingValueAnalysis::getMostUnderlyingValue(
+          operand, [&](Value value) {
+            return solver.lookupState<UnderlyingValueLattice>(value);
+          });
+      std::optional<ArrayRef<Operation *>> nextAcc =
+          nextAccess->getAdjacentAccess(value);
+      if (!nextAcc) {
+        attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
+        continue;
+      }
+
+      SmallVector<Attribute> innerAttrs;
+      innerAttrs.reserve(nextAcc->size());
+      for (Operation *nextAccOp : *nextAcc) {
+        if (auto nextAccTag =
+                nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) {
+          innerAttrs.push_back(nextAccTag);
+          continue;
+        }
+        std::string repr;
+        llvm::raw_string_ostream os(repr);
+        nextAccOp->print(os);
+        innerAttrs.push_back(StringAttr::get(op->getContext(), os.str()));
+      }
+      attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs));
+    }
+    return ArrayAttr::get(op->getContext(), attrs);
+  }
 
   void runOnOperation() override {
     Operation *op = getOperation();
@@ -113,7 +202,6 @@ struct TestNextAccessPass
       emitError(op->getLoc(), "dataflow solver failed");
       return signalPassFailure();
     }
-
     op->walk([&](Operation *op) {
       auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
       if (!tag)
@@ -122,42 +210,28 @@ struct TestNextAccessPass
       const NextAccess *nextAccess = solver.lookupState<NextAccess>(
           op->getNextNode() == nullptr ? ProgramPoint(op->getBlock())
                                        : op->getNextNode());
-      if (!nextAccess) {
-        op->setAttr(kNextAccessAttrName,
-                    StringAttr::get(op->getContext(), "not computed"));
+      op->setAttr(kNextAccessAttrName,
+                  makeNextAccessAttribute(op, solver, nextAccess));
+
+      auto iface = dyn_cast<RegionBranchOpInterface>(op);
+      if (!iface)
         return;
-      }
 
-      SmallVector<Attribute> attrs;
-      for (Value operand : op->getOperands()) {
-        Value value = UnderlyingValueAnalysis::getMostUnderlyingValue(
-            operand, [&](Value value) {
-              return solver.lookupState<UnderlyingValueLattice>(value);
-            });
-        std::optional<ArrayRef<Operation *>> nextAcc =
-            nextAccess->getAdjacentAccess(value);
-        if (!nextAcc) {
-          attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
+      SmallVector<Attribute> entryPointNextAccess;
+      SmallVector<RegionSuccessor> regionSuccessors;
+      iface.getSuccessorRegions(std::nullopt, regionSuccessors);
+      for (const RegionSuccessor &successor : regionSuccessors) {
+        if (!successor.getSuccessor() || successor.getSuccessor()->empty())
           continue;
-        }
-
-        SmallVector<Attribute> innerAttrs;
-        innerAttrs.reserve(nextAcc->size());
-        for (Operation *nextAccOp : *nextAcc) {
-          if (auto nextAccTag =
-                  nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) {
-            innerAttrs.push_back(nextAccTag);
-            continue;
-          }
-          std::string repr;
-          llvm::raw_string_ostream os(repr);
-          nextAccOp->print(os);
-          innerAttrs.push_back(StringAttr::get(op->getContext(), os.str()));
-        }
-        attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs));
+        Block &successorBlock = successor.getSuccessor()->front();
+        ProgramPoint successorPoint = successorBlock.empty()
+                                          ? ProgramPoint(&successorBlock)
+                                          : &successorBlock.front();
+        entryPointNextAccess.push_back(makeNextAccessAttribute(
+            op, solver, solver.lookupState<NextAccess>(successorPoint)));
       }
-
-      op->setAttr(kNextAccessAttrName, ArrayAttr::get(op->getContext(), attrs));
+      op->setAttr(kAtEntryPointAttrName,
+                  ArrayAttr::get(op->getContext(), entryPointNextAccess));
     });
   }
 };

diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
index fbfc7307a4b084..2519331b33a2d3 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDenseDataFlowAnalysis.h"
+#include "TestDialect.h"
 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
@@ -22,8 +23,7 @@ namespace {
 
 /// This lattice represents, for a given memory resource, the potential last
 /// operations that modified the resource.
-class LastModification : public AbstractDenseLattice,
-                         public test::AccessLatticeBase {
+class LastModification : public AbstractDenseLattice, public AccessLatticeBase {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
 
@@ -31,7 +31,7 @@ class LastModification : public AbstractDenseLattice,
 
   /// Join the last modifications.
   ChangeResult join(const AbstractDenseLattice &lattice) override {
-    return AccessLatticeBase::merge(static_cast<test::AccessLatticeBase>(
+    return AccessLatticeBase::merge(static_cast<AccessLatticeBase>(
         static_cast<const LastModification &>(lattice)));
   }
 
@@ -51,6 +51,17 @@ class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
   void visitOperation(Operation *op, const LastModification &before,
                       LastModification *after) override;
 
+  void visitCallControlFlowTransfer(CallOpInterface call,
+                                    CallControlFlowAction action,
+                                    const LastModification &before,
+                                    LastModification *after) override;
+
+  void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
+                                            std::optional<unsigned> regionFrom,
+                                            std::optional<unsigned> regionTo,
+                                            const LastModification &before,
+                                            LastModification *after) override;
+
   /// At an entry point, the last modifications of all memory resources are
   /// unknown.
   void setToEntryState(LastModification *lattice) override {
@@ -80,12 +91,14 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
     if (!value)
       return setToEntryState(after);
 
+    // If we cannot find the underlying value, we shouldn't just propagate the
+    // effects through; return the pessimistic state instead.
     value = UnderlyingValueAnalysis::getMostUnderlyingValue(
         value, [&](Value value) {
           return getOrCreateFor<UnderlyingValueLattice>(op, value);
         });
     if (!value)
-      return;
+      return setToEntryState(after);
 
     // Nothing to do for reads.
     if (isa<MemoryEffects::Read>(effect.getEffect()))
@@ -96,6 +109,36 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
   propagateIfChanged(after, result);
 }
 
+void LastModifiedAnalysis::visitCallControlFlowTransfer(
+    CallOpInterface call, CallControlFlowAction action,
+    const LastModification &before, LastModification *after) {
+  auto testCallAndStore =
+      dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
+  if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
+                            testCallAndStore.getStoreBeforeCall()) ||
+                           (action == CallControlFlowAction::ExitCallee &&
+                            !testCallAndStore.getStoreBeforeCall()))) {
+    return visitOperation(call, before, after);
+  }
+  AbstractDenseDataFlowAnalysis::visitCallControlFlowTransfer(call, action,
+                                                              before, after);
+}
+
+void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
+    RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
+    std::optional<unsigned> regionTo, const LastModification &before,
+    LastModification *after) {
+  auto testStoreWithARegion =
+      dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
+  if (testStoreWithARegion &&
+      ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) ||
+       (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) {
+    return visitOperation(branch, before, after);
+  }
+  AbstractDenseDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
+      branch, regionFrom, regionTo, before, after);
+}
+
 namespace {
 struct TestLastModifiedPass
     : public PassWrapper<TestLastModifiedPass, OperationPass<>> {

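Both test analyses use the same condition to decide when `test.store_with_a_region` contributes its own memory effect during a region-branch transfer. In these hooks a region index of `std::nullopt` denotes the parent operation, so `!regionFrom` means control is entering the body and `!regionTo` means it is leaving it. The standalone function below (`appliesOwnEffect` is a hypothetical name, not part of MLIR) isolates that predicate so the four cases are easy to check.

```cpp
#include <cassert>
#include <optional>

// Hypothetical helper restating the condition used by the test analyses.
// `from`/`to` are region numbers; std::nullopt stands for the parent op.
bool appliesOwnEffect(std::optional<unsigned> from, std::optional<unsigned> to,
                      bool storeBeforeRegion) {
  bool enteringBody = !from; // parent op -> region
  bool leavingBody = !to;    // region -> parent op
  // Store-before ops act on the way in, store-after ops on the way out; any
  // other transfer falls back to the default lattice propagation.
  return (enteringBody && storeBeforeRegion) ||
         (leavingBody && !storeBeforeRegion);
}
```

The `visitCallControlFlowTransfer` overrides in this patch follow the same shape, with `CallControlFlowAction::EnterCallee` and `ExitCallee` playing the role of the null region indices.
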
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 03aeac4c9dff80..072f6ff4b84d33 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Support/LogicalResult.h"
@@ -2014,6 +2015,37 @@ static void printSumProperty(OpAsmPrinter &printer, Operation *op,
   printer << second << " = " << (second + first);
 }
 
+//===----------------------------------------------------------------------===//
+// Test Dataflow
+//===----------------------------------------------------------------------===//
+
+CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
+  return getCallee();
+}
+
+void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  setCalleeAttr(callee.get<SymbolRefAttr>());
+}
+
+Operation::operand_range TestCallAndStoreOp::getArgOperands() {
+  return getCalleeOperands();
+}
+
+void TestStoreWithARegion::getSuccessorRegions(
+    std::optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  if (!index) {
+    regions.emplace_back(&getBody(), getBody().front().getArguments());
+  } else {
+    regions.emplace_back();
+  }
+}
+
+MutableOperandRange TestStoreWithARegionTerminator::getMutableSuccessorOperands(
+    std::optional<unsigned> index) {
+  return MutableOperandRange(getOperation());
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestTypeInterfaces.cpp.inc"

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6cb47451ec8b70..8056f6d7e03183 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3443,6 +3443,41 @@ def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Dataflow
+//===----------------------------------------------------------------------===//
+
+def TestCallAndStoreOp : TEST_Op<"call_and_store",
+    [DeclareOpInterfaceMethods<CallOpInterface>]> {
+  let arguments = (ins
+    SymbolRefAttr:$callee,
+    Arg<AnyMemRef, "", [MemWrite]>:$address,
+    Variadic<AnyType>:$callee_operands,
+    BoolAttr:$store_before_call
+  );
+  let results = (outs
+    Variadic<AnyType>:$results
+  );
+  let assemblyFormat =
+    "$callee `(` $callee_operands `)` `,` $address attr-dict "
+    "`:` functional-type(operands, results)";
+}
 
+def TestStoreWithARegion : TEST_Op<"store_with_a_region",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+     SingleBlock]> {
+  let arguments = (ins 
+    Arg<AnyMemRef, "", [MemWrite]>:$address,
+    BoolAttr:$store_before_region
+  );
+  let regions = (region AnyRegion:$body);
+  let assemblyFormat =
+    "$address attr-dict-with-keyword regions `:` type($address)";
+}
+
+def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator",
+    [DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>, Terminator, NoMemoryEffect]> {
+  let assemblyFormat = "attr-dict";
+}
 
 #endif // TEST_OPS
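
The two `LastModifiedAnalysis` hooks in this patch both choose *when* a test op's own write is applied. `test.call_and_store` applies it on `EnterCallee` or `ExitCallee`. `test.store_with_a_region` applies it on the parent-to-region or region-to-parent transfer. In both cases the choice depends on the `store_before_*` attribute.

The sketch below illustrates that decision logic in plain C++ with no MLIR dependency. It assumes the following, none of which are MLIR APIs:

- `LastModSketch`, `CallAction`, `StoreLikeOp`, `applyOwnEffect`, `callTransfer` and `regionTransfer` are hypothetical stand-ins.
- The lattice is reduced to a map from location name to last writer.
- Transfers that do not trigger the op's own effect pass the state through unchanged. This replaces the real base-class default.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Hypothetical, simplified stand-in for the LastModification lattice:
// maps a memory location name to the name of the op that last wrote it.
struct LastModSketch {
  std::map<std::string, std::string> lastWriter;
};

// Hypothetical mirror of the two call transfer directions.
enum class CallAction { EnterCallee, ExitCallee };

// Hypothetical stand-in for a store-like op that also has control flow.
struct StoreLikeOp {
  std::string name;    // identifier recorded as the last writer
  std::string address; // location written by the op itself
  bool storeBefore;    // mirrors store_before_call / store_before_region
};

// Applies the op's own write; plays the role of visitOperation.
LastModSketch applyOwnEffect(const StoreLikeOp &op,
                             const LastModSketch &before) {
  LastModSketch after = before;
  after.lastWriter[op.address] = op.name;
  return after;
}

// Call hook: the op's store happens on entry to the callee if it precedes
// the call, or on exit from the callee otherwise. Any other case passes
// the state through unchanged (an assumption of this sketch).
LastModSketch callTransfer(const StoreLikeOp &op, CallAction action,
                           const LastModSketch &before) {
  bool ownEffectHere =
      (action == CallAction::EnterCallee && op.storeBefore) ||
      (action == CallAction::ExitCallee && !op.storeBefore);
  return ownEffectHere ? applyOwnEffect(op, before) : before;
}

// Region hook: std::nullopt stands for the parent op itself, as with the
// regionFrom/regionTo parameters in the patch.
LastModSketch regionTransfer(const StoreLikeOp &op,
                             std::optional<unsigned> regionFrom,
                             std::optional<unsigned> regionTo,
                             const LastModSketch &before) {
  bool ownEffectHere = (!regionTo && !op.storeBefore) ||
                       (!regionFrom && op.storeBefore);
  return ownEffectHere ? applyOwnEffect(op, before) : before;
}
```

Suppose the callee body also writes the same location. With `storeBefore == true`, the state after the call reports the callee's write as the last modification. With `storeBefore == false`, it reports the op's own store. This is the distinction the `store_before_call` and `store_before_region` attributes let the tests exercise.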

