[Mlir-commits] [mlir] [mlir][scf] Relax requirements for loops fusion (PR #79187)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 29 08:30:39 PST 2024


https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/79187

>From 4aa1affa5c3b28fffdc0ada2c6c4919448a01ee0 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 22 Jan 2024 11:16:30 +0000
Subject: [PATCH] [mlir][scf] Relax requirements for loops fusion

Enable the fusion of parallel loops also when the 1st loop
contains multiple write accesses to the same buffer,
if the accesses are always on the same indices.
Avoid failing on possible aliasing when only one memref
from the function args is being written.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
---
 .../SCF/Transforms/ParallelLoopFusion.cpp     | 39 +++++++--
 .../Dialect/SCF/parallel-loop-fusion.mlir     | 81 ++++++++++++++++++-
 2 files changed, 114 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7..1b2361b59236c9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -21,6 +21,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/SetVector.h"
+
 namespace mlir {
 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -63,10 +65,27 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
     llvm::function_ref<bool(Value, Value)> mayAlias) {
   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
   SmallVector<Value> bufferStoresVec;
+  llvm::SmallSetVector<BlockArgument, 10u> writtenArgs;
   firstPloop.getBody()->walk([&](memref::StoreOp store) {
     bufferStores[store.getMemRef()].push_back(store.getIndices());
-    bufferStoresVec.emplace_back(store.getMemRef());
+    const auto storeMemRef = store.getMemRef();
+    bufferStoresVec.emplace_back(storeMemRef);
+    if (llvm::isa<BlockArgument>(storeMemRef))
+      writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
+  });
+  // retrieve also the values written in the second loop
+  secondPloop.getBody()->walk([&](memref::StoreOp store) {
+    const auto storeMemRef = store.getMemRef();
+    if (llvm::isa<BlockArgument>(storeMemRef))
+      writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
   });
+
+  // Return true if the value is a block argument (e.g. a function argument)
+  // and at most one such argument is written across the two loops.
+  auto isArgAndNoMultiWrites = [&writtenArgs](const Value &store) {
+    return llvm::isa<BlockArgument>(store) && (writtenArgs.size() <= 1u);
+  };
+
   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
     Value loadMem = load.getMemRef();
     // Stop if the memref is defined in secondPloop body. Careful alias analysis
@@ -76,20 +95,30 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
       return WalkResult::interrupt();
 
     for (Value store : bufferStoresVec)
-      if (store != loadMem && mayAlias(store, loadMem))
+      // avoid alias analysis when storing on a function argument,
+      // if it's the only argument that is being written in the two loops
+      if ((store != loadMem) && !isArgAndNoMultiWrites(store) &&
+          mayAlias(store, loadMem))
         return WalkResult::interrupt();
 
     auto write = bufferStores.find(loadMem);
     if (write == bufferStores.end())
       return WalkResult::advance();
 
-    // Allow only single write access per buffer.
-    if (write->second.size() != 1)
+    // Check that at least one store was retrieved
+    if (!write->second.size())
       return WalkResult::interrupt();
 
+    auto storeIndices = write->second.front();
+
+    // Multiple writes to the same memref are allowed only on the same indices
+    for (const auto &othStoreIndices : write->second) {
+      if (othStoreIndices != storeIndices)
+        return WalkResult::interrupt();
+    }
+
     // Check that the load indices of secondPloop coincide with store indices of
     // firstPloop for the same memrefs.
-    auto storeIndices = write->second.front();
     auto loadIndices = load.getIndices();
     if (storeIndices.size() != loadIndices.size())
       return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e524717..d62f5ed91dec8f 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -60,6 +60,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
 // CHECK:        [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
 // CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
@@ -113,10 +114,12 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
 // CHECK-SAME:     to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
 // CHECK:        [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
 // CHECK:        memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
 // CHECK:        [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
 // CHECK:        memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
 // CHECK:        memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
@@ -382,8 +385,84 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
   }
   return
 }
-
 // %sum and %result may alias with other args, do not fuse loops
 // CHECK-LABEL: func @do_not_fuse_alias
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel
+
+// -----
+
+func.func @fuse_when_1st_has_multiple_stores(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c0f32 = arith.constant 0.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c0f32, %sum[%i, %j] : memref<2x2xf32>
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %B_elem : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}},
+// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
+// CHECK:      [[C0:%.*]] = arith.constant 0 : index
+// CHECK:      [[C1:%.*]] = arith.constant 1 : index
+// CHECK:      [[C2:%.*]] = arith.constant 2 : index
+// CHECK:      [[C0F32:%.*]] = arith.constant 0.
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+// CHECK:      memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_multiple_stores_on_diff_indices(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c0_f32, %sum[%i, %j] : memref<2x2xf32>
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %B_elem : f32
+    memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK:        scf.parallel
+// CHECK:        scf.parallel
\ No newline at end of file



More information about the Mlir-commits mailing list