[Mlir-commits] [mlir] 986bef9 - [mlir] Remove redundant loads

Amy Zhuang llvmlistbot at llvm.org
Thu Jun 3 16:05:06 PDT 2021


Author: Amy Zhuang
Date: 2021-06-03T15:51:46-07:00
New Revision: 986bef97826fc41cbac1b7ff74b4f40f4594ba68

URL: https://github.com/llvm/llvm-project/commit/986bef97826fc41cbac1b7ff74b4f40f4594ba68
DIFF: https://github.com/llvm/llvm-project/commit/986bef97826fc41cbac1b7ff74b4f40f4594ba68.diff

LOG: [mlir] Remove redundant loads

Reviewed By: vinayaka-polymage, bondhugula

Differential Revision: https://reviews.llvm.org/D103294

Added: 
    

Modified: 
    mlir/lib/Transforms/MemRefDataFlowOpt.cpp
    mlir/test/Transforms/memref-dataflow-opt.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index 401aaaf1c7cd4..f0c502a5bff95 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -7,7 +7,8 @@
 //===----------------------------------------------------------------------===//
 //
 // This file implements a pass to forward memref stores to loads, thereby
-// potentially getting rid of intermediate memref's entirely.
+// potentially getting rid of intermediate memref's entirely. It also removes
+// redundant loads.
 // TODO: In the future, similar techniques could be used to eliminate
 // dead memref store's and perform more complex forwarding when support for
 // SSA scalars live out of 'affine.for'/'affine.if' statements is available.
@@ -20,6 +21,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include <algorithm>
@@ -29,21 +31,24 @@
 using namespace mlir;
 
 namespace {
-// The store to load forwarding relies on three conditions:
+// The store to load forwarding and load CSE rely on three conditions:
 //
-// 1) they need to have mathematically equivalent affine access functions
-// (checked after full composition of load/store operands); this implies that
-// they access the same single memref element for all iterations of the common
-// surrounding loop,
+// 1) store/load and load need to have mathematically equivalent affine access
+// functions (checked after full composition of load/store operands); this
+// implies that they access the same single memref element for all iterations of
+// the common surrounding loop,
 //
-// 2) the store op should dominate the load op,
+// 2) the store/load op should dominate the load op,
 //
-// 3) among all op's that satisfy both (1) and (2), the one that postdominates
-// all store op's that have a dependence into the load, is provably the last
-// writer to the particular memref location being loaded at the load op, and its
-// store value can be forwarded to the load. Note that the only dependences
-// that are to be considered are those that are satisfied at the block* of the
-// innermost common surrounding loop of the <store, load> being considered.
+// 3) among all op's that satisfy both (1) and (2), for store to load
+// forwarding, the one that postdominates all store op's that have a dependence
+// into the load, is provably the last writer to the particular memref location
+// being loaded at the load op, and its store value can be forwarded to the
+// load; for load CSE, any op that postdominates all store op's that have a
+// dependence into the load can be forwarded and the first one found is chosen.
+// Note that the only dependences that are to be considered are those that are
+// satisfied at the block* of the innermost common surrounding loop of the
+// <store/load, load> being considered.
 //
 // (* A dependence being satisfied at a block: a dependence that is satisfied by
 // virtue of the destination operation appearing textually / lexically after
@@ -64,11 +69,13 @@ namespace {
 struct MemRefDataFlowOpt : public MemRefDataFlowOptBase<MemRefDataFlowOpt> {
   void runOnFunction() override;
 
-  void forwardStoreToLoad(AffineReadOpInterface loadOp);
+  LogicalResult forwardStoreToLoad(AffineReadOpInterface loadOp);
+  void loadCSE(AffineReadOpInterface loadOp);
 
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value, 4> memrefsToErase;
-  // Load op's whose results were replaced by those forwarded from stores.
+  // Load op's whose results were replaced by those forwarded from
+  // dominating stores or loads.
   SmallVector<Operation *, 8> loadOpsToErase;
 
   DominanceInfo *domInfo = nullptr;
@@ -83,9 +90,32 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createMemRefDataFlowOptPass() {
   return std::make_unique<MemRefDataFlowOpt>();
 }
 
+// Check if the store may be reaching the load.
+static bool storeMayReachLoad(Operation *storeOp, Operation *loadOp,
+                              unsigned minSurroundingLoops) {
+  MemRefAccess srcAccess(storeOp);
+  MemRefAccess destAccess(loadOp);
+  FlatAffineConstraints dependenceConstraints;
+  unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
+  unsigned d;
+  // Dependences at loop depth <= minSurroundingLoops do NOT matter.
+  for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
+    DependenceResult result = checkMemrefAccessDependence(
+        srcAccess, destAccess, d, &dependenceConstraints,
+        /*dependenceComponents=*/nullptr);
+    if (hasDependence(result))
+      break;
+  }
+  if (d <= minSurroundingLoops)
+    return false;
+
+  return true;
+}
+
 // This is a straightforward implementation not optimized for speed. Optimize
 // if needed.
-void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
+LogicalResult
+MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
   // First pass over the use list to get the minimum number of surrounding
   // loops common between the load op and the store op, with min taken across
   // all store ops.
@@ -110,21 +140,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
   SmallVector<Operation *, 8> depSrcStores;
 
   for (auto *storeOp : storeOps) {
-    MemRefAccess srcAccess(storeOp);
-    MemRefAccess destAccess(loadOp);
-    // Find stores that may be reaching the load.
-    FlatAffineConstraints dependenceConstraints;
-    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp);
-    unsigned d;
-    // Dependences at loop depth <= minSurroundingLoops do NOT matter.
-    for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
-      DependenceResult result = checkMemrefAccessDependence(
-          srcAccess, destAccess, d, &dependenceConstraints,
-          /*dependenceComponents=*/nullptr);
-      if (hasDependence(result))
-        break;
-    }
-    if (d == minSurroundingLoops)
+    if (!storeMayReachLoad(storeOp, loadOp, minSurroundingLoops))
       continue;
 
     // Stores that *may* be reaching the load.
@@ -138,6 +154,8 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
     //     store %A[%M]
     //     load %A[%N]
    // Use the AffineValueMap difference based memref access equality checking.
+    MemRefAccess srcAccess(storeOp);
+    MemRefAccess destAccess(loadOp);
     if (srcAccess != destAccess)
       continue;
 
@@ -165,16 +183,104 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineReadOpInterface loadOp) {
     }
   }
   if (!lastWriteStoreOp)
-    return;
+    return failure();
 
   // Perform the actual store to load forwarding.
   Value storeVal =
       cast<AffineWriteOpInterface>(lastWriteStoreOp).getValueToStore();
+  // Check if 2 values have the same shape. This is needed for affine vector
+  // loads and stores.
+  if (storeVal.getType() != loadOp.getValue().getType())
+    return failure();
   loadOp.getValue().replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.
   loadOpsToErase.push_back(loadOp);
+  return success();
+}
+
+// The load to load forwarding / redundant load elimination is similar to the
+// store to load forwarding.
+// loadA will be replaced with loadB if:
+// 1) loadA and loadB have mathematically equivalent affine access functions.
+// 2) loadB dominates loadA.
+// 3) loadB postdominates all the store op's that have a dependence into loadA.
+void MemRefDataFlowOpt::loadCSE(AffineReadOpInterface loadOp) {
+  // The list of load op candidates for forwarding that satisfy conditions
+  // (1) and (2) above - they will be filtered later when checking (3).
+  SmallVector<Operation *, 8> fwdingCandidates;
+  SmallVector<Operation *, 8> storeOps;
+  unsigned minSurroundingLoops = getNestingDepth(loadOp);
+  MemRefAccess memRefAccess(loadOp);
+  // First pass over the use list to get 1) the minimum number of surrounding
+  // loops common between the load op and a load op candidate, with min taken
+  // across all load op candidates; 2) load op candidates; 3) store ops.
+  // We take min across all load op candidates instead of all load ops to make
+  // sure later dependence check is performed at loop depths that do matter.
+  for (auto *user : loadOp.getMemRef().getUsers()) {
+    if (auto storeOp = dyn_cast<AffineWriteOpInterface>(user)) {
+      storeOps.push_back(storeOp);
+    } else if (auto aLoadOp = dyn_cast<AffineReadOpInterface>(user)) {
+      MemRefAccess otherMemRefAccess(aLoadOp);
+      // No need to consider load ops that have been replaced in previous store
+      // to load forwarding or loadCSE. If loadA or storeA can be forwarded to
+      // loadB, then loadA or storeA can be forwarded to loadC iff loadB can be
+      // forwarded to loadC.
+      // If loadB is visited before loadC and replaced with loadA, we do not put
+      // loadB in candidates list, only loadA. If loadC is visited before loadB,
+      // loadC may be replaced with loadB, which will be replaced with loadA
+      // later.
+      if (aLoadOp != loadOp && !llvm::is_contained(loadOpsToErase, aLoadOp) &&
+          memRefAccess == otherMemRefAccess &&
+          domInfo->dominates(aLoadOp, loadOp)) {
+        fwdingCandidates.push_back(aLoadOp);
+        unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *aLoadOp);
+        minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
+      }
+    }
+  }
+
+  // No forwarding candidate.
+  if (fwdingCandidates.empty())
+    return;
+
+  // Store ops that have a dependence into the load.
+  SmallVector<Operation *, 8> depSrcStores;
+
+  for (auto *storeOp : storeOps) {
+    if (!storeMayReachLoad(storeOp, loadOp, minSurroundingLoops))
+      continue;
+
+    // Stores that *may* be reaching the load.
+    depSrcStores.push_back(storeOp);
+  }
+
+  // 3. Of all the load op's that meet the above criteria, return the first load
+  // found that postdominates all 'depSrcStores' and has the same shape as the
+  // load to be replaced (if one exists). The shape check is needed for affine
+  // vector loads.
+  Operation *firstLoadOp = nullptr;
+  Value oldVal = loadOp.getValue();
+  for (auto *loadOp : fwdingCandidates) {
+    if (llvm::all_of(depSrcStores,
+                     [&](Operation *depStore) {
+                       return postDomInfo->postDominates(loadOp, depStore);
+                     }) &&
+        cast<AffineReadOpInterface>(loadOp).getValue().getType() ==
+            oldVal.getType()) {
+      firstLoadOp = loadOp;
+      break;
+    }
+  }
+  if (!firstLoadOp)
+    return;
+
+  // Perform the actual load to load forwarding.
+  Value loadVal = cast<AffineReadOpInterface>(firstLoadOp).getValue();
+  loadOp.getValue().replaceAllUsesWith(loadVal);
+  // Record this to erase later.
+  loadOpsToErase.push_back(loadOp);
 }
 
 void MemRefDataFlowOpt::runOnFunction() {
@@ -191,10 +297,15 @@ void MemRefDataFlowOpt::runOnFunction() {
   loadOpsToErase.clear();
   memrefsToErase.clear();
 
-  // Walk all load's and perform store to load forwarding.
-  f.walk([&](AffineReadOpInterface loadOp) { forwardStoreToLoad(loadOp); });
+  // Walk all load's and perform store to load forwarding and loadCSE.
+  f.walk([&](AffineReadOpInterface loadOp) {
+    // Do store to load forwarding first; if it fails, try loadCSE.
+    if (failed(forwardStoreToLoad(loadOp)))
+      loadCSE(loadOp);
+  });
 
-  // Erase all load op's whose results were replaced with store fwd'ed ones.
+  // Erase all load op's whose results were replaced with store or load fwd'ed
+  // ones.
   for (auto *loadOp : loadOpsToErase)
     loadOp->erase();
 

diff  --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir
index 49a3dcbd20183..61fd9d8ce6231 100644
--- a/mlir/test/Transforms/memref-dataflow-opt.mlir
+++ b/mlir/test/Transforms/memref-dataflow-opt.mlir
@@ -300,3 +300,233 @@ func @vector_forwarding(%in : memref<512xf32>, %out : memref<512xf32>) {
 // CHECK-NEXT:   %[[LDVAL:.*]] = affine.vector_load
 // CHECK-NEXT:   affine.vector_store %[[LDVAL]],{{.*}}
 // CHECK-NEXT: }
+
+func @vector_no_forwarding(%in : memref<512xf32>, %out : memref<512xf32>) {
+  %tmp = memref.alloc() : memref<512xf32>
+  affine.for %i = 0 to 16 {
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld0, %tmp[32*%i] : memref<512xf32>, vector<32xf32>
+    %ld1 = affine.vector_load %tmp[32*%i] : memref<512xf32>, vector<16xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<16xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_no_forwarding
+// CHECK:      affine.for %{{.*}} = 0 to 16 {
+// CHECK-NEXT:   %[[LDVAL:.*]] = affine.vector_load
+// CHECK-NEXT:   affine.vector_store %[[LDVAL]],{{.*}}
+// CHECK-NEXT:   %[[LDVAL1:.*]] = affine.vector_load
+// CHECK-NEXT:   affine.vector_store %[[LDVAL1]],{{.*}}
+// CHECK-NEXT: }
+
+// CHECK-LABEL: func @simple_three_loads
+func @simple_three_loads(%in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK:       affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    // CHECK-NOT:   affine.load
+    %v1 = affine.load %in[%i0] : memref<10xf32>
+    %v2 = addf %v0, %v1 : f32
+    %v3 = affine.load %in[%i0] : memref<10xf32>
+    %v4 = addf %v2, %v3 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads_const_index
+func @nested_loads_const_index(%in : memref<10xf32>) {
+  %c0 = constant 0 : index
+  // CHECK:       affine.load
+  %v0 = affine.load %in[%c0] : memref<10xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 20 {
+      affine.for %i2 = 0 to 30 {
+        // CHECK-NOT:   affine.load
+        %v1 = affine.load %in[%c0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads
+func @nested_loads(%N : index, %in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK:       affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to %N {
+      // CHECK-NOT:   affine.load
+      %v1 = affine.load %in[%i0] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @nested_loads_different_memref_accesses_no_cse
+func @nested_loads_different_memref_accesses_no_cse(%in : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK:       affine.load
+    %v0 = affine.load %in[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to 20 {
+      // CHECK:       affine.load
+      %v1 = affine.load %in[%i1] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store
+func @load_load_store(%m : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK:       affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    // CHECK-NOT:       affine.load
+    %v1 = affine.load %m[%i0] : memref<10xf32>
+    %v2 = addf %v0, %v1 : f32
+    affine.store %v2, %m[%i0] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_2_loops_no_cse
+func @load_load_store_2_loops_no_cse(%N : index, %m : memref<10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    // CHECK:       affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to %N {
+      // CHECK:       affine.load
+      %v1 = affine.load %m[%i0] : memref<10xf32>
+      %v2 = addf %v0, %v1 : f32
+      affine.store %v2, %m[%i0] : memref<10xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_3_loops_no_cse
+func @load_load_store_3_loops_no_cse(%m : memref<10xf32>) {
+%cf1 = constant 1.0 : f32
+  affine.for %i0 = 0 to 10 {
+    // CHECK:       affine.load
+    %v0 = affine.load %m[%i0] : memref<10xf32>
+    affine.for %i1 = 0 to 20 {
+      affine.for %i2 = 0 to 30 {
+        // CHECK:       affine.load
+        %v1 = affine.load %m[%i0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+      affine.store %cf1, %m[%i0] : memref<10xf32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_store_3_loops
+func @load_load_store_3_loops(%m : memref<10xf32>) {
+%cf1 = constant 1.0 : f32
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 20 {
+      // CHECK:       affine.load
+      %v0 = affine.load %m[%i0] : memref<10xf32>
+      affine.for %i2 = 0 to 30 {
+        // CHECK-NOT:   affine.load
+        %v1 = affine.load %m[%i0] : memref<10xf32>
+        %v2 = addf %v0, %v1 : f32
+      }
+    }
+    affine.store %cf1, %m[%i0] : memref<10xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @loads_in_sibling_loops_const_index_no_cse
+func @loads_in_sibling_loops_const_index_no_cse(%m : memref<10xf32>) {
+  %c0 = constant 0 : index
+  affine.for %i0 = 0 to 10 {
+    // CHECK:       affine.load
+    %v0 = affine.load %m[%c0] : memref<10xf32>
+  }
+  affine.for %i1 = 0 to 10 {
+    // CHECK:       affine.load
+    %v0 = affine.load %m[%c0] : memref<10xf32>
+    %v1 = addf %v0, %v0 : f32
+  }
+  return
+}
+
+// CHECK-LABEL: func @load_load_affine_apply
+func @load_load_affine_apply(%in : memref<10x10xf32>) {
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %t0 = affine.apply affine_map<(d0, d1) -> (d1 + 1)>(%i0, %i1)
+      %t1 = affine.apply affine_map<(d0, d1) -> (d0)>(%i0, %i1)
+      %idx0 = affine.apply affine_map<(d0, d1) -> (d1)> (%t0, %t1)
+      %idx1 = affine.apply affine_map<(d0, d1) -> (d0 - 1)> (%t0, %t1)
+      // CHECK:       affine.load
+      %v0 = affine.load %in[%idx0, %idx1] : memref<10x10xf32>
+      // CHECK-NOT:   affine.load
+      %v1 = affine.load %in[%i0, %i1] : memref<10x10xf32>
+      %v2 = addf %v0, %v1 : f32
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_loads
+func @vector_loads(%in : memref<512xf32>, %out : memref<512xf32>) {
+  affine.for %i = 0 to 16 {
+    // CHECK:       affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    // CHECK-NOT:   affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_loads_no_cse
+func @vector_loads_no_cse(%in : memref<512xf32>, %out : memref<512xf32>) {
+  affine.for %i = 0 to 16 {
+    // CHECK:       affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    // CHECK:   affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<16xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<16xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_load_store_load_no_cse
+func @vector_load_store_load_no_cse(%in : memref<512xf32>, %out : memref<512xf32>) {
+  affine.for %i = 0 to 16 {
+    // CHECK:       affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    affine.vector_store %ld0, %in[16*%i] : memref<512xf32>, vector<32xf32>
+    // CHECK:       affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @vector_load_affine_apply_store_load
+func @vector_load_affine_apply_store_load(%in : memref<512xf32>, %out : memref<512xf32>) {
+  %cf1 = constant 1: index
+  affine.for %i = 0 to 15 {
+    // CHECK:       affine.vector_load
+    %ld0 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %idx = affine.apply affine_map<(d0) -> (d0 + 1)> (%i)
+    affine.vector_store %ld0, %in[32*%idx] : memref<512xf32>, vector<32xf32>
+    // CHECK-NOT:   affine.vector_load
+    %ld1 = affine.vector_load %in[32*%i] : memref<512xf32>, vector<32xf32>
+    %add = addf %ld0, %ld1 : vector<32xf32>
+    affine.vector_store %ld1, %out[32*%i] : memref<512xf32>, vector<32xf32>
+  }
+  return
+}


        


More information about the Mlir-commits mailing list