[Mlir-commits] [mlir] Refactor LoopFuseSiblingOp and support parallel fusion (PR #94391)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 4 13:29:34 PDT 2024

@@ -1070,6 +1071,206 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
   return tileLoops;
+/// Checks the loads of `secondPloop` against the `memref.store` ops of `firstPloop`. Returns
+/// `true` when no load was rejected: each load either reads a memref `firstPloop` does not store to
+/// (and that `mayAlias` does not tie to a different stored memref), or reads it at the stored indices.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+    scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+    const IRMapping &firstToSecondPloopIndices,
+    llvm::function_ref<bool(Value, Value)> mayAlias) {
+  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores; // stored memref -> indices of every store to it
+  SmallVector<Value> bufferStoresVec; // stored memrefs, one entry per store op
+  firstPloop.getBody()->walk([&](memref::StoreOp store) {
+    bufferStores[store.getMemRef()].push_back(store.getIndices());
+    bufferStoresVec.emplace_back(store.getMemRef());
+  });
+  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
+    Value loadMem = load.getMemRef();
+    // Give up if the memref is defined by an op in the same block as the load;
+    // handling that case would need a more careful alias analysis.
+    auto *memrefDef = loadMem.getDefiningOp();
+    if (memrefDef && memrefDef->getBlock() == load->getBlock())
+      return WalkResult::interrupt();
+    for (Value store : bufferStoresVec) // reject a load that may alias a *different* stored memref
+      if (store != loadMem && mayAlias(store, loadMem))
+        return WalkResult::interrupt();
+    auto write = bufferStores.find(loadMem);
+    if (write == bufferStores.end())
+      return WalkResult::advance(); // firstPloop never stores to this memref
+    // Check that at least one store was retrieved.
+    if (!write->second.size())
+      return WalkResult::interrupt();
+    auto storeIndices = write->second.front();
+    // Multiple writes to the same memref are allowed only on the same indices.
+    for (const auto &othStoreIndices : write->second) {
+      if (othStoreIndices != storeIndices)
+        return WalkResult::interrupt();
+    }
+    // Check that the load indices of secondPloop coincide with store indices of
+    // firstPloop for the same memrefs.
+    auto loadIndices = load.getIndices();
+    if (storeIndices.size() != loadIndices.size())
+      return WalkResult::interrupt();
+    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
+      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != // mapped store index differs from the load index
+          loadIndices[i]) {
+        auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
+        auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
+        if (storeIndexDefOp && loadIndexDefOp) { // fall back to comparing the ops that compute the two indices
+          if (!isMemoryEffectFree(storeIndexDefOp))
+            return WalkResult::interrupt();
+          if (!isMemoryEffectFree(loadIndexDefOp))
+            return WalkResult::interrupt();
+          if (!OperationEquivalence::isEquivalentTo(
+                  storeIndexDefOp, loadIndexDefOp,
+                  [&](Value storeIndex, Value loadIndex) { // operands are equivalent when they agree after the index mapping
+                    if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
+                        firstToSecondPloopIndices.lookupOrDefault(loadIndex))
+                      return failure();
+                    else
+                      return success();
+                  },
+                  /*markEquivalent=*/nullptr,
+                  OperationEquivalence::Flags::IgnoreLocations)) {
+            return WalkResult::interrupt();
+          }
+        } else // at least one of the two indices has no defining op
+          return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted(); // true iff no load was rejected above
+/// Runs `haveNoReadsAfterWriteExceptSameIndex` in both directions: first -> second with the given
+/// mapping, then second -> first with the inverse body-argument mapping. Succeeds only if both pass.
+static LogicalResult
+verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop,
+                   const IRMapping &firstToSecondPloopIndices,
+                   llvm::function_ref<bool(Value, Value)> mayAlias) {
+  if (!haveNoReadsAfterWriteExceptSameIndex( // stores of the first loop vs. loads of the second
+          firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
+    return failure();
+  IRMapping secondToFirstPloopIndices;
+  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), // inverse mapping for the reversed check
+                                firstPloop.getBody()->getArguments());
+  return success(haveNoReadsAfterWriteExceptSameIndex( // stores of the second loop vs. loads of the first
+      secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+/// Returns `true` if both loops have the same number of loops and element-wise equal lower bounds, upper bounds and steps.
+static bool equalIterationSpaces(scf::ParallelOp firstPloop,
+                                 scf::ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+  auto matchOperands = [&](const OperandRange &lhs, // compares the two ranges position by position
+                           const OperandRange &rhs) -> bool {
+    // TODO: Extend this to support aliases and equal constants.
+    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+  };
+  return matchOperands(firstPloop.getLowerBound(),
+                       secondPloop.getLowerBound()) &&
+         matchOperands(firstPloop.getUpperBound(),
+                       secondPloop.getUpperBound()) &&
+         matchOperands(firstPloop.getStep(), secondPloop.getStep());
+/// Returns `true` if a walk over the body of `ploop` finds an `scf::ParallelOp`, i.e. the loop has a nested parallel loop.
+static bool hasNestedParallelOp(scf::ParallelOp ploop) {
+  auto walkResult = ploop.getBody()->walk(
+      [](scf::ParallelOp) { return WalkResult::interrupt(); }); // stop at the first nested parallel op
+  return walkResult.wasInterrupted();
+static bool isFusionLegal(scf::ParallelOp firstPloop, // true only if all four conditions in the return expression hold
+                          scf::ParallelOp secondPloop,
+                          const IRMapping &firstToSecondPloopIndices,
+                          llvm::function_ref<bool(Value, Value)> mayAlias) {
+  return !hasNestedParallelOp(firstPloop) && // neither loop may contain a nested parallel loop
+         !hasNestedParallelOp(secondPloop) &&
+         equalIterationSpaces(firstPloop, secondPloop) && // same number of loops, bounds and steps
+         succeeded(verifyDependencies(firstPloop, secondPloop, // store/load index check in both directions
+                                      firstToSecondPloopIndices, mayAlias));
+void mlir::fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop, // secondPloop is reassigned to the fused loop on success
+                       OpBuilder builder,
+                       llvm::function_ref<bool(Value, Value)> mayAlias) {
+  Block *block1 = firstPloop.getBody();
+  Block *block2 = secondPloop.getBody();
+  IRMapping firstToSecondPloopIndices;
+  firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments()); // pair up the body arguments of the two loops
+  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
+                     mayAlias))
+    return; // not legal: nothing has been modified
+  DominanceInfo dom;
+  // We are fusing first loop into second, make sure there are no users of the
+  // first loop results between loops.
+  for (Operation *user : firstPloop->getUsers())
+    if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
+      return; // a user is not properly dominated by secondPloop: nothing has been modified
+  ValueRange inits1 = firstPloop.getInitVals();
+  ValueRange inits2 = secondPloop.getInitVals();
+  SmallVector<Value> newInitVars(inits1.begin(), inits1.end()); // fused init values: first loop's, then second loop's
+  newInitVars.append(inits2.begin(), inits2.end());
+  IRRewriter b(builder);
+  b.setInsertionPoint(secondPloop); // the fused loop is created right before secondPloop
+  auto newSecondPloop = b.create<scf::ParallelOp>(
+      secondPloop.getLoc(), secondPloop.getLowerBound(),
+      secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
+  Block *newBlock = newSecondPloop.getBody();
+  auto term1 = cast<scf::ReduceOp>(block1->getTerminator()); // grab both terminators before the bodies are inlined
+  auto term2 = cast<scf::ReduceOp>(block2->getTerminator());
+  b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
+                      newBlock->getArguments());
+  b.inlineBlockBefore(block1, newBlock, newBlock->begin(), // inlined at the front as well, so it lands ahead of block2's ops
+                      newBlock->getArguments());
+  ValueRange results = newSecondPloop.getResults();
+  if (!results.empty()) { // only loops with results need a combined reduce terminator
+    b.setInsertionPointToEnd(newBlock);
+    ValueRange reduceArgs1 = term1.getOperands();
+    ValueRange reduceArgs2 = term2.getOperands();
+    SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); // same order as the init values
+    newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
+    auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
+    for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( // move each old reduction body into the i-th new region
+             term1.getReductions(), term2.getReductions()))) {
+      Block &oldRedBlock = reg.front();
+      Block &newRedBlock = newReduceOp.getReductions()[i].front();
+      b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
+                          newRedBlock.getArguments());
+    }
+    firstPloop.replaceAllUsesWith(results.take_front(inits1.size())); // leading results replace the first loop's
+    secondPloop.replaceAllUsesWith(results.take_back(inits2.size())); // trailing results replace the second loop's
+  }
+  term1->erase();
+  term2->erase();
+  firstPloop.erase();
+  secondPloop.erase();
+  secondPloop = newSecondPloop; // hand the fused loop back through the reference parameter
srcarroll wrote:

This whole block of changes was copied from `ParallelLoopFusion.cpp` without modification.


More information about the Mlir-commits mailing list