[Mlir-commits] [mlir] 71513a7 - [MLIR][Affine] Improve load elimination

llvmlistbot at llvm.org
Sun Jul 9 05:53:28 PDT 2023


Author: rikhuijzer
Date: 2023-07-09T12:46:36+02:00
New Revision: 71513a71cdf380efd6a44be6939e2cb979a62407

URL: https://github.com/llvm/llvm-project/commit/71513a71cdf380efd6a44be6939e2cb979a62407
DIFF: https://github.com/llvm/llvm-project/commit/71513a71cdf380efd6a44be6939e2cb979a62407.diff

LOG: [MLIR][Affine] Improve load elimination

Fixes #62639.

Differential Revision: https://reviews.llvm.org/D154769
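
For context, the redundancy this change eliminates looks like the sketch
below (a minimal example distilled from the elim_load_after_store test added
in this commit; the function and buffer names here are illustrative, not part
of the patch). Store-to-load forwarding replaces the load of %buf[0] with the
stored value; the store then becomes dead, and a subsequent load-CSE walk can
merge the repeated loads of %arg0[%i]:

  func.func @redundant_load_sketch(%arg0: memref<100xf32>, %out: memref<100xf32>) {
    %buf = memref.alloc() : memref<1xf32>
    affine.for %i = 0 to 100 {
      %0 = affine.load %arg0[%i] : memref<100xf32>
      affine.store %0, %buf[0] : memref<1xf32>
      // Forwarding rewrites %1 to %0, leaving the store above dead.
      %1 = affine.load %buf[0] : memref<1xf32>
      // Once that dead store is erased, load CSE can fold this
      // re-load of the same location into %0 as well.
      %2 = affine.load %arg0[%i] : memref<100xf32>
      %3 = arith.addf %1, %2 : f32
      affine.store %3, %out[%i] : memref<100xf32>
    }
    return
  }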

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/test/Dialect/Affine/scalrep.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 1ba9c82e5af958..8781efb3193137 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -862,9 +862,10 @@ bool mlir::affine::hasNoInterveningEffect(Operation *start, T memOp) {
 /// other operations will overwrite the memory loaded between the given load
 /// and store.  If such a value exists, the replaced `loadOp` will be added to
 /// `loadOpsToErase` and its memref will be added to `memrefsToErase`.
-static LogicalResult forwardStoreToLoad(
-    AffineReadOpInterface loadOp, SmallVectorImpl<Operation *> &loadOpsToErase,
-    SmallPtrSetImpl<Value> &memrefsToErase, DominanceInfo &domInfo) {
+static void forwardStoreToLoad(AffineReadOpInterface loadOp,
+                               SmallVectorImpl<Operation *> &loadOpsToErase,
+                               SmallPtrSetImpl<Value> &memrefsToErase,
+                               DominanceInfo &domInfo) {
 
   // The store op candidate for forwarding that satisfies all conditions
   // to replace the load, if any.
@@ -911,7 +912,7 @@ static LogicalResult forwardStoreToLoad(
   }
 
   if (!lastWriteStoreOp)
-    return failure();
+    return;
 
   // Perform the actual store to load forwarding.
   Value storeVal =
@@ -919,13 +920,12 @@ static LogicalResult forwardStoreToLoad(
   // Check if 2 values have the same shape. This is needed for affine vector
   // loads and stores.
   if (storeVal.getType() != loadOp.getValue().getType())
-    return failure();
+    return;
   loadOp.getValue().replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
   // Record this to erase later.
   loadOpsToErase.push_back(loadOp);
-  return success();
 }
 
 template bool
@@ -995,16 +995,16 @@ static void loadCSE(AffineReadOpInterface loadA,
     MemRefAccess srcAccess(loadB);
     MemRefAccess destAccess(loadA);
 
-    // 1. The accesses have to be to the same location.
+    // 1. The accesses should be to the same location.
     if (srcAccess != destAccess) {
       continue;
     }
 
-    // 2. The store has to dominate the load op to be candidate.
+    // 2. loadB should dominate loadA.
     if (!domInfo.dominates(loadB, loadA))
       continue;
 
-    // 3. There is no write between loadA and loadB.
+    // 3. There should not be a write between loadA and loadB.
     if (!affine::hasNoInterveningEffect<MemoryEffects::Write>(
             loadB.getOperation(), loadA))
       continue;
@@ -1073,13 +1073,8 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
 
  // Walk all loads and perform store to load forwarding.
   f.walk([&](AffineReadOpInterface loadOp) {
-    if (failed(
-            forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo))) {
-      loadCSE(loadOp, opsToErase, domInfo);
-    }
+    forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo);
   });
-
-  // Erase all load op's whose results were replaced with store fwd'ed ones.
   for (auto *op : opsToErase)
     op->erase();
   opsToErase.clear();
@@ -1088,9 +1083,9 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
   f.walk([&](AffineWriteOpInterface storeOp) {
     findUnusedStore(storeOp, opsToErase, postDomInfo);
   });
-  // Erase all store op's which don't impact the program
   for (auto *op : opsToErase)
     op->erase();
+  opsToErase.clear();
 
   // Check if the store fwd'ed memrefs are now left with only stores and
   // deallocs and can thus be completely deleted. Note: the canonicalize pass
@@ -1114,6 +1109,15 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
       user->erase();
     defOp->erase();
   }
+
+  // To eliminate as many loads as possible, run load CSE after eliminating
+  // stores. Otherwise, some stores are wrongly seen as having an intervening
+  // effect.
+  f.walk([&](AffineReadOpInterface loadOp) {
+    loadCSE(loadOp, opsToErase, domInfo);
+  });
+  for (auto *op : opsToErase)
+    op->erase();
 }
 
 // Perform the replacement in `op`.

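Worth noting about the Utils.cpp changes above: forwardStoreToLoad now returns
void, and affineScalarReplace runs load CSE as a separate final walk after
dead stores have been erased, instead of invoking it only when forwarding
fails for a given load. A minimal MLIR sketch of why the ordering matters
(names are illustrative; per the comment added in the patch, a not-yet-erased
store can be wrongly seen as an intervening write between two equal loads):

  %a = affine.load %m[%i] : memref<100xf32>
  // After store-to-load forwarding, this store is dead, but until it is
  // erased it can cause hasNoInterveningEffect to reject merging the two
  // loads of %m[%i]. Running loadCSE after the erasure avoids that.
  affine.store %a, %t[0] : memref<1xf32>
  %b = affine.load %m[%i] : memref<100xf32>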
diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir
index 64b8534a9816a5..22d394bfcf0979 100644
--- a/mlir/test/Dialect/Affine/scalrep.mlir
+++ b/mlir/test/Dialect/Affine/scalrep.mlir
@@ -280,6 +280,31 @@ func.func @refs_not_known_to_be_equal(%A : memref<100 x 100 x f32>, %M : index)
   return
 }
 
+// CHECK-LABEL: func @elim_load_after_store
+func.func @elim_load_after_store(%arg0: memref<100xf32>, %arg1: memref<100xf32>) {
+  %alloc = memref.alloc() : memref<1xf32>
+  %alloc_0 = memref.alloc() : memref<1xf32>
+  // CHECK: affine.for
+  affine.for %arg2 = 0 to 100 {
+    // CHECK: affine.load
+    %0 = affine.load %arg0[%arg2] : memref<100xf32>
+    %1 = affine.load %arg0[%arg2] : memref<100xf32>
+    // CHECK: arith.addf
+    %2 = arith.addf %0, %1 : f32
+    affine.store %2, %alloc_0[0] : memref<1xf32>
+    %3 = affine.load %arg0[%arg2] : memref<100xf32>
+    %4 = affine.load %alloc_0[0] : memref<1xf32>
+    // CHECK-NEXT: arith.addf
+    %5 = arith.addf %3, %4 : f32
+    affine.store %5, %alloc[0] : memref<1xf32>
+    %6 = affine.load %arg0[%arg2] : memref<100xf32>
+    %7 = affine.load %alloc[0] : memref<1xf32>
+    %8 = arith.addf %6, %7 : f32
+    affine.store %8, %arg1[%arg2] : memref<100xf32>
+  }
+  return
+}
+
 // The test checks for value forwarding from vector stores to vector loads.
 // The value loaded from %in can directly be stored to %out by eliminating
 // store and load from %tmp.

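For reference, with this patch the loop body of elim_load_after_store should
reduce to roughly the following (a sketch inferred from the CHECK lines above,
which only pin down the single remaining load and the first two consecutive
addf ops; exact SSA names and op order beyond that may differ):

  affine.for %arg2 = 0 to 100 {
    %0 = affine.load %arg0[%arg2] : memref<100xf32>
    %1 = arith.addf %0, %0 : f32
    %2 = arith.addf %0, %1 : f32
    %3 = arith.addf %0, %2 : f32
    affine.store %3, %arg1[%arg2] : memref<100xf32>
  }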