[Mlir-commits] [mlir] [mlir][scf] Refactor and improve ParallelLoopFusion (PR #179284)
Ivan Butygin
llvmlistbot at llvm.org
Mon Feb 16 11:35:43 PST 2026
================
@@ -55,124 +71,690 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
matchOperands(firstPloop.getStep(), secondPloop.getStep());
}
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
+/// Check if both operations are the same kind of memory-writing op and write
+/// to exactly the same memory location.
+///
+/// Only `memref.store`, `vector.transfer_write` and `vector.store` are
+/// recognized (support only these memory-writing ops for now). For them, the
+/// buffer and indices must match, and for vector writes so must every attribute
+/// that affects which elements are written (vector type, permutation map, mask,
+/// in_bounds, alignment, nontemporal). An op trivially writes the same location
+/// as itself. Any other pair returns false, which is the conservative answer.
+static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
+  if (!op1 || !op2 || op1->getName() != op2->getName())
+    return false;
+  if (op1 == op2)
+    return true;
+  // Unsupported op kinds fall through to `Default` and yield false.
+  return llvm::TypeSwitch<Operation *, bool>(op1)
+      .Case([&](memref::StoreOp storeOp1) {
+        auto storeOp2 = cast<memref::StoreOp>(op2);
+        return storeOp1.getMemRef() == storeOp2.getMemRef() &&
+               storeOp1.getIndices() == storeOp2.getIndices();
+      })
+      .Case([&](vector::TransferWriteOp writeOp1) {
+        auto writeOp2 = cast<vector::TransferWriteOp>(op2);
+        // The permutation map decides which memref dimensions the vector
+        // dimensions are laid out along. Two writes with the same base,
+        // indices and vector type but different maps touch different
+        // elements, so the map must match as well.
+        return writeOp1.getBase() == writeOp2.getBase() &&
+               writeOp1.getIndices() == writeOp2.getIndices() &&
+               writeOp1.getPermutationMap() == writeOp2.getPermutationMap() &&
+               writeOp1.getMask() == writeOp2.getMask() &&
+               writeOp1.getValueToStore().getType() ==
+                   writeOp2.getValueToStore().getType() &&
+               writeOp1.getInBounds() == writeOp2.getInBounds();
+      })
+      .Case([&](vector::StoreOp vecStoreOp1) {
+        auto vecStoreOp2 = cast<vector::StoreOp>(op2);
+        return vecStoreOp1.getBase() == vecStoreOp2.getBase() &&
+               vecStoreOp1.getIndices() == vecStoreOp2.getIndices() &&
+               vecStoreOp1.getValueToStore().getType() ==
+                   vecStoreOp2.getValueToStore().getType() &&
+               vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment() &&
+               vecStoreOp1.getNontemporal() == vecStoreOp2.getNontemporal();
+      })
+      .Default([](Operation *) { return false; });
+}
+
+/// Check if `val1` (from the first parallel loop) and `val2` (from the second)
+/// are equivalent, considering the mapping of induction variables from the
+/// first to the second parallel loop.
+///
+/// The values are considered equivalent if any of these holds:
+///  - they are identical;
+///  - `loopsIVsMap` maps one onto the other;
+///  - they are the same result position of two structurally equivalent,
+///    memory-effect-free ops whose operands are pairwise identical or mapped
+///    onto each other by `loopsIVsMap`.
+/// Operand equivalence is not checked recursively.
+static bool valsAreEquivalent(Value val1, Value val2,
+                              const IRMapping &loopsIVsMap) {
+  if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
+      loopsIVsMap.lookupOrDefault(val2) == val1)
+    return true;
+  // Block arguments that are not related through the map are never
+  // equivalent.
+  auto res1 = dyn_cast<OpResult>(val1);
+  auto res2 = dyn_cast<OpResult>(val2);
+  if (!res1 || !res2)
+    return false;
+  // Equivalent multi-result ops (e.g. `affine.delinearize_index`) only yield
+  // equivalent values at the same result position.
+  if (res1.getResultNumber() != res2.getResultNumber())
+    return false;
+  Operation *val1DefOp = res1.getOwner();
+  Operation *val2DefOp = res2.getOwner();
+  if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
+    return false;
+  return OperationEquivalence::isEquivalentTo(
+      val1DefOp, val2DefOp,
+      [&](Value v1, Value v2) {
+        return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
+                       loopsIVsMap.lookupOrDefault(v2) == v1);
+      },
+      /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
+}
+
+/// Return the integer value of `value` if it is produced by a constant op
+/// (e.g. `arith.constant` or `index.constant`), and std::nullopt otherwise.
+///
+/// This uses the same constant matching as `getAddConstant`, so a constant
+/// index is recognized regardless of the dialect that materialized it.
+static std::optional<int64_t> getConstantIndex(Value value) {
+  return getConstantIntValue(value);
+}
+
+/// Return the constant `C` when `expr` computes `base + C` or `C + base`, where
+/// `base` is matched with `valsAreEquivalent` under `loopsIVsMap`.
+///
+/// Recognized producers of `expr`:
+///  - `arith.addi` and `index.add` with one constant operand;
+///  - a single-result, single-dimension, symbol-free `affine.apply` whose
+///    result expression is `d0 + C` or `C + d0`.
+/// Anything else yields std::nullopt.
+static std::optional<int64_t> getAddConstant(Value expr, Value base,
+                                             const IRMapping &loopsIVsMap) {
+  // Shared matcher for two-operand integer additions: try the constant on the
+  // left first, then on the right.
+  auto matchBinaryAdd = [&](Value lhs, Value rhs) -> std::optional<int64_t> {
+    std::optional<int64_t> lhsCst = getConstantIntValue(lhs);
+    if (lhsCst && valsAreEquivalent(rhs, base, loopsIVsMap))
+      return lhsCst;
+    std::optional<int64_t> rhsCst = getConstantIntValue(rhs);
+    if (rhsCst && valsAreEquivalent(lhs, base, loopsIVsMap))
+      return rhsCst;
+    return std::nullopt;
+  };
+
+  if (auto arithAdd = expr.getDefiningOp<arith::AddIOp>())
+    return matchBinaryAdd(arithAdd.getLhs(), arithAdd.getRhs());
+  if (auto indexAdd = expr.getDefiningOp<index::AddOp>())
+    return matchBinaryAdd(indexAdd.getLhs(), indexAdd.getRhs());
+
+  auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>();
+  if (!applyOp)
+    return std::nullopt;
+  AffineMap map = applyOp.getAffineMap();
+  bool isUnaryDimMap = map.getNumResults() == 1 && map.getNumDims() == 1 &&
+                       map.getNumSymbols() == 0;
+  if (!isUnaryDimMap ||
+      !valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
+    return std::nullopt;
+  auto sum = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
+  if (!sum || sum.getKind() != AffineExprKind::Add)
+    return std::nullopt;
+  // The map has exactly one dimension, so any dimension operand is `d0`.
+  if (auto cst = dyn_cast<AffineConstantExpr>(sum.getLHS());
+      cst && isa<AffineDimExpr>(sum.getRHS()))
+    return cst.getValue();
+  if (auto cst = dyn_cast<AffineConstantExpr>(sum.getRHS());
+      cst && isa<AffineDimExpr>(sum.getLHS()))
+    return cst.getValue();
+  return std::nullopt;
+}
+
+// Return true if the scalar load index may hit any element covered by a
+// vector.store/transfer_write along a single memref dimension. Supported cases:
+//
+// 1) Direct index match (with optional offset):
+// vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
+// %x = memref.load %A[%i] : memref<...>
+//
+// 2) Loop IV range intersects the write range:
+// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
+// scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
+//
+// 3) Constant index (or IV + constant) within the write range:
+// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
+// %x = memref.load %A[%c2] : memref<...>
+// %y = memref.load %A[%i + %c1] : memref<...>
+//
+// Args:
+// - loadIndex: index used by the scalar load for this dimension.
+// - offset: subview offset for the base memref dimension (if any).
+// - writeIndex: index used by the transfer_write for this dimension.
+// - extent: vector size along this dimension (number of elements written).
+// - loopsIVsMap: IV equivalence map between fused loops.
+static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
+ Value writeIndex, int64_t extent,
+ const IRMapping &loopsIVsMap) {
+ if (extent <= 0)
+ return false;
+
+ // Extract constant loop bounds for IVs (scf.for / scf.parallel).
+ auto getConstLoopBoundsForIV =
+ [](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
+ auto blockArg = dyn_cast<BlockArgument>(index);
+ if (!blockArg)
+ return std::nullopt;
+ auto *parentOp = blockArg.getOwner()->getParentOp();
+ auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
+ if (!loopLike)
+ return std::nullopt;
+ auto ranges = getConstLoopBounds(loopLike);
+ if (ranges.empty())
+ return std::nullopt;
+
+ auto ivs = loopLike.getLoopInductionVars();
+ if (!isa<scf::ForOp, scf::ParallelOp>(parentOp) || !ivs)
+ return std::nullopt;
+ auto it = llvm::find(*ivs, blockArg);
+ if (it == ivs->end())
+ return std::nullopt;
+ unsigned pos = std::distance(ivs->begin(), it);
+ if (pos >= ranges.size())
+ return std::nullopt;
+ auto [lb, ub, step] = ranges[pos];
+ return std::make_tuple(lb, ub, step);
+ };
+
+ std::optional<int64_t> offsetConst = getConstantIntValue(offset);
+ std::optional<int64_t> writeConst =
+ writeIndex ? getConstantIndex(writeIndex) : std::optional<int64_t>(0);
+ if (!writeConst && writeIndex) {
+ // Treat single-iteration IVs as constants for matching.
+ if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
+ auto [lb, ub, step] = *bounds;
+ if (step > 0 && ub == lb + step)
+ writeConst = lb;
+ }
+ }
+
+ // Check whether a loop IV is fully contained in a constant write range.
+ auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
+ int64_t rangeStart, int64_t rangeExtent) -> bool {
+ if (rangeExtent <= 0 || step <= 0)
+ return false;
+ if (ub <= lb)
+ return false;
+ int64_t rangeEnd = rangeStart + rangeExtent;
+ return lb >= rangeStart && ub <= rangeEnd;
+ };
+
+ if (offsetConst && writeConst) {
+ // Constant start of the write range; check constant load or loop IV range.
+ int64_t start = *offsetConst + *writeConst;
+ if (auto loadConst = getConstantIndex(loadIndex))
+ return (*loadConst >= start && *loadConst < start + extent);
+ if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
+ auto [lb, ub, step] = *bounds;
+ return loopIVWithinRange(lb, ub, step, start, extent);
+ }
+ }
+
+ if (writeIndex) {
+ // Direct IV match (or IV + constant) against the write index.
+ if (offsetConst && *offsetConst == 0 &&
+ valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
+ return true;
+ if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
+ // Match load index of the form writeIndex + C within the write extent.
+ if (offsetConst) {
+ int64_t start = *offsetConst;
+ return (*addConst >= start && *addConst < start + extent);
+ }
+ }
+ return false;
+ }
+
+ if (auto offsetVal = offset.dyn_cast<Value>()) {
----------------
Hardcode84 wrote:
`x.dyn_cast<>` is deprecated in favor of `dyn_cast<>(x)`
https://github.com/llvm/llvm-project/pull/179284
More information about the Mlir-commits
mailing list