[Mlir-commits] [mlir] 986bef9 - [mlir] Remove redundant loads
Amy Zhuang
llvmlistbot at llvm.org
Thu Jun 3 16:05:06 PDT 2021
Author: Amy Zhuang
Date: 2021-06-03T15:51:46-07:00
New Revision: 986bef97826fc41cbac1b7ff74b4f40f4594ba68
URL: https://github.com/llvm/llvm-project/commit/986bef97826fc41cbac1b7ff74b4f40f4594ba68
DIFF: https://github.com/llvm/llvm-project/commit/986bef97826fc41cbac1b7ff74b4f40f4594ba68.diff
LOG: [mlir] Remove redundant loads
Reviewed By: vinayaka-polymage, bondhugula
Differential Revision: https://reviews.llvm.org/D103294
Added:
Modified:
mlir/lib/Transforms/MemRefDataFlowOpt.cpp
mlir/test/Transforms/memref-dataflow-opt.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index 401aaaf1c7cd4..f0c502a5bff95 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -7,7 +7,8 @@
//===----------------------------------------------------------------------===//
//
// This file implements a pass to forward memref stores to loads, thereby
-// potentially getting rid of intermediate memref's entirely.
+// potentially getting rid of intermediate memref's entirely. It also removes
+// redundant loads.
// TODO: In the future, similar techniques could be used to eliminate
// dead memref store's and perform more complex forwarding when support for
// SSA scalars live out of 'affine.for'/'affine.if' statements is available.
@@ -20,6 +21,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <algorithm>
@@ -29,21 +31,24 @@
using namespace mlir;
namespace {
-// The store to load forwarding relies on three conditions:
+// The store to load forwarding and load CSE rely on three conditions:
//
-// 1) they need to have mathematically equivalent affine access functions
-// (checked after full composition of load/store operands); this implies that
-// they access the same single memref element for all iterations of the common
-// surrounding loop,
+// 1) the store (or the load being reused) and the load need to have
+// mathematically equivalent affine access functions (checked after full
+// composition of load/store operands); this implies that they access the same
+// single memref element for all iterations of the common surrounding loop,
//
-// 2) the store op should dominate the load op,
+// 2) the store op (or the load op being reused) should dominate the load op,
//
-// 3) among all op's that satisfy both (1) and (2), the one that postdominates
-// all store op's that have a dependence into the load, is provably the last
-// writer to the particular memref location being loaded at the load op, and its
-// store value can be forwarded to the load. Note that the only dependences
-// that are to be considered are those that are satisfied at the block* of the
-// innermost common surrounding loop of the <store, load> being considered.
+// 3) among all op's that satisfy both (1) and (2), for store to load
+// forwarding, the one that postdominates all store op's that have a dependence
+// into the load is provably the last writer to the particular memref location
+// being loaded at the load op, and its store value can be forwarded to the
+// load; for load CSE, any load op that postdominates all store op's that have
+// a dependence into the load can have its result reused, and the first such op
+// found is chosen. Note that the only dependences that are to be considered
+// are those that are satisfied at the block* of the innermost common
+// surrounding loop of the <store/load, load> being considered.
//
// (* A dependence being satisfied at a block: a dependence that is satisfied by
// virtue of the destination operation appearing textually / lexically after
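For illustration (not part of the commit; the function name and %v below are made up), a minimal sketch in the style of the tests added further down, showing the three conditions lining up for store to load forwarding:

  // Hypothetical example, not from the commit.
  func @forwarding_conditions_sketch(%m : memref<10xf32>, %v : f32) {
    affine.for %i0 = 0 to 10 {
      affine.store %v, %m[%i0] : memref<10xf32>
      // (1) the load has the same affine access function as the store,
      // (2) the store dominates the load, and (3) the store postdominates
      // every store with a dependence into the load, so it is provably the
      // last writer: the load is replaced by %v and becomes dead.
      %v0 = affine.load %m[%i0] : memref<10xf32>
      %v1 = addf %v0, %v0 : f32
    }
    return
  }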
@@ -64,11 +69,13 @@ namespace {
struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
void runOnFunction() override;
- void forwardStoreToLoad(AffineReadOpInterface loadOp);
+ LogicalResult forwardStoreToLoad(AffineReadOpInterface loadOp);
+ void loadCSE(AffineReadOpInterface loadOp);
// A list of memref's that are potentially dead / could be eliminated.
SmallPtrSet<Value, 4> memrefsToErase;
- // Load op's whose results were replaced by those forwarded from stores.
+ // Load op's whose results were replaced by those forwarded from dominating
+ // stores or loads.
SmallVector<Operation *, 8> loadOpsToErase;
DominanceInfo *domInfo = nullptr;
@@ -83,9 +90,32 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createMemRefDataFlowOptPass() {
return std::make_unique<MemRefDataFlowOpt>();
}
+// Check if the store may be reaching the load.
+static bool storeMayReachLoad(Operation *storeOp, Operation *loadOp,
+ unsigned minSurroundingLoops) {
+ MemRefAccess srcAccess(storeOp);
+ MemRefAccess destAccess(loadOp);
+ FlatAffineConstraints dependenceConstraints;
+ unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
+ unsigned d;
+ // Dependences at loop depth <= minSurroundingLoops do NOT matter.
+ for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
+ DependenceResult result = checkMemrefAccessDependence(
+ srcAccess, destAccess, d, &dependenceConstraints,
+ /*dependenceComponents=*/nullptr);
+ if (hasDependence(result))
+ break;
+ }
+ if (d <= minSurroundingLoops)
+ return false;
+
+ return true;
+}
+
// This is a straightforward implementation not optimized for speed. Optimize
// if needed.
-void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
+LogicalResult
+MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
// First pass over the use list to get the minimum number of surrounding
// loops common between the load op and the store op, with min taken across
// all store ops.
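As a sketch of the depth filtering in storeMayReachLoad (hypothetical code, not from the commit): a store whose dependence into the load is satisfied only outside the loops common to the two accesses being compared is filtered out, so it does not block the rewrite.

  // Hypothetical example, not from the commit.
  func @sibling_store_then_cse_sketch(%m : memref<10xf32>, %v : f32) {
    affine.for %i0 = 0 to 10 {
      affine.store %v, %m[%i0] : memref<10xf32>
    }
    affine.for %i1 = 0 to 10 {
      // The store above reaches this load, but nothing is forwarded: the
      // access functions differ and the store does not dominate the load.
      %v0 = affine.load %m[%i1] : memref<10xf32>
      // The dependence from the store is satisfied outside %i1, i.e. at a
      // loop depth <= minSurroundingLoops of the two loads, so it does not
      // prevent this load from reusing %v0.
      %v1 = affine.load %m[%i1] : memref<10xf32>
      %v2 = addf %v0, %v1 : f32
    }
    return
  }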
@@ -110,21 +140,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
SmallVector<Operation *, 8> depSrcStores;
for (auto *storeOp : storeOps) {
- MemRefAccess srcAccess(storeOp);
- MemRefAccess destAccess(loadOp);
- // Find stores that may be reaching the load.
- FlatAffineConstraints dependenceConstraints;
- unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
- unsigned d;
- // Dependences at loop depth <= minSurroundingLoops do NOT matter.
- for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
- DependenceResult result = checkMemrefAccessDependence(
- srcAccess, destAccess, d, &dependenceConstraints,
- /*dependenceComponents=*/nullptr);
- if (hasDependence(result))
- break;
- }
- if (d == minSurroundingLoops)
+ if (!storeMayReachLoad(storeOp, loadOp, minSurroundingLoops))
continue;
// Stores that *may* be reaching the load.
@@ -138,6 +154,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
// store %A[%M]
// load %A[%N]
// Use the AffineValueMap difference based memref access equality checking.
+ MemRefAccess srcAccess(storeOp);
+ MemRefAccess destAccess(loadOp);
if (srcAccess != destAccess)
continue;
@@ -165,16 +183,104 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
}
}
if (!lastWriteStoreOp)
- return;
+ return failure();
// Perform the actual store to load forwarding.
Value storeVal =
cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
+  // Check that the store value and the loaded value have the same type; this
+  // is needed for affine vector loads and stores, where the shapes must match.
+ if (storeVal.getType() != loadOp.getValue().getType())
+ return failure();
loadOp.getValue().replaceAllUsesWith(storeVal);
// Record the memref for a later sweep to optimize away.
memrefsToErase.insert(loadOp.getMemRef());
// Record this to erase later.
loadOpsToErase.push_back(loadOp);
+ return success();
+}
+
+// The load to load forwarding / redundant load elimination is similar to the
+// store to load forwarding.
+// loadA will be replaced with loadB if:
+// 1) loadA and loadB have mathematically equivalent affine access functions.
+// 2) loadB dominates loadA.
+// 3) loadB postdominates all the store op's that have a dependence into loadA.
+void MemRefDataFlowOpt::loadCSE(AffineReadOpInterface loadOp) {
+ // The list of load op candidates for forwarding that satisfy conditions
+ // (1) and (2) above - they will be filtered later when checking (3).
+ SmallVector<Operation *, 8> fwdingCandidates;
+ SmallVector<Operation *, 8> storeOps;
+ unsigned minSurroundingLoops = getNestingDepth(loadOp);
+ MemRefAccess memRefAccess(loadOp);
+ // First pass over the use list to get 1) the minimum number of surrounding
+ // loops common between the load op and a load op candidate, with min taken
+ // across all load op candidates; 2) load op candidates; 3) store ops.
+ // We take min across all load op candidates instead of all load ops to make
+ // sure later dependence check is performed at loop depths that do matter.
+ for (auto *user : loadOp.getMemRef().getUsers()) {
+ if (auto storeOp = dyn_cast<AffineWriteOpInterface>(user)) {
+ storeOps.push_back(storeOp);
+ } else if (auto aLoadOp = dyn_cast<AffineReadOpInterface>(user)) {
+ MemRefAccess otherMemRefAccess(aLoadOp);
+ // No need to consider load ops that have already been replaced by a
+ // previous store to load forwarding or loadCSE. If loadA or storeA can be
+ // forwarded to loadB, then loadA or storeA can be forwarded to loadC iff
+ // loadB can be forwarded to loadC.
+ // If loadB is visited before loadC and replaced with loadA, we do not put
+ // loadB in the candidate list, only loadA. If loadC is visited before
+ // loadB, loadC may be replaced with loadB, which will be replaced with
+ // loadA later.
+ if (aLoadOp != loadOp && !llvm::is_contained(loadOpsToErase, aLoadOp) &&
+ memRefAccess == otherMemRefAccess &&
+ domInfo->dominates(aLoadOp, loadOp)) {
+ fwdingCandidates.push_back(aLoadOp);
+ unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *aLoadOp);
+ minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
+ }
+ }
+ }
+
+ // No forwarding candidate.
+ if (fwdingCandidates.empty())
+ return;
+
+ // Store ops that have a dependence into the load.
+ SmallVector<Operation *, 8> depSrcStores;
+
+ for (auto *storeOp : storeOps) {
+ if (!storeMayReachLoad(storeOp, loadOp, minSurroundingLoops))
+ continue;
+
+ // Stores that *may* be reaching the load.
+ depSrcStores.push_back(storeOp);
+ }
+
+  // 3. Of all the load op's that meet the above criteria, pick the first load
+ // found that postdominates all 'depSrcStores' and has the same shape as the
+ // load to be replaced (if one exists). The shape check is needed for affine
+ // vector loads.
+ Operation *firstLoadOp = nullptr;
+ Value oldVal = loadOp.getValue();
+ for (auto *loadOp : fwdingCandidates) {
+ if (llvm::all_of(depSrcStores,
+ [&](Operation *depStore) {
+ return postDomInfo->postDominates(loadOp, depStore);
+ }) &&
+ cast<AffineReadOpInterface>(loadOp).getValue().getType() ==
+ oldVal.getType()) {
+ firstLoadOp = loadOp;
+ break;
+ }
+ }
+ if (!firstLoadOp)
+ return;
+
+ // Perform the actual load to load forwarding.
+ Value loadVal = cast<AffineReadOpInterface>(firstLoadOp).getValue();
+ loadOp.getValue().replaceAllUsesWith(loadVal);
+ // Record this to erase later.
+ loadOpsToErase.push_back(loadOp);
}
void MemRefDataFlowOpt::runOnFunction() {
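The postdominance requirement in condition (3) is what keeps loadCSE from reusing a value across a possibly clobbering store. A hypothetical sketch (names made up, not from the commit): the store may write %m[0] in the iteration where %i0 == 0, and the first load does not postdominate it, so the second load is left alone.

  // Hypothetical example, not from the commit.
  func @load_cse_blocked_sketch(%m : memref<10xf32>, %c : f32) {
    %c0 = constant 0 : index
    affine.for %i0 = 0 to 10 {
      // A CSE candidate for the load below, but it does not postdominate the
      // store, which has a dependence into that load.
      %v0 = affine.load %m[%c0] : memref<10xf32>
      affine.store %c, %m[%i0] : memref<10xf32>
      // Neither forwarded (its access differs from the store's) nor CSE'd
      // with %v0, so this load is kept.
      %v1 = affine.load %m[%c0] : memref<10xf32>
      %v2 = addf %v0, %v1 : f32
    }
    return
  }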
@@ -191,10 +297,15 @@ void MemRefDataFlowOpt::runOnFunction() {
loadOpsToErase.clear();
memrefsToErase.clear();
- // Walk all load's and perform store to load forwarding.
- f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); });
+ // Walk all load's and perform store to load forwarding and loadCSE.
+ f.walk([&](AffineReadOpInterface loadOp) {
+    // Try store to load forwarding first; if it fails, try loadCSE.
+ if (failed(forwardStoreToLoad(loadOp)))
+ loadCSE(loadOp);
+ });
- // Erase all load op's whose results were replaced with store fwd'ed ones.
+ // Erase all load op's whose results were replaced with store or load fwd'ed
+ // ones.
for (auto *loadOp : loadOpsToErase)
loadOp->erase();
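With this fallback order, a load first tries to pick up the value of a uniquely identified last store; only when that fails does it try to reuse an earlier identical load. A hypothetical sketch (not from the commit) combining both rewrites:

  // Hypothetical example, not from the commit.
  func @forward_then_cse_sketch(%m : memref<10xf32>) {
    affine.for %i0 = 0 to 10 {
      // Kept: no store reaches it and no earlier identical load exists.
      %v0 = affine.load %m[%i0] : memref<10xf32>
      // Store to load forwarding finds nothing, so loadCSE reuses %v0.
      %v1 = affine.load %m[%i0] : memref<10xf32>
      %v2 = addf %v0, %v1 : f32
      affine.store %v2, %m[%i0] : memref<10xf32>
      // Store to load forwarding succeeds: replaced by %v2; loadCSE is not
      // tried.
      %v3 = affine.load %m[%i0] : memref<10xf32>
      %v4 = addf %v2, %v3 : f32
    }
    return
  }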
diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir
index 49a3dcbd20183..61fd9d8ce6231 100644
--- a/mlir/test/Transforms/memref-dataflow-opt.mlir
+++ b/mlir/test/Transforms/memref-dataflow-opt.mlir
@@ -300,3 +300,233 @@ func @vector_forwarding(%in : memref<512xf32>, %out : memref<512xf32>) {
// CHECK-NEXT: %[[LDVAL:.*]] = affine.vector_load
// CHECK-NEXT: affine.vector_store %[[LDVAL]],{{.*}}
// CHECK-NEXT: }
+
+func @vector_no_forwarding(%in : memref<512xf32>, %out : memref<512xf32>) {
+ %tmp = memref.alloc() : memref<512xf32>
+ affine.for %i = 0 to 16 {
+ %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+ affine.vector_store %ld0, %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+ %ld1 = affine.vector_load %tmp[32*%i] : memref<512xf32>, vector<16xf32>
+ affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<16xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @vector_no_forwarding
+// CHECK: affine.for %{{.*}} = 0 to 16 {
+// CHECK-NEXT: %[[LDVAL:.*]] = affine.vector_load
+// CHECK-NEXT: affine.vector_store %[[LDVAL]],{{.*}}
+// CHECK-NEXT: %[[LDVAL1:.*]] = affine.vector_load
+// CHECK-NEXT: affine.vector_store %[[LDVAL1]],{{.*}}
+// CHECK-NEXT: }
+
+// CHECK-LABEL: func @simple_three_loads
+func @simple_three_loads(%in : memref<10xf32>) {
+ affine.for %i0 = 0 to 10 {
+ // CHECK: affine.load
+ %v0 = affine.load %in[%i0] : memref<10xf32>
+ // CHECK-NOT: affine.load
+ %v1 = affine.load %in[%i0] : memref<10xf32>
+ %v2 = addf %v0, %v1 : f32
+ %v3 = affine.load %in[%i0] : memref<10xf32>
+ %v4 = addf %v2, %v3 : f32
+ }
+ return
+}
+
+// CHECK-LABEL: func @nested_loads_const_index
+func @nested_loads_const_index(%in : memref<10xf32>) {
+ %c0 = constant 0 : index
+ // CHECK: affine.load
+ %v0 = affine.load %in[%c0] : memref<10xf32>
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 20 {
+ affine.for %i2 = 0 to 30 {
+ // CHECK-NOT: affine.load
+ %v1 = affine.load %in[%c0] : memref<10xf32>
+ %v2 = addf %v0, %v1 : f32
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @nested_loads
+func @nested_loads(%N : index, %in : memref<10xf32>) {
+ affine.for %i0 = 0 to 10 {
+ // CHECK: affine.load
+ %v0 = affine.load %in[%i0] : memref<10xf32>
+ affine.for %i1 = 0 to %N {
+ // CHECK-NOT: affine.load
+ %v1 = affine.load %in[%i0] : memref<10xf32>
+ %v2 = addf %v0, %v1 : f32
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @nested_loads_different_memref_accesses_no_cse
+func @nested_loads_different_memref_accesses_no_cse(%in : memref<10xf32>) {
+ affine.for %i0 = 0 to 10 {
+ // CHECK: affine.load
+ %v0 = affine.load %in[%i0] : memref<10xf32>
+ affine.for %i1 = 0 to 20 {
+ // CHECK: affine.load
+ %v1 = affine.load %in[%i1] : memref<10xf32>
+ %v2 = addf %v0, %v1 : f32
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @load_load_store
+func @load_load_store(%m : memref<10xf32>) {
+ affine.for %i0 = 0 to 10 {
+ // CHECK: affine.load
+ %v0 = affine.load %m[%i0] : memref<10xf32>
+ // CHECK-NOT: affine.load
+ %v1 = affine.load %m[%i0] : memref<10xf32>
+ %v2 = addf %v0, %v1 : f32
+ affine.store %v2, %m[%i0] : memref<10xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @load_load_store_2_loops_no_cse
+func @load_load_store_2_loops_no_cse(%N : index, %m : memref<10xf32>) {
+ affine.for %i0 = 0 to 10 {
+ // CHECK: affine.load
+ %v0 = affine.load %m[%i0] : memref<10xf32>
+ affine.for %i1 = 0 to %N {
+ // CHECK: affine.load
+ %v1 = affine.load %m[%i0] : memref<10xf32>
+ %v2 = addf %v0, %v1 : f32
+ affine.store %v2, %m[%i0] : memref<10xf32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @load_load_store_3_loops_no_cse
+func @load_load_store_3_loops_no_cse(%m : memref<10xf32>) {
+  %cf1 = constant 1.0 : f32
+ affine.for %i0 = 0 to 10 {
+ // CHECK: affine.load
+ %v0 = affine.load %m[%i0] : memref<10xf32>
+ affine.for %i1 = 0 to 20 {
+ affine.for %i2 = 0 to 30 {
+ // CHECK: affine.load
+ %v1 = affine.load %m[%i0] : memref<10xf32>
+ %v2 = addf %v0, %v1 : f32
+ }
+ affine.store %cf1, %m[%i0] : memref<10xf32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @load_load_store_3_loops
+func @load_load_store_3_loops(%m : memref<10xf32>) {
+  %cf1 = constant 1.0 : f32
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 20 {
+ // CHECK: affine.load
+ %v0 = affine.load %m[%i0] : memref<10xf32>
+ affine.for %i2 = 0 to 30 {
+ // CHECK-NOT: affine.load
+ %v1 = affine.load %m[%i0] : memref<10xf32>
+ %v2 = addf %v0, %v1 : f32
+ }
+ }
+ affine.store %cf1, %m[%i0] : memref<10xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @loads_in_sibling_loops_const_index_no_cse
+func @loads_in_sibling_loops_const_index_no_cse(%m : memref<10xf32>) {
+ %c0 = constant 0 : index
+ affine.for %i0 = 0 to 10 {
+ // CHECK: affine.load
+ %v0 = affine.load %m[%c0] : memref<10xf32>
+ }
+ affine.for %i1 = 0 to 10 {
+ // CHECK: affine.load
+ %v0 = affine.load %m[%c0] : memref<10xf32>
+ %v1 = addf %v0, %v0 : f32
+ }
+ return
+}
+
+// CHECK-LABEL: func @load_load_affine_apply
+func @load_load_affine_apply(%in : memref<10x10xf32>) {
+ affine.for %i0 = 0 to 10 {
+ affine.for %i1 = 0 to 10 {
+ %t0 = affine.apply affine_map<(d0, d1) -> (d1 + 1)>(%i0, %i1)
+ %t1 = affine.apply affine_map<(d0, d1) -> (d0)>(%i0, %i1)
+ %idx0 = affine.apply affine_map<(d0, d1) -> (d1)> (%t0, %t1)
+ %idx1 = affine.apply affine_map<(d0, d1) -> (d0 - 1)> (%t0, %t1)
+ // CHECK: affine.load
+ %v0 = affine.load %in[%idx0, %idx1] : memref<10x10xf32>
+ // CHECK-NOT: affine.load
+ %v1 = affine.load %in[%i0, %i1] : memref<10x10xf32>
+ %v2 = addf %v0, %v1 : f32
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @vector_loads
+func @vector_loads(%in : memref<512xf32>, %out : memref<512xf32>) {
+ affine.for %i = 0 to 16 {
+ // CHECK: affine.vector_load
+ %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+ // CHECK-NOT: affine.vector_load
+ %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+ %add = addf %ld0, %ld1 : vector<32xf32>
+ affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @vector_loads_no_cse
+func @vector_loads_no_cse(%in : memref<512xf32>, %out : memref<512xf32>) {
+ affine.for %i = 0 to 16 {
+ // CHECK: affine.vector_load
+ %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+ // CHECK: affine.vector_load
+ %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<16xf32>
+ affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<16xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @vector_load_store_load_no_cse
+func @vector_load_store_load_no_cse(%in : memref<512xf32>, %out : memref<512xf32>) {
+ affine.for %i = 0 to 16 {
+ // CHECK: affine.vector_load
+ %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+ affine.vector_store %ld0, %in[16*%i] : memref<512xf32>, vector<32xf32>
+ // CHECK: affine.vector_load
+ %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+ %add = addf %ld0, %ld1 : vector<32xf32>
+ affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @vector_load_affine_apply_store_load
+func @vector_load_affine_apply_store_load(%in : memref<512xf32>, %out : memref<512xf32>) {
+ %cf1 = constant 1: index
+ affine.for %i = 0 to 15 {
+ // CHECK: affine.vector_load
+ %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+ %idx = affine.apply affine_map<(d0) -> (d0 + 1)> (%i)
+ affine.vector_store %ld0, %in[32*%idx] : memref<512xf32>, vector<32xf32>
+ // CHECK-NOT: affine.vector_load
+ %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+ %add = addf %ld0, %ld1 : vector<32xf32>
+ affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+ }
+ return
+}