[Mlir-commits] [mlir] [mlir][scf] Improve `scf.parallel` fusion pass (PR #75852)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 18 12:56:04 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

Abort fusion if a memref load may alias a written memref without being the exact same memref. 
Add an alias check hook to `naivelyFuseParallelOps`, so the user can customize alias checking. 
Use the builtin alias analysis in the `ParallelLoopFusion` pass.

---
Full diff: https://github.com/llvm/llvm-project/pull/75852.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+4-1) 
- (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+34-14) 
- (modified) mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (+30) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index e66686d4e08f5c..d3b4b588427712 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -34,7 +34,10 @@ class ParallelOp;
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
-void naivelyFuseParallelOps(Region &region);
+/// The user can additionally customize alias checking with the `mayAlias` hook.
+/// `mayAlias` must return false if the 2 values are guaranteed to not alias.
+void naivelyFuseParallelOps(Region &region,
+                            llvm::function_ref<bool(Value, Value)> mayAlias);
 
 /// Rewrite a for loop with bounds/step that potentially do not divide evenly
 /// into a for loop where the step divides the iteration space evenly, followed
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 9a5db1b41b35ad..d7184ad0bad2c7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 
+#include "mlir/Analysis/AliasAnalysis.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -58,19 +59,27 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
 /// loop reads.
 static bool haveNoReadsAfterWriteExceptSameIndex(
     ParallelOp firstPloop, ParallelOp secondPloop,
-    const IRMapping &firstToSecondPloopIndices) {
+    const IRMapping &firstToSecondPloopIndices,
+    llvm::function_ref<bool(Value, Value)> mayAlias) {
   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+  SmallVector<Value> bufferStoresVec;
   firstPloop.getBody()->walk([&](memref::StoreOp store) {
     bufferStores[store.getMemRef()].push_back(store.getIndices());
+    bufferStoresVec.emplace_back(store.getMemRef());
   });
   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+    Value loadMem = load.getMemRef();
     // Stop if the memref is defined in secondPloop body. Careful alias analysis
     // is needed.
-    auto *memrefDef = load.getMemRef().getDefiningOp();
+    auto *memrefDef = loadMem.getDefiningOp();
     if (memrefDef && memrefDef->getBlock() == load->getBlock())
       return WalkResult::interrupt();
 
-    auto write = bufferStores.find(load.getMemRef());
+    for (Value store : bufferStoresVec)
+      if (store != loadMem && mayAlias(store, loadMem))
+        return WalkResult::interrupt();
+
+    auto write = bufferStores.find(loadMem);
     if (write == bufferStores.end())
       return WalkResult::advance();
 
@@ -98,35 +107,39 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
 /// write patterns.
 static LogicalResult
 verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
-                   const IRMapping &firstToSecondPloopIndices) {
-  if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
-                                            firstToSecondPloopIndices))
+                   const IRMapping &firstToSecondPloopIndices,
+                   llvm::function_ref<bool(Value, Value)> mayAlias) {
+  if (!haveNoReadsAfterWriteExceptSameIndex(
+          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
     return failure();
 
   IRMapping secondToFirstPloopIndices;
   secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
                                 firstPloop.getBody()->getArguments());
   return success(haveNoReadsAfterWriteExceptSameIndex(
-      secondPloop, firstPloop, secondToFirstPloopIndices));
+      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
 }
 
 static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
-                          const IRMapping &firstToSecondPloopIndices) {
+                          const IRMapping &firstToSecondPloopIndices,
+                          llvm::function_ref<bool(Value, Value)> mayAlias) {
   return !hasNestedParallelOp(firstPloop) &&
          !hasNestedParallelOp(secondPloop) &&
          equalIterationSpaces(firstPloop, secondPloop) &&
          succeeded(verifyDependencies(firstPloop, secondPloop,
-                                      firstToSecondPloopIndices));
+                                      firstToSecondPloopIndices, mayAlias));
 }
 
 /// Prepends operations of firstPloop's body into secondPloop's body.
 static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
-                        OpBuilder b) {
+                        OpBuilder b,
+                        llvm::function_ref<bool(Value, Value)> mayAlias) {
   IRMapping firstToSecondPloopIndices;
   firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
                                 secondPloop.getBody()->getArguments());
 
-  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
+  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+                     mayAlias))
     return;
 
   b.setInsertionPointToStart(secondPloop.getBody());
@@ -135,7 +148,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
   firstPloop.erase();
 }
 
-void mlir::scf::naivelyFuseParallelOps(Region &region) {
+void mlir::scf::naivelyFuseParallelOps(
+    Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
   OpBuilder b(region);
   // Consider every single block and attempt to fuse adjacent loops.
   for (auto &block : region) {
@@ -159,7 +173,7 @@ void mlir::scf::naivelyFuseParallelOps(Region &region) {
     }
     for (ArrayRef<ParallelOp> ploops : ploopChains) {
       for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
-        fuseIfLegal(ploops[i], ploops[i + 1], b);
+        fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
     }
   }
 }
@@ -168,9 +182,15 @@ namespace {
 struct ParallelLoopFusion
     : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
   void runOnOperation() override {
+    auto &AA = getAnalysis<AliasAnalysis>();
+
+    auto mayAlias = [&](Value val1, Value val2) -> bool {
+      return !AA.alias(val1, val2).isNo();
+    };
+
     getOperation()->walk([&](Operation *child) {
       for (Region &region : child->getRegions())
-        naivelyFuseParallelOps(region);
+        naivelyFuseParallelOps(region, mayAlias);
     });
   }
 };
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index aab64b2751caf7..0fcc21fa7c4875 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -357,3 +357,33 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
 // CHECK:        }
 // CHECK:      }
 // CHECK:      memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
+                             %C: memref<2x2xf32>, %result: memref<2x2xf32>,
+                             %sum: memref<2x2xf32>) {
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = memref.load %C[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %C_elem : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.yield
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    scf.yield
+  }
+  return
+}
+
+// %sum and %result may alias, do not fuse loops
+// CHECK-LABEL: func @do_not_fuse_alias
+// CHECK:      scf.parallel
+// CHECK:      scf.parallel

``````````

</details>


https://github.com/llvm/llvm-project/pull/75852


More information about the Mlir-commits mailing list