[Mlir-commits] [mlir] c0d2ea9 - [mlir][scf] Improve `scf.parallel` fusion pass (#75852)
llvmlistbot at llvm.org
Tue Dec 19 07:07:51 PST 2023
Author: Ivan Butygin
Date: 2023-12-19T18:07:46+03:00
New Revision: c0d2ea9d4202c7cce4214b3057a709ff2f1128ae
URL: https://github.com/llvm/llvm-project/commit/c0d2ea9d4202c7cce4214b3057a709ff2f1128ae
DIFF: https://github.com/llvm/llvm-project/commit/c0d2ea9d4202c7cce4214b3057a709ff2f1128ae.diff
LOG: [mlir][scf] Improve `scf.parallel` fusion pass (#75852)
Abort fusion if a memref load may alias a store but is not the exact same memref.
Add an alias-check hook to `naivelyFuseParallelOps` so users can customize
alias checking.
Use the builtin alias analysis in the `ParallelLoopFusion` pass.
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index e66686d4e08f5c..e91f9e4469ab72 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -34,7 +34,10 @@ class ParallelOp;
/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
-void naivelyFuseParallelOps(Region &region);
+/// User can additionally customize alias checking with `mayAlias` hook.
+/// `mayAlias` must return false if 2 values are guaranteed to not alias.
+void naivelyFuseParallelOps(Region &region,
+ llvm::function_ref<bool(Value, Value)> mayAlias);
/// Rewrite a for loop with bounds/step that potentially do not divide evenly
/// into a for loop where the step divides the iteration space evenly, followed
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 9a5db1b41b35ad..d7184ad0bad2c7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -58,19 +59,27 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
/// loop reads.
static bool haveNoReadsAfterWriteExceptSameIndex(
ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices) {
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+ SmallVector<Value> bufferStoresVec;
firstPloop.getBody()->walk([&](memref::StoreOp store) {
bufferStores[store.getMemRef()].push_back(store.getIndices());
+ bufferStoresVec.emplace_back(store.getMemRef());
});
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+ Value loadMem = load.getMemRef();
// Stop if the memref is defined in secondPloop body. Careful alias analysis
// is needed.
- auto *memrefDef = load.getMemRef().getDefiningOp();
+ auto *memrefDef = loadMem.getDefiningOp();
if (memrefDef && memrefDef->getBlock() == load->getBlock())
return WalkResult::interrupt();
- auto write = bufferStores.find(load.getMemRef());
+ for (Value store : bufferStoresVec)
+ if (store != loadMem && mayAlias(store, loadMem))
+ return WalkResult::interrupt();
+
+ auto write = bufferStores.find(loadMem);
if (write == bufferStores.end())
return WalkResult::advance();
@@ -98,35 +107,39 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
/// write patterns.
static LogicalResult
verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices) {
- if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
- firstToSecondPloopIndices))
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
+ if (!haveNoReadsAfterWriteExceptSameIndex(
+ firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
return failure();
IRMapping secondToFirstPloopIndices;
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
firstPloop.getBody()->getArguments());
return success(haveNoReadsAfterWriteExceptSameIndex(
- secondPloop, firstPloop, secondToFirstPloopIndices));
+ secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
}
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices) {
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
return !hasNestedParallelOp(firstPloop) &&
!hasNestedParallelOp(secondPloop) &&
equalIterationSpaces(firstPloop, secondPloop) &&
succeeded(verifyDependencies(firstPloop, secondPloop,
- firstToSecondPloopIndices));
+ firstToSecondPloopIndices, mayAlias));
}
/// Prepends operations of firstPloop's body into secondPloop's body.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
- OpBuilder b) {
+ OpBuilder b,
+ llvm::function_ref<bool(Value, Value)> mayAlias) {
IRMapping firstToSecondPloopIndices;
firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
secondPloop.getBody()->getArguments());
- if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
+ if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+ mayAlias))
return;
b.setInsertionPointToStart(secondPloop.getBody());
@@ -135,7 +148,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
firstPloop.erase();
}
-void mlir::scf::naivelyFuseParallelOps(Region &region) {
+void mlir::scf::naivelyFuseParallelOps(
+ Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
OpBuilder b(region);
// Consider every single block and attempt to fuse adjacent loops.
for (auto &block : region) {
@@ -159,7 +173,7 @@ void mlir::scf::naivelyFuseParallelOps(Region &region) {
}
for (ArrayRef<ParallelOp> ploops : ploopChains) {
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
- fuseIfLegal(ploops[i], ploops[i + 1], b);
+ fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
}
}
}
@@ -168,9 +182,15 @@ namespace {
struct ParallelLoopFusion
: public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
void runOnOperation() override {
+ auto &AA = getAnalysis<AliasAnalysis>();
+
+ auto mayAlias = [&](Value val1, Value val2) -> bool {
+ return !AA.alias(val1, val2).isNo();
+ };
+
getOperation()->walk([&](Operation *child) {
for (Region &region : child->getRegions())
- naivelyFuseParallelOps(region);
+ naivelyFuseParallelOps(region, mayAlias);
});
}
};
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index aab64b2751caf7..8a42b3a1000ed6 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -357,3 +357,33 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
// CHECK: }
// CHECK: }
// CHECK: memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
+ %C: memref<2x2xf32>, %result: memref<2x2xf32>,
+ %sum: memref<2x2xf32>) {
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+ %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
+ %sum_elem = arith.addf %B_elem, %C_elem : f32
+ memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+ scf.yield
+ }
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+ %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+ %product_elem = arith.mulf %sum_elem, %A_elem : f32
+ memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+ scf.yield
+ }
+ return
+}
+
+// %sum and %result may alias with other args, do not fuse loops
+// CHECK-LABEL: func @do_not_fuse_alias
+// CHECK: scf.parallel
+// CHECK: scf.parallel
More information about the Mlir-commits mailing list