[Mlir-commits] [mlir] 2b5d177 - [MLIR][Affine-loop-fusion] Fix a bug in affine-loop-fusion pass when there are non-affine operations
Uday Bondhugula
llvmlistbot at llvm.org
Fri Jun 26 05:57:54 PDT 2020
Author: Tung D. Le
Date: 2020-06-26T18:26:42+05:30
New Revision: 2b5d1776ffad2614756ef059d64b957c7731e7be
URL: https://github.com/llvm/llvm-project/commit/2b5d1776ffad2614756ef059d64b957c7731e7be
DIFF: https://github.com/llvm/llvm-project/commit/2b5d1776ffad2614756ef059d64b957c7731e7be.diff
LOG: [MLIR][Affine-loop-fusion] Fix a bug in affine-loop-fusion pass when there are non-affine operations
When there is a mix of affine load/store and non-affine operations (e.g. std.load, std.store),
affine-loop-fusion ignores the presence of non-affine ops, thus changing the program semantics.
E.g., consider a program with three affine loops operating on the same memref, one of which uses std.load and std.store, as follows.
```
affine.for
affine.store %1
affine.for
std.load %1
std.store %1
affine.for
affine.load %1
affine.store %1
```
affine-loop-fusion will produce the following result, which changes the program semantics:
```
affine.for
std.load %1
std.store %1
affine.for
affine.store %1
affine.load %1
affine.store %1
```
This patch fixes the above problem by checking for non-affine users of the memref that lie between the source and destination nodes of interest.
Differential Revision: https://reviews.llvm.org/D82158
Added:
Modified:
mlir/lib/Transforms/LoopFusion.cpp
mlir/test/Transforms/loop-fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index 6a7a88e5a1ad..f71ff2aba9e9 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -948,6 +948,65 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
return newMemRef;
}
+/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
+/// 'dstId'), if there is any non-affine operation accessing 'memref', return
+/// true. Otherwise, return false.
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+ Value memref,
+ MemRefDependenceGraph *mdg) {
+ auto *srcNode = mdg->getNode(srcId);
+ auto *dstNode = mdg->getNode(dstId);
+ Value::user_range users = memref.getUsers();
+ // For each MemRefDependenceGraph's node that is between 'srcNode' and
+ // 'dstNode' (exclusive of 'srcNode' and 'dstNode'), check whether any
+ // non-affine operation in the node accesses the 'memref'.
+ for (auto &idAndNode : mdg->nodes) {
+ Operation *op = idAndNode.second.op;
+ // Take care of operations between 'srcNode' and 'dstNode'.
+ if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
+ // Walk inside the operation to find any use of the memref.
+ // Interrupt the walk if found.
+ auto walkResult = op->walk([&](Operation *user) {
+ // Skip affine ops.
+ if (isMemRefDereferencingOp(*user))
+ return WalkResult::advance();
+ // Find a non-affine op that uses the memref.
+ if (llvm::is_contained(users, user))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted())
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Check whether a memref value in node 'srcId' has a non-affine user that
+/// lies between node 'srcId' and node 'dstId' (exclusive of 'srcId' and
+/// 'dstId').
+static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
+ MemRefDependenceGraph *mdg) {
+ // Collect memref values in node 'srcId'.
+ auto *srcNode = mdg->getNode(srcId);
+ llvm::SmallDenseSet<Value, 2> memRefValues;
+ srcNode->op->walk([&](Operation *op) {
+ // Skip the 'AffineForOp' loop ops themselves; only other ops' operands are collected.
+ if (isa<AffineForOp>(op))
+ return WalkResult::advance();
+ for (Value v : op->getOperands())
+ // Collect memref values only.
+ if (v.getType().isa<MemRefType>())
+ memRefValues.insert(v);
+ return WalkResult::advance();
+ });
+ // Looking for users between node 'srcId' and node 'dstId'.
+ for (Value memref : memRefValues)
+ if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg))
+ return true;
+ return false;
+}
+
// Checks if node 'srcId' can be safely fused into node 'dstId'. Node 'srcId'
// may write to multiple memrefs but it is required that only one of them,
// 'srcLiveOutStoreOp', has output edges.
@@ -1008,6 +1067,12 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId,
// TODO(andydavis) Check the shape and lower bounds here too.
if (srcNumElements != dstNumElements)
return false;
+
+ // Return false if 'memref' is used by a non-affine operation that is
+ // between node 'srcId' and node 'dstId'.
+ if (hasNonAffineUsersOnThePath(srcId, dstId, mdg))
+ return false;
+
return true;
}
@@ -1793,6 +1858,12 @@ struct GreedyFusion {
}
if (storeMemrefs.size() != 1)
return false;
+
+ // Skip if a memref value in one node is used by a non-affine memref
+ // access that lies between 'dstNode' and 'sibNode'.
+ if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
+ hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
+ return false;
return true;
};
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index a153b96bf362..51d2fb42a1c1 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -2570,3 +2570,67 @@ func @calc(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>, %le
// CHECK-NEXT: affine.store %{{.*}}, %arg{{.*}}[%arg{{.*}}] : memref<?xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_non_affine_users
+func @should_not_fuse_since_non_affine_users(%in0 : memref<32xf32>,
+ %in1 : memref<32xf32>) {
+ affine.for %d = 0 to 32 {
+ %lhs = affine.load %in0[%d] : memref<32xf32>
+ %rhs = affine.load %in1[%d] : memref<32xf32>
+ %add = addf %lhs, %rhs : f32
+ affine.store %add, %in0[%d] : memref<32xf32>
+ }
+ affine.for %d = 0 to 32 {
+ %lhs = load %in0[%d] : memref<32xf32>
+ %rhs = load %in1[%d] : memref<32xf32>
+ %add = subf %lhs, %rhs : f32
+ store %add, %in0[%d] : memref<32xf32>
+ }
+ affine.for %d = 0 to 32 {
+ %lhs = affine.load %in0[%d] : memref<32xf32>
+ %rhs = affine.load %in1[%d] : memref<32xf32>
+ %add = mulf %lhs, %rhs : f32
+ affine.store %add, %in0[%d] : memref<32xf32>
+ }
+ return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: subf
+// CHECK: affine.for
+// CHECK: mulf
+
+// -----
+
+// CHECK-LABEL: func @should_not_fuse_since_top_level_non_affine_users
+func @should_not_fuse_since_top_level_non_affine_users(%in0 : memref<32xf32>,
+ %in1 : memref<32xf32>) {
+ %sum = alloc() : memref<f32>
+ affine.for %d = 0 to 32 {
+ %lhs = affine.load %in0[%d] : memref<32xf32>
+ %rhs = affine.load %in1[%d] : memref<32xf32>
+ %add = addf %lhs, %rhs : f32
+ store %add, %sum[] : memref<f32>
+ affine.store %add, %in0[%d] : memref<32xf32>
+ }
+ %load_sum = load %sum[] : memref<f32>
+ affine.for %d = 0 to 32 {
+ %lhs = affine.load %in0[%d] : memref<32xf32>
+ %rhs = affine.load %in1[%d] : memref<32xf32>
+ %add = mulf %lhs, %rhs : f32
+ %sub = subf %add, %load_sum: f32
+ affine.store %sub, %in0[%d] : memref<32xf32>
+ }
+ dealloc %sum : memref<f32>
+ return
+}
+
+// CHECK: affine.for
+// CHECK: addf
+// CHECK: affine.for
+// CHECK: mulf
+// CHECK: subf
More information about the Mlir-commits
mailing list