[Mlir-commits] [mlir] ce4f99e - [mlir][Linalg] Add comprehensive bufferization support for subtensor (5/n)
Nicolas Vasilache
llvmlistbot at llvm.org
Thu May 27 05:51:10 PDT 2021
Author: Nicolas Vasilache
Date: 2021-05-27T12:48:08Z
New Revision: ce4f99e7f272d481d0689c551d9818019992c841
URL: https://github.com/llvm/llvm-project/commit/ce4f99e7f272d481d0689c551d9818019992c841
DIFF: https://github.com/llvm/llvm-project/commit/ce4f99e7f272d481d0689c551d9818019992c841.diff
LOG: [mlir][Linalg] Add comprehensive bufferization support for subtensor (5/n)
This revision refactors and simplifies the pattern detection logic: thanks to SSA value properties, the analysis can look at all the uses of a given value directly and avoid pattern-matching specific chains of operations.
A bufferization pattern for subtensor is added, and a dedicated inplaceability analysis is implemented for the simple subtensor case. More advanced use cases will follow.
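For example, a subtensor whose result is not marked inplaceable bufferizes to a subview of its source plus an alloc and a copy. A minimal sketch, in the spirit of the new subtensor_fun test (the strided layout maps are elided and written as #map, a placeholder):

  %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>

bufferizes roughly to:

  %mA    = memref.buffer_cast %A : memref<?xf32, #map>
  %alloc = memref.alloc() : memref<4xf32>
  %sv    = memref.subview %mA[0] [4] [1] : memref<?xf32, #map> to memref<4xf32, #map>
  linalg.copy(%sv, %alloc) : memref<4xf32, #map>, memref<4xf32>

When the analysis does mark the subtensor result inplaceable, the alloc and copy are skipped and later uses read directly through the subview.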
Differential revision: https://reviews.llvm.org/D102512
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 2023f8820d0a4..d36d655638a2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -34,6 +34,11 @@ def LinalgComprehensiveFuncBufferize :
same buffers. The analysis is performed on SSA use-def chains starting from
function operands that are annotated with the 'inplaceable' attribute
}];
+ let options = [
+ Option<"testAnalysisOnly", "test-analysis-only", "bool",
+ /*default=*/"false",
+ "Only runs inplaceability analysis (for testing purposes only)">
+ ];
let constructor = "mlir::createLinalgComprehensiveFuncBufferizePass()";
}
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 60f20432d3c3b..0094fffeea966 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -32,6 +32,10 @@ class OffsetSizeAndStrideOpInterface;
namespace detail {
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
+
+bool sameOffsetsSizesAndStrides(
+ OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
+ llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp);
} // namespace detail
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index f94350fe50e31..e26a02f61966a 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -419,6 +419,23 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
return $_op.getOperand(getIndexOfDynamicStride(idx));
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if all `other`'s offsets, sizes and strides are the same.
+ Takes a custom `cmp` comparison function on OpFoldResult to avoid taking
+ a dialect dependence.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isSameAs",
+ /*args=*/(ins "OffsetSizeAndStrideOpInterface":$other,
+ "llvm::function_ref<bool(OpFoldResult, OpFoldResult)>":$cmp),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::sameOffsetsSizesAndStrides(
+ ::mlir::cast<::mlir::OffsetSizeAndStrideOpInterface>(
+ $_op.getOperation()), other, cmp);
+ }]
+ >,
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 4ad8095505025..d14281a6b19c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -88,6 +88,7 @@
#include "mlir/Transforms/BufferUtils.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#define DEBUG_TYPE "comprehensive-func-bufferize"
@@ -152,12 +153,18 @@ OpResult getMatchingOpResult(SubTensorInsertOp op, OpOperand &opOperand) {
/// analysis to determine which op results reuse the same buffer as some
/// operand.
OpResult getMatchingOpResult(OpOperand &opOperand) {
- OpResult res =
- llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
- .Case<LinalgOp, SubTensorInsertOp, VectorTransferOpInterface>(
- [&](auto op) { return getMatchingOpResult(op, opOperand); })
- .Default([&](Operation *op) { return OpResult(); });
- return res;
+ return llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+ // clang-format off
+ // Ops that perform destructive updates on operand(s) to produce
+ // result(s).
+ .Case<LinalgOp,
+ SubTensorInsertOp,
+ VectorTransferOpInterface>(
+ [&](auto op) { return getMatchingOpResult(op, opOperand); })
+ // Other ops.
+ .Case<SubTensorOp>([&](auto op) { return OpResult(); })
+ .Default([&](Operation *op) { return OpResult(); });
+ // clang-format on
}
//===----------------------------------------------------------------------===//
@@ -290,70 +297,6 @@ static Value lookup(BlockAndValueMapping &bvm, Value key) {
return bvm.lookup(key);
}
-//===----------------------------------------------------------------------===//
-// Bufferization-specific support.
-//===----------------------------------------------------------------------===//
-
-/// Determine whether any subsequent read of the tensor `opOperand` may occur.
-/// For now, this assumes any use is a read. If any use of the tensor does not
-/// properly dominate `opOperand.getOwner()`, then the tensor cannot be
-/// bufferized inPlace.
-// TODO: For now, this assumes any use is a read. Refine this.
-bool hasInterferingTensorRead(OpOperand &opOperand,
- const DominanceInfo &domInfo) {
- if (!opOperand.get().getType().isa<RankedTensorType>())
- return false;
- for (auto &use : opOperand.get().getUses()) {
- Operation *user = use.getOwner();
-
- // If properly dominate, there is a clear sequence point and we can dismiss
- // read.
- if (domInfo.properlyDominates(user, opOperand.getOwner()))
- continue;
- // Otherwise, we need to analyze self-dependencies, for now just let it go.
- // TODO: proper self-dependence analysis.
- if (domInfo.dominates(user, opOperand.getOwner()))
- continue;
- if (user == opOperand.getOwner() &&
- use.getOperandNumber() == opOperand.getOperandNumber())
- continue;
- LLVM_DEBUG(DBGS() << "found interfering read operand #"
- << opOperand.getOperandNumber()
- << " in op: " << *opOperand.getOwner() << "\n");
- return true;
- }
- LLVM_DEBUG(DBGS() << "no interfering read\n");
- return false;
-}
-
-/// Return false if either:
-/// 1. `opOperand` is produced by a constant op. For now this is assumed to be
-/// bufferized to a GlobalMemrefOp that cannot be written. Generalize in the
-/// future.
-/// 2.`opOperand` is a BlockArgument of a FuncOp that is not known to be
-/// bufferizable inplace.
-/// 3.`opOperand` has an interfering tensor read.
-/// Return true otherwise.
-bool isBufferizableInPlace(OpOperand &opOperand, const DominanceInfo &domInfo) {
- // Constant tensors are deemed not bufferizable for now.
- if (auto constantOp =
- dyn_cast_or_null<ConstantOp>(opOperand.get().getDefiningOp()))
- return !constantOp.getResult().getType().isa<RankedTensorType>();
- if (auto bbArg = opOperand.get().dyn_cast<BlockArgument>()) {
- // Uses of function arguments that may not be written-to need to be copied.
- // If the function argument itself is not inplaceable, early return false.
- // If is is inplaceable, interfering tensor read need to be checked.
- //
- // TODO: better propagate the fact that we want a single clone inside the
- // function. Atm every user that wants to write inplace will create its own
- // alloc, irrespective of whether or not interfering reads occur.
- if (isa<FuncOp>(bbArg.getOwner()->getParentOp()))
- if (getInPlace(bbArg) != InPlaceSpec::True)
- return false;
- }
- return !hasInterferingTensorRead(opOperand, domInfo);
-}
-
//===----------------------------------------------------------------------===//
// Bufferization-specific MemRefType support.
//===----------------------------------------------------------------------===//
@@ -399,26 +342,6 @@ static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
stridedLayout, addressSpace);
}
-//===----------------------------------------------------------------------===//
-// Bufferization-specific inPlace pattern matching support.
-//===----------------------------------------------------------------------===//
-
-/// First assign `op` if `slice.back()` isa `T`, then check condition.
-/// If anything fails just return failure. Otherwise update `sliceRef` by
-/// dropping `sliceRef.back()`, then return success().
-template <typename T>
-static LogicalResult
-matchAndDropBack(ArrayRef<Operation *> &sliceRef, T &op,
- llvm::function_ref<LogicalResult(T)> condition = nullptr) {
- if (sliceRef.empty())
- return failure();
- op = dyn_cast<T>(sliceRef.back());
- if (!op || (condition && failed(condition(op))))
- return failure();
- sliceRef = sliceRef.drop_back();
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
@@ -470,121 +393,6 @@ static Value createNewAllocDeallocPairForShapedValue(
return casted;
}
-//===----------------------------------------------------------------------===//
-// Bufferization-specific inPlace analysis support.
-//===----------------------------------------------------------------------===//
-
-/// Detect the simple terminator pattern:
-/// ```
-/// candidate -> ... -> inplaceable_op(candidate) -> term
-/// ```
-template <typename ContainerOp, typename TerminatorOp>
-static LogicalResult detectInplaceOpToTerminator(Operation *parentOp,
- BlockArgument candidate,
- ArrayRef<Operation *> slice) {
- assert(parentOp && "Unexpected null parent op");
- if (!isa<ContainerOp>(parentOp))
- return failure();
- TerminatorOp terminatorOp;
- // Match returnOp and update slice.
- if (failed(matchAndDropBack(slice, terminatorOp))) {
- LLVM_DEBUG(DBGS() << "FAIL: inplaceOpToTerm pattern -> slice must end with "
- "a known terminator\n");
- return failure();
- }
- return success();
-}
-
-/// The following uses internal knowledge of the position of tied operand /
-/// results.
-static void propagateInPlace(const SmallVector<OpOperand *> &initalWorklist,
- const DominanceInfo &domInfo) {
- LLVM_DEBUG(DBGS() << "\n\n");
- LLVM_DEBUG(DBGS() << "Start propagateInPlace from initial WL\n");
- LLVM_DEBUG(for (OpOperand *operand
- : initalWorklist) DBGS()
- << "WL item: " << operand->get() << " used by "
- << *operand->getOwner() << "\n");
- SmallVector<OpOperand *> worklist(initalWorklist);
- for (unsigned idx = 0; idx < worklist.size(); ++idx) {
- // TODO: bail on subtensor/subtensor_insert and vector.transfer_read/write
- // that should have been already captured in destructive update patterns?
- OpOperand &operand = *worklist[idx];
- LLVM_DEBUG(DBGS() << "WL item: " << *operand.getOwner() << "\n");
- // If the owner turns out to be a CallOp without
- // `kWriteableFuncBufferArgsAttrName` this will be a noop.
- if (isBufferizableInPlace(operand, domInfo)) {
- LLVM_DEBUG(DBGS() << "bufferizable inplace\n");
- setInPlaceOpResult(getMatchingOpResult(operand));
- }
- LLVM_DEBUG(DBGS() << "propagatedInPlace: " << *operand.getOwner() << "\n");
- // use can have interfering reads that prevent it from being written inPlace
- // but the values it produces are still themselves candidates for inPlace at
- // their point of use.
- for (Value v : operand.getOwner()->getResults()) {
- LLVM_DEBUG(DBGS() << "propagate result: " << v << "\n");
- for (auto &use : v.getUses()) {
- LLVM_DEBUG(DBGS() << "add use to WL: " << use.get() << "\n");
- worklist.push_back(&use);
- }
- }
- }
- LLVM_DEBUG(DBGS() << "\n\n");
-}
-
-static void propagateInPlace(BlockArgument &bbArg,
- const DominanceInfo &domInfo) {
- SmallVector<OpOperand *> worklist;
- for (auto &use : bbArg.getUses())
- worklist.push_back(&use);
- propagateInPlace(worklist, domInfo);
-}
-
-/// Iterate over bbArgs of `parentOp` and determine if they are the root of a
-/// known destructive update chain. Such a destructive update is related to
-/// traditional loop nest + memory analysis but provides a simpler SSA use-def
-/// chain-based abstraction.
-static void destructiveUpdateAnalysis(Block *block,
- const DominanceInfo &domInfo) {
- Operation *parentOp = block->getParentOp();
- for (BlockArgument candidate : block->getArguments()) {
- LLVM_DEBUG(llvm::dbgs() << "\n\n");
- LLVM_DEBUG(DBGS() << "Destructive update analysis on candidate: "
- << candidate << "\nof:\n"
- << *parentOp << "\n");
-
- if (!candidate.getType().isa<ShapedType>()) {
- LLVM_DEBUG(DBGS() << "Not a tensor\n");
- continue;
- }
-
- // FuncOp arguments must be inplaceable otherwise they cannot be the root of
- // a destructive update chain.
- if (isa<FuncOp>(parentOp) && getInPlace(candidate) != InPlaceSpec::True) {
- LLVM_DEBUG(DBGS() << "Not inplace\n");
- continue;
- }
-
- llvm::SetVector<Operation *> slice;
- getForwardSlice(candidate, &slice,
- [&](Operation *op) { return op->getBlock() == block; });
-
- LLVM_DEBUG(DBGS() << "Slice:\n");
- LLVM_DEBUG(for (auto *op : slice) DBGS() << *op << "\n");
-
- bool failedDetectingDestructiveUpdate =
- // func / return inplace patterns.
- failed(detectInplaceOpToTerminator<FuncOp, ReturnOp>(
- parentOp, candidate, slice.getArrayRef()));
- if (failedDetectingDestructiveUpdate) {
- LLVM_DEBUG(DBGS() << "Failed to detect a destructive update pattern\n");
- continue;
- }
-
- propagateInPlace(candidate, domInfo);
- }
-}
-
//===----------------------------------------------------------------------===//
// Bufferization as simple BlockAndValueMapping rewrites.
//===----------------------------------------------------------------------===//
@@ -748,6 +556,55 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
return success();
}
+/// Bufferize SubTensorOp to subview with optional alloc + copy depending on
+/// whether or not it is marked inplaceable.
+/// Note that `getMatchingOpResult` on a SubTensorOp always returns null.
+/// As a consequence, a SubTensorOp always allocs + copies when taken in isolation.
+static LogicalResult bufferize(OpBuilder &b, SubTensorOp subTensorOp,
+ BlockAndValueMapping &bvm) {
+ LLVM_DEBUG(DBGS() << "bufferize: " << *subTensorOp << "\n");
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(subTensorOp);
+
+ Location loc = subTensorOp.getLoc();
+ // Bail if source was not bufferized.
+ Value srcMemref = lookup(bvm, subTensorOp.source());
+ if (!srcMemref)
+ return failure();
+ auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+ auto dstTensorType = subTensorOp.result().getType().cast<RankedTensorType>();
+
+ // If not inplaceable, alloc.
+ Value alloc;
+ auto inPlace = getInPlace(subTensorOp->getResult(0));
+ if (inPlace != InPlaceSpec::True) {
+ alloc =
+ createNewAllocDeallocPairForShapedValue(b, loc, subTensorOp.result());
+ b.setInsertionPointAfter(alloc.getDefiningOp());
+ }
+
+ // Bufferize to subview.
+ auto subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ dstTensorType.getRank(), srcMemrefType, subTensorOp.getMixedOffsets(),
+ subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides())
+ .cast<MemRefType>();
+ Value subView = b.create<memref::SubViewOp>(
+ loc, subviewMemRefType, srcMemref, subTensorOp.getMixedOffsets(),
+ subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides());
+
+ /// If not inplaceable, copy.
+ if (alloc) {
+ b.create<CopyOp>(subTensorOp.getLoc(), subView, alloc);
+ subView = alloc;
+ }
+
+ map(bvm, subTensorOp.result(), subView);
+ return success();
+}
+
static LogicalResult bufferize(OpBuilder &b,
SubTensorInsertOp subTensorInsertOp,
BlockAndValueMapping &bvm) {
@@ -765,7 +622,7 @@ static LogicalResult bufferize(OpBuilder &b,
if (inPlace != InPlaceSpec::True) {
// Since subtensor_insert arise from tiling and introducing loops, this case
// is generally a deal breaker. When used with loops, this ends up cloning
- // the whole tensor on every single iteration and is a symtpom of a
+ // the whole tensor on every single iteration and is a symptom of a
// catastrophically bad scheduling decision.
// TODO: be very loud about it or even consider failing the pass.
Value newDstMemref = createNewAllocDeallocPairForShapedValue(
@@ -865,13 +722,165 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
return success();
}
+//===----------------------------------------------------------------------===//
+// Functions and calls bufferization support.
+//===----------------------------------------------------------------------===//
+
+/// Determine whether any subsequent read of the tensor `opOperand` may occur.
+/// For now, this assumes any use is a read. If any use of the tensor does not
+/// properly dominate `opOperand.getOwner()`, then the tensor cannot be
+/// bufferized inPlace.
+// TODO: For now, this assumes any use is a read. Refine this.
+bool hasInterferingTensorRead(OpOperand &opOperand,
+ const DominanceInfo &domInfo) {
+ if (!opOperand.get().getType().isa<RankedTensorType>())
+ return false;
+ for (auto &use : opOperand.get().getUses()) {
+ Operation *user = use.getOwner();
+    // If the user properly dominates the owner of `opOperand`, there is a
+    // clear sequence point and we can dismiss the read.
+ if (domInfo.properlyDominates(user, opOperand.getOwner()))
+ continue;
+ // Otherwise, we need to analyze self-dependencies, for now just let it go.
+ // TODO: proper self-dependence analysis.
+ if (domInfo.dominates(user, opOperand.getOwner()))
+ continue;
+ if (user == opOperand.getOwner() &&
+ use.getOperandNumber() == opOperand.getOperandNumber())
+ continue;
+ LLVM_DEBUG(DBGS() << "found interfering read operand #"
+ << opOperand.getOperandNumber()
+ << " in op: " << *opOperand.getOwner() << "\n");
+ return true;
+ }
+ LLVM_DEBUG(DBGS() << "no interfering read\n");
+ return false;
+}
+
+/// Return false if either:
+/// 1. `opOperand` is produced by a constant op. For now this is assumed to be
+/// bufferized to a GlobalMemrefOp that cannot be written. Generalize in the
+/// future.
+/// 2.`opOperand` is a BlockArgument of a FuncOp that is not known to be
+/// bufferizable inplace.
+/// Return true otherwise.
+static bool bufferizeToWriteable(OpOperand &opOperand) {
+ // Constant tensors are deemed not bufferizable for now.
+ if (auto constantOp =
+ dyn_cast_or_null<ConstantOp>(opOperand.get().getDefiningOp()))
+ return !constantOp.getResult().getType().isa<RankedTensorType>();
+ if (auto bbArg = opOperand.get().dyn_cast<BlockArgument>()) {
+ // Uses of function arguments that may not be written-to need to be copied.
+ // If the function argument itself is not inplaceable, early return false.
+    // If it is inplaceable, interfering tensor reads need to be checked.
+ //
+ // TODO: better propagate the fact that we want a single clone inside the
+ // function. Atm every user that wants to write inplace will create its own
+ // alloc, irrespective of whether or not interfering reads occur.
+ if (isa<FuncOp>(bbArg.getOwner()->getParentOp())) {
+ if (getInPlace(bbArg) != InPlaceSpec::True)
+ return false;
+ } else {
+ // Conservatively dump any other block argument for now.
+ return false;
+ }
+ }
+ return true;
+}
+
+/// Return false if either:
+/// 1. `opOperand` is produced by a constant op. For now this is assumed to be
+/// bufferized to a GlobalMemrefOp that cannot be written. Generalize in the
+/// future.
+/// 2.`opOperand` is a BlockArgument of a FuncOp that is not known to be
+/// bufferizable inplace.
+/// 3.`opOperand` has an interfering tensor read.
+/// Return true otherwise.
+static bool isBufferizableInPlace(OpOperand &opOperand,
+ const DominanceInfo &domInfo) {
+ return bufferizeToWriteable(opOperand) &&
+ !hasInterferingTensorRead(opOperand, domInfo);
+}
+
+/// Return true if `operand` bufferizes to a buffer that is known to never be
+/// written.
+static bool bufferizeToReadOnly(OpOperand &operand) {
+ return llvm::TypeSwitch<Operation *, bool>(operand.getOwner())
+ .Case([&](LinalgOp linalgOp) { return linalgOp.isInputTensor(&operand); })
+ .Default([&](Operation *op) { return false; });
+}
+
+/// Assume operand is a use of a `subTensorOp`.
+/// Return true if this use bufferizes to a buffer that is known to never be
+/// written.
+/// Note: This function takes into consideration uses of subTensorOp and whether
+/// the owner of those uses is inplaceable. This needs to be run in postorder to
+/// provide the most accurate analysis; otherwise it is conservative.
+static bool subTensorUseBufferizesToReadOnly(OpOperand &operand) {
+ assert(operand.get().getDefiningOp<SubTensorOp>() && "expected subtensor op");
+ if (auto subTensorInsertOp =
+ dyn_cast<SubTensorInsertOp>(operand.getOwner())) {
+ return operand.getOperandNumber() == 0 /* source of the subTensorInsert*/ &&
+ // If the subTensorInsertOp is not inplace, there is no possible
+ // internal aliasing with subTensorOp, which is inplaceable.
+ getInPlace(subTensorInsertOp->getResult(0)) != InPlaceSpec::True;
+ }
+ return bufferizeToReadOnly(operand);
+}
+
+/// Return true if `dominator.getOwner()` dominates all other uses of
+/// `dominator.get()`.
+static bool dominatesAllOtherUses(OpOperand &dominator,
+ const DominanceInfo &domInfo) {
+ for (OpOperand &use : dominator.get().getUses()) {
+ // Same use.
+ if (use.getOwner() == dominator.getOwner() &&
+ use.getOperandNumber() == dominator.getOperandNumber())
+ continue;
+ if (!domInfo.properlyDominates(dominator.getOwner(), use.getOwner()))
+ return false;
+ }
+ return true;
+}
+
+/// SubTensorOp introduces potential aliasing and a combination of things need
+/// to occur to determine whether it is inplaceable.
+static void analyzeInPlaceSubTensor(SubTensorOp subTensorOp,
+ const DominanceInfo &domInfo) {
+ // Case 1:
+ // a. All uses are known to bufferize to readonly buffers.
+ // b. The source has no use that is not dominated by subTensorOp.
+ // This can skip bufferizeToWriteable analysis / function boundary annotation.
+ if (llvm::all_of(subTensorOp.result().getUses(),
+ subTensorUseBufferizesToReadOnly) &&
+ dominatesAllOtherUses(subTensorOp->getOpOperand(0), domInfo))
+ return setInPlaceOpResult(subTensorOp->getResult(0), InPlaceSpec::True);
+
+  // TODO: Implement more advanced use cases. There is a notion of transitivity
+ // and interference sets lurking.
+}
+
+/// Analyze the internals of a FuncOp to determine inplaceable ops.
static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
const DominanceInfo &domInfo) {
assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
"expected a funcOp definition with a body");
- // Start propagating from FuncOp bbArgs.
- destructiveUpdateAnalysis(&funcOp.body().front(), domInfo);
+ funcOp.walk([&](Operation *op) {
+ // Skip SubTensorOp in a first pass.
+ if (auto subTensorOp = dyn_cast<SubTensorOp>(op))
+ return analyzeInPlaceSubTensor(subTensorOp, domInfo);
+
+ // All other ops are checked for `isBufferizableInPlace`.
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ OpResult result = getMatchingOpResult(opOperand);
+ if (result && isBufferizableInPlace(opOperand, domInfo)) {
+ LLVM_DEBUG(DBGS() << "bufferizable inplace operand #"
+ << opOperand.getOperandNumber() << " in " << *op);
+ setInPlaceOpResult(result);
+ }
+ }
+ });
}
static LogicalResult bufferizeFuncOpInternals(
@@ -881,15 +890,22 @@ static LogicalResult bufferizeFuncOpInternals(
/// Start by bufferizing `funcOp` arguments.
if (failed(bufferize(b, funcOp, bvm)))
return failure();
- WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
+ WalkResult result = funcOp.walk<WalkOrder::PostOrder>([&](Operation *op) {
LogicalResult status =
llvm::TypeSwitch<Operation *, LogicalResult>(op)
// Skip BufferCast and TensorLoad ops.
- .Case<memref::BufferCastOp, memref::TensorLoadOp>(
+ // clang-format off
+ .Case<memref::BufferCastOp,
+ memref::TensorLoadOp>(
[&](auto) { return success(); })
- .Case<memref::DimOp, LinalgOp, ReturnOp, SubTensorInsertOp,
+ .Case<memref::DimOp,
+ LinalgOp,
+ ReturnOp,
+ SubTensorOp,
+ SubTensorInsertOp,
VectorTransferOpInterface>(
[&](auto op) { return bufferize(b, op, bvm); })
+ // clang-format on
.Default([&](Operation *op) {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
@@ -925,8 +941,17 @@ void LinalgComprehensiveFuncBufferize::runOnFunction() {
DominanceInfo domInfo(funcOp);
BlockAndValueMapping bvm;
DenseMap<FuncOp, SmallVector<int64_t>> tiedResultsMap;
+ LLVM_DEBUG(llvm::dbgs() << "\n\n");
+ LLVM_DEBUG(DBGS() << "Begin InPlaceAnalysisFuncOpInternals:\n"
+ << funcOp << "\n");
inPlaceAnalysisFuncOpInternals(funcOp, domInfo);
+ LLVM_DEBUG(DBGS() << "End InPlaceAnalysisFuncOpInternals:\n"
+ << funcOp << "\n");
+
+ if (testAnalysisOnly)
+ return;
+ LLVM_DEBUG(llvm::dbgs() << "\n\n");
LLVM_DEBUG(DBGS() << "Begin BufferizeFuncOpInternals:\n" << funcOp << "\n");
auto guard = llvm::make_scope_exit([&] {
funcOp.walk(
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index ad2b6b49feb7e..4a963a1d54fda 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -155,3 +155,24 @@ ParseResult mlir::parseOperandsOrIntegersSizesList(
return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
integers);
}
+
+bool mlir::detail::sameOffsetsSizesAndStrides(
+ OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
+ llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
+ if (a.static_offsets().size() != b.static_offsets().size())
+ return false;
+ if (a.static_sizes().size() != b.static_sizes().size())
+ return false;
+ if (a.static_strides().size() != b.static_strides().size())
+ return false;
+ for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
+ if (!cmp(std::get<0>(it), std::get<1>(it)))
+ return false;
+ for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
+ if (!cmp(std::get<0>(it), std::get<1>(it)))
+ return false;
+ for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
+ if (!cmp(std::get<0>(it), std::get<1>(it)))
+ return false;
+ return true;
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
index 19599b99866bf..674483b7acdc4 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -linalg-comprehensive-func-bufferize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-func-bufferize=test-analysis-only -split-input-file | FileCheck %s --check-prefix=ANALYSIS
// CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
@@ -218,3 +219,140 @@ func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32> {linalg.inplaceable =
%r1 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_fun
+func @subtensor_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
+ -> tensor<4xf32>
+{
+ // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf32>
+ // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]][0] [4] [1]
+ // CHECK: linalg.copy(%[[SV]], %[[ALLOC]])
+ %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+ return %r0: tensor<4xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_readonly_use
+func @subtensor_readonly_use(
+ %A : tensor<?x?xf32> {linalg.inplaceable = true},
+ %B : tensor<4x4xf32>, %C : tensor<4x4xf32>) -> tensor<4x4xf32>
+{
+ // subtensor is only used as a read.
+ // ANALYSIS: subtensor {{.*}} {__inplace_results_attr__ = ["true"]}
+ %sA = subtensor %A[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
+ // matmul output operand is not inplaceable at the function boundary.
+ // ANALYSIS: linalg.matmul {{.*}}
+ // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+ %D = linalg.matmul ins(%sA, %B: tensor<4x4xf32>, tensor<4x4xf32>)
+ outs(%B: tensor<4x4xf32>)
+ -> tensor<4x4xf32>
+ return %D: tensor<4x4xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_nonmatching_subtensor_insert_inplace
+func @subtensor_nonmatching_subtensor_insert_inplace(
+ %A : tensor<?xf32> {linalg.inplaceable = true}, %idx: index)
+ -> tensor<?xf32>
+{
+ // subtensor has no matching subtensor_insert and is not just used by known
+ // readonly ops.
+ // ANALYSIS: subtensor {{.*}}
+ // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+ %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+ // subtensor_insert can bufferize inplace fine.
+ // ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]}
+ %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+ return %r1: tensor<?xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_nonmatching_subtensor_insert_non_inplace
+func @subtensor_nonmatching_subtensor_insert_non_inplace(
+ %A : tensor<?xf32> {linalg.inplaceable = false}, %idx: index)
+ -> tensor<?xf32>
+{
+ // subtensor has no matching subtensor_insert and is not just used by known
+ // readonly ops.
+ // ANALYSIS: subtensor {{.*}} {__inplace_results_attr__ = ["true"]}
+ %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+ // subtensor_insert cannot bufferize inplace.
+ // ANALYSIS: subtensor_insert {{.*}}
+ // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+ %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+ return %r1: tensor<?xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_matching_subtensor_insert
+func @subtensor_matching_subtensor_insert(%A : tensor<?xf32> {linalg.inplaceable = true})
+ -> tensor<?xf32>
+{
+ // subtensor has a matching subtensor_insert that bufferizes inplace.
+ // TODO: Atm subtensor is not inplaceable but can be.
+ // In the grander scheme, this will canonicalize away beforehand.
+ // ANALYSIS: subtensor {{.*}}
+ // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+ %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+ // subtensor_insert can bufferize inplace fine.
+ // ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]}
+ %r1 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+ return %r1: tensor<?xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_matching_and_nonmatching_1
+func @subtensor_matching_and_nonmatching_1(%A : tensor<?xf32> {linalg.inplaceable = true}, %idx: index)
+ -> (tensor<?xf32>, tensor<?xf32>)
+{
+ // %r1 is not inplaceable and %r2 is a matching subtensor_insert so %r0 could
+ // be inplaceable.
+ // In the grander scheme, %r2 will canonicalize away beforehand but %r0 will still
+ // not be inplaceable as the production of %r1 may involve a self-copy.
+ // ANALYSIS: subtensor {{.*}}
+ // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+ %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+ // ANALYSIS: subtensor_insert {{.*}}
+ // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+ %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+ // ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]}
+ %r2 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+ return %r1, %r2: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// ANALYSIS-LABEL: func @subtensor_matching_and_nonmatching_2
+func @subtensor_matching_and_nonmatching_2(%A : tensor<?xf32> {linalg.inplaceable = true}, %idx: index)
+ -> (tensor<?xf32>, tensor<?xf32>)
+{
+ // %r1 is not inplaceable and %r2 is a matching subtensor_insert so %r0 should
+ // be inplaceable.
+ // In the grander scheme, %r2 will canonicalize away beforehand and %r0 will become
+ // inplaceable by reducing to the `subtensor_nonmatching_subtensor_insert_non_inplace`
+  // case.
+ // ANALYSIS: subtensor {{.*}}
+ // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+ %r0 = subtensor %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
+ // ANALYSIS: subtensor_insert {{.*}}
+ // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]}
+ %r2 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+ // ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]}
+ %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
+
+ return %r1, %r2: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// TODO: unknown ops, linalg chain success, linalg chain failure.
+