[Mlir-commits] [mlir] fe9d0a4 - [MLIR] Generalize affine fusion to work on `Block` instead of `FuncOp`

Uday Bondhugula llvmlistbot at llvm.org
Wed Dec 14 09:27:19 PST 2022


Author: Uday Bondhugula
Date: 2022-12-14T22:56:29+05:30
New Revision: fe9d0a47d55ff3c55d8caa03c55c2651985b2f0a

URL: https://github.com/llvm/llvm-project/commit/fe9d0a47d55ff3c55d8caa03c55c2651985b2f0a
DIFF: https://github.com/llvm/llvm-project/commit/fe9d0a47d55ff3c55d8caa03c55c2651985b2f0a.diff

LOG: [MLIR] Generalize affine fusion to work on `Block` instead of `FuncOp`

The affine fusion pass can actually work on the top level of a `Block`
and does not need to be run on a `FuncOp`. Remove this restriction
and generalize the pass to work on any `Block`. This allows fusion to be
performed, for example, on multiple blocks of a FuncOp or any
region-holding op like an scf.while, scf.if or even at an inner depth of
an affine.for or affine.if op. This generalization has no effect on
existing functionality. No changes to the fusion logic or its
transformational power were needed.

Update fusion pass to be a generic operation pass (instead of FuncOp
pass) and remove references and assumptions on the parent being a
FuncOp.
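To illustrate the per-block structure described above, here is a minimal C++ sketch over a hypothetical toy IR model. `ToyOp`, `ToyRegion`, `ToyBlock` and `forEachTopLevelBlock` are invented for illustration and are not MLIR classes. The loop mirrors the shape of the new `LoopFusion::runOnOperation`, which hands every block of every region held directly by the op the pass runs on to `runOnBlock`; it is a sketch, not the pass's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy IR (not MLIR's classes): an op owns regions, a region owns
// blocks, and a block lists the names of its top-level ops.
struct ToyBlock {
  std::vector<std::string> ops;
};
struct ToyRegion {
  std::vector<ToyBlock> blocks;
};
struct ToyOp {
  std::string name;
  std::vector<ToyRegion> regions;
};

// Hands every block of every region held directly by `op` to `runOnBlock`,
// in region order and then block order. Blocks nested more deeply (for
// example inside an affine.for body) are not visited by this loop.
void forEachTopLevelBlock(ToyOp &op,
                          const std::function<void(ToyBlock &)> &runOnBlock) {
  for (ToyRegion &region : op.regions)
    for (ToyBlock &block : region.blocks)
      runOnBlock(block);
}
```

Because each block gets its own dependence graph, loop nests in different blocks are never fused with one another; this matches the new multi-block test, in which `^bb0` and `^bb1` are each fused independently.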

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D139293

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/Passes.h
    mlir/include/mlir/Dialect/Affine/Passes.td
    mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
    mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
    mlir/test/Transforms/loop-fusion-2.mlir
    mlir/test/Transforms/loop-fusion-3.mlir
    mlir/test/Transforms/loop-fusion-4.mlir
    mlir/test/Transforms/loop-fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index a735fc0474ee0..7102626db0f60 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -73,10 +73,11 @@ createAffineScalarReplacementPass();
 /// bounds into a single loop.
 std::unique_ptr<OperationPass<func::FuncOp>> createLoopCoalescingPass();
 
-/// Creates a loop fusion pass which fuses loops according to type of fusion
+/// Creates a loop fusion pass which fuses affine loop nests at the top-level of
+/// the operation the pass is created on according to the type of fusion
 /// specified in `fusionMode`. Buffers of size less than or equal to
 /// `localBufSizeThreshold` are promoted to memory space `fastMemorySpace`.
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<Pass>
 createLoopFusionPass(unsigned fastMemorySpace = 0,
                      uint64_t localBufSizeThreshold = 0,
                      bool maximalFusion = false,

diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 67b7eb32d1e9b..493fb39cd91e5 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -43,22 +43,24 @@ def AffineDataCopyGeneration : Pass<"affine-data-copy-generate", "func::FuncOp">
   ];
 }
 
-def AffineLoopFusion : Pass<"affine-loop-fusion", "func::FuncOp"> {
+def AffineLoopFusion : Pass<"affine-loop-fusion"> {
   let summary = "Fuse affine loop nests";
   let description = [{
-    This pass performs fusion of loop nests using a slicing-based approach. It
-    combines two fusion strategies: producer-consumer fusion and sibling fusion.
-    Producer-consumer fusion is aimed at fusing pairs of loops where the first
-    one writes to a memref that the second reads. Sibling fusion targets pairs
-    of loops that share no dependences between them but that load from the same
-    memref. The fused loop nests, when possible, are rewritten to access
-    significantly smaller local buffers instead of the original memref's, and
-    the latter are often either completely optimized away or contracted. This
-    transformation leads to enhanced locality and lower memory footprint through
-    the elimination or contraction of temporaries/intermediate memref's. These
-    benefits are sometimes achieved at the expense of redundant computation
-    through a cost model that evaluates available choices such as the depth at
-    which a source slice should be materialized in the designation slice.
+    This pass performs fusion of loop nests using a slicing-based approach. The
+    transformation works at MLIR `Block` granularity and applies to all blocks
+    of the op the pass is run on. It combines two fusion strategies:
+    producer-consumer fusion and sibling fusion. Producer-consumer fusion is
+    aimed at fusing pairs of loops where the first one writes to a memref that
+    the second reads. Sibling fusion targets pairs of loops that share no
+    dependences between them but that load from the same memref. The fused loop
+    nests, when possible, are rewritten to access significantly smaller local
+    buffers instead of the original memref's, and the latter are often either
+    completely optimized away or contracted. This transformation leads to
+    enhanced locality and lower memory footprint through the elimination or
+    contraction of temporaries/intermediate memref's. These benefits are
+    sometimes achieved at the expense of redundant computation through a cost
+    model that evaluates available choices such as the depth at which a source
+    slice should be materialized in the destination slice.
 
     Example 1: Producer-consumer fusion.
     Input:

diff  --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index e30079b6e3745..cd02a8f19f79a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements loop fusion.
+// This file implements affine fusion.
 //
 //===----------------------------------------------------------------------===//
 
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -65,12 +64,13 @@ struct LoopFusion : public impl::AffineLoopFusionBase<LoopFusion> {
     this->affineFusionMode = affineFusionMode;
   }
 
+  void runOnBlock(Block *block);
   void runOnOperation() override;
 };
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<Pass>
 mlir::createLoopFusionPass(unsigned fastMemorySpace,
                            uint64_t localBufSizeThreshold, bool maximalFusion,
                            enum FusionMode affineFusionMode) {
@@ -104,7 +104,7 @@ struct LoopNestStateCollector {
 };
 
 // MemRefDependenceGraph is a graph data structure where graph nodes are
-// top-level operations in a FuncOp which contain load/store ops, and edges
+// top-level operations in a `Block` which contain load/store ops, and edges
 // are memref dependences between the nodes.
 // TODO: Add a more flexible dependence graph representation.
 // TODO: Add a depth parameter to dependence graph construction.
@@ -207,11 +207,11 @@ struct MemRefDependenceGraph {
   // The next unique identifier to use for newly created graph nodes.
   unsigned nextNodeId = 0;
 
-  MemRefDependenceGraph() = default;
+  MemRefDependenceGraph(Block &block) : block(block) {}
 
   // Initializes the dependence graph based on operations in 'f'.
   // Returns true on success, false otherwise.
-  bool init(func::FuncOp f);
+  bool init(Block *block);
 
   // Returns the graph node for 'id'.
   Node *getNode(unsigned id) {
@@ -258,7 +258,7 @@ struct MemRefDependenceGraph {
   }
 
   // Returns true if node 'id' writes to any memref which escapes (or is an
-  // argument to) the function/block. Returns false otherwise.
+  // argument to) the block. Returns false otherwise.
   bool writesToLiveInOrEscapingMemrefs(unsigned id) {
     Node *node = getNode(id);
     for (auto *storeOpInst : node->stores) {
@@ -267,7 +267,8 @@ struct MemRefDependenceGraph {
       // Return true if 'memref' is a block argument.
       if (!op)
         return true;
-      // Return true if any use of 'memref' escapes the function.
+      // Return true if any use of 'memref' does not dereference it in an affine
+      // way.
       for (auto *user : memref.getUsers())
         if (!isa<AffineMapAccessInterface>(*user))
           return true;
@@ -597,6 +598,9 @@ struct MemRefDependenceGraph {
     }
   }
   void dump() const { print(llvm::errs()); }
+
+  /// The block for which this graph is created to perform fusion.
+  Block █
 };
 
 /// Returns true if node 'srcId' can be removed after fusing it with node
@@ -710,13 +714,14 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                                 producerConsumerMemrefs);
 }
 
-/// A memref escapes the function if either:
+/// A memref escapes in the context of the fusion pass if either:
 ///   1. it (or its alias) is a block argument, or
 ///   2. created by an op not known to guarantee alias freedom,
-///   3. it (or its alias) is used by a non-affine op (e.g., call op, memref
-///      load/store ops, alias creating ops, unknown ops, etc.); such ops
-///      do not deference the memref in an affine way.
-static bool isEscapingMemref(Value memref) {
+///   3. it (or its alias) is used by ops other than affine dereferencing ops
+///   (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
+///   terminator ops, etc.); such ops do not dereference the memref in an affine
+///   way.
+static bool isEscapingMemref(Value memref, Block *block) {
   Operation *defOp = memref.getDefiningOp();
   // Check if 'memref' is a block argument.
   if (!defOp)
@@ -724,7 +729,7 @@ static bool isEscapingMemref(Value memref) {
 
   // Check if this is defined to be an alias of another memref.
   if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
-    if (isEscapingMemref(viewOp.getViewSource()))
+    if (isEscapingMemref(viewOp.getViewSource(), block))
       return true;
 
   // Any op besides allocating ops wouldn't guarantee alias freedom
@@ -733,14 +738,18 @@ static bool isEscapingMemref(Value memref) {
 
   // Check if 'memref' is used by a non-deferencing op (including unknown ones)
   // (e.g., call ops, alias creating ops, etc.).
-  for (Operation *user : memref.getUsers())
+  for (Operation *user : memref.getUsers()) {
+    // Ignore users outside of `block`.
+    if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
+      continue;
     if (!isa<AffineMapAccessInterface>(*user))
       return true;
+  }
   return false;
 }
 
 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
-/// that escape the function or are accessed by non-affine ops.
+/// that escape the block or are accessed in a non-affine way.
 void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                            DenseSet<Value> &escapingMemRefs) {
   auto *node = mdg->getNode(id);
@@ -748,29 +757,25 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
     auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
     if (escapingMemRefs.count(memref))
       continue;
-    if (isEscapingMemref(memref))
+    if (isEscapingMemref(memref, &mdg->block))
       escapingMemRefs.insert(memref);
   }
 }
 
 } // namespace
 
-// Initializes the data dependence graph by walking operations in 'f'.
+// Initializes the data dependence graph by walking operations in `block`.
 // Assigns each node in the graph a node id based on program order in 'f'.
 // TODO: Add support for taking a Block arg to construct the
// dependence graph at a different depth.
-bool MemRefDependenceGraph::init(func::FuncOp f) {
+bool MemRefDependenceGraph::init(Block *block) {
   LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
   // Map from a memref to the set of ids of the nodes that have ops accessing
   // the memref.
   DenseMap<Value, SetVector<unsigned>> memrefAccesses;
 
-  // TODO: support multi-block functions.
-  if (!llvm::hasSingleElement(f))
-    return false;
-
   DenseMap<Operation *, unsigned> forToNodeMap;
-  for (auto &op : f.front()) {
+  for (Operation &op : *block) {
     if (auto forOp = dyn_cast<AffineForOp>(op)) {
       // Create graph node 'id' to represent top-level 'forOp' and record
       // all loads and store accesses it contains.
@@ -845,14 +850,18 @@ bool MemRefDependenceGraph::init(func::FuncOp f) {
     // Stores don't define SSA values, skip them.
     if (!node.stores.empty())
       continue;
-    auto *opInst = node.op;
-    for (auto value : opInst->getResults()) {
-      for (auto *user : value.getUsers()) {
+    Operation *opInst = node.op;
+    for (Value value : opInst->getResults()) {
+      for (Operation *user : value.getUsers()) {
+        // Ignore users outside of the block.
+        if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() !=
+            block)
+          continue;
         SmallVector<AffineForOp, 4> loops;
         getLoopIVs(*user, &loops);
         if (loops.empty())
           continue;
-        assert(forToNodeMap.count(loops[0]) > 0);
+        assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
         unsigned userLoopNestId = forToNodeMap[loops[0]];
         addEdge(node.id, userLoopNestId, value);
       }
@@ -918,7 +927,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   // Create builder to insert alloc op just before 'forOp'.
   OpBuilder b(forInst);
   // Builder to create constants at the top level.
-  OpBuilder top(forInst->getParentOfType<func::FuncOp>().getBody());
+  OpBuilder top(forInst->getParentRegion());
   // Create new memref type based on slice bounds.
   auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
   auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
@@ -979,7 +988,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   // a constant shape.
   // TODO: Create/move alloc ops for private memrefs closer to their
   // consumer loop nests to reduce their live range. Currently they are added
-  // at the beginning of the function, because loop nests can be reordered
+  // at the beginning of the block, because loop nests can be reordered
   // during the fusion pass.
   Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
 
@@ -1508,8 +1517,8 @@ struct GreedyFusion {
               }))
             continue;
 
-          // Gather memrefs in 'srcNode' that are written and escape to the
-          // function (e.g., memref function arguments, returned memrefs,
+          // Gather memrefs in 'srcNode' that are written and escape out of the
+          // block (e.g., memref block arguments, returned memrefs,
           // memrefs passed to function calls, etc.).
           DenseSet<Value> srcEscapingMemRefs;
           gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
@@ -1829,7 +1838,7 @@ struct GreedyFusion {
     }
   }
 
-  // Searches function argument uses and the graph from 'dstNode' looking for a
+  // Searches block argument uses and the graph from 'dstNode' looking for a
   // fusion candidate sibling node which shares no dependences with 'dstNode'
   // but which loads from the same memref. Returns true and sets
   // 'idAndMemrefToFuse' on success. Returns false otherwise.
@@ -1874,36 +1883,37 @@ struct GreedyFusion {
       return true;
     };
 
-    // Search for siblings which load the same memref function argument.
-    auto fn = dstNode->op->getParentOfType<func::FuncOp>();
-    for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
-      for (auto *user : fn.getArgument(i).getUsers()) {
-        if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
-          // Gather loops surrounding 'use'.
-          SmallVector<AffineForOp, 4> loops;
-          getLoopIVs(*user, &loops);
-          // Skip 'use' if it is not within a loop nest.
-          if (loops.empty())
-            continue;
-          Node *sibNode = mdg->getForOpNode(loops[0]);
-          assert(sibNode != nullptr);
-          // Skip 'use' if it not a sibling to 'dstNode'.
-          if (sibNode->id == dstNode->id)
-            continue;
-          // Skip 'use' if it has been visited.
-          if (visitedSibNodeIds->count(sibNode->id) > 0)
-            continue;
-          // Skip 'use' if it does not load from the same memref as 'dstNode'.
-          auto memref = loadOp.getMemRef();
-          if (dstNode->getLoadOpCount(memref) == 0)
-            continue;
-          // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
-          if (canFuseWithSibNode(sibNode, memref)) {
-            visitedSibNodeIds->insert(sibNode->id);
-            idAndMemrefToFuse->first = sibNode->id;
-            idAndMemrefToFuse->second = memref;
-            return true;
-          }
+    // Search for siblings which load the same memref block argument.
+    Block *block = dstNode->op->getBlock();
+    for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
+      for (Operation *user : block->getArgument(i).getUsers()) {
+        auto loadOp = dyn_cast<AffineReadOpInterface>(user);
+        if (!loadOp)
+          continue;
+        // Gather loops surrounding 'use'.
+        SmallVector<AffineForOp, 4> loops;
+        getLoopIVs(*user, &loops);
+        // Skip 'use' if it is not within a loop nest.
+        if (loops.empty())
+          continue;
+        Node *sibNode = mdg->getForOpNode(loops[0]);
+        assert(sibNode != nullptr);
+        // Skip 'use' if it not a sibling to 'dstNode'.
+        if (sibNode->id == dstNode->id)
+          continue;
+        // Skip 'use' if it has been visited.
+        if (visitedSibNodeIds->count(sibNode->id) > 0)
+          continue;
+        // Skip 'use' if it does not load from the same memref as 'dstNode'.
+        auto memref = loadOp.getMemRef();
+        if (dstNode->getLoadOpCount(memref) == 0)
+          continue;
+        // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
+        if (canFuseWithSibNode(sibNode, memref)) {
+          visitedSibNodeIds->insert(sibNode->id);
+          idAndMemrefToFuse->first = sibNode->id;
+          idAndMemrefToFuse->second = memref;
+          return true;
         }
       }
     }
@@ -1968,8 +1978,7 @@ struct GreedyFusion {
     mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                    dstLoopCollector.storeOpInsts);
     // Remove old sibling loop nest if it no longer has outgoing dependence
-    // edges, and it does not write to a memref which escapes the
-    // function.
+    // edges, and it does not write to a memref which escapes the block.
     if (mdg->getOutEdgeCount(sibNode->id) == 0) {
       Operation *op = sibNode->op;
       mdg->removeNode(sibNode->id);
@@ -1996,9 +2005,10 @@ struct GreedyFusion {
 
 } // namespace
 
-void LoopFusion::runOnOperation() {
-  MemRefDependenceGraph g;
-  if (!g.init(getOperation()))
+/// Run fusion on `block`.
+void LoopFusion::runOnBlock(Block *block) {
+  MemRefDependenceGraph g(*block);
+  if (!g.init(block))
     return;
 
   Optional<unsigned> fastMemorySpaceOpt;
@@ -2015,3 +2025,9 @@ void LoopFusion::runOnOperation() {
   else
     fusion.runGreedyFusion();
 }
+
+void LoopFusion::runOnOperation() {
+  for (Region &region : getOperation()->getRegions())
+    for (Block &block : region.getBlocks())
+      runOnBlock(&block);
+}
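The new "Ignore users outside of `block`" checks in `isEscapingMemref` and `MemRefDependenceGraph::init` map each user to its ancestor op in the block's region via `Region::findAncestorOpInRegion` and compare that ancestor's block with `block`. The C++ sketch below models that test on a hypothetical toy IR; `Op`, `Block`, `Region`, `findAncestorOp` and `isUserInBlock` are illustrative names, not MLIR's API. `findAncestorOpInRegion` is documented to return null when the op does not lie within the region at all; this sketch treats that case as "outside the block", whereas the code in the patch dereferences the returned pointer without such a check.

```cpp
#include <cassert>

// Hypothetical simplified IR (not MLIR's classes): each op records the block
// containing it, each block its parent region, and each region the op that
// owns it (nullptr for the outermost region).
struct Op;
struct Region;
struct Block {
  Region *parent = nullptr;
};
struct Region {
  Op *owner = nullptr;
};
struct Op {
  Block *block = nullptr;
};

// Modeled on Region::findAncestorOpInRegion: returns `op` if it lies directly
// in `region`, otherwise its ancestor op that does, or nullptr if `op` is not
// nested under `region` at all.
Op *findAncestorOp(Region *region, Op *op) {
  while (op) {
    Region *opRegion = op->block ? op->block->parent : nullptr;
    if (opRegion == region)
      return op;
    op = opRegion ? opRegion->owner : nullptr;
  }
  return nullptr;
}

// True if `user` is a top-level op of `block` or nested inside one. Users in
// sibling blocks, or outside the region entirely, are skipped.
bool isUserInBlock(Block *block, Op *user) {
  Op *ancestor = findAncestorOp(block->parent, user);
  return ancestor != nullptr && ancestor->block == block;
}
```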

diff  --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 889c9078c5300..c183a338475d3 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/Support/Debug.h"
@@ -471,7 +470,7 @@ bool mlir::getLoopNestStats(AffineForOp forOpRoot, LoopNestStats *stats) {
   auto walkResult = forOpRoot.walk([&](AffineForOp forOp) {
     auto *childForOp = forOp.getOperation();
     auto *parentForOp = forOp->getParentOp();
-    if (!llvm::isa<func::FuncOp>(parentForOp)) {
+    if (forOp != forOpRoot) {
       if (!isa<AffineForOp>(parentForOp)) {
         LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n");
         return WalkResult::interrupt();
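In `getLoopNestStats`, the walk used to skip the parent check whenever an `affine.for`'s parent was a `func.func`; it now skips it only for `forOpRoot` itself, so the root's parent can be any op (for instance an `scf.while`), while every inner `affine.for` must still be directly nested in another `affine.for`. A minimal sketch of that rule, using a hypothetical `NestOp` type and `nonRootForsHaveForParents` helper that are not part of MLIR:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical nest node (not MLIR's classes): `kind` is e.g. "affine.for"
// or "affine.if"; `parent` is the enclosing op (nullptr at the top).
struct NestOp {
  std::string kind;
  NestOp *parent = nullptr;
};

// Every affine.for in `forOps` other than `root` must have an affine.for as
// its direct parent. The root's own parent is never inspected.
bool nonRootForsHaveForParents(const NestOp *root,
                               const std::vector<const NestOp *> &forOps) {
  for (const NestOp *forOp : forOps) {
    if (forOp == root)
      continue;
    if (forOp->parent == nullptr || forOp->parent->kind != "affine.for")
      return false;
  }
  return true;
}
```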

diff  --git a/mlir/test/Transforms/loop-fusion-2.mlir b/mlir/test/Transforms/loop-fusion-2.mlir
index c1fded7a16bb9..49538a018ce9c 100644
--- a/mlir/test/Transforms/loop-fusion-2.mlir
+++ b/mlir/test/Transforms/loop-fusion-2.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=MAXIMAL
 
 // Part I of fusion tests in  mlir/test/Transforms/loop-fusion.mlir.
 // Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir

diff  --git a/mlir/test/Transforms/loop-fusion-3.mlir b/mlir/test/Transforms/loop-fusion-3.mlir
index 54457a17a2a5e..37ad178235dc9 100644
--- a/mlir/test/Transforms/loop-fusion-3.mlir
+++ b/mlir/test/Transforms/loop-fusion-3.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=MAXIMAL
 
 // Part I of fusion tests in  mlir/test/Transforms/loop-fusion.mlir.
 // Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir
@@ -532,8 +532,8 @@ func.func @should_fuse_defining_node_has_no_dependence_from_source_node(
     %2 = arith.divf %0, %1 : f32
   }
 
-	// Loops '%i0' and '%i1' should be fused even though there is a defining
-  // node between the loops. It is because the node has no dependence from '%i0'.
+  // Loops '%i0' and '%i1' should be fused even though there is a defining node
+  // between the loops. It is because the node has no dependence from '%i0'.
   // CHECK:       affine.load %{{.*}}[] : memref<f32>
   // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
   // CHECK-NEXT:    affine.load %{{.*}}[] : memref<f32>
@@ -561,8 +561,8 @@ func.func @should_not_fuse_defining_node_has_dependence_from_source_loop(
     %2 = arith.divf %0, %1 : f32
   }
 
-	// Loops '%i0' and '%i1' should not be fused because the defining node
-  // of '%0' used in '%i1' has dependence from loop '%i0'.
+  // Loops '%i0' and '%i1' should not be fused because the defining node of '%0'
+  // used in '%i1' has dependence from loop '%i0'.
   // CHECK:       affine.for %{{.*}} = 0 to 10 {
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[] : memref<f32>
   // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>

diff  --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir
index d3dc5f788de26..b84e0f2e21e54 100644
--- a/mlir/test/Transforms/loop-fusion-4.mlir
+++ b/mlir/test/Transforms/loop-fusion-4.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="mode=producer" -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal mode=sibling" -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
 
 // Part I of fusion tests in  mlir/test/Transforms/loop-fusion.mlir.
 // Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir
@@ -141,3 +141,36 @@ func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : m
 // SIBLING-MAXIMAL-NEXT:             affine.for %[[idx_1:.*]] = 0 to 64 {
 // SIBLING-MAXIMAL-NEXT:               %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) {
 // SIBLING-MAXIMAL-NEXT:                 %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) {
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: func @fusion_for_multiple_blocks() {
+func.func @fusion_for_multiple_blocks() {
+^bb0:
+  %m = memref.alloc() : memref<10xf32>
+  %cf7 = arith.constant 7.0 : f32
+
+  affine.for %i0 = 0 to 10 {
+    affine.store %cf7, %m[%i0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    %v0 = affine.load %m[%i1] : memref<10xf32>
+  }
+  // PRODUCER-CONSUMER:      affine.for %{{.*}} = 0 to 10 {
+  // PRODUCER-CONSUMER-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT: }
+  cf.br ^bb1
+^bb1:
+  affine.for %i0 = 0 to 10 {
+    affine.store %cf7, %m[%i0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    %v0 = affine.load %m[%i1] : memref<10xf32>
+  }
+  // PRODUCER-CONSUMER:      affine.for %{{.*}} = 0 to 10 {
+  // PRODUCER-CONSUMER-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT: }
+  return
+}

diff  --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index fcf8c67dac003..5ec195cb6c946 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
 
 // Part II of fusion tests in  mlir/test/Transforms/loop-fusion=2.mlir.
 // Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir


        

