[Mlir-commits] [mlir] 203d5ee - [MLIR][affine-loop-fusion] Handle defining ops between the source and dest loops

Diego Caballero llvmlistbot at llvm.org
Thu Feb 25 08:23:48 PST 2021


Author: Tung D. Le
Date: 2021-02-25T18:12:34+02:00
New Revision: 203d5eeec55b1f0e0dd2aa28f5c5ebe292802e62

URL: https://github.com/llvm/llvm-project/commit/203d5eeec55b1f0e0dd2aa28f5c5ebe292802e62
DIFF: https://github.com/llvm/llvm-project/commit/203d5eeec55b1f0e0dd2aa28f5c5ebe292802e62.diff

LOG: [MLIR][affine-loop-fusion] Handle defining ops between the source and dest loops

This patch handles defining ops between the source and dest loop nests, and prevents loop nests with `iter_args` from being fused.

If any SSA value used in the dest loop nest is defined by an op that depends on the source loop nest, the loop nests cannot be fused.

If an `affine.for` has `iter_args` (i.e., the loop nest returns values), it is not fused.

Reviewed By: dcaballe, bondhugula

Differential Revision: https://reviews.llvm.org/D97030

Added: 
    

Modified: 
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/test/Transforms/loop-fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index d6d26a85215b..4e02f2790bd2 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -179,8 +179,8 @@ struct MemRefDependenceGraph {
     // which contain accesses to the same memref 'value'. If the value is a
     // non-memref value, then the dependence is between a graph node which
     // defines an SSA value and another graph node which uses the SSA value
-    // (e.g. a constant operation defining a value which is used inside a loop
-    // nest).
+    // (e.g. a constant or load operation defining a value which is used inside
+    // a loop nest).
     Value value;
   };
 
@@ -369,6 +369,16 @@ struct MemRefDependenceGraph {
     return outEdgeCount;
   }
 
+  /// Add to 'definingNodes' every node that defines an SSA value used by 'id'.
+  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
+    for (MemRefDependenceGraph::Edge edge : inEdges[id])
+      // Per the Edge definition, an in-edge whose value is not a memref
+      // connects the node defining that SSA value ('edge.id') to its user,
+      // node 'id'; memref-valued edges are memory-access deps, so skip them.
+      if (!edge.value.getType().isa<MemRefType>())
+        definingNodes.insert(edge.id);
+  }
+
   // Computes and returns an insertion point operation, before which the
   // the fused <srcId, dstId> loop nest can be inserted while preserving
   // dependences. Returns nullptr if no such insertion point is found.
@@ -376,6 +386,18 @@ struct MemRefDependenceGraph {
     if (outEdges.count(srcId) == 0)
       return getNode(dstId)->op;
 
+    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
+    DenseSet<unsigned> definingNodes;
+    gatherDefiningNodes(dstId, definingNodes);
+    if (llvm::any_of(definingNodes, [&](unsigned id) {
+          return hasDependencePath(srcId, id);
+        })) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Can't fuse: a defining op with a user in the dst "
+                    "loop has dependence from the src loop\n");
+      return nullptr;
+    }
+
     // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
     SmallPtrSet<Operation *, 2> srcDepInsts;
     for (auto &outEdge : outEdges[srcId])
@@ -784,10 +806,11 @@ bool MemRefDependenceGraph::init(FuncOp f) {
   }
 
   // Add dependence edges between nodes which produce SSA values and their
-  // users.
+  // users. Load ops can be considered as the ones producing SSA values.
   for (auto &idAndNode : nodes) {
     const Node &node = idAndNode.second;
-    if (!node.loads.empty() || !node.stores.empty())
+    // Stores don't define SSA values, skip them.
+    if (!node.stores.empty())
       continue;
     auto *opInst = node.op;
     for (auto value : opInst->getResults()) {
@@ -956,7 +979,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
 
 /// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
 /// 'dstId'), if there is any non-affine operation accessing 'memref', return
-/// false. Otherwise, return true.
+/// true. Otherwise, return false.
 static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                        Value memref,
                                        MemRefDependenceGraph *mdg) {
@@ -1389,6 +1412,10 @@ struct GreedyFusion {
       // Skip if 'dstNode' is not a loop nest.
       if (!isa<AffineForOp>(dstNode->op))
         continue;
+      // Skip if 'dstNode' is a loop nest returning values.
+      // TODO: support loop nests that return values.
+      if (dstNode->op->getNumResults() > 0)
+        continue;
 
       LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
 
@@ -1419,6 +1446,11 @@ struct GreedyFusion {
           LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                   << " for dst loop " << dstId << "\n");
 
+          // Skip if 'srcNode' is a loop nest returning values.
+          // TODO: support loop nests that return values.
+          if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
+            continue;
+
           DenseSet<Value> producerConsumerMemrefs;
           gatherProducerConsumerMemrefs(srcId, dstId, mdg,
                                         producerConsumerMemrefs);

diff  --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 21584fa6f7f1..9b2966b51c17 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -2837,6 +2837,7 @@ func @should_fuse_multi_store_producer_with_scaping_memrefs_and_preserve_src(
 }
 
 // -----
+
 func @should_not_fuse_due_to_dealloc(%arg0: memref<16xf32>){
   %A = alloc() : memref<16xf32>
   %C = alloc() : memref<16xf32>
@@ -2866,3 +2867,152 @@ func @should_not_fuse_due_to_dealloc(%arg0: memref<16xf32>){
 // CHECK-NEXT:      affine.load
 // CHECK-NEXT:      addf
 // CHECK-NEXT:      affine.store
+
+// -----
+
+// CHECK-LABEL: func @should_fuse_defining_node_has_no_dependence_from_source_node
+func @should_fuse_defining_node_has_no_dependence_from_source_node(
+    %a : memref<10xf32>, %b : memref<f32>) -> () {
+  affine.for %i0 = 0 to 10 {
+    %0 = affine.load %b[] : memref<f32>
+    affine.store %0, %a[%i0] : memref<10xf32>
+  }
+  %0 = affine.load %b[] : memref<f32>
+  affine.for %i1 = 0 to 10 {
+    %1 = affine.load %a[%i1] : memref<10xf32>
+    %2 = divf %0, %1 : f32
+  }
+
+  // Loops '%i0' and '%i1' should be fused even though there is a defining
+  // node between the loops, because that node has no dependence on '%i0'.
+  // CHECK:       affine.load %{{.*}}[] : memref<f32>
+  // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.load %{{.*}}[] : memref<f32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    divf
+  // CHECK-NEXT:  }
+  // CHECK-NOT:   affine.for
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_defining_node_has_dependence_from_source_loop
+func @should_not_fuse_defining_node_has_dependence_from_source_loop(
+    %a : memref<10xf32>, %b : memref<f32>) -> () {
+  %cst = constant 0.000000e+00 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.store %cst, %b[] : memref<f32>
+    affine.store %cst, %a[%i0] : memref<10xf32>
+  }
+  %0 = affine.load %b[] : memref<f32>
+  affine.for %i1 = 0 to 10 {
+    %1 = affine.load %a[%i1] : memref<10xf32>
+    %2 = divf %0, %1 : f32
+  }
+
+  // Loops '%i0' and '%i1' should not be fused because the defining node of
+  // '%0' (used in '%i1') loads '%b', which loop '%i0' stores to.
+  // CHECK:       affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[] : memref<f32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  affine.load %{{.*}}[] : memref<f32>
+  // CHECK:       affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    divf
+  // CHECK-NEXT:  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_defining_node_has_transitive_dependence_from_source_loop
+func @should_not_fuse_defining_node_has_transitive_dependence_from_source_loop(
+    %a : memref<10xf32>, %b : memref<10xf32>, %c : memref<f32>) -> () {
+  %cst = constant 0.000000e+00 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.store %cst, %a[%i0] : memref<10xf32>
+    affine.store %cst, %b[%i0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    %1 = affine.load %b[%i1] : memref<10xf32>
+    affine.store %1, %c[] : memref<f32>
+  }
+  %0 = affine.load %c[] : memref<f32>
+  affine.for %i2 = 0 to 10 {
+    %1 = affine.load %a[%i2] : memref<10xf32>
+    %2 = divf %0, %1 : f32
+  }
+
+  // When loops '%i0' and '%i2' are evaluated first, they should not be
+  // fused: the defining node of '%0' (used in '%i2') transitively depends on
+  // loop '%i0' ('%i0' writes '%b', '%i1' reads '%b' and writes '%c', '%0'
+  // reads '%c'). Then loops '%i0' and '%i1' are evaluated and fused as usual.
+  // CHECK:       affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[] : memref<f32>
+  // CHECK-NEXT:  }
+  // CHECK-NEXT:  affine.load %{{.*}}[] : memref<f32>
+  // CHECK:       affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    divf
+  // CHECK-NEXT:  }
+  // CHECK-NOT:   affine.for
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_dest_loop_nest_return_value
+func @should_not_fuse_dest_loop_nest_return_value(
+    %a : memref<10xf32>) -> () {
+  %cst = constant 0.000000e+00 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.store %cst, %a[%i0] : memref<10xf32>
+  }
+  %b = affine.for %i1 = 0 to 10 step 2 iter_args(%b_iter = %cst) -> f32 {
+    %load_a = affine.load %a[%i1] : memref<10xf32>
+    affine.yield %load_a: f32
+  }
+  // Dst loop '%i1' returns a value via iter_args, so it must not be fused.
+  // CHECK:       affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:  }
+  // CHECK:       affine.for %{{.*}} = 0 to 10 step 2 iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+  // CHECK-NEXT:    affine.load
+  // CHECK-NEXT:    affine.yield
+  // CHECK-NEXT:  }
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_src_loop_nest_return_value
+func @should_not_fuse_src_loop_nest_return_value(
+    %a : memref<10xf32>) -> () {
+  %cst = constant 1.000000e+00 : f32
+  %b = affine.for %i = 0 to 10 step 2 iter_args(%b_iter = %cst) -> f32 {
+    %c = addf %b_iter, %b_iter : f32
+    affine.store %c, %a[%i] : memref<10xf32>
+    affine.yield %c: f32
+  }
+  affine.for %i1 = 0 to 10 {
+    %1 = affine.load %a[%i1] : memref<10xf32>
+  }
+  // Src loop '%i' returns a value via iter_args, so it must not be fused.
+  // CHECK:       %{{.*}} = affine.for %{{.*}} = 0 to 10 step 2 iter_args(%{{.*}} = %{{.*}}) -> (f32) {
+  // CHECK-NEXT:    %{{.*}} = addf %{{.*}}, %{{.*}} : f32
+  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:    affine.yield %{{.*}} : f32
+  // CHECK-NEXT:  }
+  // CHECK:       affine.for %{{.*}} = 0 to 10 {
+  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+  // CHECK-NEXT:  }
+
+  return
+}


        


More information about the Mlir-commits mailing list