[Mlir-commits] [mlir] [mlir][scf] Relax requirements for loops fusion (PR #79187)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 24 01:15:14 PST 2024
https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/79187
>From 06605146db67b5205fabf4ebf7e37dd88bd17296 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 22 Jan 2024 11:16:30 +0000
Subject: [PATCH] [mlir][scf] Relax requirements for loops fusion
Enable the fusion of parallel loops even when the first loop
contains multiple write accesses to the same buffer,
provided all of those accesses use the same indices.
Skip the conservative aliasing check when at most one memref
block argument (e.g. a function argument) is written by the two loops.
Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
---
.../SCF/Transforms/ParallelLoopFusion.cpp | 30 +++++--
.../Dialect/SCF/parallel-loop-fusion.mlir | 81 ++++++++++++++++++-
2 files changed, 105 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7a..88e22b104bcfc74 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -21,6 +21,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/SetVector.h"
+
namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -63,10 +65,20 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
llvm::function_ref<bool(Value, Value)> mayAlias) {
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
SmallVector<Value> bufferStoresVec;
+ llvm::SmallSetVector<BlockArgument, 10u> writtenArgs;
firstPloop.getBody()->walk([&](memref::StoreOp store) {
bufferStores[store.getMemRef()].push_back(store.getIndices());
- bufferStoresVec.emplace_back(store.getMemRef());
+ const auto storeMemRef = store.getMemRef();
+ bufferStoresVec.emplace_back(storeMemRef);
+ if (llvm::isa<BlockArgument>(storeMemRef))
+ writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
+ });
+ secondPloop.getBody()->walk([&](memref::StoreOp store) {
+ const auto storeMemRef = store.getMemRef();
+ if (llvm::isa<BlockArgument>(storeMemRef))
+ writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
});
+
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
Value loadMem = load.getMemRef();
// Stop if the memref is defined in secondPloop body. Careful alias analysis
@@ -76,20 +88,28 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
return WalkResult::interrupt();
for (Value store : bufferStoresVec)
- if (store != loadMem && mayAlias(store, loadMem))
+ if ((store != loadMem) && (writtenArgs.size() > 1) &&
+ mayAlias(store, loadMem))
return WalkResult::interrupt();
auto write = bufferStores.find(loadMem);
if (write == bufferStores.end())
return WalkResult::advance();
- // Allow only single write access per buffer.
- if (write->second.size() != 1)
+ // Check that at least one store was retrieved.
+ if (!write->second.size())
return WalkResult::interrupt();
+ auto storeIndices = write->second.front();
+
+ // Multiple writes to the same memref are allowed only on the same indices
+ for (const auto &othStoreIndices : write->second) {
+ if (othStoreIndices != storeIndices)
+ return WalkResult::interrupt();
+ }
+
// Check that the load indices of secondPloop coincide with store indices of
// firstPloop for the same memrefs.
- auto storeIndices = write->second.front();
auto loadIndices = load.getIndices();
if (storeIndices.size() != loadIndices.size())
return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e5247178..d62f5ed91dec8fb 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -60,6 +60,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK: [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
@@ -113,10 +114,12 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
// CHECK: [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
// CHECK: memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
// CHECK: [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
// CHECK: memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
// CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
// CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
@@ -382,8 +385,84 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
}
return
}
-
// %sum and %result may alias with other args, do not fuse loops
// CHECK-LABEL: func @do_not_fuse_alias
// CHECK: scf.parallel
// CHECK: scf.parallel
+
+// -----
+
+// The first loop stores to %sum twice, but both stores use the indices
+// [%i, %j] that the second loop loads %sum from, so the loops can be fused.
+func.func @fuse_when_1st_has_multiple_stores(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0f32 = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0f32, %sum[%i, %j] : memref<2x2xf32>
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}},
+// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
+// CHECK: [[C0:%.*]] = arith.constant 0 : index
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[C2:%.*]] = arith.constant 2 : index
+// CHECK: [[C0F32:%.*]] = arith.constant 0.
+// CHECK: [[SUM:%.*]] = memref.alloc()
+// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK: memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT: scf.parallel
+// CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: memref.dealloc [[SUM]]
+
+// -----
+
+// The first loop stores to %sum at [%i, %j] and again at [%c0, %j]. Because
+// the two store index lists differ, the loops must not be fused.
+func.func @do_not_fuse_multiple_stores_on_diff_indices(
+ %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0_f32 = arith.constant 0.0 : f32
+ %sum = memref.alloc() : memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c0_f32, %sum[%i, %j] : memref<2x2xf32>
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %B_elem : f32
+ memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %sum : memref<2x2xf32>
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK: scf.parallel
+// CHECK: scf.parallel
\ No newline at end of file
More information about the Mlir-commits
mailing list