[Mlir-commits] [mlir] 1e01a89 - [mlir][Linalg] Add ComprehensiveBufferize for functions(step 1/n)
Nicolas Vasilache
llvmlistbot at llvm.org
Thu May 13 15:25:05 PDT 2021
Author: Nicolas Vasilache
Date: 2021-05-13T22:24:40Z
New Revision: 1e01a8919f8d0fdc8c2f5f679fcc541b61381b0f
URL: https://github.com/llvm/llvm-project/commit/1e01a8919f8d0fdc8c2f5f679fcc541b61381b0f
DIFF: https://github.com/llvm/llvm-project/commit/1e01a8919f8d0fdc8c2f5f679fcc541b61381b0f.diff
LOG: [mlir][Linalg] Add ComprehensiveBufferize for functions(step 1/n)
This is the first step towards upstreaming comprehensive bufferization following the
discourse post: https://llvm.discourse.group/t/rfc-linalg-on-tensors-update-and-comprehensive-bufferization-rfc/3373/6.
This first commit introduces a basic pass for bufferizing within function boundaries,
assuming that the inplaceable function arguments have been marked as such.
Differential revision: https://reviews.llvm.org/D101693
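For orientation, here is a minimal before/after sketch of what the pass does, distilled from the file header comment and the test added below (the function names, the layout map alias and the exact output are illustrative, not verbatim pass output):

```
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Input: the caller promises %A may be overwritten.
func @fill(%A : tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32> {
  %f0 = constant 0.0 : f32
  %r = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
  return %r : tensor<?xf32>
}

// After -linalg-comprehensive-func-bufferize: no allocation, the fill is
// rewritten on a buffer view of %A (compare with @fill_inplace in the test).
func @fill_bufferized(%A : tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32> {
  %0 = memref.buffer_cast %A : memref<?xf32, #map>
  %f0 = constant 0.0 : f32
  linalg.fill(%0, %f0) : memref<?xf32, #map>, f32
  %r = memref.tensor_load %0 : memref<?xf32, #map>
  return %r : tensor<?xf32>
}
```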
Added:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 007cb6de12f60..307eebf6ddf5e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -37,7 +37,17 @@ def Linalg_Dialect : Dialect {
let dependentDialects = [
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
];
+ let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
+ /// Attribute name used to memoize indexing maps for named ops.
+ constexpr const static ::llvm::StringLiteral
+ kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
+
+ /// Attribute name used to mark region arguments that can be bufferized
+ /// in-place during linalg comprehensive bufferization.
+ constexpr const static ::llvm::StringLiteral
+ kInplaceableAttrName = "linalg.inplaceable";
+
using RegionBuilderFunType = llvm::function_ref<void(Block &, ValueRange)>;
RegionBuilderFunType getRegionBuilder(StringRef name) {
return namedStructuredOpRegionBuilders.lookup(name);
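For illustration, the `linalg.inplaceable` attribute declared above is intended to be attached to tensor arguments of function-like ops; a hypothetical example (not part of the patch):

```
// Bufferization may write through %A in place, but must treat %B as
// read-only. The verifier added in LinalgTypes.cpp below expects the
// attribute to be a BoolAttr used on function-like operations.
func @foo(%A : tensor<?xf32> {linalg.inplaceable = true},
          %B : tensor<?xf32> {linalg.inplaceable = false}) -> tensor<?xf32> {
  return %A : tensor<?xf32>
}
```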
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index b81ea52ba3357..47dac77701143 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -53,6 +53,12 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToParallelLoopsPass();
/// Placeholder for now, this is NYI.
std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
+/// Create a pass that bufferizes the body of a FuncOp and tries to reuse the
+/// buffers for those arguments that:
+/// a) have been annotated 'inplaceable' and
+/// b) whose buffer uses would be free of memory hazards.
+std::unique_ptr<Pass> createLinalgComprehensiveFuncBufferizePass();
+
/// Create a pass to convert Linalg operations which work on tensors to use
/// buffers instead.
std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index c529dfded3eab..2934f17f1e901 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -22,6 +22,21 @@ def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
}
+def LinalgComprehensiveFuncBufferize :
+ FunctionPass<"linalg-comprehensive-func-bufferize"> {
+ let summary = "Bufferize (tensor into memref) the body of a FuncOp and try "
+ "to reuse the buffers for those arguments that "
+ "a) have been annotated 'inplaceable' and "
+ "b) whose buffer uses would be free of memory hazards";
+ let description = [{
+ This pass implements a cross-dialect bufferization approach and performs an
+ analysis to determine which op operands and results may be bufferized in the
+ same buffers. The analysis is performed on SSA use-def chains starting from
+ function operands that are annotated with the 'inplaceable' attribute.
+ }];
+ let constructor = "mlir::createLinalgComprehensiveFuncBufferizePass()";
+}
+
def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index 0337365761a29..047e1e9a62be5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionSupport.h"
#include "mlir/Parser.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -57,6 +58,14 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
// LinalgDialect
//===----------------------------------------------------------------------===//
+/// Attribute name used to memoize indexing maps for named ops.
+constexpr const ::llvm::StringLiteral
+ LinalgDialect::kMemoizedIndexingMapsAttrName;
+
+/// Attribute name used to mark region arguments that can be bufferized
+/// in-place during linalg comprehensive bufferization.
+constexpr const ::llvm::StringLiteral LinalgDialect::kInplaceableAttrName;
+
/// Trait to check if T provides a `regionBuilder` method.
template <typename T, typename... Args>
using has_region_builder = decltype(T::regionBuilder);
@@ -131,3 +140,21 @@ void mlir::linalg::LinalgDialect::printType(Type type,
DialectAsmPrinter &os) const {
print(type.cast<RangeType>(), os);
}
+
+LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
+ NamedAttribute attr) {
+ if (attr.first == LinalgDialect::kInplaceableAttrName) {
+ if (!attr.second.isa<BoolAttr>()) {
+ return op->emitError() << "'" << LinalgDialect::kInplaceableAttrName
+ << "' is expected to be a boolean attribute";
+ }
+ if (!op->hasTrait<OpTrait::FunctionLike>())
+ return op->emitError() << "expected " << attr.first
+ << " to be used on function-like operations";
+ return success();
+ }
+ if (attr.first == LinalgDialect::kMemoizedIndexingMapsAttrName)
+ return success();
+ return op->emitError() << "attribute '" << attr.first
+ << "' not supported by the linalg dialect";
+}
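To make the new hook concrete, a few hypothetical attachments and the error strings emitted above (none of this IR is part of the patch):

```
// Accepted: a boolean linalg.inplaceable on a function-like op.
func private @ok() attributes {linalg.inplaceable = true}

// Rejected: "'linalg.inplaceable' is expected to be a boolean attribute".
func private @bad_type() attributes {linalg.inplaceable = 42 : i64}

// Rejected: "attribute 'linalg.other' not supported by the linalg dialect".
func private @bad_name() attributes {linalg.other = true}
```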
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 6a2f81185af9b..e52bb39815c70 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRLinalgTransforms
Bufferize.cpp
CodegenStrategy.cpp
+ ComprehensiveBufferize.cpp
Detensorize.cpp
DropUnitDims.cpp
ElementwiseToLinalg.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
new file mode 100644
index 0000000000000..2534eeeb7dcd5
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -0,0 +1,785 @@
+//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Perform inplace bufferization within function boundaries.
+// This is a specialized pass that supports inplace analysis for a fixed subset
+// of ops that have well-defined inplace semantics.
+// This pass caters to high-performance codegen where buffer reuse is deemed
+// necessary: the pass should fail if the bufferized form of the function needs
+// to return any buffer.
+// Generic control-flow and branching are unsupported.
+// Composability with an extensible set of ops is not a first-class concern.
+//
+// Bufferization occurs by:
+// a. performing an inPlace analysis `inPlaceAnalysisFuncOpInternals`
+// which marks each operation within the function with the
+// `kInPlaceResultsAttrName` attribute.
+// b. traversing each operation in the function and rewriting it in
+// buffer form and keeping a BlockAndValueMapping mapping of the
+// rewrites. New allocations are introduced during this step.
+// TODO: Allocation + depending op hoisting to outermost enclosing
+// sequential scope.
+// c. at the end of this bufferization, 2 cases may occur:
+// * inplaceable function arguments may be reused in place after the
+// function itself has been bufferized. This is encoded by IR resembling:
+//
+// ```
+// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// func @foo(%A: tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32> {
+// %0 = memref.buffer_cast %A : memref<?xf32, #map>
+// // ... uses of %0
+// %res = memref.tensor_load %0 : memref<?xf32, #map>
+// return %res : tensor<?xf32>
+// }
+// ```
+//
+// This is the cue that the function foo (and calls to it) may bufferize to
+// `func @foo(%A: memref<?xf32, some_layout>)`.
+// To fully achieve bufferization, an additional analysis is needed to
+// determine whether function argument/operand pairs bufferize to a single
+// inplace buffer argument (i.e. functions may return tensors in arbitrary
+// order that may not match argument numbers).
+// * results that don't map to an inplaceable function argument must be
+// allocated. Since memref semantics wrt ownership of the underlying
+// memory region are not well-defined, comprehensive bufferization chooses
+// to perform allocations in a scoped fashion: returning memrefs is always
+// considered illegal. Such scenarios are encoded by IR resembling:
+//
+// ```
+// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// func @foo(%A: tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32> {
+// %0 = memref.buffer_cast %A : memref<?xf32, #map>
+// %1 = memref.dim %0, %c0 : memref<?xf32, #map>
+// %2 = memref.alloc(%1) : memref<?xf32>
+// %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
+// // ... uses of %3
+// memref.dealloc %2 : memref<?xf32, #map>
+// %res = memref.tensor_load %3 : memref<?xf32, #map>
+// return %res : tensor<?xf32>
+// }
+// ```
+//
+// This is the cue that the function foo (and calls to it) must bufferize to
+// `func @foo(%A: memref<?xf32, some_layout>,
+// %B: memref<?xf32, some_layout>)`
+// (i.e. make a cloned allocation of the result tensor).
+// To fully achieve bufferization, the alloc/dealloc pair must be lifted
+// out of the function at each call site.
+//
+// Lastly, note that the layout map chosen to bufferize is the most dynamic
+// canonical strided layout of the proper rank. This ensures compatibility with
+// expected layouts after transformations. Combinations of memref.cast +
+// canonicalization are responsible for cleanups.
+
+#include "PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/BufferUtils.h"
+
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "comprehensive-func-bufferize"
+
+using namespace mlir;
+using namespace linalg;
+using namespace tensor;
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
+//===----------------------------------------------------------------------===//
+// Op-specific semantics helper to retrieve matching inplaceable result.
+//===----------------------------------------------------------------------===//
+
+/// Return the OpResult that matches an operand.
+/// Return null if no such result exists.
+OpResult getMatchingOpResult(LinalgOp linalgOp, OpOperand &opOperand) {
+ if (!opOperand.get().getType().isa<RankedTensorType>())
+ return OpResult();
+ // For now assume inputs are never inplaceable.
+ // TODO: refine this.
+ if (opOperand.getOperandNumber() < linalgOp.getNumInputs())
+ return OpResult();
+ // For now assume if the operand appears twice, it is not inplaceable.
+ // TODO: refine this.
+ for (auto &opOperand2 : linalgOp->getOpOperands()) {
+ if (opOperand.getOperandNumber() == opOperand2.getOperandNumber())
+ continue;
+ if (opOperand.get() == opOperand2.get())
+ return OpResult();
+ }
+ int64_t outputOperandIndex =
+ opOperand.getOperandNumber() - linalgOp.getNumInputs();
+ int64_t numOutputBuffers = 0;
+ for (unsigned idx = 0; idx < outputOperandIndex; ++idx)
+ if (!linalgOp.getOutputShapedType(idx).isa<TensorType>())
+ ++numOutputBuffers;
+ return linalgOp->getResult(outputOperandIndex - numOutputBuffers);
+}
+
+/// Determine which results may be reused inplace by the bufferization
+/// patterns of `bufferizeFuncOpInternals`.
+/// The inplace analysis uses this information along with interfering read
+/// analysis to determine which op results reuse the same buffer as some
+/// operand.
+OpResult getMatchingOpResult(OpOperand &opOperand) {
+ OpResult res =
+ llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+ .Case([&](LinalgOp op) { return getMatchingOpResult(op, opOperand); })
+ .Default([&](Operation *op) { return OpResult(); });
+ return res;
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific attribute manipulation.
+//===----------------------------------------------------------------------===//
+
+/// Attribute marker to specify op results that can be bufferized inPlace.
+constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_results_attr__";
+
+// TODO: proper enum.
+enum class InPlaceSpec {
+ False,
+ True,
+ None,
+};
+
+static StringRef stringify(InPlaceSpec val) {
+ switch (val) {
+ case InPlaceSpec::False:
+ return "false";
+ case InPlaceSpec::True:
+ return "true";
+ case InPlaceSpec::None:
+ return "none";
+ }
+ return "";
+}
+
+static Optional<InPlaceSpec> symbolize(StringRef str) {
+ return StringSwitch<Optional<InPlaceSpec>>(str)
+ .Case("false", InPlaceSpec::False)
+ .Case("true", InPlaceSpec::True)
+ .Case("none", InPlaceSpec::None)
+ .Default(None);
+}
+
+/// Mark whether OpResult can actually be bufferized inplace. If `inPlace` is
+/// `InPlaceSpec::True`, the use-def chain analysis has guaranteed that no
+/// subsequent write would occur to the bufferized tensor value (i.e. the result
+/// can be bufferized inPlace).
+static void setInPlaceOpResult(OpResult opResult,
+ InPlaceSpec inPlace = InPlaceSpec::True) {
+ if (!opResult)
+ return;
+
+ Operation *op = opResult.getOwner();
+ auto attr =
+ op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
+ SmallVector<StringRef> inPlaceVector =
+ attr ? SmallVector<StringRef>(
+ llvm::to_vector<4>(attr.getAsValueRange<StringAttr>()))
+ : SmallVector<StringRef>(op->getNumResults(),
+ stringify(InPlaceSpec::None));
+ LLVM_DEBUG(DBGS() << "Set inPlace=" << stringify(inPlace) << ": " << *op
+ << " @idx=" << opResult.getResultNumber() << "\n");
+ inPlaceVector[opResult.getResultNumber()] = stringify(inPlace);
+ op->setAttr(kInPlaceResultsAttrName,
+ OpBuilder(op).getStrArrayAttr(inPlaceVector));
+}
+
+/// Get the InPlaceSpec attribute entry `kInPlaceResultsAttrName` for
+/// `opResult`. If the result is `InPlaceSpec::True`, the use-def chain analysis
+/// has guaranteed that no subsequent read of the tensor value occurs and the
+/// result can be bufferized inPlace.
+/// If no InPlaceSpec attribute has been set for `opResult`, return
+/// InPlaceSpec::None.
+static InPlaceSpec getInPlace(OpResult opResult) {
+ if (!opResult)
+ return InPlaceSpec::None;
+
+ Operation *op = opResult.getOwner();
+ auto attr =
+ op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null<ArrayAttr>();
+ if (!attr)
+ return InPlaceSpec::None;
+
+ // Must return a proper value.
+ return *symbolize(*(attr.getAsValueRange<StringAttr>().begin() +
+ opResult.getResultNumber()));
+}
+
+/// Get inPlace information for `bbArg`.
+/// If it does not come from a function, return InPlaceSpec::False.
+static InPlaceSpec getInPlace(BlockArgument bbArg) {
+ auto funcOp = dyn_cast<FuncOp>(bbArg.getOwner()->getParentOp());
+ if (!funcOp)
+ return InPlaceSpec::False;
+ auto attr = funcOp.getArgAttrOfType<BoolAttr>(
+ bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName);
+ if (!attr)
+ return InPlaceSpec::None;
+ return attr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific BlockAndValueMapping support with debugging.
+//===----------------------------------------------------------------------===//
+
+/// Wrapper for better debugging.
+static void map(BlockAndValueMapping &bvm, ValueRange keys, ValueRange values) {
+ assert(!keys.empty() && "Unexpected empty keys");
+ LLVM_DEBUG(DBGS() << "Map: " << keys.front() << " to " << values.front()
+ << "\n");
+ return bvm.map(keys, values);
+}
+
+/// Wrapper for better debugging.
+static void map(BlockAndValueMapping &bvm, Value key, Value value) {
+ LLVM_DEBUG(DBGS() << "Map: " << key << " to " << value << "\n");
+ return bvm.map(key, value);
+}
+
+/// Wrapper for better debugging.
+static Value lookup(BlockAndValueMapping &bvm, Value key) {
+ // TODO: if key comes from bbArg, forward.
+ assert(key.getType().isa<TensorType>());
+ if (!bvm.lookupOrNull(key)) {
+ if (auto bbArg = key.dyn_cast<BlockArgument>()) {
+ if (isa<FuncOp>(key.getParentBlock()->getParentOp()))
+ key.getParentBlock()->getParentOp()->dump();
+ else
+ key.getParentBlock()->getParentOp()->getParentOfType<FuncOp>()->dump();
+ bbArg.getOwner()->getParentOp()->dump();
+ } else {
+ key.getDefiningOp()->getParentOfType<FuncOp>()->dump();
+ }
+ llvm::errs() << "NO VALUE FOR KEY: " << key << "\n";
+ abort();
+ }
+ return bvm.lookup(key);
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific support.
+//===----------------------------------------------------------------------===//
+
+/// Determine whether any subsequent read of the tensor `opOperand` may occur.
+/// For now, this assumes any use is a read. If any use of the tensor does not
+/// properly dominate `opOperand.getOwner()`, then the tensor cannot be
+/// bufferized inPlace.
+// TODO: For now, this assumes any use is a read. Refine this.
+bool hasInterferingTensorRead(OpOperand &opOperand,
+ const DominanceInfo &domInfo) {
+ if (!opOperand.get().getType().isa<RankedTensorType>())
+ return false;
+ for (auto &use : opOperand.get().getUses()) {
+ Operation *user = use.getOwner();
+
+ // If the user properly dominates the owner, there is a clear sequence point
+ // and we can dismiss the read.
+ if (domInfo.properlyDominates(user, opOperand.getOwner()))
+ continue;
+ // Otherwise, we need to analyze self-dependencies, for now just let it go.
+ // TODO: proper self-dependence analysis.
+ if (domInfo.dominates(user, opOperand.getOwner()))
+ continue;
+ if (user == opOperand.getOwner() &&
+ use.getOperandNumber() == opOperand.getOperandNumber())
+ continue;
+ LLVM_DEBUG(DBGS() << "found interfering read operand #"
+ << opOperand.getOperandNumber()
+ << " in op: " << *opOperand.getOwner() << "\n");
+ return true;
+ }
+ LLVM_DEBUG(DBGS() << "no interfering read\n");
+ return false;
+}
+
+/// Return false if either:
+/// 1. `opOperand` is produced by a constant op. For now this is assumed to be
+/// bufferized to a GlobalMemrefOp that cannot be written. Generalize in the
+/// future.
+/// 2. `opOperand` is a BlockArgument of a FuncOp that is not known to be
+/// bufferizable inplace.
+/// 3. `opOperand` has an interfering tensor read.
+/// Return true otherwise.
+bool isBufferizableInPlace(OpOperand &opOperand, const DominanceInfo &domInfo) {
+ // Constant tensors are deemed not bufferizable for now.
+ if (auto constantOp =
+ dyn_cast_or_null<ConstantOp>(opOperand.get().getDefiningOp()))
+ return !constantOp.getResult().getType().isa<RankedTensorType>();
+ if (auto bbArg = opOperand.get().dyn_cast<BlockArgument>()) {
+ // Uses of function arguments that may not be written to need to be copied.
+ // If the function argument itself is not inplaceable, early return false.
+ // If it is inplaceable, interfering tensor reads need to be checked.
+ //
+ // TODO: better propagate the fact that we want a single clone inside the
+ // function. Atm every user that wants to write inplace will create its own
+ // alloc, irrespective of whether or not interfering reads occur.
+ if (isa<FuncOp>(bbArg.getOwner()->getParentOp()))
+ if (getInPlace(bbArg) != InPlaceSpec::True)
+ return false;
+ }
+ return !hasInterferingTensorRead(opOperand, domInfo);
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific MemRefType support.
+//===----------------------------------------------------------------------===//
+
+/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) with
+/// the same shape as `shapedType` and specified `layout` and `addressSpace`.
+static MemRefType getContiguousMemRefType(ShapedType shapedType,
+ ArrayRef<AffineMap> layout = {},
+ unsigned addressSpace = 0) {
+ if (RankedTensorType tensorType = shapedType.dyn_cast<RankedTensorType>())
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+ layout, addressSpace);
+ MemRefType memrefType = shapedType.cast<MemRefType>();
+ return MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+ layout, addressSpace);
+}
+
+/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) with
+/// the same shape as `shapedType` and specified `layout` and `addressSpace` or
+/// an UnrankedMemRefType otherwise.
+static Type getContiguousOrUnrankedMemRefType(Type type,
+ ArrayRef<AffineMap> layout = {},
+ unsigned addressSpace = 0) {
+ if (type.isa<RankedTensorType, MemRefType>())
+ return getContiguousMemRefType(type.cast<ShapedType>(), layout,
+ addressSpace);
+ assert(layout.empty() && "expected empty layout with UnrankedMemRefType");
+ return UnrankedMemRefType::get(getElementTypeOrSelf(type), addressSpace);
+}
+
+/// Return a MemRefType to which the `tensorType` can be bufferized in a
+/// composable fashion. The layout must be the most dynamic possible and
+/// canonicalize away once bufferization is finished.
+static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
+ unsigned addressSpace = 0) {
+ // TODO: address space decisions to connect with the actual alloc.
+ int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
+ SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+ ShapedType::kDynamicStrideOrOffset);
+ AffineMap stridedLayout = makeStridedLinearLayoutMap(
+ dynamicStrides, dynamicOffset, tensorType.getContext());
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+ stridedLayout, addressSpace);
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific inPlace pattern matching support.
+//===----------------------------------------------------------------------===//
+
+/// First assign `op` if `sliceRef.back()` is a `T`, then check the condition.
+/// If anything fails just return failure. Otherwise update `sliceRef` by
+/// dropping `sliceRef.back()`, then return success().
+template <typename T>
+static LogicalResult
+matchAndDropBack(ArrayRef<Operation *> &sliceRef, T &op,
+ llvm::function_ref<LogicalResult(T)> condition = nullptr) {
+ if (sliceRef.empty())
+ return failure();
+ op = dyn_cast<T>(sliceRef.back());
+ if (!op || (condition && failed(condition(op))))
+ return failure();
+ sliceRef = sliceRef.drop_back();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific scoped alloc/dealloc insertion support.
+//===----------------------------------------------------------------------===//
+
+/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
+/// `shapedValue.getDefiningOp` (or at the top of the block in case of a bbArg)
+/// and the DeallocOp is at the end of the block.
+static Value createNewAllocDeallocPairForShapedValue(
+ OpBuilder &b, Location loc, Value shapedValue,
+ SmallVector<Value, 4> dynOperands = {}) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+
+ // TODO: non-zero address space.
+ // TODO: layout information if relevant.
+ // Cannot allocate an unranked memref so just always go for the contiguous
+ // form.
+ MemRefType allocMemRefType =
+ getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+ assert(shapedValue.getType().isa<ShapedType>());
+ MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
+ memRefType = memRefType ? memRefType : allocMemRefType;
+
+ if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+ b.setInsertionPointToStart(bbArg.getOwner());
+ loc = bbArg.getOwner()->getParentOp()->getLoc();
+ } else {
+ b.setInsertionPointAfter(shapedValue.getDefiningOp());
+ loc = shapedValue.getDefiningOp()->getLoc();
+ }
+
+ // If the dynOperands are not passed explicitly, compute them.
+ // This circumvents currently missing dim(init_tensor) canonicalizations.
+ // TODO: dim(init_tensor) canonicalization.
+ if (dynOperands.empty()) {
+ for (auto dim : llvm::enumerate(memRefType.getShape()))
+ if (dim.value() == ShapedType::kDynamicSize)
+ dynOperands.push_back(
+ b.create<memref::DimOp>(loc, shapedValue, dim.index()));
+ }
+
+ Value allocated =
+ b.create<memref::AllocOp>(loc, allocMemRefType, dynOperands);
+ Value casted = allocated;
+ if (memRefType != allocMemRefType)
+ casted = b.create<memref::CastOp>(loc, memRefType, allocated);
+ b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
+ b.create<memref::DeallocOp>(loc, allocated);
+ return casted;
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization-specific inPlace analysis support.
+//===----------------------------------------------------------------------===//
+
+/// Detect the simple terminator pattern:
+/// ```
+/// candidate -> ... -> inplaceable_op(candidate) -> term
+/// ```
+template <typename ContainerOp, typename TerminatorOp>
+static LogicalResult detectInplaceOpToTerminator(Operation *parentOp,
+ BlockArgument candidate,
+ ArrayRef<Operation *> slice) {
+ assert(parentOp && "Unexpected null parent op");
+ if (!isa<ContainerOp>(parentOp))
+ return failure();
+ TerminatorOp terminatorOp;
+ // Match the terminator op and update the slice.
+ if (failed(matchAndDropBack(slice, terminatorOp))) {
+ LLVM_DEBUG(DBGS() << "FAIL: inplaceOpToTerm pattern -> slice must end with "
+ "a known terminator\n");
+ return failure();
+ }
+ return success();
+}
+
+/// The following uses internal knowledge of the position of tied operand /
+/// results.
+static void propagateInPlace(const SmallVector<OpOperand *> &initialWorklist,
+ const DominanceInfo &domInfo) {
+ LLVM_DEBUG(DBGS() << "\n\n");
+ LLVM_DEBUG(DBGS() << "Start propagateInPlace from initial WL\n");
+ for (OpOperand *operand : initialWorklist)
+ LLVM_DEBUG(DBGS() << "WL item: " << operand->get() << " used by "
+ << *operand->getOwner() << "\n");
+ SmallVector<OpOperand *> worklist(initialWorklist);
+ for (unsigned idx = 0; idx < worklist.size(); ++idx) {
+ // TODO: bail on subtensor/subtensor_insert and vector.transfer_read/write
+ // that should have been already captured in destructive update patterns?
+ OpOperand &operand = *worklist[idx];
+ LLVM_DEBUG(DBGS() << "WL item: " << *operand.getOwner() << "\n");
+ // If the owner turns out to be an op without a matching result (e.g. an
+ // unsupported CallOp), setInPlaceOpResult below is a noop.
+ if (isBufferizableInPlace(operand, domInfo)) {
+ LLVM_DEBUG(DBGS() << "bufferizable inplace\n");
+ setInPlaceOpResult(getMatchingOpResult(operand));
+ }
+ LLVM_DEBUG(DBGS() << "propagatedInPlace: " << *operand.getOwner() << "\n");
+ // The use can have interfering reads that prevent it from being written
+ // inPlace, but the values produced by its owner are still themselves
+ // candidates for inPlace at their point of use.
+ for (Value v : operand.getOwner()->getResults()) {
+ LLVM_DEBUG(DBGS() << "propagate result: " << v << "\n");
+ for (auto &use : v.getUses()) {
+ LLVM_DEBUG(DBGS() << "add use to WL: " << use.get() << "\n");
+ worklist.push_back(&use);
+ }
+ }
+ }
+ LLVM_DEBUG(DBGS() << "\n\n");
+}
+
+static void propagateInPlace(BlockArgument &bbArg,
+ const DominanceInfo &domInfo) {
+ SmallVector<OpOperand *> worklist;
+ for (auto &use : bbArg.getUses())
+ worklist.push_back(&use);
+ propagateInPlace(worklist, domInfo);
+}
+
+/// Iterate over the bbArgs of `block` and determine if they are the root of a
+/// known destructive update chain. Such a destructive update is related to
+/// traditional loop nest + memory analysis but provides a simpler SSA use-def
+/// chain-based abstraction.
+static void destructiveUpdateAnalysis(Block *block,
+ const DominanceInfo &domInfo) {
+ Operation *parentOp = block->getParentOp();
+ for (BlockArgument candidate : block->getArguments()) {
+ LLVM_DEBUG(llvm::dbgs() << "\n\n");
+ LLVM_DEBUG(DBGS() << "Destructive update analysis on candidate: "
+ << candidate << "\nof:\n"
+ << *parentOp << "\n");
+
+ if (!candidate.getType().isa<ShapedType>()) {
+ LLVM_DEBUG(DBGS() << "Not a tensor\n");
+ continue;
+ }
+
+ // FuncOp arguments must be inplaceable otherwise they cannot be the root of
+ // a destructive update chain.
+ if (isa<FuncOp>(parentOp) && getInPlace(candidate) != InPlaceSpec::True) {
+ LLVM_DEBUG(DBGS() << "Not inplace\n");
+ continue;
+ }
+
+ llvm::SetVector<Operation *> slice;
+ getForwardSlice(candidate, &slice,
+ [&](Operation *op) { return op->getBlock() == block; });
+
+ LLVM_DEBUG(DBGS() << "Slice:\n");
+ for (auto *op : slice)
+ LLVM_DEBUG(DBGS() << *op << "\n");
+
+ bool failedDetectingDestructiveUpdate =
+ // func / return inplace patterns.
+ failed(detectInplaceOpToTerminator<FuncOp, ReturnOp>(
+ parentOp, candidate, slice.getArrayRef()));
+ if (failedDetectingDestructiveUpdate) {
+ LLVM_DEBUG(DBGS() << "Failed to detect a destructive update pattern\n");
+ continue;
+ }
+
+ propagateInPlace(candidate, domInfo);
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Bufferization as simple BlockAndValueMapping rewrites.
+//===----------------------------------------------------------------------===//
+
+/// Helper function for LinalgOp bufferization.
+/// Operate on mixed tensor + buffer Linalg ops for progressive bufferization.
+/// Allocate the output buffers for the remaining tensor output operands of
+/// the Linalg op. If the tensor is an "init" tensor (i.e. its value is
+/// actually used in the payload region), we additionally copy the original
+/// value into the newly allocated buffer.
+static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+ SmallVectorImpl<Value> &resultBuffers,
+ BlockAndValueMapping &bvm) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+
+ // Lazily compute loopRanges.
+ SmallVector<Range, 4> loopRanges;
+
+ // Linalg invariant: output tensors and results match 1-1.
+ assert(op.getNumOutputTensors() == op->getNumResults());
+ for (auto &opOperand : op.getOutputOpOperands()) {
+ Value output = opOperand.get();
+ if (output.getType().isa<MemRefType>()) {
+ resultBuffers.push_back(output);
+ continue;
+ }
+
+ // If output tensor is marked inPlace, just use the buffer.
+ // The following uses internal knowledge of the position of tied operand /
+ // results.
+ OpResult tiedResult = getMatchingOpResult(op, opOperand);
+ if (getInPlace(tiedResult) == InPlaceSpec::True) {
+ resultBuffers.push_back(lookup(bvm, output));
+ continue;
+ }
+
+ Value dimTensor = bvm.lookupOrDefault(output);
+ Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor);
+ b.setInsertionPointAfter(alloc.getDefiningOp());
+ resultBuffers.push_back(alloc);
+
+ // Additionally, if the output buffer is used, clone its value for now.
+ if (op.payloadUsesValueFromOpOperand(&opOperand))
+ b.create<CopyOp>(loc, lookup(bvm, output), alloc);
+ }
+ if (op->getNumResults())
+ map(bvm, op->getResults(), resultBuffers);
+}
+
+static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op,
+ ValueRange inputs, ValueRange outputs,
+ BlockAndValueMapping &bvm) {
+ SmallVector<Value, 8> newOperands = inputs;
+ newOperands.append(outputs.begin(), outputs.end());
+ auto otherOperands = op.getAssumedNonShapedOperands();
+ newOperands.append(otherOperands.begin(), otherOperands.end());
+ Location loc = op.getLoc();
+ op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands);
+
+ // Replace the results of the old op with the new output buffers.
+ if (op->getNumResults())
+ map(bvm, op->getResults(), outputs);
+ if (!op.hasTensorSemantics())
+ op->erase();
+}
+
+/// Generic conversion for any LinalgOp.
+/// Operate on mixed tensor + buffer Linalg ops for progressive bufferization.
+static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
+ BlockAndValueMapping &bvm) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+
+ if (op.hasBufferSemantics())
+ return failure();
+
+ LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
+
+ b.setInsertionPoint(op);
+ Location loc = op.getLoc();
+ SmallVector<Value, 2> newInputBuffers;
+ newInputBuffers.reserve(op.getNumInputs());
+ for (Value v : op.getInputs())
+ newInputBuffers.push_back(lookup(bvm, v));
+ SmallVector<Value, 2> newOutputBuffers;
+ allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm);
+ finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
+ return success();
+}
+
+/// DimOp tensor operand is modified inplace. This allows leaving dead tensors
+/// behind that will get DCE'd.
+static LogicalResult convertDimOp(OpBuilder &b, memref::DimOp dimOp,
+ BlockAndValueMapping &bvm) {
+ if (dimOp.memrefOrTensor().getType().isa<RankedTensorType>())
+ dimOp.memrefOrTensorMutable().assign(lookup(bvm, dimOp.memrefOrTensor()));
+ return success();
+}
+
+/// FuncOp bufferization always creates memref::BufferCastOp ops for tensor arguments.
+static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
+ BlockAndValueMapping &bvm) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointToStart(&funcOp.body().front());
+ for (auto bbArg : funcOp.getArguments()) {
+ auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+ if (!tensorType)
+ continue;
+ auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
+ // Cast the tensor to the most dynamic buffer possible. Further
+ // canonicalizations will clean up.
+ Type memRefType = rankedTensorType
+ ? getDynamicMemRefType(rankedTensorType)
+ : getContiguousOrUnrankedMemRefType(tensorType);
+ Value tensorToMemref =
+ b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
+ map(bvm, bbArg, tensorToMemref);
+ }
+ return success();
+}
+
+/// ReturnOp always creates memref::TensorLoadOp.
+static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
+ BlockAndValueMapping &bvm) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(returnOp);
+
+ FuncOp funcOp = cast<FuncOp>(returnOp->getParentOp());
+ assert(funcOp && "only support FuncOp parent for ReturnOp");
+ for (OpOperand &operand : returnOp->getOpOperands()) {
+ auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+ if (!tensorType)
+ continue;
+ operand.set(b.create<memref::TensorLoadOp>(returnOp.getLoc(),
+ lookup(bvm, operand.get())));
+ }
+ return success();
+}
+
+static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
+ const DominanceInfo &domInfo) {
+ assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
+ "expected a funcOp definition with a body");
+
+ // Start propagating from FuncOp bbArgs.
+ destructiveUpdateAnalysis(&funcOp.body().front(), domInfo);
+}
+
+static LogicalResult bufferizeFuncOpInternals(
+ FuncOp funcOp, BlockAndValueMapping &bvm,
+ const DenseMap<FuncOp, SmallVector<int64_t>> &tiedResultsMap) {
+ OpBuilder b(funcOp->getContext());
+ /// Start by converting `funcOp` arguments.
+ if (failed(convertFuncOp(b, funcOp, bvm)))
+ return failure();
+ WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
+ LogicalResult status =
+ llvm::TypeSwitch<Operation *, LogicalResult>(op)
+ // Skip BufferCast and TensorLoad ops.
+ .Case<memref::BufferCastOp, memref::TensorLoadOp>(
+ [&](auto) { return success(); })
+ .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
+ .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
+ .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
+ .Default([&](Operation *op) {
+ auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+ if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+ llvm::any_of(op->getResultTypes(), isaTensor))
+ return failure();
+ return success();
+ });
+ if (failed(status)) {
+ op->emitError("Failed bufferization");
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ if (result.wasInterrupted())
+ return failure();
+ return success();
+}
+
+namespace {
+struct LinalgComprehensiveFuncBufferize
+ : public LinalgComprehensiveFuncBufferizeBase<
+ LinalgComprehensiveFuncBufferize> {
+ void runOnFunction() override;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<linalg::LinalgDialect, memref::MemRefDialect>();
+ }
+};
+} // end namespace
+
+void LinalgComprehensiveFuncBufferize::runOnFunction() {
+ auto funcOp = getFunction();
+ DominanceInfo domInfo(funcOp);
+ BlockAndValueMapping bvm;
+ DenseMap<FuncOp, SmallVector<int64_t>> tiedResultsMap;
+ inPlaceAnalysisFuncOpInternals(funcOp, domInfo);
+
+ LLVM_DEBUG(DBGS() << "Begin BufferizeFuncOpInternals:\n" << funcOp << "\n");
+ auto guard = llvm::make_scope_exit([&] {
+ funcOp.walk(
+ [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
+ LLVM_DEBUG(DBGS() << "End BufferizeFuncOpInternals:\n" << funcOp << "\n");
+ });
+ if (failed(bufferizeFuncOpInternals(funcOp, bvm, tiedResultsMap)))
+ signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveFuncBufferizePass() {
+ return std::make_unique<LinalgComprehensiveFuncBufferize>();
+}
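To make the conventions used by `getMatchingOpResult` and `kInPlaceResultsAttrName` concrete, a small hypothetical example (the function is made up; the attribute is internal to the pass and removed before it finishes):

```
func @tied(%lhs : tensor<?x?xf32>, %rhs : tensor<?x?xf32>,
           %acc : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x?xf32> {
  // The i-th `outs` tensor is tied to the i-th result (getMatchingOpResult):
  // here %r is tied to %acc.
  %r = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%acc : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %r : tensor<?x?xf32>
}
```

During the analysis the verdict is recorded on each op as `__inplace_results_attr__ = ["true"]` (or "false"/"none") and stripped at the end of the pass; when "true", the rewrite reuses the buffer mapped to the tied operand, otherwise it goes through `createNewAllocDeallocPairForShapedValue`.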
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
new file mode 100644
index 0000000000000..69c4e3fe59196
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -0,0 +1,83 @@
+// RUN: mlir-opt %s -linalg-comprehensive-func-bufferize -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK-LABEL: func @fill_inplace(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true})
+func @fill_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32> {
+ // CHECK: %[[I:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_2d_dyn]]>
+
+ // CHECK: %[[F0:.*]] = constant 0.000000e+00 : f32
+ %f0 = constant 0.0 : f32
+
+ /// Inplaceable, no alloc
+ // CHECK-NOT: alloc
+ // CHECK: linalg.fill(%[[I]], %[[F0]]) : memref<?xf32, #[[$map_2d_dyn]]>, f32
+ %r = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+
+ // CHECK: %[[R:.*]] = memref.tensor_load %[[I]] : memref<?xf32, #[[$map_2d_dyn]]>
+ // CHECK: return %[[R]] : tensor<?xf32>
+ return %r: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+/// No linalg.inplaceable flag, must allocate.
+// CHECK-LABEL: func @not_inplace(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32>)
+func @not_inplace(%A : tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK: %[[I:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_2d_dyn]]>
+
+ // CHECK: %[[D0:.*]] = memref.dim %[[I]], {{.*}} : memref<?xf32, #[[$map_2d_dyn]]>
+ // CHECK: %[[ALLOC:.*]] = memref.alloc(%[[D0]]) : memref<?xf32>
+ // CHECK: %[[I2:.*]] = memref.cast %[[ALLOC]] : memref<?xf32> to memref<?xf32, #map>
+
+ // CHECK: %[[F0:.*]] = constant 0.000000e+00 : f32
+ %f0 = constant 0.0 : f32
+
+ // CHECK: linalg.fill(%[[I2]], %[[F0]]) : memref<?xf32, #[[$map_2d_dyn]]>, f32
+ %r = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+
+ // CHECK: dealloc %[[ALLOC]] : memref<?xf32>
+ // CHECK: %[[R:.*]] = memref.tensor_load %[[I2]] : memref<?xf32, #[[$map_2d_dyn]]>
+ // CHECK: return %[[R]] : tensor<?xf32>
+ return %r: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @not_inplace
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?x?xf32>
+func @not_inplace(%A : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x?xf32> {
+ %f0 = constant 0.0 : f32
+
+ // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?x?xf32
+
+ /// Cross-op multiple uses of %A, the first op which has interfering reads must alloc.
+ // CHECK: %[[ALLOC:.*]] = memref.alloc
+ // CHECK: %[[CAST:.*]] = memref.cast %[[ALLOC]]
+ // CHECK: linalg.fill(%[[CAST]]
+ %f = linalg.fill(%A, %f0) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+
+ /// The second op has no interfering reads and can reuse.
+ // CHECK-NOT: alloc
+ // CHECK: linalg.matmul{{.*}}outs(%[[BUFFER_CAST]]
+ %r = linalg.matmul ins(%f, %f: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%A: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+ return %r: tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @not_inplace
+func @not_inplace(%A : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x?xf32> {
+ /// Within op multiple uses of %A, must alloc.
+ // CHECK: alloc
+ %r = linalg.matmul ins(%A, %A: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%A: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+ return %r: tensor<?x?xf32>
+}