[Mlir-commits] [mlir] [mlir][affine] Improve `--affine-scalrep` to identify reduction variables (PR #118987)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 6 07:22:15 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-affine

Author: Clément Fournier (oowekyala)

<details>
<summary>Changes</summary>

Improve the affine scalar replacement pass to identify memref accesses that are used as reduction variables and turn them into `iter_args` variables. For instance, in:
```mlir
%x = memref.alloc() : memref<10x10xf32>
%min = memref.alloc() : memref<10xf32>
// initialize %min
affine.for %i = 0 to 10 {
  affine.for %j = 0 to 10 {
    %0 = affine.load %min[%i] : memref<10xf32>
    %1 = affine.load %x[%i, %j] : memref<10x10xf32>
    %2 = arith.minimumf %0, %1 : f32
    affine.store %2, %min[%i] : memref<10xf32>
  }
}
```
the load/store pattern on `%min` in the inner loop is characteristic of a reduction. The memory location `%min[%i]` is invariant with respect to the inner loop's induction variable, so it is effectively used as a scalar. We can rewrite this loop to the following:
```mlir
%x = memref.alloc() : memref<10x10xf32>
%min = memref.alloc() : memref<10xf32>
// initialize %min
affine.for %i = 0 to 10 {
  %0 = affine.load %min[%i] : memref<10xf32>
  %1 = affine.for %j = 0 to 10 iter_args(%acc = %0) -> f32 {
    %2 = affine.load %x[%i, %j] : memref<10x10xf32>
    %3 = arith.minimumf %acc, %2 : f32
    affine.yield %3 : f32
  }
  affine.store %1, %min[%i] : memref<10xf32>
}
```
where this memory location is "scalarized" as an `iter_args` variable. This allows existing affine passes to apply more optimizations to the reduction loop, e.g., it can be vectorized, or it can be turned into an `affine.parallel` loop with a combiner for the reduction.
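
As a rough illustration of why the two loop nests above compute the same thing, here is a minimal C++ sketch on plain arrays. It is not part of the patch: the names `rowMinInMemory`, `rowMinScalarized`, and `demoInput` are hypothetical, and `std::min` only stands in for `arith.minimumf` (their NaN handling differs, so the sketch assumes NaN-free input).

```cpp
#include <algorithm>
#include <array>

constexpr int N = 10;
using Matrix = std::array<std::array<float, N>, N>;
using Vector = std::array<float, N>;

// Models the original nest: min[i] is loaded and stored on every
// inner iteration, even though its address does not depend on j.
void rowMinInMemory(const Matrix &x, Vector &min) {
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < N; ++j) {
      float cur = min[i];        // load inside the inner loop
      float v = x[i][j];
      min[i] = std::min(cur, v); // store inside the inner loop
    }
  }
}

// Models the rewritten nest: the load is hoisted before the inner loop,
// the value is carried in a local (the analogue of iter_args), and the
// store is sunk after the loop.
void rowMinScalarized(const Matrix &x, Vector &min) {
  for (int i = 0; i < N; ++i) {
    float acc = min[i];          // hoisted load
    for (int j = 0; j < N; ++j)
      acc = std::min(acc, x[i][j]);
    min[i] = acc;                // sunk store
  }
}

// Hypothetical NaN-free test input, for illustration only.
Matrix demoInput() {
  Matrix x{};
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < N; ++j)
      x[i][j] = static_cast<float>((i * 7 + j * 3) % 11);
  return x;
}
```

Calling both functions on the same input and the same initial `min` contents should leave identical results in `min`; the second form simply makes the loop-carried dependence explicit.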

This kind of code pattern is often found in the affine loops generated from linalg code, so I think it's very useful to include this.

I expect some pushback over why I put this into the scalar replacement pass instead of a new pass. I think this is justified because:
1. This transformation moves some loads and stores out of the loop, and these may be forwardable by the existing scalar replacement transformations. Conversely, forwarding some loads and stores may free up dependencies and make this new loop rewriting pattern applicable. So to me these transformations are tightly related, and maybe they should even be put into a fixed-point loop within the scalrep pass.
2. This transformation effectively replaces buffer accesses with a scalar `iter_args` variable. So even if it seems unrelated to the load/store forwarding that the pass currently does, I think it still fits within the scope of `--affine-scalrep`.
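
The interplay described in point 1 can be sketched in plain C++, loosely following the `reduction_extraction` test added by this PR. This is only an illustrative model under stated assumptions, not the pass itself: the function names `sumInMemory`, `sumHoisted`, `sumForwarded`, and `rampGrid` are hypothetical, and a one-element array stands in for the `memref<f32>` buffer.

```cpp
#include <array>

constexpr int kN = 10;
using Grid = std::array<std::array<float, kN>, kN>;

// Stage 0: the accumulator lives in a one-element buffer that is
// loaded and stored on every inner iteration.
float sumInMemory(const Grid &x) {
  std::array<float, 1> b{};
  b[0] = 0.0f;
  for (int i = 0; i < kN; ++i)
    for (int j = 0; j < kN; ++j) {
      float acc = b[0];
      b[0] = acc + x[i][j];
    }
  return b[0];
}

// Stage 1: both loops carry the accumulator as a local (the analogue
// of iter_args). One load remains before the nest, one store after it.
float sumHoisted(const Grid &x) {
  std::array<float, 1> b{};
  b[0] = 0.0f;
  float outerAcc = b[0];  // hoisted load, now adjacent to the initial store
  for (int i = 0; i < kN; ++i) {
    float innerAcc = outerAcc;
    for (int j = 0; j < kN; ++j)
      innerAcc += x[i][j];
    outerAcc = innerAcc;
  }
  b[0] = outerAcc;        // sunk store, now adjacent to the final load
  return b[0];
}

// Stage 2: store-to-load forwarding replaces the hoisted load with the
// constant 0 and the final load with the loop result; the buffer is dead.
float sumForwarded(const Grid &x) {
  float outerAcc = 0.0f;
  for (int i = 0; i < kN; ++i) {
    float innerAcc = outerAcc;
    for (int j = 0; j < kN; ++j)
      innerAcc += x[i][j];
    outerAcc = innerAcc;
  }
  return outerAcc;
}

// Hypothetical input with x[i][j] = i + j, for illustration only.
Grid rampGrid() {
  Grid x{};
  for (int i = 0; i < kN; ++i)
    for (int j = 0; j < kN; ++j)
      x[i][j] = static_cast<float>(i + j);
  return x;
}
```

All three stages add the elements in the same order, so they should return the same value. The point of the sketch is that stage 2 only becomes reachable once stage 1 has moved the load and store next to the accesses outside the loop.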

Thanks for reading!


---
Full diff: https://github.com/llvm/llvm-project/pull/118987.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td (+10) 
- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+2-1) 
- (modified) mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp (+8) 
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+5) 
- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+133) 
- (modified) mlir/test/Dialect/Affine/scalrep.mlir (+41-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
index c07ab9deca48c1..efbe15eb00d7a8 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -138,6 +138,16 @@ def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
         return $_op.getOperand($_op.getStoredValOperandIndex());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/"Returns the value to store.",
+      /*retTy=*/"::mlir::OpOperand&",
+      /*methodName=*/"getValueToStoreMutable",
+      /*args=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return $_op->getOpOperand($_op.getStoredValOperandIndex());
+      }]
+    >,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 03172f7ce00e4b..0f49a26f7aebe1 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -119,7 +119,8 @@ def AffineForOp : Affine_Op<"for",
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
-      "getLoopUpperBounds", "getYieldedValuesMutable",
+      "getLoopUpperBounds", "getYieldedValuesMutable", "getLoopResults",
+      "getInitsMutable", "getYieldedValuesMutable",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 9b776900c379a2..bee82906b31f11 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -491,6 +491,14 @@ LogicalResult MemRefAccess::getAccessRelation(IntegerRelation &rel) const {
   IntegerRelation domainRel = domain;
   if (rel.getSpace().isUsingIds() && !domainRel.getSpace().isUsingIds())
     domainRel.resetIds();
+
+  if (!rel.getSpace().isUsingIds()) {
+    assert(rel.getNumVars() == 0);
+    rel.resetIds();
+    if (!domainRel.getSpace().isUsingIds())
+      domainRel.resetIds();
+  }
+
   domainRel.appendVar(VarKind::Range, accessValueMap.getNumResults());
   domainRel.mergeAndAlignSymbols(rel);
   domainRel.mergeLocalVars(rel);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index dceebbfec586c8..075b3be4e5c9e6 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -2460,6 +2461,10 @@ bool AffineForOp::matchingBoundOperandList() {
 
 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
+std::optional<ResultRange> AffineForOp::getLoopResults() {
+  return {getResults()};
+}
+
 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
   return SmallVector<Value>{getInductionVar()};
 }
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 07d399adae0cd4..64d301c5a7fe98 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/Utils.h"
 
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
@@ -26,9 +27,13 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/LogicalResult.h"
 #include <optional>
+#include <tuple>
 
 #define DEBUG_TYPE "affine-utils"
 
@@ -886,6 +891,8 @@ static void forwardStoreToLoad(
   // loads and stores.
   if (storeVal.getType() != loadOp.getValue().getType())
     return;
+  LLVM_DEBUG(llvm::dbgs() << "Erased load (forwarded from store): " << loadOp
+                          << "\n");
   loadOp.getValue().replaceAllUsesWith(storeVal);
   // Record the memref for a later sweep to optimize away.
   memrefsToErase.insert(loadOp.getMemRef());
@@ -940,11 +947,132 @@ static void findUnusedStore(AffineWriteOpInterface writeA,
                                                              mayAlias))
       continue;
 
+    LLVM_DEBUG(llvm::dbgs() << "Erased store (unused): " << writeA << "\n");
     opsToErase.push_back(writeA);
     break;
   }
 }
 
+static bool isLoopInvariant(LoopLikeOpInterface loop,
+                            AffineReadOpInterface load) {
+  for (auto operand : load.getMapOperands()) {
+    if (!loop.isDefinedOutsideOfLoop(operand)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+/// This attempts to find load-store pairs in the body of the loop
+/// that could be replaced by an iter_args variable on the loop. The
+/// initial load and the final store are moved out of the loop. For
+/// such a pair to be eligible:
+/// 1. the load must be followed by the store
+/// 2. the memref must not be read again after the store
+/// 3. the indices of the load and store must match AND be
+/// loop-invariant for the given loop.
+///
+/// This is a useful transformation as
+/// - it exposes reduction dependencies that can be extracted by
+/// --affine-parallelize
+/// - it is a common pattern in code lowered from linalg.
+/// - it exposes more opportunities for forwarding of load/store by
+/// moving the load/store out of the loop and into an enclosing scope,
+/// which may themselves have some load/stores that can be matched with
+/// the new ones.
+///
+/// This last point is why it makes sense to include this transformation within
+/// the scalar replacement pass.
+static void findReductionVariablesAndRewrite(
+    LoopLikeOpInterface loop, PostDominanceInfo &postDominanceInfo,
+    llvm::function_ref<bool(Value, Value)> mayAlias) {
+
+  if (!loop.getLoopResults())
+    return;
+
+  SmallVector<std::pair<AffineReadOpInterface, AffineWriteOpInterface>> result;
+  auto *region = loop.getLoopRegions()[0];
+  auto &block = region->front();
+
+  for (auto &op : block.without_terminator()) {
+    // iterate over ops to find loop-invariant load/store pairs
+    auto asLoad = dyn_cast<AffineReadOpInterface>(op);
+    if (!asLoad) {
+      continue;
+    }
+
+    // Indices must be loop-invariant
+    if (!isLoopInvariant(loop, asLoad))
+      continue;
+
+    // find a corresponding store
+    for (auto *user : asLoad.getMemRef().getUsers()) {
+      if (user->getBlock() != &block || user->isBeforeInBlock(&op))
+        continue;
+      auto asStore = dyn_cast<AffineWriteOpInterface>(user);
+      if (!asStore)
+        continue;
+
+      // both load and store must access the same index
+      if (MemRefAccess(asLoad) != MemRefAccess(asStore)) {
+        break;
+      }
+
+      // Check that nobody could be reading from the store before the next load,
+      // as we want to eliminate the store.
+      if (!affine::hasNoInterveningEffect<MemoryEffects::Read>(
+              asStore.getOperation(), asLoad, mayAlias))
+        break;
+
+      // now let's just replace this pair of accesses with loop iter args
+      result.push_back({asLoad, asStore});
+    }
+  }
+  if (result.empty())
+    return;
+
+  SmallVector<Value> newInitOperands;
+  SmallVector<Value> newYieldOperands;
+  IRRewriter rewriter(loop->getContext());
+  rewriter.startOpModification(loop->getParentOp());
+  rewriter.setInsertionPoint(loop);
+  for (auto [load, store] : result) {
+    auto rewrittenLoad = cast<AffineReadOpInterface>(rewriter.clone(*load));
+    newInitOperands.push_back(rewrittenLoad.getValue());
+    newYieldOperands.push_back(store.getValueToStore());
+  }
+
+  const auto numResults = loop.getLoopResults()->size();
+  auto rewritten = loop.replaceWithAdditionalYields(
+      rewriter, newInitOperands, false,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+        return newYieldOperands;
+      });
+  if (failed(rewritten)) {
+    rewriter.cancelOpModification(loop->getParentOp());
+    return;
+  }
+  auto newLoop = *rewritten;
+
+  rewriter.setInsertionPointAfter(newLoop);
+  Operation *next = newLoop;
+  for (auto [loadStore, bbArg, loopRes] :
+       llvm::zip(result, rewritten->getRegionIterArgs().drop_front(numResults),
+                 rewritten->getLoopResults()->drop_front(numResults))) {
+    auto load = loadStore.first;
+    rewriter.replaceOp(load, bbArg);
+
+    auto store = loadStore.second;
+    rewriter.moveOpAfter(store, next);
+    store.getValueToStoreMutable().set(loopRes);
+    next = store;
+  }
+
+  rewriter.finalizeOpModification(newLoop->getParentOp());
+  LLVM_DEBUG(llvm::dbgs() << "Replaced loop reduction variable: \n"
+                          << newLoop << "\n");
+}
+
 // The load to load forwarding / redundant load elimination is similar to the
 // store to load forwarding.
 // loadA will be be replaced with loadB if:
@@ -1045,6 +1173,11 @@ void mlir::affine::affineScalarReplace(func::FuncOp f, DominanceInfo &domInfo,
     return !aliasAnalysis.alias(val1, val2).isNo();
   };
 
+  // scalarize reduction variables as iter_args
+  f.walk([&](AffineForOp loop) {
+    findReductionVariablesAndRewrite(loop, postDomInfo, mayAlias);
+  });
+
   // Walk all load's and perform store to load forwarding.
   f.walk([&](AffineReadOpInterface loadOp) {
     forwardStoreToLoad(loadOp, opsToErase, memrefsToErase, domInfo, mayAlias);
diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir
index fdfe3bfb62f957..d238a8af07e507 100644
--- a/mlir/test/Dialect/Affine/scalrep.mlir
+++ b/mlir/test/Dialect/Affine/scalrep.mlir
@@ -141,9 +141,14 @@ func.func @store_load_store_nested_no_fwd(%N : index) {
   affine.for %i0 = 0 to 10 {
     affine.store %cf7, %m[%i0] : memref<10xf32>
     affine.for %i1 = 0 to %N {
-      // CHECK: %{{[0-9]+}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
+      // CHECK:      %[[C7:.*]] = arith.constant 7.0{{.*}}
+      // CHECK:      %[[C9:.*]] = arith.constant 9.0{{.*}}
+      // CHECK:      %{{[0-9]+}} = affine.for %{{.*}} = 0 to %{{.*}} iter_args(%[[A:.*]] = %[[C7]]) -> (f32)
+      // CHECK-NEXT:    %[[R:.*]] = arith.addf %[[A]], %[[A]] : f32
+      // CHECK:    affine.yield %[[C9]] : f32
       %v0 = affine.load %m[%i0] : memref<10xf32>
       %v1 = arith.addf %v0, %v0 : f32
+      "use"(%v1) : (f32) -> ()
       affine.store %cf9, %m[%i0] : memref<10xf32>
     }
   }
@@ -423,7 +428,8 @@ func.func @load_load_store_2_loops_no_cse(%N : index, %m : memref<10xf32>) {
     // CHECK:       affine.load
     %v0 = affine.load %m[%i0] : memref<10xf32>
     affine.for %i1 = 0 to %N {
-      // CHECK:       affine.load
+      // CHECK:       iter_args
+      // CHECK-NOT:       affine.load
       %v1 = affine.load %m[%i0] : memref<10xf32>
       %v2 = arith.addf %v0, %v1 : f32
       affine.store %v2, %m[%i0] : memref<10xf32>
@@ -556,10 +562,11 @@ func.func @reduction_multi_store() -> memref<1xf32> {
    "test.foo"(%m) : (f32) -> ()
   }
 
-// CHECK:       affine.for
-// CHECK:         affine.load
-// CHECK:         affine.store %[[S:.*]],
-// CHECK-NEXT:    "test.foo"(%[[S]])
+// CHECK:       affine.for {{.*}}
+// CHECK-NEXT:    %[[A:.*]] = affine.load
+// CHECK-NEXT:    %[[X:.*]] = arith.addf %[[A]],
+// CHECK-NEXT:    affine.store %[[X]]
+// CHECK-NEXT:    "test.foo"(%[[X]])
 
   return %A : memref<1xf32>
 }
@@ -891,6 +898,34 @@ func.func @parallel_surrounding_for() {
 // CHECK-NEXT:  return
 }
 
+// CHECK-LABEL: func @reduction_extraction
+func.func @reduction_extraction(%x : memref<10x10xf32>) -> f32 {
+  %b = memref.alloc() : memref<f32>
+  %cst = arith.constant 0.0 : f32
+  affine.store %cst, %b[] : memref<f32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %v0 = affine.load %x[%i0,%i1] : memref<10x10xf32>
+      %acc = affine.load %b[] : memref<f32>
+      %v1 = arith.addf %acc, %v0 : f32
+      affine.store %v1, %b[] : memref<f32>
+    }
+  }
+  %x2 = affine.load %b[]: memref<f32>
+  return %x2 : f32
+// CHECK:       %[[I:.*]] = arith.constant 0{{.*}} : f32
+// CHECK-NEXT:  %[[SUM2:.*]] = affine.for %{{.*}} = 0 to 10 iter_args(%[[ACC2:.*]] = %[[I]]) -> (f32) {
+// CHECK-NEXT:    %[[SUM:.*]] = affine.for %{{.*}} = 0 to 10 iter_args(%[[ACC:.*]] = %[[ACC2]]) -> (f32) {
+// CHECK-NEXT:      %[[X:.*]] = affine.load {{.*}} : memref<10x10xf32>
+// CHECK-NEXT:      %[[Y:.*]] = arith.addf %[[ACC]], %[[X]] : f32
+// CHECK-NEXT:      affine.yield %[[Y]] : f32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    affine.yield %[[SUM]] : f32
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[SUM2]] : f32
+}
+
+
 // CHECK-LABEL: func.func @dead_affine_region_op
 func.func @dead_affine_region_op() {
   %c1 = arith.constant 1 : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/118987

