[Mlir-commits] [mlir] [MLIR][Affine] Extend/generalize MDG to properly add edges between non-affine ops (PR #125451)

Uday Bondhugula llvmlistbot at llvm.org
Sun Feb 2 19:15:56 PST 2025


https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/125451

>From 2a12a212070e1c1bef35d4b932133b5ee113d5c9 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Fri, 10 Jan 2025 12:47:47 +0530
Subject: [PATCH] [MLIR][Affine] Extend/generalize MDG to properly add edges
 between non-affine ops

Drop arbitrary checks and hacks from affine fusion MDG construction and
handle all ops using memory read/write effects. This has been a long
pending change and it now makes affine fusion more powerful in the
presence of non-affine ops and does not limit fusion in parts of the
block where it is feasible simply because of non-affine ops elsewhere or
intervening non-affine users.

Populate memref read and write ops in non-affine region holding ops and
non-affine ops at the top level of the Block properly; add the
appropriate edges to MDG. Use memory read-write effects and drop
assumptions and special handling of ops due to historic reasons.

Update MDG to drop unnecessary "unhandled region" hack. This hack is no
longer needed with the update to fully and properly construct the MDG.

MDG edges now capture dependences between nodes completely. Drop
non-affine users check. With the MDG generalization to properly include edges
between non-affine nodes/operations, the non-affine users on path check
in fusion is no longer needed. Add more test cases to exercise MDG
generalization.

Drop unnecessary failure when encountering side-effect-free affine.if
ops.

Improve documentation on MDG.
---
 .../mlir/Dialect/Affine/Analysis/Utils.h      |  60 ++++-
 mlir/lib/Dialect/Affine/Analysis/Utils.cpp    | 226 ++++++++++++++----
 .../Dialect/Affine/Transforms/LoopFusion.cpp  |  77 +-----
 mlir/test/Dialect/Affine/loop-fusion-3.mlir   | 178 +++++++++++++-
 .../Dialect/Affine/loop-fusion-inner.mlir     |   8 +-
 mlir/test/Dialect/Affine/loop-fusion.mlir     |  12 +-
 6 files changed, 419 insertions(+), 142 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b8f354892ee60a..d78d4432578683 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -37,9 +37,16 @@ struct MemRefAccess;
 // was encountered in the loop nest.
 struct LoopNestStateCollector {
   SmallVector<AffineForOp, 4> forOps;
+  // Affine loads.
   SmallVector<Operation *, 4> loadOpInsts;
+  // Affine stores.
   SmallVector<Operation *, 4> storeOpInsts;
-  bool hasNonAffineRegionOp = false;
+  // Non-affine loads.
+  SmallVector<Operation *, 4> memrefLoads;
+  // Non-affine stores.
+  SmallVector<Operation *, 4> memrefStores;
+  // Free operations.
+  SmallVector<Operation *, 4> memrefFrees;
 
   // Collects load and store operations, and whether or not a region holding op
   // other than ForOp and IfOp was encountered in the loop nest.
@@ -47,9 +54,15 @@ struct LoopNestStateCollector {
 };
 
 // MemRefDependenceGraph is a graph data structure where graph nodes are
-// top-level operations in a `Block` which contain load/store ops, and edges
-// are memref dependences between the nodes.
-// TODO: Add a more flexible dependence graph representation.
+// top-level operations in a `Block` and edges are memref dependences or SSA
+// dependences (on memrefs) between the nodes. Nodes are created for all
+// top-level operations except in certain cases (see `init` method). Edges are
+// created between nodes with a dependence (see `Edge` documentation). Edges
+// aren't created from/to nodes that have no memory effects. This strucuture
+// also supports checkpointing the current state and reverting to the last
+// committed state. Note that we maintain only one committed state and hence
+// it's not possible to recover a commit state other than the last committed
+// state.
 struct MemRefDependenceGraph {
 public:
   // Node represents a node in the graph. A Node is either an entire loop nest
@@ -60,10 +73,18 @@ struct MemRefDependenceGraph {
     unsigned id;
     // The top-level statement which is (or contains) a load/store.
     Operation *op;
-    // List of load operations.
+    // List of affine loads.
     SmallVector<Operation *, 4> loads;
-    // List of store op insts.
+    // List of non-affine loads.
+    SmallVector<Operation *, 4> memrefLoads;
+    // List of affine store ops.
     SmallVector<Operation *, 4> stores;
+    // List of non-affine stores.
+    SmallVector<Operation *, 4> memrefStores;
+    // List of free operations.
+    SmallVector<Operation *, 4> memrefFrees;
+    // Set of private memrefs used in this node.
+    DenseSet<Value> privateMemrefs;
 
     Node(unsigned id, Operation *op) : id(id), op(op) {}
 
@@ -73,6 +94,13 @@ struct MemRefDependenceGraph {
     // Returns the store op count for 'memref'.
     unsigned getStoreOpCount(Value memref) const;
 
+    /// Returns true if there exists an operation with a write memory effect to
+    /// `memref` in this node.
+    unsigned hasStore(Value memref) const;
+
+    // Returns true if the node has a free op on `memref`.
+    unsigned hasFree(Value memref) const;
+
     // Returns all store ops in 'storeOps' which access 'memref'.
     void getStoreOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *storeOps) const;
@@ -86,7 +114,16 @@ struct MemRefDependenceGraph {
     void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const;
   };
 
-  // Edge represents a data dependence between nodes in the graph.
+  // Edge represents a data dependence between nodes in the graph. It can either
+  // be a memory dependence or an SSA dependence. In the former case, it
+  // corresponds to a pair of memory accesses to the same memref or aliasing
+  // memrefs where at least one of them has a write or free memory effect. The
+  // memory accesses need not be affine load/store operations. Operations are
+  // checked for read/write effects and edges may be added conservatively. Edges
+  // are not created to/from nodes that have no memory effect. An exception to
+  // this are SSA dependences between operations that define memrefs (like
+  // alloc's, view-like ops) and their memory-effecting users that are enclosed
+  // in loops.
   struct Edge {
     // The id of the node at the other end of the edge.
     // If this edge is stored in Edge = Node.inEdges[i], then
@@ -182,9 +219,12 @@ struct MemRefDependenceGraph {
   // of sibling node 'sibId' into node 'dstId'.
   void updateEdges(unsigned sibId, unsigned dstId);
 
-  // Adds ops in 'loads' and 'stores' to node at 'id'.
-  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
-                 const SmallVectorImpl<Operation *> &stores);
+  // Adds the specified ops to lists of node at 'id'.
+  void addToNode(unsigned id, ArrayRef<Operation *> loads,
+                 ArrayRef<Operation *> stores,
+                 ArrayRef<Operation *> memrefLoads,
+                 ArrayRef<Operation *> memrefStores,
+                 ArrayRef<Operation *> memrefFrees);
 
   void clearNodeLoadAndStores(unsigned id);
 
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 29608647d85746..d87c6a1b5141e8 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -42,23 +42,48 @@ using Node = MemRefDependenceGraph::Node;
 // was encountered in the loop nest.
 void LoopNestStateCollector::collect(Operation *opToWalk) {
   opToWalk->walk([&](Operation *op) {
-    if (isa<AffineForOp>(op))
-      forOps.push_back(cast<AffineForOp>(op));
-    else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
-      hasNonAffineRegionOp = true;
-    else if (isa<AffineReadOpInterface>(op))
+    if (auto forOp = dyn_cast<AffineForOp>(op)) {
+      forOps.push_back(forOp);
+    } else if (isa<AffineReadOpInterface>(op)) {
       loadOpInsts.push_back(op);
-    else if (isa<AffineWriteOpInterface>(op))
+    } else if (isa<AffineWriteOpInterface>(op)) {
       storeOpInsts.push_back(op);
+    } else {
+      auto memInterface = dyn_cast<MemoryEffectOpInterface>(op);
+      if (!memInterface) {
+        if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+          // This op itself is memory-effect free.
+          return;
+        // Check operands. Eg. ops like the `call` op are handled here.
+        for (Value v : op->getOperands()) {
+          if (!isa<MemRefType>(v.getType()))
+            continue;
+          // Conservatively, we assume the memref is read and written to.
+          memrefLoads.push_back(op);
+          memrefStores.push_back(op);
+        }
+      } else {
+        // Non-affine loads and stores.
+        if (hasEffect<MemoryEffects::Read>(op))
+          memrefLoads.push_back(op);
+        if (hasEffect<MemoryEffects::Write>(op))
+          memrefStores.push_back(op);
+        if (hasEffect<MemoryEffects::Free>(op))
+          memrefFrees.push_back(op);
+      }
+    }
   });
 }
 
-// Returns the load op count for 'memref'.
 unsigned Node::getLoadOpCount(Value memref) const {
   unsigned loadOpCount = 0;
   for (Operation *loadOp : loads) {
-    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
+    if (auto affineLoad = dyn_cast<AffineReadOpInterface>(loadOp)) {
+      if (memref == affineLoad.getMemRef())
+        ++loadOpCount;
+    } else if (hasEffect<MemoryEffects::Read>(loadOp, memref)) {
       ++loadOpCount;
+    }
   }
   return loadOpCount;
 }
@@ -66,13 +91,39 @@ unsigned Node::getLoadOpCount(Value memref) const {
 // Returns the store op count for 'memref'.
 unsigned Node::getStoreOpCount(Value memref) const {
   unsigned storeOpCount = 0;
-  for (Operation *storeOp : stores) {
-    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
+  for (auto *storeOp : llvm::concat<Operation *const>(stores, memrefStores)) {
+    if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
+      if (memref == affineStore.getMemRef())
+        ++storeOpCount;
+    } else if (hasEffect<MemoryEffects::Write>(const_cast<Operation *>(storeOp),
+                                               memref)) {
       ++storeOpCount;
+    }
   }
   return storeOpCount;
 }
 
+// Returns the store op count for 'memref'.
+unsigned Node::hasStore(Value memref) const {
+  return llvm::any_of(
+      llvm::concat<Operation *const>(stores, memrefStores),
+      [&](Operation *storeOp) {
+        if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
+          if (memref == affineStore.getMemRef())
+            return true;
+        } else if (hasEffect<MemoryEffects::Write>(storeOp, memref)) {
+          return true;
+        }
+        return false;
+      });
+}
+
+unsigned Node::hasFree(Value memref) const {
+  return llvm::any_of(memrefFrees, [&](Operation *freeOp) {
+    return hasEffect<MemoryEffects::Free>(freeOp, memref);
+  });
+}
+
 // Returns all store ops in 'storeOps' which access 'memref'.
 void Node::getStoreOpsForMemref(Value memref,
                                 SmallVectorImpl<Operation *> *storeOps) const {
@@ -106,8 +157,88 @@ void Node::getLoadAndStoreMemrefSet(
   }
 }
 
-// Initializes the data dependence graph by walking operations in `block`.
-// Assigns each node in the graph a node id based on program order in 'f'.
+/// Returns the values that this op has a memref effect of type `EffectTys` on,
+/// not considering recursive effects.
+template <typename... EffectTys>
+static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
+  auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!memOp) {
+    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+      // No effects.
+      return;
+    // Memref operands have to be considered as being affected.
+    for (Value operand : op->getOperands()) {
+      if (isa<MemRefType>(operand.getType()))
+        values.push_back(operand);
+    }
+    return;
+  }
+  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
+  memOp.getEffects(effects);
+  for (auto &effect : effects) {
+    Value effectVal = effect.getValue();
+    if (isa<EffectTys...>(effect.getEffect()) && effectVal &&
+        isa<MemRefType>(effectVal.getType()))
+      values.push_back(effectVal);
+  };
+}
+
+/// Add `op` to MDG creating a new node and adding its memory accesses (affine
+/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
+Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
+                   DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
+  auto &nodes = mdg.nodes;
+  // Create graph node 'id' to represent top-level 'forOp' and record
+  // all loads and store accesses it contains.
+  LoopNestStateCollector collector;
+  collector.collect(nodeOp);
+  unsigned newNodeId = mdg.nextNodeId++;
+  Node &node = nodes.insert({newNodeId, Node(newNodeId, nodeOp)}).first->second;
+  for (Operation *op : collector.loadOpInsts) {
+    node.loads.push_back(op);
+    auto memref = cast<AffineReadOpInterface>(op).getMemRef();
+    memrefAccesses[memref].insert(node.id);
+  }
+  for (Operation *opInst : collector.storeOpInsts) {
+    node.stores.push_back(opInst);
+    auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
+    memrefAccesses[memref].insert(node.id);
+  }
+  for (Operation *op : collector.memrefLoads) {
+    SmallVector<Value> effectedValues;
+    getEffectedValues<MemoryEffects::Read>(op, effectedValues);
+    if (llvm::any_of(((ValueRange)effectedValues).getTypes(),
+                     [](Type type) { return !isa<MemRefType>(type); }))
+      // We do not the interaction here.
+      return nullptr;
+    for (Value memref : effectedValues)
+      memrefAccesses[memref].insert(node.id);
+    node.memrefLoads.push_back(op);
+  }
+  for (Operation *op : collector.memrefStores) {
+    SmallVector<Value> effectedValues;
+    getEffectedValues<MemoryEffects::Write>(op, effectedValues);
+    if (llvm::any_of((ValueRange(effectedValues)).getTypes(),
+                     [](Type type) { return !isa<MemRefType>(type); }))
+      return nullptr;
+    for (Value memref : effectedValues)
+      memrefAccesses[memref].insert(node.id);
+    node.memrefStores.push_back(op);
+  }
+  for (Operation *op : collector.memrefFrees) {
+    SmallVector<Value> effectedValues;
+    getEffectedValues<MemoryEffects::Free>(op, effectedValues);
+    if (llvm::any_of((ValueRange(effectedValues)).getTypes(),
+                     [](Type type) { return !isa<MemRefType>(type); }))
+      return nullptr;
+    for (Value memref : effectedValues)
+      memrefAccesses[memref].insert(node.id);
+    node.memrefFrees.push_back(op);
+  }
+
+  return &node;
+}
+
 bool MemRefDependenceGraph::init() {
   LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
   // Map from a memref to the set of ids of the nodes that have ops accessing
@@ -116,36 +247,19 @@ bool MemRefDependenceGraph::init() {
 
   DenseMap<Operation *, unsigned> forToNodeMap;
   for (Operation &op : block) {
-    if (dyn_cast<AffineForOp>(op)) {
-      // Create graph node 'id' to represent top-level 'forOp' and record
-      // all loads and store accesses it contains.
-      LoopNestStateCollector collector;
-      collector.collect(&op);
-      // Return false if a region holding op other than 'affine.for' and
-      // 'affine.if' was found (not currently supported).
-      if (collector.hasNonAffineRegionOp)
+    if (auto forOp = dyn_cast<AffineForOp>(op)) {
+      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
+      if (!node)
         return false;
-      Node node(nextNodeId++, &op);
-      for (auto *opInst : collector.loadOpInsts) {
-        node.loads.push_back(opInst);
-        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
-        memrefAccesses[memref].insert(node.id);
-      }
-      for (auto *opInst : collector.storeOpInsts) {
-        node.stores.push_back(opInst);
-        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
-        memrefAccesses[memref].insert(node.id);
-      }
-      forToNodeMap[&op] = node.id;
-      nodes.insert({node.id, node});
-    } else if (dyn_cast<AffineReadOpInterface>(op)) {
+      forToNodeMap[&op] = node->id;
+    } else if (isa<AffineReadOpInterface>(op)) {
       // Create graph node for top-level load op.
       Node node(nextNodeId++, &op);
       node.loads.push_back(&op);
       auto memref = cast<AffineReadOpInterface>(op).getMemRef();
       memrefAccesses[memref].insert(node.id);
       nodes.insert({node.id, node});
-    } else if (dyn_cast<AffineWriteOpInterface>(op)) {
+    } else if (isa<AffineWriteOpInterface>(op)) {
       // Create graph node for top-level store op.
       Node node(nextNodeId++, &op);
       node.stores.push_back(&op);
@@ -155,8 +269,9 @@ bool MemRefDependenceGraph::init() {
     } else if (op.getNumResults() > 0 && !op.use_empty()) {
       // Create graph node for top-level producer of SSA values, which
       // could be used by loop nest nodes.
-      Node node(nextNodeId++, &op);
-      nodes.insert({node.id, node});
+      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
+      if (!node)
+        return false;
     } else if (!isMemoryEffectFree(&op) &&
                (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
       // Create graph node for top-level op unless it is known to be
@@ -165,9 +280,10 @@ bool MemRefDependenceGraph::init() {
       // well-defined control flow. During the fusion validity checks, we look
       // for non-affine ops on the path from source to destination, at which
       // point we check which memrefs if any are used in the region.
-      Node node(nextNodeId++, &op);
-      nodes.insert({node.id, node});
-    } else if (op.getNumRegions() != 0) {
+      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
+      if (!node)
+        return false;
+    } else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(op)) {
       // Return false if non-handled/unknown region-holding ops are found. We
       // won't know what such ops do or what its regions mean; for e.g., it may
       // not be an imperative op.
@@ -175,6 +291,9 @@ bool MemRefDependenceGraph::init() {
                  << "MDG init failed; unknown region-holding op found!\n");
       return false;
     }
+    // We aren't creating nodes for memory-effect free ops either with no
+    // regions (unless it has results being used) or those with branch op
+    // interface.
   }
 
   for (auto &idAndNode : nodes) {
@@ -216,16 +335,20 @@ bool MemRefDependenceGraph::init() {
   // Walk memref access lists and add graph edges between dependent nodes.
   for (auto &memrefAndList : memrefAccesses) {
     unsigned n = memrefAndList.second.size();
+    Value srcMemRef = memrefAndList.first;
+    // Add edges between all dependent pairs among the node IDs on this memref.
     for (unsigned i = 0; i < n; ++i) {
       unsigned srcId = memrefAndList.second[i];
-      bool srcHasStore =
-          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
+      Node *srcNode = getNode(srcId);
+      bool srcHasStoreOrFree =
+          srcNode->hasStore(srcMemRef) || srcNode->hasFree(srcMemRef);
       for (unsigned j = i + 1; j < n; ++j) {
         unsigned dstId = memrefAndList.second[j];
-        bool dstHasStore =
-            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
-        if (srcHasStore || dstHasStore)
-          addEdge(srcId, dstId, memrefAndList.first);
+        Node *dstNode = getNode(dstId);
+        bool dstHasStoreOrFree =
+            dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
+        if (srcHasStoreOrFree || dstHasStoreOrFree)
+          addEdge(srcId, dstId, srcMemRef);
       }
     }
   }
@@ -565,12 +688,17 @@ void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
 }
 
 // Adds ops in 'loads' and 'stores' to node at 'id'.
-void MemRefDependenceGraph::addToNode(
-    unsigned id, const SmallVectorImpl<Operation *> &loads,
-    const SmallVectorImpl<Operation *> &stores) {
+void MemRefDependenceGraph::addToNode(unsigned id, ArrayRef<Operation *> loads,
+                                      ArrayRef<Operation *> stores,
+                                      ArrayRef<Operation *> memrefLoads,
+                                      ArrayRef<Operation *> memrefStores,
+                                      ArrayRef<Operation *> memrefFrees) {
   Node *node = getNode(id);
   llvm::append_range(node->loads, loads);
   llvm::append_range(node->stores, stores);
+  llvm::append_range(node->memrefLoads, memrefLoads);
+  llvm::append_range(node->memrefStores, memrefStores);
+  llvm::append_range(node->memrefFrees, memrefFrees);
 }
 
 void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 6fefe4487ef59a..c22ec213be95c8 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -343,51 +343,6 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   return newMemRef;
 }
 
-/// Returns true if there are any non-affine uses of `memref` in any of
-/// the operations between `start` and `end` (both exclusive). Any other
-/// than affine read/write are treated as non-affine uses of `memref`.
-static bool hasNonAffineUsersOnPath(Operation *start, Operation *end,
-                                    Value memref) {
-  assert(start->getBlock() == end->getBlock());
-  assert(start->isBeforeInBlock(end) && "start expected to be before end");
-  Block *block = start->getBlock();
-  // Check if there is a non-affine memref user in any op between `start` and
-  // `end`.
-  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
-    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(user))
-      return false;
-    Operation *ancestor = block->findAncestorOpInBlock(*user);
-    return ancestor && start->isBeforeInBlock(ancestor) &&
-           ancestor->isBeforeInBlock(end);
-  });
-}
-
-/// Check whether a memref value used in any operation of 'src' has a
-/// non-affine operation that is between `src` and `end` (exclusive of `src`
-/// and `end`)  where `src` and `end` are expected to be in the same Block.
-/// Any other than affine read/write are treated as non-affine uses of memref.
-static bool hasNonAffineUsersOnPath(Operation *src, Operation *end) {
-  assert(src->getBlock() == end->getBlock() && "same block expected");
-
-  // Trivial case. `src` and `end` are exclusive.
-  if (src == end || end->isBeforeInBlock(src))
-    return false;
-
-  // Collect relevant memref values.
-  llvm::SmallDenseSet<Value, 2> memRefValues;
-  src->walk([&](Operation *op) {
-    for (Value v : op->getOperands())
-      // Collect memref values only.
-      if (isa<MemRefType>(v.getType()))
-        memRefValues.insert(v);
-    return WalkResult::advance();
-  });
-  // Look for non-affine users between `src` and `end`.
-  return llvm::any_of(memRefValues, [&](Value memref) {
-    return hasNonAffineUsersOnPath(src, end, memref);
-  });
-}
-
 // Checks the profitability of fusing a backwards slice of the loop nest
 // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
 // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
@@ -864,19 +819,6 @@ struct GreedyFusion {
         DenseSet<Value> srcEscapingMemRefs;
         gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
 
-        // Skip if there are non-affine operations in between the 'srcNode'
-        // and 'dstNode' using their memrefs. If so, we wouldn't be able to
-        // compute a legal insertion point for now. 'srcNode' and 'dstNode'
-        // memrefs with non-affine operation users would be considered
-        // escaping memrefs so we can limit this check to only scenarios with
-        // escaping memrefs.
-        if (!srcEscapingMemRefs.empty() &&
-            hasNonAffineUsersOnPath(srcNode->op, dstNode->op)) {
-          LLVM_DEBUG(llvm::dbgs()
-                     << "Can't fuse: non-affine users in between the loops\n");
-          continue;
-        }
-
         // Compute an operation list insertion point for the fused loop
         // nest which preserves dependences.
         Operation *fusedLoopInsPoint =
@@ -1039,8 +981,10 @@ struct GreedyFusion {
 
         // Clear and add back loads and stores.
         mdg->clearNodeLoadAndStores(dstNode->id);
-        mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
-                       dstLoopCollector.storeOpInsts);
+        mdg->addToNode(
+            dstId, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts,
+            dstLoopCollector.memrefLoads, dstLoopCollector.memrefStores,
+            dstLoopCollector.memrefFrees);
 
         if (removeSrcNode) {
           LLVM_DEBUG(llvm::dbgs()
@@ -1229,15 +1173,7 @@ struct GreedyFusion {
         storeMemrefs.insert(
             cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
       }
-      if (storeMemrefs.size() > 1)
-        return false;
-
-      // Skip if a memref value in one node is used by a non-affine memref
-      // access that lies between 'dstNode' and 'sibNode'.
-      if (hasNonAffineUsersOnPath(dstNode->op, sibNode->op) ||
-          hasNonAffineUsersOnPath(sibNode->op, dstNode->op))
-        return false;
-      return true;
+      return storeMemrefs.size() <= 1;
     };
 
     // Search for siblings which load the same memref block argument.
@@ -1339,7 +1275,8 @@ struct GreedyFusion {
     // Clear and add back loads and stores
     mdg->clearNodeLoadAndStores(dstNode->id);
     mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
-                   dstLoopCollector.storeOpInsts);
+                   dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
+                   dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
     // Remove old sibling loop nest if it no longer has outgoing dependence
     // edges, and it does not write to a memref which escapes the block.
     if (mdg->getOutEdgeCount(sibNode->id) == 0) {
diff --git a/mlir/test/Dialect/Affine/loop-fusion-3.mlir b/mlir/test/Dialect/Affine/loop-fusion-3.mlir
index 6bc4feadb8c98f..d291f46a519c3f 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-3.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-3.mlir
@@ -351,8 +351,157 @@ func.func @should_not_fuse_since_top_level_non_affine_mem_write_users(
 // CHECK:  affine.for
 // CHECK:    arith.addf
 
+// Tests that fusion isn't prevented by the presence of a dealloc op in
+// between since we can move the fused nest.
+
+// CHECK-LABEL: func @fuse_non_affine_intervening_op
+func.func @fuse_non_affine_intervening_op() {
+  %cst = arith.constant 0.0 : f32
+
+  %a = memref.alloc() : memref<100xf32>
+  %b = memref.alloc() : memref<100xf32>
+  %c = memref.alloc() : memref<100xf32>
+
+  affine.for %i = 0 to 100 {
+    affine.store %cst, %a[%i] : memref<100xf32>
+    affine.store %cst, %c[%i] : memref<100xf32>
+  }
+
+  // The source is fused into the destination while being moved here.
+  // CHECK:      affine.for %{{.*}} = 0 to 100
+  // CHECK-NEXT:   affine.store %cst
+  // CHECK-NEXT:   affine.store %cst{{.*}}
+  // CHECK-NEXT:   affine.load
+  // CHECK-NEXT:   affine.store
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc
+
+  memref.dealloc %c : memref<100xf32>
+
+  affine.for %i = 0 to 100 {
+    %v = affine.load %a[%i] : memref<100xf32>
+    affine.store %v, %b[%i] : memref<100xf32>
+  }
+
+  return
+}
+
+// Tests that fusion happens in the presence of intervening non-affine reads.
+
+// CHECK-LABEL: func @fuse_non_affine_intervening_read
+func.func @fuse_non_affine_intervening_read() {
+  %cst = arith.constant 0.0 : f32
+
+  %a = memref.alloc() : memref<100xf32>
+  %b = memref.alloc() : memref<100xf32>
+  %c = memref.alloc() : memref<100xf32>
+
+  affine.for %i = 0 to 100 {
+    affine.store %cst, %a[%i] : memref<100xf32>
+  }
+
+  // The source is fused into the destination while being moved here.
+  // CHECK:      affine.for %{{.*}} = 0 to 100
+  // CHECK-NEXT:   affine.store %cst
+  // CHECK-NEXT:   affine.load
+  // CHECK-NEXT:   affine.store
+  // CHECK-NEXT: }
+
+  // CHECK:     affine.for %{{.*}} = 0 to 100
+  // CHECK-NEXT:  memref.load
+  affine.for %i = 0 to 100 {
+    memref.load %a[%i] : memref<100xf32>
+  }
+
+  affine.for %i = 0 to 100 {
+    %v = affine.load %a[%i] : memref<100xf32>
+    affine.store %v, %b[%i] : memref<100xf32>
+  }
+
+  return
+}
+
+// Tests that fusion happens in the presence of intervening non-affine region
+// ops.
+
+// CHECK-LABEL: func @fuse_non_affine_intervening_read_nest
+func.func @fuse_non_affine_intervening_read_nest() {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+
+  %a = memref.alloc() : memref<100xf32>
+  %b = memref.alloc() : memref<100xf32>
+  %c = memref.alloc() : memref<100xf32>
+
+  affine.for %i = 0 to 100 {
+    affine.store %cst, %a[%i] : memref<100xf32>
+  }
+
+  // The source is fused into the destination while being moved here.
+  // CHECK:      affine.for %{{.*}} = 0 to 100
+  // CHECK-NEXT:   affine.store %cst
+  // CHECK-NEXT:   affine.load
+  // CHECK-NEXT:   affine.store
+  // CHECK-NEXT: }
+
+  // CHECK: scf.for
+  // CHECK-NEXT:  memref.load
+  scf.for %i = %c0 to %c100 step %c1 {
+    memref.load %a[%i] : memref<100xf32>
+  }
+
+  affine.for %i = 0 to 100 {
+    %v = affine.load %a[%i] : memref<100xf32>
+    affine.store %v, %b[%i] : memref<100xf32>
+  }
+
+  return
+}
+
+// Tests that fusion does not happen when there are non-affine sources
+// intervening.
+
+// CHECK-LABEL: func @no_fusion_scf_for_store
+func.func @no_fusion_scf_for_store() {
+  %cst = arith.constant 0.0 : f32
+  %cst1 = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c100 = arith.constant 100 : index
+
+  %a = memref.alloc() : memref<100xf32>
+  %b = memref.alloc() : memref<100xf32>
+  %c = memref.alloc() : memref<100xf32>
+
+  // CHECK:      affine.for %{{.*}} = 0 to 100
+  // CHECK-NEXT:   affine.store
+  // CHECK-NEXT: }
+  affine.for %i = 0 to 100 {
+    affine.store %cst, %a[%i] : memref<100xf32>
+  }
+
+  // CHECK: scf.for
+  scf.for %i = %c0 to %c100 step %c1 {
+    memref.store %cst1, %a[%i] : memref<100xf32>
+  }
+
+  // CHECK:      affine.for %{{.*}} = 0 to 100
+  // CHECK-NEXT:   affine.load
+  // CHECK-NEXT:   affine.store
+  affine.for %i = 0 to 100 {
+    // Non-affine source for this load.
+    %v = affine.load %a[%i] : memref<100xf32>
+    affine.store %v, %b[%i] : memref<100xf32>
+  }
+
+  return
+}
+
 // -----
 
+// CHECK-LABEL: func @fuse_minor_affine_map
 // MAXIMAL-LABEL: func @fuse_minor_affine_map
 func.func @fuse_minor_affine_map(%in: memref<128xf32>, %out: memref<20x512xf32>) {
   %tmp = memref.alloc() : memref<128xf32>
@@ -418,7 +567,7 @@ func.func @should_fuse_multi_store_producer_and_privatize_memfefs() {
   return
 }
 
-
+// CHECK-LABEL: func @should_fuse_multi_store_producer_with_escaping_memrefs_and_remove_src
 func.func @should_fuse_multi_store_producer_with_escaping_memrefs_and_remove_src(
     %a : memref<10xf32>, %b : memref<10xf32>) {
   %cst = arith.constant 0.000000e+00 : f32
@@ -467,7 +616,7 @@ func.func @should_fuse_multi_store_producer_with_escaping_memrefs_and_preserve_s
     %0 = affine.load %b[%i2] : memref<10xf32>
   }
 
-	// Loops '%i0' and '%i2' should be fused first and '%i0' should be removed
+  // Loops '%i0' and '%i2' should be fused first and '%i0' should be removed
   // since fusion is maximal. Then the fused loop and '%i1' should be fused
   // and the fused loop shouldn't be removed since fusion is not maximal.
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
@@ -517,6 +666,31 @@ func.func @should_not_fuse_due_to_dealloc(%arg0: memref<16xf32>){
 // CHECK-NEXT:      arith.addf
 // CHECK-NEXT:      affine.store
 
+// CHECK-LABEL: func @cannot_fuse_intervening_deallocs
+func.func @cannot_fuse_intervening_deallocs(%arg0: memref<16xf32>){
+  %A = memref.alloc() : memref<16xf32>
+  %C = memref.alloc() : memref<16xf32>
+  %cst_1 = arith.constant 1.000000e+00 : f32
+  // CHECK: affine.for %{{.*}} = 0 to 16
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %arg0[%arg1] : memref<16xf32>
+    affine.store %a, %A[%arg1] : memref<16xf32>
+    affine.store %a, %C[%arg1] : memref<16xf32>
+  }
+  // The presence of B's alloc prevents placement of the fused nest above C's
+  // dealloc. No fusion here.
+  memref.dealloc %C : memref<16xf32>
+  %B = memref.alloc() : memref<16xf32>
+  // CHECK: affine.for %{{.*}} = 0 to 16
+  affine.for %arg1 = 0 to 16 {
+    %a = affine.load %A[%arg1] : memref<16xf32>
+    %b = arith.addf %cst_1, %a : f32
+    affine.store %b, %B[%arg1] : memref<16xf32>
+  }
+  memref.dealloc %A : memref<16xf32>
+  return
+}
+
 // -----
 
 // CHECK-LABEL: func @should_fuse_defining_node_has_no_dependence_from_source_node
diff --git a/mlir/test/Dialect/Affine/loop-fusion-inner.mlir b/mlir/test/Dialect/Affine/loop-fusion-inner.mlir
index 61af9a4baf46d2..e76441f2c32373 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-inner.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-inner.mlir
@@ -90,13 +90,13 @@ func.func @fusion_inner_multiple_nests() {
   // CHECK:      affine.for %{{.*}} = 0 to 4 {
   // Everything inside fused into two nests (the second will be DCE'd).
   // CHECK-NEXT:   memref.alloc() : memref<4xi8>
-  // CHECK-NEXT:   memref.alloc() : memref<1xi8>
-  // CHECK-NEXT:   memref.alloc() : memref<1xi8>
   // CHECK-NEXT:   memref.alloc() : memref<8x4xi8>
   // CHECK-NEXT:   memref.alloc() : memref<4xi8>
-  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 2 {
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 4 {
   // CHECK:        }
-  // CHECK:        affine.for %{{.*}} = 0 to 4 {
+  // CHECK-NEXT:   affine.for %{{.*}} = 0 to 2 {
+  // CHECK:          arith.muli
+  // CHECK-NEXT:     arith.extsi
   // CHECK:        }
   // CHECK-NEXT:   memref.dealloc
   // CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Affine/loop-fusion.mlir b/mlir/test/Dialect/Affine/loop-fusion.mlir
index 045b1bec272e1e..1c119e87c53360 100644
--- a/mlir/test/Dialect/Affine/loop-fusion.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion.mlir
@@ -448,8 +448,8 @@ func.func @should_fuse_no_top_level_access() {
 
 #set0 = affine_set<(d0) : (1 == 0)>
 
-// CHECK-LABEL: func @should_not_fuse_if_op_at_top_level() {
-func.func @should_not_fuse_if_op_at_top_level() {
+// CHECK-LABEL: func @should_fuse_despite_affine_if() {
+func.func @should_fuse_despite_affine_if() {
   %m = memref.alloc() : memref<10xf32>
   %cf7 = arith.constant 7.0 : f32
 
@@ -462,12 +462,10 @@ func.func @should_not_fuse_if_op_at_top_level() {
   %c0 = arith.constant 4 : index
   affine.if #set0(%c0) {
   }
-  // Top-level IfOp should prevent fusion.
+  // An unrelated affine.if op doesn't prevent fusion.
   // CHECK:      affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
-  // CHECK-NEXT: }
-  // CHECK:      affine.for %{{.*}} = 0 to 10 {
-  // CHECK-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // CHECK-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
   // CHECK-NEXT: }
   return
 }



More information about the Mlir-commits mailing list