[Mlir-commits] [mlir] [mlir][scf] Refactor and improve ParallelLoopFusion (PR #179284)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 2 08:57:14 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (fabrizio-indirli)
<details>
<summary>Changes</summary>
Refactor and extend the scf::ParallelLoopFusion pass:
- Refactor code, rename functions and add comments to improve readability
- Make the dependency analysis safer: check for read-after-write dependencies on vector.load/store and vector.transfer_read/write ops in addition to memref.load/store, and bail out when any other unsupported op with memory effects is found.
- Extend the cases in which the fusion is applied: also allow fusing when one of the two loops reads/writes memory through a full view/alias of the buffer (the same buffer read/written by the counterpart operation in the other loop) that can be trivially resolved, including rank-reducing full subviews.
---
Patch is 38.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/179284.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp (+408-71)
- (modified) mlir/test/Dialect/SCF/parallel-loop-fusion.mlir (+330-5)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 4ea832177c4f9..f3e841dfb5dcf 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -14,14 +14,19 @@
#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -55,110 +60,442 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
matchOperands(firstPloop.getStep(), secondPloop.getStep());
}
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
+/// Check if both operations are the same type of memory write op and
+/// write to the same memory location (same buffer and same indices).
+static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
+ if (!op1 || !op2 || op1->getName() != op2->getName())
+ return false;
+ // support only these memory-writing ops for now
+ if (!isa<memref::StoreOp, vector::TransferWriteOp, vector::StoreOp>(op1))
+ return false;
+ bool opsAreIdentical =
+ llvm::TypeSwitch<Operation *, bool>(op1)
+ .Case([&](memref::StoreOp storeOp1) {
+ auto storeOp2 = cast<memref::StoreOp>(op2);
+ return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
+ (storeOp1.getIndices() == storeOp2.getIndices());
+ })
+ .Case([&](vector::TransferWriteOp writeOp1) {
+ auto writeOp2 = cast<vector::TransferWriteOp>(op2);
+ return (writeOp1.getBase() == writeOp2.getBase()) &&
+ (writeOp1.getIndices() == writeOp2.getIndices()) &&
+ (writeOp1.getMask() == writeOp2.getMask()) &&
+ (writeOp1.getValueToStore().getType() ==
+ writeOp2.getValueToStore().getType()) &&
+ (writeOp1.getInBounds() == writeOp2.getInBounds());
+ })
+ .Case([&](vector::StoreOp vecStoreOp1) {
+ auto vecStoreOp2 = cast<vector::StoreOp>(op2);
+ return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
+ (vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
+ (vecStoreOp1.getValueToStore().getType() ==
+ vecStoreOp2.getValueToStore().getType()) &&
+ (vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
+ (vecStoreOp1.getNontemporal() ==
+ vecStoreOp2.getNontemporal());
+ })
+ .Default([](Operation *) { return false; });
+ return opsAreIdentical;
+}
+
+/// Check if val1 (from the first parallel loop) and val2 (from the
+/// second) are equivalent, considering the mapping of induction variables from
+/// the first to the second parallel loop.
+static bool valsAreEquivalent(Value val1, Value val2,
+ const IRMapping &loopsIVsMap) {
+ if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
+ loopsIVsMap.lookupOrDefault(val2) == val1)
+ return true;
+ Operation *val1DefOp = val1.getDefiningOp();
+ Operation *val2DefOp = val2.getDefiningOp();
+ if (!val1DefOp || !val2DefOp)
+ return false;
+ if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
+ return false;
+ return OperationEquivalence::isEquivalentTo(
+ val1DefOp, val2DefOp,
+ [&](Value v1, Value v2) {
+ return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
+ loopsIVsMap.lookupOrDefault(v2) == v1);
+ },
+ /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
+}
+
+/// Return the base memref value used by the given memory op.
+template <typename OpTy>
+static Value getBaseMemref(OpTy op) {
+ return llvm::TypeSwitch<Operation *, Value>(op.getOperation())
+ .Case([&](memref::LoadOp load) { return load.getMemRef(); })
+ .Case([&](memref::StoreOp store) { return store.getMemRef(); })
+ .Case([&](vector::TransferReadOp read) { return read.getBase(); })
+ .Case([&](vector::TransferWriteOp write) { return write.getBase(); })
+ .Case([&](vector::LoadOp load) { return load.getBase(); })
+ .Case([&](vector::StoreOp store) { return store.getBase(); })
+ .Default([](Operation *) { return Value(); });
+}
+
+/// Recognize scalar memref.load of an element produced by a
+/// vector.transfer_write (optionally through a rank-reducing, unit-stride
+/// subview) of the same buffer. This covers the pattern where a vector write
+/// stores a full lane pack and a subsequent loop iterates over the lane
+/// dimension with scalar loads. EXAMPLE:
+/// vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
+/// vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
+/// scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
+/// %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
+/// ...
+/// }
+///
+static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
+ vector::TransferWriteOp writeOp,
+ const IRMapping &ivsMap) {
+ auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
+ if (!vecTy || vecTy.getRank() != 1)
+ return false;
+
+ Value base = writeOp.getBase();
+ MemrefValue baseMemref = nullptr;
+ SmallVector<OpFoldResult> offsets;
+ SmallVector<OpFoldResult> sizes;
+ auto ctx = loadOp.getContext();
+ if (auto subView = base.getDefiningOp<memref::SubViewOp>()) {
+ if (!subView.hasUnitStride())
+ return false;
+ baseMemref = cast<MemrefValue>(subView.getSource());
+ offsets = llvm::to_vector(subView.getMixedOffsets());
+ sizes = llvm::to_vector(subView.getMixedSizes());
+ } else {
+ baseMemref = dyn_cast<MemrefValue>(base);
+ if (!baseMemref)
+ return false;
+ // Fabricate rank-1 view matching the vector length at the end.
+ sizes = SmallVector<OpFoldResult>{
+ IntegerAttr::get(IndexType::get(ctx), vecTy.getDimSize(0))};
+ offsets = SmallVector<OpFoldResult>{
+ writeOp.getIndices().empty()
+ ? OpFoldResult(IntegerAttr::get(IndexType::get(ctx), 0))
+ : writeOp.getIndices().front()};
+ }
+
+ if (sizes.empty() || !isConstantIntValue(sizes.back(), vecTy.getDimSize(0)))
+ return false;
+
+ if (loadOp.getMemref() != baseMemref)
+ return false;
+
+ auto loadIndices = loadOp.getIndices();
+ if (loadIndices.size() != sizes.size())
+ return false;
+
+ // All leading dims size-1; offsets must match load indices.
+ for (unsigned i = 0; i + 1 < sizes.size(); ++i) {
+ if (!isConstantIntValue(sizes[i], 1))
+ return false;
+ if (auto attr = offsets[i].dyn_cast<Attribute>()) {
+ auto cst = dyn_cast<IntegerAttr>(attr);
+ if (!cst || cst.getInt() != 0 || !matchPattern(loadIndices[i], m_Zero()))
+ return false;
+ } else if (auto val = offsets[i].dyn_cast<Value>()) {
+ if (!valsAreEquivalent(val, loadIndices[i], ivsMap))
+ return false;
+ } else {
+ return false;
+ }
+ }
+
+ // transfer_write must start at lane 0 of the subview.
+ if (writeOp.getIndices().size() != 1 ||
+ !matchPattern(writeOp.getIndices().front(), m_Zero()))
+ return false;
+
+ // Last load index must be an scf.for induction variable iterating [0,
+ // vecLen).
+ auto laneIdx = dyn_cast_or_null<BlockArgument>(loadIndices.back());
+ auto forOp = laneIdx ? dyn_cast<scf::ForOp>(laneIdx.getOwner()->getParentOp())
+ : nullptr;
+ if (!forOp || laneIdx != forOp.getInductionVar())
+ return false;
+ auto lb = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
+ auto ub = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
+ auto step = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+ if (!lb || lb.value() != 0 || !step || step.value() != 1 || !ub)
+ return false;
+ if (ub.value() != vecTy.getDimSize(0))
+ return false;
+
+ return true;
+}
+
+/// Check if both operations access the same positions of the same
+/// buffer, but one of the two does it through a rank-reducing full subview of
+/// the buffer (the other's base). EXAMPLE:
+/// memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
+/// %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
+/// to memref<2x2xf32>
+/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
+template <typename OpTy1, typename OpTy2>
+static bool opsAccessSameIndicesViaRankReducingSubview(
+ OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap) {
+ auto base1 = cast<MemrefValue>(getBaseMemref(op1));
+ auto base2 = cast<MemrefValue>(getBaseMemref(op2));
+ if (!base1 || !base2)
+ return false;
+
+ auto accessThroughTrivialSubviewIsSame =
+ [](memref::SubViewOp subView, ValueRange subViewAccess,
+ ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
+ if (!subView.hasZeroOffset() || !subView.hasUnitStride())
+ return false;
+
+ MemRefType srcType = subView.getSourceType();
+ MemRefType resType = subView.getType();
+ unsigned srcRank = srcType.getRank();
+ unsigned resRank = resType.getRank();
+ if (sourceAccess.size() != srcRank || subViewAccess.size() != resRank)
+ return false;
+
+ auto staticSizes = subView.getStaticSizes();
+ auto droppedDims =
+ mlir::computeRankReductionMask(srcType.getShape(), resType.getShape());
+ if (!droppedDims || (droppedDims->size() != srcRank - resRank))
+ return false;
+
+ unsigned resPos = 0;
+ for (unsigned srcPos = 0; srcPos < srcRank; ++srcPos) {
+ if (droppedDims->contains(srcPos)) {
+ if (staticSizes[srcPos] != 1 ||
+ !matchPattern(sourceAccess[srcPos], m_Zero()))
+ return false;
+ continue;
+ }
+ if (resPos >= resRank || !valsAreEquivalent(subViewAccess[resPos],
+ sourceAccess[srcPos], ivsMap))
+ return false;
+ ++resPos;
+ }
+ return resPos == resRank;
+ };
+
+ // Case 1: op1 uses a subview of op2's base.
+ if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
+ subView &&
+ memref::isSameViewOrTrivialAlias(
+ base2, cast<MemrefValue>(subView.getSource())) &&
+ accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
+ op2.getIndices(),
+ firstToSecondPloopIVsMap))
+ return true;
+
+ // Case 2: op2 uses a subview of op1's base.
+ if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
+ subView &&
+ memref::isSameViewOrTrivialAlias(
+ base1, cast<MemrefValue>(subView.getSource())) &&
+ accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
+ op1.getIndices(),
+ firstToSecondPloopIVsMap))
+ return true;
+
+ return false;
+}
+
+/// Check if both memory read/write operations access the same indices
+/// (considering also the mapping of induction variables from the first to the
+/// second parallel loop).
+template <typename OpTy1, typename OpTy2>
+static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
+ const IRMapping &loopsIVsMap) {
+ auto indices1 = op1.getIndices();
+ auto indices2 = op2.getIndices();
+ if (indices1.size() != indices2.size())
+ return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap);
+ for (auto [idx1, idx2] : llvm::zip(indices1, indices2)) {
+ if (!valsAreEquivalent(idx1, idx2, loopsIVsMap))
+ return false;
+ }
+ return true;
+}
+
+/// Check if the loadOp reads from the same memory location (same buffer,
+/// same indices and same properties) as written by the storeOp.
+static bool loadsFromSameMemoryLocationWrittenBy(
+ Operation *loadOp, Operation *storeOp,
+ const IRMapping &firstToSecondPloopIVsMap) {
+ if (!loadOp || !storeOp)
+ return false;
+ // support only these memory-reading ops for now
+ if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
+ return false;
+ bool accessSameMemory =
+ llvm::TypeSwitch<Operation *, bool>(loadOp)
+ .Case([&](memref::LoadOp memLoadOp) {
+ if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
+ return opsAccessSameIndices(memLoadOp, memStoreOp,
+ firstToSecondPloopIVsMap);
+ if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
+ return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
+ firstToSecondPloopIVsMap);
+ return false;
+ })
+ .Case([&](vector::TransferReadOp vecReadOp) {
+ auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
+ if (!vecWriteOp)
+ return false;
+ return opsAccessSameIndices(vecReadOp, vecWriteOp,
+ firstToSecondPloopIVsMap) &&
+ (vecReadOp.getMask() == vecWriteOp.getMask()) &&
+ (vecReadOp.getInBounds() == vecWriteOp.getInBounds());
+ })
+ .Case([&](vector::LoadOp vecLoadOp) {
+ auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
+ if (!vecStoreOp)
+ return false;
+ return opsAccessSameIndices(vecLoadOp, vecStoreOp,
+ firstToSecondPloopIVsMap) &&
+ (vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
+ })
+ .Default([](Operation *) { return false; });
+ return accessSameMemory;
+}
+
+static Value getStoreOpTargetBuffer(Operation *op) {
+ return llvm::TypeSwitch<Operation *, Value>(op)
+ .Case([&](memref::StoreOp storeOp) { return storeOp.getMemRef(); })
+ .Case([&](vector::TransferWriteOp writeOp) { return writeOp.getBase(); })
+ .Case([&](vector::StoreOp vecStoreOp) { return vecStoreOp.getBase(); })
+ .Default([](Operation *) { return Value(); });
+}
+
+/// Check that the parallel loops have no mixed access to the same buffers.
+/// Return `true` if the second parallel loop does not read or write the buffers
+/// written by the first loop using different indices.
+static bool haveNoDataDependenciesExceptSameIndex(
ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias) {
- DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
- SmallVector<Value> bufferStoresVec;
- firstPloop.getBody()->walk([&](memref::StoreOp store) {
- bufferStores[store.getMemRef()].push_back(store.getIndices());
- bufferStoresVec.emplace_back(store.getMemRef());
- });
- auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
- Value loadMem = load.getMemRef();
+ // Map buffers to their store/write ops in the firstPloop
+ DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
+ // Record all the memory buffers used in store/write ops found in firstPloop
+ llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
+
+ // Walk the first parallel loop to collect all store/write ops and their
+ // target buffers
+ if (firstPloop.getBody()
+ ->walk([&](Operation *op) {
+ auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
+ // ignore ops that don't write to memory
+ if (!memOpInterf ||
+ (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
+ !memOpInterf.hasEffect<MemoryEffects::Free>()))
+ return WalkResult::advance();
+
+ // only these memory-writing ops are supported for now:
+ // memref.store, vector.transfer_write, vector.store
+ Value storeOpBase = getStoreOpTargetBuffer(op);
+ if (!storeOpBase)
+ return WalkResult::interrupt();
+
+ // Expect the base operand to be a Memref
+ MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
+ if (!storeOpBaseMemref)
+ return WalkResult::interrupt();
+ // Get the original memref buffer, skipping full view-like ops
+ Value buffer =
+ memref::skipFullyAliasingOperations(storeOpBaseMemref);
+ bufferStoresInFirstPloop[buffer].push_back(op);
+ buffersWrittenInFirstPloop.insert(buffer);
+ return WalkResult::advance();
+ })
+ .wasInterrupted())
+ return false;
+
+ // Walk the second parallel loop to check load/read ops against the stores
+ // collected from the first parallel loop: the loops can be fused only if in
+ // the 2nd loop there are no loads/stores from/to the buffers written in the
+ // 1st loop, except when on the same exact memory location (same indices) as
+ // written in the 1st loop.
+ auto walkResult = secondPloop.getBody()->walk([&](Operation *loadOp) {
+ auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
+ // ignore ops that don't read from memory
+ if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Read>()))
+ return WalkResult::advance();
+ // support only these memory-reading ops for now
+ if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
+ !isa<MemrefValue>(loadOp->getOperand(0)))
+ return WalkResult::interrupt();
+
+ MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
+ MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
// Stop if the memref is defined in secondPloop body. Careful alias analysis
// is needed.
- auto *memrefDef = loadMem.getDefiningOp();
- if (memrefDef && memrefDef->getBlock() == load->getBlock())
+ auto *memrefDef = loadedOrigBuf.getDefiningOp();
+ if (memrefDef && secondPloop->isAncestor(memrefDef))
return WalkResult::interrupt();
- for (Value store : bufferStoresVec)
- if (store != loadMem && mayAlias(store, loadMem))
+ for (Value storedMem : buffersWrittenInFirstPloop)
+ if (storedMem != loadedOrigBuf && mayAlias(storedMem, loadedOrigBuf))
return WalkResult::interrupt();
- auto write = bufferStores.find(loadMem);
- if (write == bufferStores.end())
+ auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
+ if (writeOpsIt == bufferStoresInFirstPloop.end())
return WalkResult::advance();
+ // Store/write ops to this buffer in the firstPloop
+ auto &writeOps = writeOpsIt->second;
- // Check that at last one store was retrieved
- if (write->second.empty())
- return WalkResult::interrupt();
+ // If the first loop has no writes to this buffer, continue
+ if (writeOps.empty())
+ return WalkResult::advance();
- auto storeIndices = write->second.front();
+ Operation *writeOp = writeOps.front();
- // Multiple writes to the same memref are allowed only on the same indices
- for (const auto &othStoreIndices : write->second) {
- if (othStoreIndices != storeIndices)
- return WalkResult::interrupt();
- }
+ // In the first parallel loop, multiple writes to the same memref are
+ // allowed only on the same memory location
+ if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
+ return opsWriteSameMemLocation(writeOp, otherWriteOp);
+ }))
+ return WalkResult::interrupt();
- // Check that the load indices of secondPloop coincide with store indices of
- // firstPloop for the same memrefs.
- auto loadIndices...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/179284
More information about the Mlir-commits
mailing list