[Mlir-commits] [mlir] [mlir][scf] Relax requirements for loops fusion (PR #79187)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 23 10:19:13 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: None (fabrizio-indirli)
<details>
<summary>Changes</summary>
Enable fusion of parallel loops even when the first loop contains multiple writes to the same buffer, provided all of those writes use the same indices.
Skip the may-alias check between stored and loaded memrefs when at most one block-argument memref (e.g. a function argument) is written across the two loops.
---
Full diff: https://github.com/llvm/llvm-project/pull/79187.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+25-5)
- (modified) mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (+80-1)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7a..88e22b104bcfc74 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -21,6 +21,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/SetVector.h"
+
namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -63,10 +65,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
llvm::function_ref<bool(Value, Value)> mayAlias) {
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
SmallVector<Value> bufferStoresVec;
+ llvm::SmallSetVector<BlockArgument, 10u> writtenArgs;
firstPloop.getBody()->walk([&](memref::StoreOp store) {
bufferStores[store.getMemRef()].push_back(store.getIndices());
- bufferStoresVec.emplace_back(store.getMemRef());
+ const auto storeMemRef = store.getMemRef();
+ bufferStoresVec.emplace_back(storeMemRef);
+ if (llvm::isa<BlockArgument>(storeMemRef))
+ writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
+ });
+ secondPloop.getBody()->walk([&](memref::StoreOp store) {
+ const auto storeMemRef = store.getMemRef();
+ if (llvm::isa<BlockArgument>(storeMemRef))
+ writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
});
+
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
Value loadMem = load.getMemRef();
// Stop if the memref is defined in secondPloop body. Careful alias analysis
@@ -76,20 +88,28 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
return WalkResult::interrupt();
for (Value store : bufferStoresVec)
- if (store != loadMem && mayAlias(store, loadMem))
+ if ((store != loadMem) && (writtenArgs.size() > 1) &&
+ mayAlias(store, loadMem))
return WalkResult::interrupt();
auto write = bufferStores.find(loadMem);
if (write == bufferStores.end())
return WalkResult::advance();
- // Allow only single write access per buffer.
- if (write->second.size() != 1)
+ // Check that at least one store was retrieved
+ if (!write->second.size())
return WalkResult::interrupt();
+ auto storeIndices = write->second.front();
+
+ // Multiple writes to the same memref are allowed only on the same indices
+ for (const auto &othStoreIndices : write->second) {
+ if (othStoreIndices != storeIndices)
+ return WalkResult::interrupt();
+ }
+
// Check that the load indices of secondPloop coincide with store indices of
// firstPloop for the same memrefs.
- auto storeIndices = write->second.front();
auto loadIndices = load.getIndices();
if (storeIndices.size() != loadIndices.size())
return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e5247178..d62f5ed91dec8fb 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -60,6 +60,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
@@ -113,10 +114,12 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
@@ -382,8 +385,84 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
}
return
}
-
// %sum and %result may alias with other args, do not fuse loops
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_when_1st_has_multiple_stores( // Fusion expected: the first loop writes %sum twice, always at [%i, %j].
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0f32 = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32> // local buffer, not a function argument
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0f32, %sum[%i, %j] : memref<2x2xf32> // first store to %sum at [%i, %j]
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> // second store to %sum, same indices as the first
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> // reads %sum at the same indices the first loop wrote
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}},
+// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
+// CHECK: [[C0:%.*]] = arith.constant 0 : index
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[C2:%.*]] = arith.constant 2 : index
+// CHECK: [[C0F32:%.*]] = arith.constant 0.
+// CHECK: [[SUM:%.*]] = memref.alloc()
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_multiple_stores_on_diff_indices( // Fusion not expected: the two stores to %sum use different indices.
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0_f32, %sum[%i, %j] : memref<2x2xf32> // store at [%i, %j]
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32> // store at [%c0, %j]: indices differ from the store above
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> // row 0 of %sum is written by every %i iteration of the first loop
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK: scf.parallel
+// CHECK: scf.parallel
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/79187
More information about the Mlir-commits
mailing list