[Mlir-commits] [mlir] [mlir][scf] Refactor and improve ParallelLoopFusion (PR #179284)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 2 08:57:14 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (fabrizio-indirli)

<details>
<summary>Changes</summary>

Refactor and extend the scf::ParallelLoopFusion pass:
- Refactor code, rename functions and add comments to improve readability.
- Make the dependency analysis safer: check for read-after-write dependencies on vector.load/store and vector.transfer_read/write ops in addition to memref.load/store, and bail out when other, unsupported ops with memory effects are found.
- Extend the cases in which fusion is applied: also fuse when one of the two loops reads/writes memory through a full view/alias of the buffer that is read/written by the dual operation in the other loop, provided the alias can be trivially resolved (this includes rank-reducing full subviews).

---

Patch is 38.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/179284.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+408-71) 
- (modified) mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (+330-5) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 4ea832177c4f9..f3e841dfb5dcf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -14,14 +14,19 @@
 
 #include "mlir/Analysis/AliasAnalysis.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
 namespace mlir {
 #define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -55,110 +60,442 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
          matchOperands(firstPloop.getStep(), secondPloop.getStep());
 }
 
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
+/// Check if both operations are the same type of memory write op and
+/// write to the same memory location (same buffer and same indices).
+static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
+  if (!op1 || !op2 || op1->getName() != op2->getName())
+    return false;
+  // support only these memory-writing ops for now
+  if (!isa<memref::StoreOp, vector::TransferWriteOp, vector::StoreOp>(op1))
+    return false;
+  bool opsAreIdentical =
+      llvm::TypeSwitch<Operation *, bool>(op1)
+          .Case([&](memref::StoreOp storeOp1) {
+            auto storeOp2 = cast<memref::StoreOp>(op2);
+            return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
+                   (storeOp1.getIndices() == storeOp2.getIndices());
+          })
+          .Case([&](vector::TransferWriteOp writeOp1) {
+            auto writeOp2 = cast<vector::TransferWriteOp>(op2);
+            return (writeOp1.getBase() == writeOp2.getBase()) &&
+                   (writeOp1.getIndices() == writeOp2.getIndices()) &&
+                   (writeOp1.getMask() == writeOp2.getMask()) &&
+                   (writeOp1.getValueToStore().getType() ==
+                    writeOp2.getValueToStore().getType()) &&
+                   (writeOp1.getInBounds() == writeOp2.getInBounds());
+          })
+          .Case([&](vector::StoreOp vecStoreOp1) {
+            auto vecStoreOp2 = cast<vector::StoreOp>(op2);
+            return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
+                   (vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
+                   (vecStoreOp1.getValueToStore().getType() ==
+                    vecStoreOp2.getValueToStore().getType()) &&
+                   (vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
+                   (vecStoreOp1.getNontemporal() ==
+                    vecStoreOp2.getNontemporal());
+          })
+          .Default([](Operation *) { return false; });
+  return opsAreIdentical;
+}
+
+/// Check if val1 (from the first parallel loop) and val2 (from the
+/// second) are equivalent, considering the mapping of induction variables from
+/// the first to the second parallel loop.
+static bool valsAreEquivalent(Value val1, Value val2,
+                              const IRMapping &loopsIVsMap) {
+  if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
+      loopsIVsMap.lookupOrDefault(val2) == val1)
+    return true;
+  Operation *val1DefOp = val1.getDefiningOp();
+  Operation *val2DefOp = val2.getDefiningOp();
+  if (!val1DefOp || !val2DefOp)
+    return false;
+  if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
+    return false;
+  return OperationEquivalence::isEquivalentTo(
+      val1DefOp, val2DefOp,
+      [&](Value v1, Value v2) {
+        return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
+                       loopsIVsMap.lookupOrDefault(v2) == v1);
+      },
+      /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
+}
+
+/// Return the base memref value used by the given memory op.
+template <typename OpTy>
+static Value getBaseMemref(OpTy op) {
+  return llvm::TypeSwitch<Operation *, Value>(op.getOperation())
+      .Case([&](memref::LoadOp load) { return load.getMemRef(); })
+      .Case([&](memref::StoreOp store) { return store.getMemRef(); })
+      .Case([&](vector::TransferReadOp read) { return read.getBase(); })
+      .Case([&](vector::TransferWriteOp write) { return write.getBase(); })
+      .Case([&](vector::LoadOp load) { return load.getBase(); })
+      .Case([&](vector::StoreOp store) { return store.getBase(); })
+      .Default([](Operation *) { return Value(); });
+}
+
+/// Recognize scalar memref.load of an element produced by a
+/// vector.transfer_write (optionally through a rank-reducing, unit-stride
+/// subview) of the same buffer. This covers the pattern where a vector write
+/// stores a full lane pack and a subsequent loop iterates over the lane
+/// dimension with scalar loads. EXAMPLE:
+///  vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
+///             vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
+///  scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
+///    %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
+///    ...
+///  }
+///
+static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
+                                   vector::TransferWriteOp writeOp,
+                                   const IRMapping &ivsMap) {
+  auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
+  if (!vecTy || vecTy.getRank() != 1)
+    return false;
+
+  Value base = writeOp.getBase();
+  MemrefValue baseMemref = nullptr;
+  SmallVector<OpFoldResult> offsets;
+  SmallVector<OpFoldResult> sizes;
+  auto ctx = loadOp.getContext();
+  if (auto subView = base.getDefiningOp<memref::SubViewOp>()) {
+    if (!subView.hasUnitStride())
+      return false;
+    baseMemref = cast<MemrefValue>(subView.getSource());
+    offsets = llvm::to_vector(subView.getMixedOffsets());
+    sizes = llvm::to_vector(subView.getMixedSizes());
+  } else {
+    baseMemref = dyn_cast<MemrefValue>(base);
+    if (!baseMemref)
+      return false;
+    // Fabricate rank-1 view matching the vector length at the end.
+    sizes = SmallVector<OpFoldResult>{
+        IntegerAttr::get(IndexType::get(ctx), vecTy.getDimSize(0))};
+    offsets = SmallVector<OpFoldResult>{
+        writeOp.getIndices().empty()
+            ? OpFoldResult(IntegerAttr::get(IndexType::get(ctx), 0))
+            : writeOp.getIndices().front()};
+  }
+
+  if (sizes.empty() || !isConstantIntValue(sizes.back(), vecTy.getDimSize(0)))
+    return false;
+
+  if (loadOp.getMemref() != baseMemref)
+    return false;
+
+  auto loadIndices = loadOp.getIndices();
+  if (loadIndices.size() != sizes.size())
+    return false;
+
+  // All leading dims size-1; offsets must match load indices.
+  for (unsigned i = 0; i + 1 < sizes.size(); ++i) {
+    if (!isConstantIntValue(sizes[i], 1))
+      return false;
+    if (auto attr = offsets[i].dyn_cast<Attribute>()) {
+      auto cst = dyn_cast<IntegerAttr>(attr);
+      if (!cst || cst.getInt() != 0 || !matchPattern(loadIndices[i], m_Zero()))
+        return false;
+    } else if (auto val = offsets[i].dyn_cast<Value>()) {
+      if (!valsAreEquivalent(val, loadIndices[i], ivsMap))
+        return false;
+    } else {
+      return false;
+    }
+  }
+
+  // transfer_write must start at lane 0 of the subview.
+  if (writeOp.getIndices().size() != 1 ||
+      !matchPattern(writeOp.getIndices().front(), m_Zero()))
+    return false;
+
+  // Last load index must be an scf.for induction variable iterating [0,
+  // vecLen).
+  auto laneIdx = dyn_cast_or_null<BlockArgument>(loadIndices.back());
+  auto forOp = laneIdx ? dyn_cast<scf::ForOp>(laneIdx.getOwner()->getParentOp())
+                       : nullptr;
+  if (!forOp || laneIdx != forOp.getInductionVar())
+    return false;
+  auto lb = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
+  auto ub = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
+  auto step = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+  if (!lb || lb.value() != 0 || !step || step.value() != 1 || !ub)
+    return false;
+  if (ub.value() != vecTy.getDimSize(0))
+    return false;
+
+  return true;
+}
+
+/// Check if both operations access the same positions of the same
+/// buffer, but one of the two does it through a rank-reducing full subview of
+/// the buffer (the other's base). EXAMPLE:
+///  memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
+///  %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
+///                           to memref<2x2xf32>
+///  %val = memref.load %alias[%i, %j] : memref<2x2xf32>
+template <typename OpTy1, typename OpTy2>
+static bool opsAccessSameIndicesViaRankReducingSubview(
+    OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap) {
+  auto base1 = cast<MemrefValue>(getBaseMemref(op1));
+  auto base2 = cast<MemrefValue>(getBaseMemref(op2));
+  if (!base1 || !base2)
+    return false;
+
+  auto accessThroughTrivialSubviewIsSame =
+      [](memref::SubViewOp subView, ValueRange subViewAccess,
+         ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
+    if (!subView.hasZeroOffset() || !subView.hasUnitStride())
+      return false;
+
+    MemRefType srcType = subView.getSourceType();
+    MemRefType resType = subView.getType();
+    unsigned srcRank = srcType.getRank();
+    unsigned resRank = resType.getRank();
+    if (sourceAccess.size() != srcRank || subViewAccess.size() != resRank)
+      return false;
+
+    auto staticSizes = subView.getStaticSizes();
+    auto droppedDims =
+        mlir::computeRankReductionMask(srcType.getShape(), resType.getShape());
+    if (!droppedDims || (droppedDims->size() != srcRank - resRank))
+      return false;
+
+    unsigned resPos = 0;
+    for (unsigned srcPos = 0; srcPos < srcRank; ++srcPos) {
+      if (droppedDims->contains(srcPos)) {
+        if (staticSizes[srcPos] != 1 ||
+            !matchPattern(sourceAccess[srcPos], m_Zero()))
+          return false;
+        continue;
+      }
+      if (resPos >= resRank || !valsAreEquivalent(subViewAccess[resPos],
+                                                  sourceAccess[srcPos], ivsMap))
+        return false;
+      ++resPos;
+    }
+    return resPos == resRank;
+  };
+
+  // Case 1: op1 uses a subview of op2's base.
+  if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
+      subView &&
+      memref::isSameViewOrTrivialAlias(
+          base2, cast<MemrefValue>(subView.getSource())) &&
+      accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
+                                        op2.getIndices(),
+                                        firstToSecondPloopIVsMap))
+    return true;
+
+  // Case 2: op2 uses a subview of op1's base.
+  if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
+      subView &&
+      memref::isSameViewOrTrivialAlias(
+          base1, cast<MemrefValue>(subView.getSource())) &&
+      accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
+                                        op1.getIndices(),
+                                        firstToSecondPloopIVsMap))
+    return true;
+
+  return false;
+}
+
+/// Check if both memory read/write operations access the same indices
+/// (considering also the mapping of induction variables from the first to the
+/// second parallel loop).
+template <typename OpTy1, typename OpTy2>
+static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
+                                 const IRMapping &loopsIVsMap) {
+  auto indices1 = op1.getIndices();
+  auto indices2 = op2.getIndices();
+  if (indices1.size() != indices2.size())
+    return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap);
+  for (auto [idx1, idx2] : llvm::zip(indices1, indices2)) {
+    if (!valsAreEquivalent(idx1, idx2, loopsIVsMap))
+      return false;
+  }
+  return true;
+}
+
+/// Check if the loadOp reads from the same memory location (same buffer,
+/// same indices and same properties) as written by the storeOp.
+static bool loadsFromSameMemoryLocationWrittenBy(
+    Operation *loadOp, Operation *storeOp,
+    const IRMapping &firstToSecondPloopIVsMap) {
+  if (!loadOp || !storeOp)
+    return false;
+  // support only these memory-reading ops for now
+  if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
+    return false;
+  bool accessSameMemory =
+      llvm::TypeSwitch<Operation *, bool>(loadOp)
+          .Case([&](memref::LoadOp memLoadOp) {
+            if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
+              return opsAccessSameIndices(memLoadOp, memStoreOp,
+                                          firstToSecondPloopIVsMap);
+            if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
+              return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
+                                            firstToSecondPloopIVsMap);
+            return false;
+          })
+          .Case([&](vector::TransferReadOp vecReadOp) {
+            auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
+            if (!vecWriteOp)
+              return false;
+            return opsAccessSameIndices(vecReadOp, vecWriteOp,
+                                        firstToSecondPloopIVsMap) &&
+                   (vecReadOp.getMask() == vecWriteOp.getMask()) &&
+                   (vecReadOp.getInBounds() == vecWriteOp.getInBounds());
+          })
+          .Case([&](vector::LoadOp vecLoadOp) {
+            auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
+            if (!vecStoreOp)
+              return false;
+            return opsAccessSameIndices(vecLoadOp, vecStoreOp,
+                                        firstToSecondPloopIVsMap) &&
+                   (vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
+          })
+          .Default([](Operation *) { return false; });
+  return accessSameMemory;
+}
+
+static Value getStoreOpTargetBuffer(Operation *op) {
+  return llvm::TypeSwitch<Operation *, Value>(op)
+      .Case([&](memref::StoreOp storeOp) { return storeOp.getMemRef(); })
+      .Case([&](vector::TransferWriteOp writeOp) { return writeOp.getBase(); })
+      .Case([&](vector::StoreOp vecStoreOp) { return vecStoreOp.getBase(); })
+      .Default([](Operation *) { return Value(); });
+}
+
+/// Check that the parallel loops have no mixed access to the same buffers.
+/// Return `true` if the second parallel loop does not read or write the buffers
+/// written by the first loop using different indices.
+static bool haveNoDataDependenciesExceptSameIndex(
     ParallelOp firstPloop, ParallelOp secondPloop,
     const IRMapping &firstToSecondPloopIndices,
     llvm::function_ref<bool(Value, Value)> mayAlias) {
-  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
-  SmallVector<Value> bufferStoresVec;
-  firstPloop.getBody()->walk([&](memref::StoreOp store) {
-    bufferStores[store.getMemRef()].push_back(store.getIndices());
-    bufferStoresVec.emplace_back(store.getMemRef());
-  });
-  auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
-    Value loadMem = load.getMemRef();
+  // Map buffers to their store/write ops in the firstPloop
+  DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
+  // Record all the memory buffers used in store/write ops found in firstPloop
+  llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
+
+  // Walk the first parallel loop to collect all store/write ops and their
+  // target buffers
+  if (firstPloop.getBody()
+          ->walk([&](Operation *op) {
+            auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
+            // ignore ops that don't write to memory
+            if (!memOpInterf ||
+                (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
+                 !memOpInterf.hasEffect<MemoryEffects::Free>()))
+              return WalkResult::advance();
+
+            // only these memory-writing ops are supported for now:
+            // memref.store, vector.transfer_write, vector.store
+            Value storeOpBase = getStoreOpTargetBuffer(op);
+            if (!storeOpBase)
+              return WalkResult::interrupt();
+
+            // Expect the base operand to be a Memref
+            MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
+            if (!storeOpBaseMemref)
+              return WalkResult::interrupt();
+            // Get the original memref buffer, skipping full view-like ops
+            Value buffer =
+                memref::skipFullyAliasingOperations(storeOpBaseMemref);
+            bufferStoresInFirstPloop[buffer].push_back(op);
+            buffersWrittenInFirstPloop.insert(buffer);
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
+  // Walk the second parallel loop to check load/read ops against the stores
+  // collected from the first parallel loop: the loops can be fused only if in
+  // the 2nd loop there are no loads/stores from/to the buffers written in the
+  // 1st loop, except when on the same exact memory location (same indices) as
+  // written in the 1st loop.
+  auto walkResult = secondPloop.getBody()->walk([&](Operation *loadOp) {
+    auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
+    // ignore ops that don't read from memory
+    if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Read>()))
+      return WalkResult::advance();
+    // support only these memory-reading ops for now
+    if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
+        !isa<MemrefValue>(loadOp->getOperand(0)))
+      return WalkResult::interrupt();
+
+    MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
+    MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
     // Stop if the memref is defined in secondPloop body. Careful alias analysis
     // is needed.
-    auto *memrefDef = loadMem.getDefiningOp();
-    if (memrefDef && memrefDef->getBlock() == load->getBlock())
+    auto *memrefDef = loadedOrigBuf.getDefiningOp();
+    if (memrefDef && secondPloop->isAncestor(memrefDef))
       return WalkResult::interrupt();
 
-    for (Value store : bufferStoresVec)
-      if (store != loadMem && mayAlias(store, loadMem))
+    for (Value storedMem : buffersWrittenInFirstPloop)
+      if (storedMem != loadedOrigBuf && mayAlias(storedMem, loadedOrigBuf))
         return WalkResult::interrupt();
 
-    auto write = bufferStores.find(loadMem);
-    if (write == bufferStores.end())
+    auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
+    if (writeOpsIt == bufferStoresInFirstPloop.end())
       return WalkResult::advance();
+    // Store/write ops to this buffer in the firstPloop
+    auto &writeOps = writeOpsIt->second;
 
-    // Check that at last one store was retrieved
-    if (write->second.empty())
-      return WalkResult::interrupt();
+    // If the first loop has no writes to this buffer, continue
+    if (writeOps.empty())
+      return WalkResult::advance();
 
-    auto storeIndices = write->second.front();
+    Operation *writeOp = writeOps.front();
 
-    // Multiple writes to the same memref are allowed only on the same indices
-    for (const auto &othStoreIndices : write->second) {
-      if (othStoreIndices != storeIndices)
-        return WalkResult::interrupt();
-    }
+    // In the first parallel loop, multiple writes to the same memref are
+    // allowed only on the same memory location
+    if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
+          return opsWriteSameMemLocation(writeOp, otherWriteOp);
+        }))
+      return WalkResult::interrupt();
 
-    // Check that the load indices of secondPloop coincide with store indices of
-    // firstPloop for the same memrefs.
-    auto loadIndices...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/179284


More information about the Mlir-commits mailing list