[Mlir-commits] [mlir] 2b5d177 - [MLIR][Affine-loop-fusion] Fix a bug in affine-loop-fusion pass when there are non-affine operations

Uday Bondhugula llvmlistbot at llvm.org
Fri Jun 26 05:57:54 PDT 2020


Author: Tung D. Le
Date: 2020-06-26T18:26:42+05:30
New Revision: 2b5d1776ffad2614756ef059d64b957c7731e7be

URL: https://github.com/llvm/llvm-project/commit/2b5d1776ffad2614756ef059d64b957c7731e7be
DIFF: https://github.com/llvm/llvm-project/commit/2b5d1776ffad2614756ef059d64b957c7731e7be.diff

LOG: [MLIR][Affine-loop-fusion] Fix a bug in affine-loop-fusion pass when there are non-affine operations

When there is a mix of affine load/store and non-affine operations (e.g. std.load, std.store),
affine-loop-fusion ignores the presence of the non-affine ops, thus changing the program semantics.

For example, consider a program of three affine loops operating on the same memref, one of which uses std.load and std.store:
```
affine.for
  affine.store %1
affine.for
  std.load %1
  std.store %1
affine.for
  affine.load %1
  affine.store %1
```
affine-loop-fusion produces the following result, which changes the program semantics:
```
affine.for
  std.load %1
  std.store %1
affine.for
  affine.store %1
  affine.load %1
  affine.store %1
```

This patch fixes the above problem by checking for non-affine users of the memref that lie between the source and destination nodes of interest, and refusing to fuse when any are found.

Differential Revision: https://reviews.llvm.org/D82158

Added: 
    

Modified: 
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/test/Transforms/loop-fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 6a7a88e5a1ad..f71ff2aba9e9 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -948,6 +948,65 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
   return newMemRef;
 }
 
+/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
+/// 'dstId'), if there is any non-affine operation accessing 'memref', return
+/// true. Otherwise, return false.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       Value memref,
+                                       MemRefDependenceGraph *mdg) {
+  auto *srcNode = mdg->getNode(srcId);
+  auto *dstNode = mdg->getNode(dstId);
+  // For each MemRefDependenceGraph node strictly between 'srcNode' and
+  // 'dstNode' (an empty range if 'srcNode' is not before 'dstNode'), check
+  // whether any non-affine operation in the node accesses 'memref'.
+  for (auto &idAndNode : mdg->nodes) {
+    Operation *op = idAndNode.second.op;
+    // Only consider operations between 'srcNode' and 'dstNode'.
+    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
+      // Walk inside the operation to find any use of the memref.
+      // Interrupt the walk if found.
+      auto walkResult = op->walk([&](Operation *user) {
+        // Skip affine accesses, as classified by isMemRefDereferencingOp.
+        if (isMemRefDereferencingOp(*user))
+          return WalkResult::advance();
+        // Any other op that has 'memref' as an operand is a non-affine user;
+        // checking the op's operands avoids scanning every user of 'memref'.
+        if (llvm::is_contained(user->getOperands(), memref))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      });
+      if (walkResult.wasInterrupted())
+        return true;
+    }
+  }
+  return false;
+}
+
+/// Check whether any memref value used by an operation in node 'srcId' is
+/// also used by a non-affine operation in a node strictly between node
+/// 'srcId' and node 'dstId' (exclusive of both nodes). Return true if such a
+/// user exists; otherwise, return false.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+                                       MemRefDependenceGraph *mdg) {
+  // Collect the memref operands of the operations in node 'srcId'.
+  auto *srcNode = mdg->getNode(srcId);
+  llvm::SmallDenseSet<Value, 2> memRefValues;
+  srcNode->op->walk([&](Operation *op) {
+    // Operands of 'affine.for' itself are not collected; the walk still
+    // visits the operations nested in its body.
+    if (isa<AffineForOp>(op))
+      return;
+    for (Value v : op->getOperands())
+      if (v.getType().isa<MemRefType>())
+        memRefValues.insert(v);
+  });
+  // Look for non-affine users of each collected memref between node 'srcId'
+  // and node 'dstId'.
+  return llvm::any_of(memRefValues, [&](Value memref) {
+    return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
+  });
+}
+
 // Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
 // may write to multiple memrefs but it is required that only one of them,
 // 'srcLiveOutStoreOp', has output edges.
@@ -1008,6 +1067,12 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
   // TODO(andydavis) Check the shape and lower bounds here too.
   if (srcNumElements != dstNumElements)
     return false;
+
+  // Return false if 'memref' is used by a non-affine operation that is
+  // between node 'srcId' and node 'dstId'.
+  if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+    return false;
+
   return true;
 }
 
@@ -1793,6 +1858,12 @@ struct GreedyFusion {
       }
       if (storeMemrefs.size() != 1)
         return false;
+
+      // Skip if a memref value in one node is used by a non-affine memref
+      // access that lies between 'dstNode' and 'sibNode'.
+      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
+          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
+        return false;
       return true;
     };
 

diff  --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index a153b96bf362..51d2fb42a1c1 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -2570,3 +2570,67 @@ func @calc(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>, %le
 // CHECK-NEXT:    affine.store %{{.*}}, %arg{{.*}}[%arg{{.*}}] : memref<?xf32>
 // CHECK-NEXT:  }
 // CHECK-NEXT:  return
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_non_affine_users
+func @should_not_fuse_since_non_affine_users(%in0 : memref<32xf32>,
+                      %in1 : memref<32xf32>) {
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = load %in0[%d] : memref<32xf32>
+    %rhs = load %in1[%d] : memref<32xf32>
+    %add = subf %lhs, %rhs : f32
+    store %add, %in0[%d] : memref<32xf32>
+  }
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = mulf %lhs, %rhs : f32
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  return
+}
+// Expect three unfused loops; the middle one uses std.load/std.store on %in0.
+// CHECK:  affine.for
+// CHECK:    addf
+// CHECK:  affine.for
+// CHECK:    subf
+// CHECK:  affine.for
+// CHECK:    mulf
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_top_level_non_affine_users
+func @should_not_fuse_since_top_level_non_affine_users(%in0 : memref<32xf32>,
+                      %in1 : memref<32xf32>) {
+  %sum = alloc() : memref<f32>
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = addf %lhs, %rhs : f32
+    store %add, %sum[] : memref<f32>
+    affine.store %add, %in0[%d] : memref<32xf32>
+  }
+  %load_sum = load %sum[] : memref<f32>
+  affine.for %d = 0 to 32 {
+    %lhs = affine.load %in0[%d] : memref<32xf32>
+    %rhs = affine.load %in1[%d] : memref<32xf32>
+    %add = mulf %lhs, %rhs : f32
+    %sub = subf %add, %load_sum: f32
+    affine.store %sub, %in0[%d] : memref<32xf32>
+  }
+  dealloc %sum : memref<f32>
+  return
+}
+// Expect two unfused loops; the top-level std.load of %sum sits between them.
+// CHECK:  affine.for
+// CHECK:    addf
+// CHECK:  affine.for
+// CHECK:    mulf
+// CHECK:    subf


        


More information about the Mlir-commits mailing list