[Mlir-commits] [mlir] [mlir][scf] Relax requirements for loops fusion (PR #79187)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 29 08:26:55 PST 2024


https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/79187

>From 375b51984521944723bf01081b463327ad202559 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Mon, 22 Jan 2024 11:16:30 +0000
Subject: [PATCH] [mlir][scf] Relax requirements for loops fusion

Enable the fusion of parallel loops also when the first loop
contains multiple write accesses to the same buffer, provided
that all of these accesses use the same indices.
Avoid failing on possible aliasing when only one memref among
the block arguments (e.g. the function arguments) is written
by the two loops.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
---
 .../SCF/Transforms/ParallelLoopFusion.cpp     | 39 +++++++--
 .../Dialect/SCF/parallel-loop-fusion.mlir     | 81 ++++++++++++++++++-
 2 files changed, 114 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index d7184ad0bad2c7a..2d791a32f716c8a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -21,6 +21,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/SetVector.h"
+
 namespace mlir {
 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -63,10 +65,27 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
     llvm::function_ref<bool(Value, Value)> mayAlias) {
   DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
   SmallVector<Value> bufferStoresVec;
+  llvm::SmallSetVector<BlockArgument, 10u> writtenArgs;
   firstPloop.getBody()->walk([&](memref::StoreOp store) {
     bufferStores[store.getMemRef()].push_back(store.getIndices());
-    bufferStoresVec.emplace_back(store.getMemRef());
+    const auto storeMemRef = store.getMemRef();
+    bufferStoresVec.emplace_back(storeMemRef);
+    if (llvm::isa<BlockArgument>(storeMemRef))
+      writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
+  });
+  // Also collect the block-argument memrefs stored to in the second loop.
+  secondPloop.getBody()->walk([&](memref::StoreOp store) {
+    const auto storeMemRef = store.getMemRef();
+    if (llvm::isa<BlockArgument>(storeMemRef))
+      writtenArgs.insert(llvm::cast<BlockArgument>(storeMemRef));
   });
+
+  // True if `store` is a block argument and at most one block-argument memref
+  // is written across both loops. NOTE(review): other args may still alias it.
+  auto isArgAndNoMultiWrites = [&writtenArgs](const Value &store) {
+    return llvm::isa<BlockArgument>(store) && (writtenArgs.size() <= 1u);
+  };
+
   auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
     Value loadMem = load.getMemRef();
     // Stop if the memref is defined in secondPloop body. Careful alias analysis
@@ -76,20 +95,30 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
       return WalkResult::interrupt();
 
     for (Value store : bufferStoresVec)
-      if (store != loadMem && mayAlias(store, loadMem))
+      // Skip alias analysis when the stored memref is a block argument and at
+      // most one block-argument memref is written across both loops.
+      if ((store != loadMem) && !isArgAndNoMultiWrites(store) &&
+          mayAlias(store, loadMem))
         return WalkResult::interrupt();
 
     auto write = bufferStores.find(loadMem);
     if (write == bufferStores.end())
       return WalkResult::advance();
 
-    // Allow only single write access per buffer.
-    if (write->second.size() != 1)
+    // Check that at least one store was retrieved
+    if (!write->second.size())
       return WalkResult::interrupt();
 
+    auto storeIndices = write->second.front();
+
+    // Multiple writes to the same memref are allowed only on the same indices
+    for (const auto &othStoreIndices : write->second) {
+      if (othStoreIndices != storeIndices)
+        return WalkResult::interrupt();
+    }
+
     // Check that the load indices of secondPloop coincide with store indices of
     // firstPloop for the same memrefs.
-    auto storeIndices = write->second.front();
     auto loadIndices = load.getIndices();
     if (storeIndices.size() != loadIndices.size())
       return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 9fd33b4e5247178..d62f5ed91dec8fb 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -60,6 +60,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
 // CHECK:        [[C_ELEM:%.*]] = memref.load [[C]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[C_ELEM]]
 // CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[SUM_ELEM_:%.*]] = memref.load [[SUM]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]]
@@ -113,10 +114,12 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
 // CHECK-SAME:     to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
 // CHECK:        [[RHS_ELEM:%.*]] = memref.load [[RHS]]{{\[}}[[I]]]
 // CHECK:        memref.store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[LHS_ELEM:%.*]] = memref.load [[LHS]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[BROADCAST_RHS_ELEM:%.*]] = memref.load [[BROADCAST_RHS]]
 // CHECK:        [[DIFF_ELEM:%.*]] = arith.subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
 // CHECK:        memref.store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
 // CHECK:        [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]]
 // CHECK:        [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]]
 // CHECK:        memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
@@ -382,8 +385,84 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
   }
   return
 }
-
 // %sum and %result may alias with other args, do not fuse loops
 // CHECK-LABEL: func @do_not_fuse_alias
 // CHECK:      scf.parallel
 // CHECK:      scf.parallel
+
+// -----
+
+func.func @fuse_when_1st_has_multiple_stores(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c0f32 = arith.constant 0.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c0f32, %sum[%i, %j] : memref<2x2xf32>
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %B_elem : f32
+    memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_when_1st_has_multiple_stores
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}},
+// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
+// CHECK:      [[C0:%.*]] = arith.constant 0 : index
+// CHECK:      [[C1:%.*]] = arith.constant 1 : index
+// CHECK:      [[C2:%.*]] = arith.constant 2 : index
+// CHECK:      [[C0F32:%.*]] = arith.constant 0.
+// CHECK:      [[SUM:%.*]] = memref.alloc()
+// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = memref.load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = arith.addf [[B_ELEM]], [[B_ELEM]]
+// CHECK:        memref.store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK-NOT:  scf.parallel
+// CHECK:        [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = arith.mulf
+// CHECK:        memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:        scf.reduce
+// CHECK:      }
+// CHECK:      memref.dealloc [[SUM]]
+
+// -----
+
+func.func @do_not_fuse_multiple_stores_on_diff_indices(
+  %A: memref<2x2xf32>, %B: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %sum = memref.alloc()  : memref<2x2xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    memref.store %c0_f32, %sum[%i, %j] : memref<2x2xf32>
+    %B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
+    %sum_elem = arith.addf %B_elem, %B_elem : f32
+    memref.store %sum_elem, %sum[%c0, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = arith.mulf %sum_elem, %A_elem : f32
+    memref.store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    scf.reduce
+  }
+  memref.dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_multiple_stores_on_diff_indices
+// CHECK:        scf.parallel
+// CHECK:        scf.parallel
\ No newline at end of file



More information about the Mlir-commits mailing list