[Mlir-commits] [mlir] c5ae550 - [mlir][scf] Refactor and improve ParallelLoopFusion (#179284)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 18 06:54:17 PST 2026
Author: fabrizio-indirli
Date: 2026-02-18T14:54:12Z
New Revision: c5ae550344d5da2d2f44523829a9af979d8af6f4
URL: https://github.com/llvm/llvm-project/commit/c5ae550344d5da2d2f44523829a9af979d8af6f4
DIFF: https://github.com/llvm/llvm-project/commit/c5ae550344d5da2d2f44523829a9af979d8af6f4.diff
LOG: [mlir][scf] Refactor and improve ParallelLoopFusion (#179284)
Refactor and extend the scf::ParalleLoopFusion pass:
- Refactor code, rename functions and add comments to improve
readability
- Make the dependency analysis safer by checking for read-after-write
dependencies also with vector.load/store & vector.transfer_read/write
ops, in addition to memref.load/store, and bail out when other
unsupported ops with memory effects are found.
- Extend the cases when the fusion is applied: allow fusing also when
one of the two loops reads/writes to memory through a full view/alias of
the buffer (read/written by the dual operation in the other loop) that
can be trivially resolved, including rank-reducing full subviews.
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 5d2429bb476e6..9af0f301d763c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -157,6 +157,21 @@ void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices);
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a rank-reducing full subview op, returns the indices w.r.t to the source
+/// memref of the memref.subview op. For example
+///
+/// %alias = memref.subview %src[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
+/// to memref<2x2xf32>
+/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
+///
+/// could be folded into
+///
+/// %val = memref.load %src[0, %i, %j] : memref<1x2x2xf32>
+LogicalResult resolveSourceIndicesRankReducingSubview(
+ Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices);
+
} // namespace memref
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index c85f3b02c4a44..a758032ef69b4 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -18,6 +18,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include <optional>
+#include <tuple>
namespace mlir {
class Location;
@@ -248,6 +249,12 @@ FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr,
IRMapping *clonedToSrcOpsMap = nullptr);
+/// Get constant loop bounds and steps for each of the induction variables of
+/// the given loop operation, if all the loop's ranges are constant. Each entry
+/// in the returned vector is a tuple (lowerBound, upperBound, step).
+llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>>
+getConstLoopBounds(mlir::LoopLikeOpInterface loopOp);
+
/// Get constant trip counts for each of the induction variables of the given
/// loop operation. If any of the loop's trip counts is not constant, return an
/// empty vector.
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 2d341dce665e5..cf126cd85ddce 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -286,5 +286,46 @@ void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
}
}
+LogicalResult resolveSourceIndicesRankReducingSubview(
+ Location loc, OpBuilder &b, memref::SubViewOp subViewOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+ if (!subViewOp.hasZeroOffset() || !subViewOp.hasUnitStride())
+ return failure();
+
+ MemRefType srcType = subViewOp.getSourceType();
+ MemRefType resType = subViewOp.getType();
+ unsigned srcRank = srcType.getRank();
+ unsigned resRank = resType.getRank();
+ if (srcRank <= resRank || indices.size() != resRank)
+ return failure();
+
+ auto droppedDims = subViewOp.getDroppedDims();
+ if (droppedDims.none() || droppedDims.count() != srcRank - resRank)
+ return failure();
+
+ auto mixedSizes = subViewOp.getMixedSizes();
+ if (mixedSizes.size() != srcRank)
+ return failure();
+
+ unsigned resultDim = 0;
+ for (unsigned sourceDim = 0; sourceDim < srcRank; ++sourceDim) {
+ if (droppedDims.test(sourceDim)) {
+ auto sizeCst = getConstantIntValue(mixedSizes[sourceDim]);
+ if (!sizeCst || *sizeCst != 1)
+ return failure();
+ sourceIndices.push_back(
+ getValueOrCreateConstantIndexOp(b, loc, b.getIndexAttr(0)));
+ continue;
+ }
+ if (resultDim >= indices.size())
+ return failure();
+ sourceIndices.push_back(indices[resultDim++]);
+ }
+ if (resultDim != indices.size())
+ return failure();
+
+ return success();
+}
+
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a9ffa9dc208a0..fb9aa4018d263 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
MLIRBufferizationTransforms
MLIRDestinationStyleOpInterface
MLIRDialectUtils
+ MLIRIndexDialect
MLIRIR
MLIRMemRefDialect
MLIRPass
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
index 4ea832177c4f9..0b132e9109492 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
@@ -13,15 +13,31 @@
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Analysis/AliasAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+#include <tuple>
+
namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
@@ -55,114 +71,670 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
matchOperands(firstPloop.getStep(), secondPloop.getStep());
}
-/// Checks if the parallel loops have mixed access to the same buffers. Returns
-/// `true` if the first parallel loop writes to the same indices that the second
-/// loop reads.
-static bool haveNoReadsAfterWriteExceptSameIndex(
+/// Check if both operations are the same type of memory write op and
+/// write to the same memory location (same buffer and same indices).
+static bool opsWriteSameMemLocation(Operation *op1, Operation *op2) {
+ if (!op1 || !op2 || op1->getName() != op2->getName())
+ return false;
+ if (op1 == op2)
+ return true;
+ // support only these memory-writing ops for now
+ if (!isa<memref::StoreOp, vector::TransferWriteOp, vector::StoreOp>(op1))
+ return false;
+ bool opsAreIdentical =
+ llvm::TypeSwitch<Operation *, bool>(op1)
+ .Case([&](memref::StoreOp storeOp1) {
+ auto storeOp2 = cast<memref::StoreOp>(op2);
+ return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
+ (storeOp1.getIndices() == storeOp2.getIndices());
+ })
+ .Case([&](vector::TransferWriteOp writeOp1) {
+ auto writeOp2 = cast<vector::TransferWriteOp>(op2);
+ return (writeOp1.getBase() == writeOp2.getBase()) &&
+ (writeOp1.getIndices() == writeOp2.getIndices()) &&
+ (writeOp1.getMask() == writeOp2.getMask()) &&
+ (writeOp1.getValueToStore().getType() ==
+ writeOp2.getValueToStore().getType()) &&
+ (writeOp1.getInBounds() == writeOp2.getInBounds());
+ })
+ .Case([&](vector::StoreOp vecStoreOp1) {
+ auto vecStoreOp2 = cast<vector::StoreOp>(op2);
+ return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
+ (vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
+ (vecStoreOp1.getValueToStore().getType() ==
+ vecStoreOp2.getValueToStore().getType()) &&
+ (vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
+ (vecStoreOp1.getNontemporal() ==
+ vecStoreOp2.getNontemporal());
+ })
+ .Default([](Operation *) { return false; });
+ return opsAreIdentical;
+}
+
+/// Check if val1 (from the first parallel loop) and val2 (from the
+/// second) are equivalent, considering the mapping of induction variables from
+/// the first to the second parallel loop.
+static bool valsAreEquivalent(Value val1, Value val2,
+ const IRMapping &loopsIVsMap) {
+ if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
+ loopsIVsMap.lookupOrDefault(val2) == val1)
+ return true;
+ Operation *val1DefOp = val1.getDefiningOp();
+ Operation *val2DefOp = val2.getDefiningOp();
+ if (!val1DefOp || !val2DefOp)
+ return false;
+ if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
+ return false;
+ return OperationEquivalence::isEquivalentTo(
+ val1DefOp, val2DefOp,
+ [&](Value v1, Value v2) {
+ return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
+ loopsIVsMap.lookupOrDefault(v2) == v1);
+ },
+ /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
+}
+
+/// If the `expr` value is the result of an integer addition of `base` and a
+/// constant, return the constant.
+static std::optional<int64_t> getAddConstant(Value expr, Value base,
+ const IRMapping &loopsIVsMap) {
+ if (auto addOp = expr.getDefiningOp<arith::AddIOp>()) {
+ if (auto constOp = getConstantIntValue(addOp.getLhs());
+ constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
+ return constOp.value();
+ if (auto constOp = getConstantIntValue(addOp.getRhs());
+ constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
+ return constOp.value();
+ return std::nullopt;
+ }
+
+ if (auto addOp = expr.getDefiningOp<index::AddOp>()) {
+ if (auto constOp = getConstantIntValue(addOp.getLhs());
+ constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
+ return constOp.value();
+ if (auto constOp = getConstantIntValue(addOp.getRhs());
+ constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
+ return constOp.value();
+ return std::nullopt;
+ }
+
+ if (auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>()) {
+ AffineMap map = applyOp.getAffineMap();
+ if (map.getNumResults() != 1 || map.getNumDims() != 1 ||
+ map.getNumSymbols() != 0)
+ return std::nullopt;
+ if (!valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
+ return std::nullopt;
+ AffineExpr result = map.getResult(0);
+ auto bin = dyn_cast<AffineBinaryOpExpr>(result);
+ if (!bin || bin.getKind() != AffineExprKind::Add)
+ return std::nullopt;
+ auto lhsDim = dyn_cast<AffineDimExpr>(bin.getLHS());
+ auto rhsDim = dyn_cast<AffineDimExpr>(bin.getRHS());
+ auto lhsConst = dyn_cast<AffineConstantExpr>(bin.getLHS());
+ auto rhsConst = dyn_cast<AffineConstantExpr>(bin.getRHS());
+ if (lhsConst && rhsDim)
+ return lhsConst.getValue();
+ if (rhsConst && lhsDim)
+ return rhsConst.getValue();
+ }
+ return std::nullopt;
+}
+
+// Return true if the scalar load index may hit any element covered by a
+// vector.store/transfer_write along a single memref dimension. Supported cases:
+//
+// 1) Direct index match (with optional offset):
+// vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
+// %x = memref.load %A[%i] : memref<...>
+//
+// 2) Loop IV range intersects the write range:
+// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
+// scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
+//
+// 3) Constant index (or IV + constant) within the write range:
+// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
+// %x = memref.load %A[%c2] : memref<...>
+// %y = memref.load %A[%i + %c1] : memref<...>
+//
+// Args:
+// - loadIndex: index used by the scalar load for this dimension.
+// - offset: subview offset for the base memref dimension (if any).
+// - writeIndex: index used by the transfer_write for this dimension. Can be
+// null if the dim was dropped by a rank reducing subview, whose result is
+// written by the vector.write.
+// - extent: vector size along this dimension (number of elements written).
+// - loopsIVsMap: IV equivalence map between fused loops.
+static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
+ Value writeIndex, int64_t extent,
+ const IRMapping &loopsIVsMap) {
+ if (extent <= 0)
+ return false;
+
+ // Extract constant loop bounds for loop IVs (e.g. from scf.for).
+ auto getConstLoopBoundsForIV =
+ [](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
+ auto blockArg = dyn_cast<BlockArgument>(index);
+ if (!blockArg)
+ return std::nullopt;
+ auto *parentOp = blockArg.getOwner()->getParentOp();
+ auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
+ if (!loopLike)
+ return std::nullopt;
+ auto ranges = getConstLoopBounds(loopLike);
+ if (ranges.empty())
+ return std::nullopt;
+
+ auto ivs = loopLike.getLoopInductionVars();
+ if (!ivs)
+ return std::nullopt;
+ auto it = llvm::find(*ivs, blockArg);
+ if (it == ivs->end())
+ return std::nullopt;
+ unsigned pos = std::distance(ivs->begin(), it);
+ if (pos >= ranges.size())
+ return std::nullopt;
+ auto [lb, ub, step] = ranges[pos];
+ return std::make_tuple(lb, ub, step);
+ };
+
+ std::optional<int64_t> offsetConst = getConstantIntValue(offset);
+ std::optional<int64_t> writeConst =
+ writeIndex ? getConstantIntValue(writeIndex) : std::optional<int64_t>(0);
+ if (!writeConst && writeIndex) {
+ // Treat single-iteration IVs as constants for matching.
+ if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
+ auto [lb, ub, step] = *bounds;
+ if (step > 0 && ub == lb + step)
+ writeConst = lb;
+ }
+ }
+
+ // Check whether a loop IV is fully contained in a constant write range.
+ auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
+ int64_t rangeStart, int64_t rangeExtent) -> bool {
+ if (rangeExtent <= 0 || step <= 0)
+ return false;
+ if (ub <= lb)
+ return false;
+ int64_t rangeEnd = rangeStart + rangeExtent;
+ return lb >= rangeStart && ub <= rangeEnd;
+ };
+
+ if (offsetConst && writeConst) {
+ // Constant start of the write range; check constant load or loop IV range.
+ int64_t start = *offsetConst + *writeConst;
+ if (auto loadConst = getConstantIntValue(loadIndex))
+ return (*loadConst >= start && *loadConst < start + extent);
+ if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
+ auto [lb, ub, step] = *bounds;
+ return loopIVWithinRange(lb, ub, step, start, extent);
+ }
+ }
+
+ if (writeIndex) {
+ // Direct IV match (or IV + constant) against the write index.
+ if (offsetConst && *offsetConst == 0 &&
+ valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
+ return true;
+ if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
+ // Match load index of the form writeIndex + C within the write extent.
+ if (offsetConst) {
+ int64_t start = *offsetConst;
+ return (*addConst >= start && *addConst < start + extent);
+ }
+ }
+ return false;
+ }
+
+ if (auto offsetVal = dyn_cast<Value>(offset)) {
+ // Exact match when extent is 1 and the load hits the offset value.
+ if (extent == 1 && valsAreEquivalent(loadIndex, offsetVal, loopsIVsMap))
+ return true;
+ }
+
+ return false;
+}
+
+/// Return the base memref value used by the given memory op.
+static Value getBaseMemref(Operation *op) {
+ // TODO: use the common interface for memory ops once available.
+ return llvm::TypeSwitch<Operation *, Value>(op)
+ .Case([&](memref::LoadOp load) { return load.getMemRef(); })
+ .Case([&](memref::StoreOp store) { return store.getMemRef(); })
+ .Case([&](vector::TransferReadOp read) { return read.getBase(); })
+ .Case([&](vector::TransferWriteOp write) { return write.getBase(); })
+ .Case([&](vector::LoadOp load) { return load.getBase(); })
+ .Case([&](vector::StoreOp store) { return store.getBase(); })
+ .Default([](Operation *) { return Value(); });
+}
+
+/// Recognize scalar memref.load of an element produced by a vector write
+/// (vector.transfer_write or vector.store, optionally through a rank-reducing
+/// unit-stride subview) of the same buffer. This covers the pattern where a
+/// vector write stores a full lane pack and a subsequent scalar load reads an
+/// element from that lane pack. EXAMPLE:
+/// vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
+/// vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
+/// scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
+/// %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
+/// ...
+/// }
+///
+static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase,
+ ValueRange writeIndices, VectorType vecTy,
+ ArrayRef<int64_t> vectorDimForWriteDim,
+ const IRMapping &ivsMap) {
+ if (!vecTy)
+ return false;
+
+ Value base = writeBase;
+ // The write base if there is no subview, or the subview source otherwise.
+ MemrefValue baseMemref = nullptr;
+ SmallVector<OpFoldResult> offsets;
+ llvm::SmallBitVector droppedDims;
+ bool hasSubview = false;
+ auto *ctx = loadOp.getContext();
+ if (auto subView = base.getDefiningOp<memref::SubViewOp>()) {
+ if (!subView.hasUnitStride())
+ return false;
+ baseMemref = cast<MemrefValue>(subView.getSource());
+ offsets = llvm::to_vector(subView.getMixedOffsets());
+ droppedDims = subView.getDroppedDims();
+ hasSubview = true;
+ } else {
+ baseMemref = dyn_cast<MemrefValue>(base);
+ if (!baseMemref)
+ return false;
+ }
+
+ auto loadIndices = loadOp.getIndices();
+ unsigned baseRank = baseMemref.getType().getRank();
+ if ((loadOp.getMemref() != baseMemref) || (loadIndices.size() != baseRank))
+ return false;
+
+ unsigned writeRank = writeIndices.size();
+ if ((!hasSubview && writeRank != baseRank) ||
+ (hasSubview && offsets.size() != baseRank) ||
+ (vectorDimForWriteDim.size() != writeRank))
+ return false;
+
+ auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
+ unsigned writeMemrefDim = 0;
+ for (unsigned baseDim : llvm::seq(baseRank)) {
+ bool wasDropped = (hasSubview && droppedDims.test(baseDim));
+ int64_t vectorDim = !wasDropped ? vectorDimForWriteDim[writeMemrefDim] : -1;
+ int64_t extent = 1;
+ if (vectorDim >= 0) {
+ int64_t dimSize = vecTy.getDimSize(vectorDim);
+ if (dimSize == ShapedType::kDynamic)
+ return false;
+ extent = dimSize;
+ }
+ Value writeIndex = !wasDropped ? writeIndices[writeMemrefDim] : Value();
+ OpFoldResult offset =
+ hasSubview ? offsets[baseDim] : OpFoldResult(zeroAttr);
+ if (!loadIndexWithinWriteRange(loadIndices[baseDim], offset, writeIndex,
+ extent, ivsMap))
+ return false;
+ if (!wasDropped)
+ ++writeMemrefDim;
+ }
+
+ return true;
+}
+
+/// Recognize scalar memref.load of an element produced by a
+/// vector.transfer_write
+static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
+ vector::TransferWriteOp writeOp,
+ const IRMapping &ivsMap) {
+ auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
+ if (!vecTy)
+ return false;
+
+ unsigned writeRank = writeOp.getIndices().size();
+ AffineMap permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isProjectedPermutation() ||
+ permutationMap.getNumResults() != vecTy.getRank() ||
+ permutationMap.getNumDims() != writeRank)
+ return false;
+
+ SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
+ for (unsigned vecDim = 0; vecDim < permutationMap.getNumResults(); ++vecDim) {
+ auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(vecDim));
+ if (!dimExpr)
+ return false;
+ unsigned writeDim = dimExpr.getPosition();
+ if (writeDim >= writeRank || vectorDimForWriteDim[writeDim] != -1)
+ return false;
+ vectorDimForWriteDim[writeDim] = vecDim;
+ }
+
+ return isLoadOnWrittenVector(loadOp, writeOp.getBase(), writeOp.getIndices(),
+ vecTy, vectorDimForWriteDim, ivsMap);
+}
+
+/// Recognize scalar memref.load of an element produced by a vector.store
+static bool loadMatchesVectorStore(memref::LoadOp loadOp,
+ vector::StoreOp storeOp,
+ const IRMapping &ivsMap) {
+ auto vecTy = dyn_cast<VectorType>(storeOp.getValueToStore().getType());
+ if (!vecTy)
+ return false;
+
+ unsigned writeRank = storeOp.getIndices().size();
+ if (vecTy.getRank() > writeRank)
+ return false;
+
+ SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
+ unsigned vecRank = vecTy.getRank();
+ for (unsigned i = 0; i < vecRank; ++i) {
+ unsigned writeDim = writeRank - vecRank + i;
+ vectorDimForWriteDim[writeDim] = i;
+ }
+
+ return isLoadOnWrittenVector(loadOp, storeOp.getBase(), storeOp.getIndices(),
+ vecTy, vectorDimForWriteDim, ivsMap);
+}
+
+/// Check if both operations access the same positions of the same
+/// buffer, but one of the two does it through a rank-reducing full subview of
+/// the buffer (the other's base). EXAMPLE:
+/// memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
+/// %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
+/// to memref<2x2xf32>
+/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
+template <typename OpTy1, typename OpTy2>
+static bool opsAccessSameIndicesViaRankReducingSubview(
+ OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap,
+ OpBuilder &b) {
+ auto base1 = cast<MemrefValue>(getBaseMemref(op1));
+ auto base2 = cast<MemrefValue>(getBaseMemref(op2));
+ if (!base1 || !base2)
+ return false;
+
+ auto accessThroughTrivialSubviewIsSame =
+ [&b](memref::SubViewOp subView, ValueRange subViewAccess,
+ ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
+ SmallVector<Value> resolvedSubviewAccess;
+ LogicalResult resolved = resolveSourceIndicesRankReducingSubview(
+ subView.getLoc(), b, subView, subViewAccess, resolvedSubviewAccess);
+ if (failed(resolved) ||
+ (resolvedSubviewAccess.size() != sourceAccess.size()))
+ return false;
+ for (auto [dimIdx, resolvedIndex] :
+ llvm::enumerate(resolvedSubviewAccess)) {
+ if (!matchPattern(resolvedIndex, m_Zero()) &&
+ !valsAreEquivalent(resolvedIndex, sourceAccess[dimIdx], ivsMap))
+ return false;
+ }
+ return true;
+ };
+
+ // Case 1: op1 uses a subview of op2's base.
+ if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
+ subView &&
+ memref::isSameViewOrTrivialAlias(
+ base2, cast<MemrefValue>(subView.getSource())) &&
+ accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
+ op2.getIndices(),
+ firstToSecondPloopIVsMap))
+ return true;
+
+ // Case 2: op2 uses a subview of op1's base.
+ if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
+ subView &&
+ memref::isSameViewOrTrivialAlias(
+ base1, cast<MemrefValue>(subView.getSource())) &&
+ accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
+ op1.getIndices(),
+ firstToSecondPloopIVsMap))
+ return true;
+
+ return false;
+}
+
+/// Check if both memory read/write operations access the same indices
+/// (considering also the mapping of induction variables from the first to the
+/// second parallel loop).
+template <typename OpTy1, typename OpTy2>
+static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
+ const IRMapping &loopsIVsMap, OpBuilder &b) {
+ auto indices1 = op1.getIndices();
+ auto indices2 = op2.getIndices();
+ if (indices1.size() != indices2.size())
+ return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap, b);
+ for (auto [idx1, idx2] : llvm::zip(indices1, indices2)) {
+ if (!valsAreEquivalent(idx1, idx2, loopsIVsMap))
+ return false;
+ }
+ return true;
+}
+
+/// Check if the loadOp reads from the same memory location (same buffer,
+/// same indices and same properties) as written by the storeOp.
+static bool
+loadsFromSameMemoryLocationWrittenBy(Operation *loadOp, Operation *storeOp,
+ const IRMapping &firstToSecondPloopIVsMap,
+ OpBuilder &b) {
+ if (!loadOp || !storeOp)
+ return false;
+ // Support only these memory-reading ops for now
+ if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
+ return false;
+ bool accessSameMemory =
+ llvm::TypeSwitch<Operation *, bool>(loadOp)
+ .Case([&](memref::LoadOp memLoadOp) {
+ if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
+ return opsAccessSameIndices(memLoadOp, memStoreOp,
+ firstToSecondPloopIVsMap, b);
+ if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
+ return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
+ firstToSecondPloopIVsMap);
+ if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp))
+ return loadMatchesVectorStore(memLoadOp, vecStoreOp,
+ firstToSecondPloopIVsMap);
+ return false;
+ })
+ .Case([&](vector::TransferReadOp vecReadOp) {
+ auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
+ if (!vecWriteOp)
+ return false;
+ return opsAccessSameIndices(vecReadOp, vecWriteOp,
+ firstToSecondPloopIVsMap, b) &&
+ (vecReadOp.getMask() == vecWriteOp.getMask()) &&
+ (vecReadOp.getInBounds() == vecWriteOp.getInBounds());
+ })
+ .Case([&](vector::LoadOp vecLoadOp) {
+ auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
+ if (!vecStoreOp)
+ return false;
+ return opsAccessSameIndices(vecLoadOp, vecStoreOp,
+ firstToSecondPloopIVsMap, b) &&
+ (vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
+ })
+ .Default([](Operation *) { return false; });
+ return accessSameMemory;
+}
+
+static Value getStoreOpTargetBuffer(Operation *op) {
+ return llvm::TypeSwitch<Operation *, Value>(op)
+ .Case([&](memref::StoreOp storeOp) { return storeOp.getMemRef(); })
+ .Case([&](vector::TransferWriteOp writeOp) { return writeOp.getBase(); })
+ .Case([&](vector::StoreOp vecStoreOp) { return vecStoreOp.getBase(); })
+ .Default([](Operation *) { return Value(); });
+}
+
+/// To be called when `mayAlias(val1, val2)` is true. Check if the potential
+/// aliasing between the loadOp and storeOp can be resolved by analyzing their
+/// access patterns.
+static bool canResolveAlias(Operation *loadOp, Operation *storeOp,
+ const IRMapping &loopsIVsMap) {
+ if (auto transfWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
+ transfWriteOp && isa<memref::LoadOp>(loadOp))
+ return loadMatchesVectorWrite(cast<memref::LoadOp>(loadOp), transfWriteOp,
+ loopsIVsMap);
+ if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
+ vecStoreOp && isa<memref::LoadOp>(loadOp))
+ return loadMatchesVectorStore(cast<memref::LoadOp>(loadOp), vecStoreOp,
+ loopsIVsMap);
+ return false;
+}
+
+/// Check that the parallel loops have no mixed access to the same buffers.
+/// Return `true` if the second parallel loop does not read or write the buffers
+/// written by the first loop using
diff erent indices.
+static bool haveNoDataDependenciesExceptSameIndex(
ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
- SmallVector<Value> bufferStoresVec;
- firstPloop.getBody()->walk([&](memref::StoreOp store) {
- bufferStores[store.getMemRef()].push_back(store.getIndices());
- bufferStoresVec.emplace_back(store.getMemRef());
- });
- auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
- Value loadMem = load.getMemRef();
- // Stop if the memref is defined in secondPloop body. Careful alias analysis
- // is needed.
- auto *memrefDef = loadMem.getDefiningOp();
- if (memrefDef && memrefDef->getBlock() == load->getBlock())
+ llvm::function_ref<bool(Value, Value)> mayAlias, OpBuilder &b) {
+ // Map buffers to their store/write ops in the firstPloop
+ DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
+ // Record all the memory buffers used in store/write ops found in firstPloop
+ llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
+
+ auto collectStoreOpsInWalk = [&](Operation *op) {
+ auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
+ // Ignore ops that don't write to memory
+ if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
+ !memOpInterf.hasEffect<MemoryEffects::Free>()))
+ return WalkResult::advance();
+
+ // Only these memory-writing ops are supported for now:
+ // memref.store, vector.transfer_write, vector.store
+ Value storeOpBase = getStoreOpTargetBuffer(op);
+ if (!storeOpBase)
return WalkResult::interrupt();
- for (Value store : bufferStoresVec)
- if (store != loadMem && mayAlias(store, loadMem))
- return WalkResult::interrupt();
+ // Expect the base operand to be a Memref
+ MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
+ if (!storeOpBaseMemref)
+ return WalkResult::interrupt();
+ // Get the original memref buffer, skipping full view-like ops
+ Value buffer = memref::skipFullyAliasingOperations(storeOpBaseMemref);
+ bufferStoresInFirstPloop[buffer].push_back(op);
+ buffersWrittenInFirstPloop.insert(buffer);
+ return WalkResult::advance();
+ };
- auto write = bufferStores.find(loadMem);
- if (write == bufferStores.end())
- return WalkResult::advance();
+ // Walk the first parallel loop to collect all store/write ops and their
+ // target buffers
+ if (firstPloop.getBody()->walk(collectStoreOpsInWalk).wasInterrupted())
+ return false;
- // Check that at last one store was retrieved
- if (write->second.empty())
+ // Check that this load/read op encountered while walking the second parallel
+ // loop does not have incompatible data dependencies with the store/write ops
+ // collected from the first parallel loop: the loops can be fused only if in
+ // the 2nd loop there are no loads/stores from/to the buffers written in the
+ // 1st loop, except when on the same exact memory location (same indices) as
+ // written in the 1st loop.
+ auto checkLoadInWalkHasNoIncompatibleDataDeps = [&](Operation *loadOp) {
+ auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
+ // To be conservative, we should stop on ops that don't advertise their
+ // memory effects. However, many ops don't implement MemoryEffectOpInterface
+ // yet, so for now we just skip them.
+ // TODO: once more ops add MemoryEffectOpInterface, interrupt the walk here.
+ if (!memOpInterf &&
+ !loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>())
+ return WalkResult::advance();
+ // Ignore ops that don't read from memory, and wrapping ops that have nested
+ // memory effects (e.g. loops, conditionals) as they will be analyzed when
+ // visiting their nested ops.
+ if ((!memOpInterf &&
+ loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) ||
+ (memOpInterf && !memOpInterf.hasEffect<MemoryEffects::Read>()))
+ return WalkResult::advance();
+ // Support only these memory-reading ops for now
+ if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
+ !isa<MemrefValue>(loadOp->getOperand(0)))
return WalkResult::interrupt();
- auto storeIndices = write->second.front();
+ MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
+ MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
- // Multiple writes to the same memref are allowed only on the same indices
- for (const auto &othStoreIndices : write->second) {
- if (othStoreIndices != storeIndices)
+ for (Value storedMem : buffersWrittenInFirstPloop)
+ if ((storedMem != loadedOrigBuf) && mayAlias(storedMem, loadedOrigBuf) &&
+ !llvm::all_of(bufferStoresInFirstPloop[storedMem],
+ [&](Operation *storeOp) {
+ return canResolveAlias(loadOp, storeOp,
+ firstToSecondPloopIndices);
+ })) {
return WalkResult::interrupt();
+ }
+
+ auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
+ if (writeOpsIt == bufferStoresInFirstPloop.end())
+ return WalkResult::advance();
+ // Store/write ops to this buffer in the firstPloop
+ SmallVector<mlir::Operation *> &writeOps = writeOpsIt->second;
+
+ // If the first loop has no writes to this buffer, continue
+ if (writeOps.empty())
+ return WalkResult::advance();
+
+ Operation *writeOp = writeOps.front();
+
+ // In the first parallel loop, multiple writes to the same memref are
+ // allowed only on the same memory location
+ if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
+ return opsWriteSameMemLocation(writeOp, otherWriteOp);
+ })) {
+ return WalkResult::interrupt();
}
- // Check that the load indices of secondPloop coincide with store indices of
- // firstPloop for the same memrefs.
- auto loadIndices = load.getIndices();
- if (storeIndices.size() != loadIndices.size())
+ // Check that the load in secondPloop reads from the same memory location as
+ // written by the corresponding store in firstPloop
+ if (!loadsFromSameMemoryLocationWrittenBy(loadOp, writeOp,
+ firstToSecondPloopIndices, b)) {
return WalkResult::interrupt();
- for (int i = 0, e = storeIndices.size(); i < e; ++i) {
- if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
- loadIndices[i]) {
- auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
- auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
- if (storeIndexDefOp && loadIndexDefOp) {
- if (!isMemoryEffectFree(storeIndexDefOp))
- return WalkResult::interrupt();
- if (!isMemoryEffectFree(loadIndexDefOp))
- return WalkResult::interrupt();
- if (!OperationEquivalence::isEquivalentTo(
- storeIndexDefOp, loadIndexDefOp,
- [&](Value storeIndex, Value loadIndex) {
- if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
- firstToSecondPloopIndices.lookupOrDefault(loadIndex))
- return failure();
- else
- return success();
- },
- /*markEquivalent=*/nullptr,
- OperationEquivalence::Flags::IgnoreLocations)) {
- return WalkResult::interrupt();
- }
- } else {
- return WalkResult::interrupt();
- }
- }
}
+
return WalkResult::advance();
- });
- return !walkResult.wasInterrupted();
+ };
+
+ // Walk the second parallel loop to check load/read ops against the stores
+ // collected from the first parallel loop.
+ return !secondPloop.getBody()
+ ->walk(checkLoadInWalkHasNoIncompatibleDataDeps)
+ .wasInterrupted();
}
-/// Analyzes dependencies in the most primitive way by checking simple read and
-/// write patterns.
-static LogicalResult
-verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
- const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
- if (!haveNoReadsAfterWriteExceptSameIndex(
- firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
- return failure();
+/// Check that in each loop there are no read ops on the buffers written
+/// by the other loop, except when reading from the same exact memory location
+/// (same indices) as written in the other loop.
+static bool
+noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
+ const IRMapping &firstToSecondPloopIndices,
+ llvm::function_ref<bool(Value, Value)> mayAlias,
+ OpBuilder &b) {
+ if (!haveNoDataDependenciesExceptSameIndex(
+ firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias, b))
+ return false;
IRMapping secondToFirstPloopIndices;
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
firstPloop.getBody()->getArguments());
- return success(haveNoReadsAfterWriteExceptSameIndex(
- secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
+ return haveNoDataDependenciesExceptSameIndex(
+ secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias, b);
}
+/// Check if fusion of the two parallel loops is legal:
+/// i.e. no nested parallel loops, equal iteration spaces,
+/// and no incompatible data dependencies between the loops.
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
- llvm::function_ref<bool(Value, Value)> mayAlias) {
+ llvm::function_ref<bool(Value, Value)> mayAlias,
+ OpBuilder &b) {
return !hasNestedParallelOp(firstPloop) &&
!hasNestedParallelOp(secondPloop) &&
equalIterationSpaces(firstPloop, secondPloop) &&
- succeeded(verifyDependencies(firstPloop, secondPloop,
- firstToSecondPloopIndices, mayAlias));
+ noIncompatibleDataDependencies(firstPloop, secondPloop,
+ firstToSecondPloopIndices, mayAlias, b);
}
-/// Prepends operations of firstPloop's body into secondPloop's body.
-/// Updates secondPloop with new loop.
+/// Prepend operations of firstPloop's body into secondPloop's body.
+/// Update secondPloop with new loop.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
OpBuilder builder,
llvm::function_ref<bool(Value, Value)> mayAlias) {
@@ -172,7 +744,7 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
- mayAlias))
+ mayAlias, builder))
return;
DominanceInfo dom;
@@ -272,6 +844,18 @@ struct ParallelLoopFusion
auto &aa = getAnalysis<AliasAnalysis>();
auto mayAlias = [&](Value val1, Value val2) -> bool {
+ // If the memref is defined in one of the parallel loops body, careful
+ // alias analysis is needed.
+ // TODO: check if this is still needed as a separate check.
+ auto val1Def = val1.getDefiningOp();
+ auto val2Def = val2.getDefiningOp();
+ auto val1Loop =
+ val1Def ? val1Def->getParentOfType<ParallelOp>() : nullptr;
+ auto val2Loop =
+ val2Def ? val2Def->getParentOfType<ParallelOp>() : nullptr;
+ if (val1Loop != val2Loop)
+ return true;
+
return !aa.alias(val1, val2).isNo();
};
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index f8a4f057c9f0d..e795f3f0b019b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1560,6 +1560,25 @@ bool mlir::isPerfectlyNestedForLoops(
return true;
}
+llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>>
+mlir::getConstLoopBounds(mlir::LoopLikeOpInterface loopOp) {
+ std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
+ std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
+ std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
+ if (!loBnds || !upBnds || !steps)
+ return {};
+ llvm::SmallVector<std::tuple<int64_t, int64_t, int64_t>> loopRanges;
+ for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
+ auto lbCst = getConstantIntValue(lb);
+ auto ubCst = getConstantIntValue(ub);
+ auto stepCst = getConstantIntValue(step);
+ if (!lbCst || !ubCst || !stepCst)
+ return {};
+ loopRanges.emplace_back(*lbCst, *ubCst, *stepCst);
+ }
+ return loopRanges;
+}
+
llvm::SmallVector<llvm::APInt>
mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
index 0d4ea6f20e8d9..d876062b704f2 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir
@@ -314,23 +314,24 @@ func.func @do_not_fuse_unmatching_read_write_patterns(
// -----
-func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
+func.func @do_not_fuse_loops_with_nonfull_alias_defined_in_loop_bodies() {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
+ %c1fp = arith.constant 1.0 : f32
%buffer = memref.alloc() : memref<2x2xf32>
- scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c1) step (%c1, %c1) {
+ memref.store %c1fp, %buffer[%i, %j] : memref<2x2xf32>
scf.reduce
}
- scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
- %A = memref.subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
- : memref<2x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
- %A_elem = memref.load %A[%i, %j] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c1) step (%c1, %c1) {
+ %A = memref.subview %buffer[%i, %c0][2, 1][1, 1] : memref<2x2xf32> to memref<2x1xf32, strided<[2, 1], offset: ?>>
+ %A_elem = memref.load %A[%i, %j] : memref<2x1xf32, strided<[2, 1], offset: ?>>
scf.reduce
}
return
}
-// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
+// CHECK-LABEL: func @do_not_fuse_loops_with_nonfull_alias_defined_in_loop_bodies
// CHECK: scf.parallel
// CHECK: scf.parallel
@@ -604,6 +605,415 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
// -----
+func.func @fuse_trivial_rank_reducing_subview() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c1fp = arith.constant 1.0 : f32
+ %buf = memref.alloc() : memref<1x2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c1fp, %buf[%c0, %i, %j] : memref<1x2x2xf32>
+ scf.reduce
+ }
+ %sub = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]
+ : memref<1x2x2xf32> to memref<2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %v = memref.load %sub[%i, %j] : memref<2x2xf32>
+ memref.store %v, %buf[%c0, %i, %j] : memref<1x2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %buf : memref<1x2x2xf32>
+ return
+}
+// CHECK-LABEL: func @fuse_trivial_rank_reducing_subview
+// CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x2x2xf32>
+// CHECK: %[[SUB:.*]] = memref.subview %[[BUF]]
+// CHECK: scf.parallel
+// CHECK: memref.store {{.*}}, %[[BUF]]
+// CHECK: %[[L:.*]] = memref.load %[[SUB]]
+// CHECK: memref.store %[[L]], %[[BUF]]
+// CHECK-NOT: scf.parallel
+// CHECK: memref.dealloc %[[BUF]] : memref<1x2x2xf32>
+
+// -----
+
+func.func @do_not_fuse_nontrivial_subview_offset() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c1fp = arith.constant 1.0 : f32
+ %buf = memref.alloc() : memref<2x2x2xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ memref.store %c1fp, %buf[%c0, %i, %j] : memref<2x2x2xf32>
+ scf.reduce
+ }
+ %sub = memref.subview %buf[1, 0, 0][1, 2, 2][1, 1, 1]
+ : memref<2x2x2xf32> to memref<2x2xf32, strided<[2, 1], offset: 4>>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ %v = memref.load %sub[%i, %j]
+ : memref<2x2xf32, strided<[2, 1], offset: 4>>
+ memref.store %v, %buf[%c0, %i, %j] : memref<2x2x2xf32>
+ scf.reduce
+ }
+ memref.dealloc %buf : memref<2x2x2xf32>
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_nontrivial_subview_offset
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_vector_load_store(%A: memref<4x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %vec0 = arith.constant dense<0.0> : vector<4xf32>
+ scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+ vector.store %vec0, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+ %v = vector.load %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+ vector.store %v, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @fuse_vector_load_store
+// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) {
+// CHECK: vector.store
+// CHECK: %[[V:.*]] = vector.load
+// CHECK: vector.store %[[V]]
+// CHECK-NOT: scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_
diff erent_indices(%A: memref<4x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %vec0 = arith.constant dense<0.0> : vector<4xf32>
+ scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+ vector.store %vec0, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+ %j = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
+ %v = vector.load %A[%j, %c0] : memref<4x4xf32>, vector<4xf32>
+ vector.store %v, %A[%i, %c0] : memref<4x4xf32>, vector<4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_
diff erent_indices
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_same_indices(%A: memref<4x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %zero = arith.constant 0.0 : f32
+ scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+ %v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
+ vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+ %v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
+ vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_same_indices
+// CHECK: scf.parallel
+// CHECK: vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK-NOT: scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_transfer_
diff erent_indices(%A: memref<4x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %zero = arith.constant 0.0 : f32
+ scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+ %v = vector.transfer_read %A[%i, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
+ vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
+ %j = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
+ %v = vector.transfer_read %A[%j, %c0], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<4x4xf32>, vector<4xf32>
+ vector.transfer_write %v, %A[%i, %c0] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<4x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_transfer_
diff erent_indices
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_with_subview(%A: memref<1x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %zero = arith.constant 0.0 : f32
+ %vec = arith.constant dense<1.0> : vector<4xf32>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %sub = memref.subview %A[0, 0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32>
+ vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+ %v = memref.load %A[%c0, %k] : memref<1x4xf32>
+ %n = arith.addf %v, %acc : f32
+ scf.yield %n : f32
+ }
+ memref.store %sum, %A[%c0, %c0] : memref<1x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_with_subview
+// CHECK: scf.parallel
+// CHECK: vector.transfer_write
+// CHECK: scf.for
+// CHECK-NOT: scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_transfer_nontrivial_subview(%A: memref<2x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %zero = arith.constant 0.0 : f32
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %v = vector.transfer_read %A[%c0, %i], %zero {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<2x4xf32>, vector<1xf32>
+ vector.transfer_write %v, %A[%c0, %i] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<2x4xf32>
+ scf.reduce
+ }
+ %sub = memref.subview %A[1, 0][1, 4][1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: 4>>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %v = vector.transfer_read %sub[%i], %zero {in_bounds = [true]} : memref<4xf32, strided<[1], offset: 4>>, vector<1xf32>
+ vector.transfer_write %v, %sub[%i] {in_bounds = [true]} : vector<1xf32>, memref<4xf32, strided<[1], offset: 4>>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_transfer_nontrivial_subview
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_transfer_
diff erent_masks(%A: memref<1x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %zero = arith.constant 0.0 : f32
+ %mask_true = vector.create_mask %c1 : vector<1xi1>
+ %mask_false = vector.create_mask %c0 : vector<1xi1>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %v = vector.transfer_read %A[%c0, %i], %zero, %mask_true {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<1x4xf32>, vector<1xf32>
+ vector.transfer_write %v, %A[%c0, %i], %mask_true {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<1x4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %v = vector.transfer_read %A[%c0, %i], %zero, %mask_false {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : memref<1x4xf32>, vector<1xf32>
+ vector.transfer_write %v, %A[%c0, %i], %mask_false {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<1xf32>, memref<1x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_transfer_
diff erent_masks
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_subview_rank_reducing(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %zero = arith.constant 0.0 : f32
+ %vec = arith.constant dense<1.0> : vector<4xf32>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %sub = memref.subview %A[%i, %c0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+ vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+ %v = memref.load %A[%i, %k] : memref<1x4xf32>
+ %n = arith.addf %v, %acc : f32
+ scf.yield %n : f32
+ }
+ memref.store %sum, %B[%i, %c0] : memref<1x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_subview_rank_reducing
+// CHECK: scf.parallel
+// CHECK: vector.transfer_write
+// CHECK: scf.for
+// CHECK-NOT: scf.parallel
+
+// -----
+
+func.func @do_not_fuse_vector_transfer_subview_offset(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %zero = arith.constant 0.0 : f32
+ %vec = arith.constant dense<1.0> : vector<4xf32>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %sub = memref.subview %A[%i, %c0][1, 4][1, 1] : memref<1x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+ vector.transfer_write %vec, %sub[%c0] {permutation_map = affine_map<(d0) -> (d0)>, in_bounds = [true]} : vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+ %v = memref.load %A[%i, %k] : memref<1x4xf32>
+ %n = arith.addf %v, %acc : f32
+ scf.yield %n : f32
+ }
+ // Read from an offset alias to prevent fusion.
+ %off = memref.subview %A[%i, %c1][1, 3][1, 1] : memref<1x4xf32> to memref<3xf32, strided<[1], offset: ?>>
+ %v0 = memref.load %off[%c0] : memref<3xf32, strided<[1], offset: ?>>
+ %res = arith.addf %sum, %v0 : f32
+ memref.store %res, %B[%i, %c0] : memref<1x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @do_not_fuse_vector_transfer_subview_offset
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_no_subview(%A: memref<1x4xf32>, %B: memref<1x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %zero = arith.constant 0.0 : f32
+ %vec = arith.constant dense<2.0> : vector<4xf32>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ vector.transfer_write %vec, %A[%c0, %i] {permutation_map = affine_map<(d0, d1) -> (d1)>, in_bounds = [true]} : vector<4xf32>, memref<1x4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+ %v = memref.load %A[%c0, %k] : memref<1x4xf32>
+ %n = arith.addf %v, %acc : f32
+ scf.yield %n : f32
+ }
+ memref.store %sum, %B[%c0, %c0] : memref<1x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_no_subview
+// CHECK: vector.transfer_write
+// CHECK: scf.for
+// CHECK-NOT: scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_scalar_load_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %vec = arith.constant dense<1.0> : vector<2x4xf32>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ vector.transfer_write %vec, %A[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : vector<2x4xf32>, memref<2x4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %v0 = memref.load %A[%c0, %c1] : memref<2x4xf32>
+ %v1 = memref.load %A[%c1, %c2] : memref<2x4xf32>
+ %sum = arith.addf %v0, %v1 : f32
+ memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_scalar_load_rank2
+// CHECK: scf.parallel
+// CHECK: vector.transfer_write
+// CHECK: memref.load
+// CHECK: memref.load
+// CHECK-NOT: scf.parallel
+
+// -----
+
+func.func @fuse_vector_transfer_scalar_load_loop_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %zero = arith.constant 0.0 : f32
+ %vec = arith.constant dense<2.0> : vector<2x4xf32>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ vector.transfer_write %vec, %A[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : vector<2x4xf32>, memref<2x4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %sum = scf.for %k = %c0 to %c4 step %c1 iter_args(%acc = %zero) -> f32 {
+ %v = memref.load %A[%c1, %k] : memref<2x4xf32>
+ %n = arith.addf %v, %acc : f32
+ scf.yield %n : f32
+ }
+ memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @fuse_vector_transfer_scalar_load_loop_rank2
+// CHECK: scf.parallel
+// CHECK: vector.transfer_write
+// CHECK: scf.for
+// CHECK-NOT: scf.parallel
+
+// -----
+
+func.func @fuse_vector_store_scalar_load_rank2(%A: memref<2x4xf32>, %B: memref<2x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %vec = arith.constant dense<3.0> : vector<2x4xf32>
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ vector.store %vec, %A[%c0, %c0] : memref<2x4xf32>, vector<2x4xf32>
+ scf.reduce
+ }
+ scf.parallel (%i) = (%c0) to (%c1) step (%c1) {
+ %v0 = memref.load %A[%c1, %c2] : memref<2x4xf32>
+ %v1 = memref.load %A[%c0, %c3] : memref<2x4xf32>
+ %sum = arith.addf %v0, %v1 : f32
+ memref.store %sum, %B[%c0, %c0] : memref<2x4xf32>
+ scf.reduce
+ }
+ return
+}
+// CHECK-LABEL: func @fuse_vector_store_scalar_load_rank2
+// CHECK: scf.parallel
+// CHECK: vector.store
+// CHECK: memref.load
+// CHECK: memref.load
+// CHECK-NOT: scf.parallel
+
+// -----
+
func.func @fuse_reductions_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f32) {
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list