[Mlir-commits] [mlir] bb273a3 - [mlir][linalg][bufferize][NFC] Move tensor interface impl to new build target
Matthias Springer
llvmlistbot at llvm.org
Wed Nov 24 01:25:30 PST 2021
Author: Matthias Springer
Date: 2021-11-24T18:25:17+09:00
New Revision: bb273a35a02a00dbba8549e858df310f4b6a32b1
URL: https://github.com/llvm/llvm-project/commit/bb273a35a02a00dbba8549e858df310f4b6a32b1
DIFF: https://github.com/llvm/llvm-project/commit/bb273a35a02a00dbba8549e858df310f4b6a32b1.diff
LOG: [mlir][linalg][bufferize][NFC] Move tensor interface impl to new build target
This makes ComprehensiveBufferize entirely independent of the tensor dialect.
Differential Revision: https://reviews.llvm.org/D114217
Added:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 881f1edb11c47..5ce675101c5aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -322,6 +322,26 @@ struct PostAnalysisStep {
SmallVector<Operation *> &newOps) = 0;
};
+/// Return a MemRefType with the same shape and element type as `shapedType`
+/// and the specified `layout` and `memorySpace`. With the default (empty)
+/// layout, the result has the canonical contiguous (identity) layout.
+MemRefType getContiguousMemRefType(ShapedType shapedType,
+ MemRefLayoutAttrInterface layout = {},
+ Attribute memorySpace = {});
+
+/// Return a MemRefType (see getContiguousMemRefType) with the specified
+/// `layout` and `memorySpace` if `type` is a ranked tensor or a memref type.
+/// Otherwise, return an UnrankedMemRefType; `layout` must be empty then.
+Type getContiguousOrUnrankedMemRefType(Type type,
+ MemRefLayoutAttrInterface layout = {},
+ Attribute memorySpace = {});
+
+/// Return a MemRefType to which the `tensorType` can be bufferized in a
+/// composable fashion. The layout must be the most dynamic possible and
+/// canonicalize away once bufferization is finished.
+MemRefType getDynamicMemRefType(RankedTensorType tensorType,
+ unsigned addressSpace = 0);
+
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index e225baa498a4d..d5ee20bf33929 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -1,3 +1,11 @@
+//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
new file mode 100644
index 0000000000000..29355ef338f3a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- TensorInterfaceImpl.h - Tensor Impl. of BufferizableOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace tensor_ext {
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tensor_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index fc9f414f7cd9d..a530139f1e507 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"
@@ -528,3 +529,31 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::
op->erase();
obsoleteOps.clear();
}
+
+MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
+ ShapedType shapedType, MemRefLayoutAttrInterface layout,
+ Attribute memorySpace) {
+ return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
+ layout, memorySpace);
+}
+
+Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
+ Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
+ if (type.isa<RankedTensorType, MemRefType>())
+ return getContiguousMemRefType(type.cast<ShapedType>(), layout,
+ memorySpace);
+ assert(!layout && "expected empty layout with UnrankedMemRefType");
+ return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
+}
+
+MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
+ RankedTensorType tensorType, unsigned addressSpace) {
+ // TODO: address space decisions to connect with the actual alloc.
+ int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
+ SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+ ShapedType::kDynamicStrideOrOffset);
+ AffineMap stridedLayout = makeStridedLinearLayoutMap(
+ dynamicStrides, dynamicOffset, tensorType.getContext());
+ return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+ stridedLayout, addressSpace);
+}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index e55653464155f..80708504c9ed5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
BufferizableOpInterface.cpp
ComprehensiveBufferize.cpp
LinalgInterfaceImpl.cpp
+ TensorInterfaceImpl.cpp
)
add_mlir_dialect_library(MLIRBufferizableOpInterface
@@ -25,6 +26,16 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
MLIRTensor
)
+add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
+ TensorInterfaceImpl.cpp
+
+ LINK_LIBS PUBLIC
+ MLIRBufferizableOpInterface
+ MLIRIR
+ MLIRMemRef
+ MLIRTensor
+)
+
add_mlir_dialect_library(MLIRComprehensiveBufferize
ComprehensiveBufferize.cpp
@@ -37,6 +48,5 @@ add_mlir_dialect_library(MLIRComprehensiveBufferize
MLIRSCF
MLIRStandard
MLIRStandardOpsTransforms
- MLIRTensor
MLIRVector
)
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index d062bbab4ad0c..7eecbc4f95331 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -587,45 +587,6 @@ getEquivalentEnclosingFuncBBArg(Value v,
// Bufferization-specific MemRefType support.
//===----------------------------------------------------------------------===//
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace`.
-static MemRefType getContiguousMemRefType(ShapedType shapedType,
- MemRefLayoutAttrInterface layout = {},
- Attribute memorySpace = {}) {
- return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
- layout, memorySpace);
-}
-
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace` or an UnrankedMemRefType otherwise.
-static Type
-getContiguousOrUnrankedMemRefType(Type type,
- MemRefLayoutAttrInterface layout = {},
- Attribute memorySpace = {}) {
- if (type.isa<RankedTensorType, MemRefType>())
- return getContiguousMemRefType(type.cast<ShapedType>(), layout,
- memorySpace);
- assert(!layout && "expected empty layout with UnrankedMemRefType");
- return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
-}
-
-/// Return a MemRefType to which the `tensorType` can be bufferized in a
-/// composable fashion. The layout must be the most dynamic possible and
-/// canonicalize away once bufferization is finished.
-static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
- unsigned addressSpace = 0) {
- // TODO: address space decisions to connect with the actual alloc.
- int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
- SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
- ShapedType::kDynamicStrideOrOffset);
- AffineMap stridedLayout = makeStridedLinearLayoutMap(
- dynamicStrides, dynamicOffset, tensorType.getContext());
- return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
- stridedLayout, addressSpace);
-}
-
/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
/// tensor is replaced by the corresponding buffer type.
/// In order for all the callers to agree, this *must* bufferize to the most
@@ -1965,420 +1926,6 @@ struct ReturnOpInterface
} // namespace std_ext
-namespace tensor_ext {
-
-struct CastOpInterface
- : public BufferizableOpInterface::ExternalModel<CastOpInterface,
- tensor::CastOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
- return false;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
- return false;
- }
-
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
- return {&op->getOpOperand(0)};
- }
-
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
- return op->getResult(0);
- }
-
- BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
- return BufferRelation::Equivalent;
- }
-
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- auto castOp = cast<tensor::CastOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(castOp);
-
- Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
- if (!resultBuffer)
- return failure();
- Type sourceType = resultBuffer.getType();
- auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
- auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
- assert(rankedMemRefType || unrankedMemRefType);
- Attribute memorySpace = rankedMemRefType
- ? rankedMemRefType.getMemorySpace()
- : unrankedMemRefType.getMemorySpace();
- TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
- MemRefLayoutAttrInterface layout =
- rankedMemRefType && tensorType.isa<RankedTensorType>()
- ? rankedMemRefType.getLayout()
- : MemRefLayoutAttrInterface();
- Type memRefType = getContiguousOrUnrankedMemRefType(
- castOp.getResult().getType(), layout, memorySpace);
- Value res =
- b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
- state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
- state.mapBuffer(castOp.getResult(), res);
- return success();
- }
-};
-
-struct DimOpInterface
- : public BufferizableOpInterface::ExternalModel<DimOpInterface,
- tensor::DimOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
- return true;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
- return false;
- }
-
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
- return OpResult();
- }
-
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- auto dimOp = cast<tensor::DimOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(dimOp);
-
- if (dimOp.source().getType().isa<RankedTensorType>()) {
- Value v = state.lookupBuffer(dimOp.source());
- dimOp.result().replaceAllUsesWith(
- b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
- }
- return success();
- }
-};
-
-struct ExtractSliceOpInterface
- : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
- tensor::ExtractSliceOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
- return false;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
- return false;
- }
-
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
- return {&op->getOpOperand(0) /*source*/};
- }
-
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
- return &opOperand == &op->getOpOperand(0) /*source*/
- ? op->getResult(0)
- : OpResult();
- }
-
- BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
- return BufferRelation::None;
- }
-
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
- LDBG("bufferize: " << *extractSliceOp << '\n');
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(extractSliceOp);
-
- Location loc = extractSliceOp.getLoc();
- Value srcMemref = state.lookupBuffer(extractSliceOp.source());
- auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
- auto dstTensorType =
- extractSliceOp.result().getType().cast<RankedTensorType>();
-
- // If not inplaceable, alloc.
- bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
- Value alloc;
- if (!inplace)
- alloc = createNewAllocDeallocPairForShapedValue(
- b, loc, extractSliceOp.result(), state);
-
- // Bufferize to subview.
- auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
- dstTensorType.getRank(), srcMemrefType,
- extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
- extractSliceOp.getMixedStrides())
- .cast<MemRefType>();
- Value subView = b.create<memref::SubViewOp>(
- loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
- extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
- // Insert new alias.
- state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
-
- /// If not inplaceable, copy.
- if (!inplace) {
- // Do not copy if the copied data is never read.
- if (isValueRead(extractSliceOp.result()))
- state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
- alloc);
- subView = alloc;
- }
-
- state.mapBuffer(extractSliceOp.result(), subView);
- return success();
- }
-};
-
-struct ExtractOpInterface
- : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
- tensor::ExtractOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
- return true;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
- return false;
- }
-
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
- return OpResult();
- }
-
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- auto extractOp = cast<tensor::ExtractOp>(op);
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(extractOp);
-
- Location loc = extractOp.getLoc();
- Value srcMemref = state.lookupBuffer(extractOp.tensor());
- Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
- extractOp.replaceAllUsesWith(l);
- return success();
- }
-};
-
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-///
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and
-/// exposed as interfaces on the proper types.
-static bool
-areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
- ExtractSliceOp st, InsertSliceOp sti) {
- if (!st || !sti)
- return false;
- if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
- return false;
- if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
- return false;
- return true;
-}
-
-/// Return true if the source of a `insertSliceOp` bufferizes to an
-/// equivalent ExtractSliceOp that bufferizes inplace.
-static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
- const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
- LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
- << '\n');
- bool foundOp = false;
- aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
- auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
- if (extractSliceOp &&
- areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
- insertSliceOp) &&
- aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
- LDBG("\tfound: " << extractSliceOp.getOperation() << '\n');
- foundOp = true;
- }
- });
-
- if (!foundOp)
- LDBG("\tnot equivalent\n");
-
- return foundOp;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
- Value value, InsertSliceOp insertOp) {
- auto condition = [&](Value val) {
- if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
- return true;
- return false;
- };
-
- return llvm::all_of(findValueInReverseUseDefChain(value, condition),
- condition);
-}
-
-struct InsertSliceOpInterface
- : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
- tensor::InsertSliceOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
- return true;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
- return &opOperand == &op->getOpOperand(1) /*dest*/;
- }
-
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
- return {&op->getOpOperand(1) /*dest*/};
- }
-
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
- return &opOperand == &op->getOpOperand(1) /*dest*/
- ? op->getResult(0)
- : OpResult();
- }
-
- BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
- return BufferRelation::Equivalent;
- }
-
- bool isNotConflicting(Operation *op, OpOperand *uRead,
- OpOperand *uConflictingWrite,
- const BufferizationAliasInfo &aliasInfo) const {
- Operation *readingOp = uRead->getOwner();
- Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
- // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
- // uRead is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
-
- // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
- if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
- insertSliceOp))
- // Case 1: The main insight is that InsertSliceOp reads only part of
- // the destination tensor. The overwritten area is not read. If
- // uConflictingWrite writes into exactly the memory location that is
- // being read by uRead, this is not a conflict.
- //
- // In the above example:
- // uRead = OpOperand 1 (%t) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
- //
- // The read of %t does not conflict with the write of the FillOp
- // (same aliases!) because the area that the FillOp operates on is
- // exactly the one that is *not* read via %t.
- return true;
-
- if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
- uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
- // Case 2: The read of the source tensor and the write to the dest
- // tensor via an InsertSliceOp is not a conflict if the read is
- // reading exactly that part of an equivalent tensor that the
- // InsertSliceOp is writing.
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of tensor.insert_slice
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- return true;
- }
-
- // If uConflictingWrite is an InsertSliceOp...
- if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
- // As an example, consider the following IR.
- //
- // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
- // %1 = linalg.fill %cst, %0 {inplace= [true] }
- // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
- // {inplace= [true] }
- // %3 = vector.transfer_read %1, %cst
- //
- // In the above example:
- // uRead = OpOperand 0 (%1) of vector.transfer_read
- // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
- // lastWrite = %1
- //
- // This is not a conflict because the InsertSliceOp overwrites the
- // memory segment of %1 with the exact same data. (Effectively, there
- // is no memory write here.)
- if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- aliasInfo.areEquivalentBufferizedValues(uRead->get(),
- insertSliceOp.source()) &&
- hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
- insertSliceOp))
- return true;
-
- return false;
- }
-
- LogicalResult bufferize(Operation *op, OpBuilder &b,
- BufferizationState &state) const {
- // insert_slice ops arise from tiling and bufferizing them out-of-place is
- // generally a deal breaker. When used with loops, this ends up cloning the
- // whole tensor on every single iteration and is a symptom of a
- // catastrophically bad scheduling decision.
- // TODO: be very loud about it or even consider failing the pass.
- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
- LDBG("bufferize: " << *insertSliceOp << '\n');
-
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(insertSliceOp);
- Location loc = insertSliceOp.getLoc();
-
- // When bufferizing out-of-place, `getResultBuffer` allocates.
- Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
- if (!dstMemref)
- return failure();
-
- // A copy of the source buffer is needed if either:
- // - The producer of `source` is not inplace. This is the case where a
- // slice is computed out of place into the inplace full tensor.
- // - The result is not inplace. This is the case where the whole tensor is
- // cloned and the clone needs to be updated.
- // TODO: Is this necessary?
- bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
- state.aliasInfo, insertSliceOp) ||
- !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
- if (needCopy) {
- LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
- << " -> copy\n");
- // Take a subview of the dst.
- auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
- auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
- insertSliceOp.getSourceType().getRank(), dstMemrefType,
- insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
- insertSliceOp.getMixedStrides())
- .cast<MemRefType>();
- Value subView = b.create<memref::SubViewOp>(
- loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
- insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
- // Insert new alias.
- state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
- // Copy tensor.
- Value srcMemref = state.lookupBuffer(insertSliceOp.source());
- state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
- subView);
- }
-
- state.mapBuffer(insertSliceOp.result(), dstMemref);
- return success();
- }
-};
-
-} // namespace tensor_ext
-
namespace vector_ext {
struct TransferReadOpInterface
@@ -2484,13 +2031,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
- registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
- registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
- registry.addOpInterface<tensor::ExtractSliceOp,
- tensor_ext::ExtractSliceOpInterface>();
- registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
- registry.addOpInterface<tensor::InsertSliceOp,
- tensor_ext::InsertSliceOpInterface>();
registry.addOpInterface<vector::TransferReadOp,
vector_ext::TransferReadOpInterface>();
registry.addOpInterface<vector::TransferWriteOp,
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
new file mode 100644
index 0000000000000..f72a23d7ca811
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -0,0 +1,437 @@
+//===- TensorInterfaceImpl.cpp - Tensor Impl. of BufferizableOpInterface --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace tensor_ext {
+
+using tensor::ExtractSliceOp;
+using tensor::InsertSliceOp;
+
+/// External model of BufferizableOpInterface for tensor::CastOp. The result
+/// aliases OpOperand 0 and is reported as equivalent to it.
+struct CastOpInterface
+ : public BufferizableOpInterface::ExternalModel<CastOpInterface,
+ tensor::CastOp> {
+ // The cast does not read the contents of its operand.
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ return false;
+ }
+
+ // The cast does not write to memory.
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ return false;
+ }
+
+ // The single result aliases the single (source) operand.
+ SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+ OpResult opResult) const {
+ return {&op->getOpOperand(0)};
+ }
+
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ return op->getResult(0);
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+ return BufferRelation::Equivalent;
+ }
+
+ /// Bufferize to a memref.cast of the buffer obtained via getResultBuffer.
+ /// The target memref type keeps that buffer's memory space. Its layout is
+ /// kept only when both the buffer and the cast's result type are ranked;
+ /// otherwise an empty layout is passed to getContiguousOrUnrankedMemRefType.
+ LogicalResult bufferize(Operation *op, OpBuilder &b,
+ BufferizationState &state) const {
+ auto castOp = cast<tensor::CastOp>(op);
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(castOp);
+
+ // A null buffer signals that no result buffer could be obtained.
+ Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
+ if (!resultBuffer)
+ return failure();
+ // The buffer being cast must be either a ranked or an unranked memref.
+ Type sourceType = resultBuffer.getType();
+ auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
+ auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
+ assert(rankedMemRefType || unrankedMemRefType);
+ Attribute memorySpace = rankedMemRefType
+ ? rankedMemRefType.getMemorySpace()
+ : unrankedMemRefType.getMemorySpace();
+ TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
+ // Only a ranked -> ranked cast can carry over the existing layout.
+ MemRefLayoutAttrInterface layout =
+ rankedMemRefType && tensorType.isa<RankedTensorType>()
+ ? rankedMemRefType.getLayout()
+ : MemRefLayoutAttrInterface();
+ Type memRefType = getContiguousOrUnrankedMemRefType(
+ castOp.getResult().getType(), layout, memorySpace);
+ Value res =
+ b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
+ // The new memref.cast value is recorded as equivalent to, and becomes the
+ // buffer of, the tensor result.
+ state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+ state.mapBuffer(castOp.getResult(), res);
+ return success();
+ }
+};
+
+/// External model of BufferizableOpInterface for tensor::DimOp. The op reads
+/// its source and has no aliasing result. When the source is a ranked tensor,
+/// all uses of the result are redirected to a memref.dim on the source buffer.
+struct DimOpInterface
+    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
+                                                    tensor::DimOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto dimOp = cast<tensor::DimOp>(op);
+
+    // Nothing is rewritten for unranked sources.
+    if (!dimOp.source().getType().isa<RankedTensorType>())
+      return success();
+
+    // Restore the builder's insertion point when leaving this scope.
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(dimOp);
+
+    Value sourceBuffer = state.lookupBuffer(dimOp.source());
+    Value memrefDim =
+        b.create<memref::DimOp>(dimOp.getLoc(), sourceBuffer, dimOp.index());
+    dimOp.result().replaceAllUsesWith(memrefDim);
+    return success();
+  }
+};
+
+/// External model of BufferizableOpInterface for tensor::ExtractSliceOp. The
+/// result aliases OpOperand 0 (the source) but is not equivalent to it
+/// (BufferRelation::None).
+struct ExtractSliceOpInterface
+ : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
+ tensor::ExtractSliceOp> {
+ // Taking a slice by itself neither reads nor writes memory.
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ return false;
+ }
+
+ SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+ OpResult opResult) const {
+ return {&op->getOpOperand(0) /*source*/};
+ }
+
+ // Only the source operand aliases the result; other operands (dynamic
+ // offsets/sizes/strides) do not.
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ return &opOperand == &op->getOpOperand(0) /*source*/
+ ? op->getResult(0)
+ : OpResult();
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+ return BufferRelation::None;
+ }
+
+ /// Bufferize to a memref.subview of the source buffer. If the result is not
+ /// in place, the result buffer is instead a new allocation. The subview is
+ /// copied into that allocation only if the result value is read.
+ LogicalResult bufferize(Operation *op, OpBuilder &b,
+ BufferizationState &state) const {
+ auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(extractSliceOp);
+
+ Location loc = extractSliceOp.getLoc();
+ Value srcMemref = state.lookupBuffer(extractSliceOp.source());
+ auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+ auto dstTensorType =
+ extractSliceOp.result().getType().cast<RankedTensorType>();
+
+ // If not inplaceable, alloc.
+ bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
+ Value alloc;
+ if (!inplace)
+ alloc = state.allocationFns.createAllocDeallocFn(
+ b, loc, extractSliceOp.result(), state);
+
+ // Bufferize to subview. The subview type is rank-reduced to the rank of
+ // the result tensor.
+ auto subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ dstTensorType.getRank(), srcMemrefType,
+ extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+ extractSliceOp.getMixedStrides())
+ .cast<MemRefType>();
+ Value subView = b.create<memref::SubViewOp>(
+ loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
+ extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+ // Insert new alias.
+ state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
+
+ // If not inplaceable, copy.
+ if (!inplace) {
+ // Do not copy if the copied data is never read.
+ if (isValueRead(extractSliceOp.result()))
+ state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
+ alloc);
+ // The fresh allocation, not the subview, becomes the result buffer.
+ subView = alloc;
+ }
+
+ state.mapBuffer(extractSliceOp.result(), subView);
+ return success();
+ }
+};
+
+/// External model of BufferizableOpInterface for tensor::ExtractOp. The op
+/// reads its tensor operand, has no aliasing result, and is replaced by a
+/// memref.load from the operand's buffer.
+struct ExtractOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
+                                                    tensor::ExtractOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto extractOp = cast<tensor::ExtractOp>(op);
+
+    // Restore the builder's insertion point when leaving this scope.
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(extractOp);
+
+    Value tensorBuffer = state.lookupBuffer(extractOp.tensor());
+    Value loaded = b.create<memref::LoadOp>(extractOp.getLoc(), tensorBuffer,
+                                            extractOp.indices());
+    extractOp.replaceAllUsesWith(loaded);
+    return success();
+  }
+};
+
+/// Return true if the (ExtractSliceOp, InsertSliceOp) pair matches. That is,
+/// both ops are non-null, the extract source and the insert destination are
+/// equivalent bufferized values, and both ops use the same offsets/sizes/
+/// strides (compared with isEqualConstantIntOrValue).
+///
+/// This is one particular type of relationship between ops on tensors that
+/// reduces to an equivalence on buffers. This should be generalized and
+/// exposed as interfaces on the proper types.
+static bool
+areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
+                             ExtractSliceOp st, InsertSliceOp sti) {
+  return st && sti &&
+         aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()) &&
+         sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue);
+}
+
+/// Return true if some value in the equivalence class of the source of
+/// `insertSliceOp` is defined by an ExtractSliceOp that matches
+/// `insertSliceOp` (see areEquivalentExtractSliceOps) and bufferizes in place.
+static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
+  bool found = false;
+  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value v) {
+    auto candidate = v.getDefiningOp<ExtractSliceOp>();
+    if (!candidate)
+      return;
+    if (!areEquivalentExtractSliceOps(aliasInfo, candidate, insertSliceOp))
+      return;
+    if (aliasInfo.isInPlace(candidate->getResult(0)))
+      found = true;
+  });
+  return found;
+}
+
+/// Return true if `value` originates from an ExtractSliceOp that matches
+/// `insertOp`. Specifically, every value returned by
+/// findValueInReverseUseDefChain for the "is a matching ExtractSliceOp"
+/// predicate must itself satisfy that predicate.
+static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+                                      Value value, InsertSliceOp insertOp) {
+  auto isMatchingExtractSlice = [&](Value candidate) -> bool {
+    auto extractOp = candidate.getDefiningOp<ExtractSliceOp>();
+    return extractOp &&
+           areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp);
+  };
+
+  auto origins = findValueInReverseUseDefChain(value, isMatchingExtractSlice);
+  return llvm::all_of(origins, isMatchingExtractSlice);
+}
+
+/// External model of BufferizableOpInterface for tensor::InsertSliceOp.
+/// OpOperand 0 is the source and OpOperand 1 is the destination. The result
+/// aliases, and is equivalent to, the destination.
+struct InsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
+                                                    tensor::InsertSliceOp> {
+  // Both the source and the destination are read.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  // Only the destination is written.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(1) /*dest*/};
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/
+               ? op->getResult(0)
+               : OpResult();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Return true if the read `uRead` and the write `uConflictingWrite` are not
+  /// a real conflict because they are part of a matching
+  /// ExtractSliceOp/InsertSliceOp pair. The cases below refer to this IR:
+  ///
+  ///   %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+  ///   %1 = linalg.fill %cst, %0 {inplace= [true] }
+  ///   %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+  ///       {inplace= [true] }
+  ///   %3 = vector.transfer_read %1, %cst
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const BufferizationAliasInfo &aliasInfo) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs when
+    // uRead is an InsertSliceOp.
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
+      // Case 1: The main insight is that InsertSliceOp reads only part of
+      // the destination tensor. The overwritten area is not read. If
+      // uConflictingWrite writes into exactly the memory location that is
+      // being read by uRead, this is not a conflict.
+      //
+      // In the above example:
+      //   uRead = OpOperand 1 (%t) of tensor.insert_slice
+      //   uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+      //
+      // The read of %t does not conflict with the write of the FillOp
+      // (same aliases!) because the area that the FillOp operates on is
+      // exactly the one that is *not* read via %t.
+      //
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
+                                    insertSliceOp)) {
+        return true;
+      }
+
+      // Case 2: The read of the source tensor and the write to the dest
+      // tensor via an InsertSliceOp is not a conflict if the read is
+      // reading exactly that part of an equivalent tensor that the
+      // InsertSliceOp is writing.
+      //
+      // In the above example:
+      //   uRead = OpOperand 0 (%1) of tensor.insert_slice
+      //   uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp)) {
+        return true;
+      }
+    }
+
+    // Case 3: uConflictingWrite is an InsertSliceOp.
+    //
+    // In the above example:
+    //   uRead = OpOperand 0 (%1) of vector.transfer_read
+    //   uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+    //   lastWrite = %1
+    //
+    // This is not a conflict because the InsertSliceOp overwrites the
+    // memory segment of %1 with the exact same data. (Effectively, there
+    // is no memory write here.)
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) {
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
+                                                  insertSliceOp.source()) &&
+          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
+                                    insertSliceOp)) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  /// Bufferize by mapping the result to the destination buffer obtained via
+  /// getResultBuffer. When needed, the source buffer is copied into a subview
+  /// of that buffer.
+  ///
+  /// insert_slice ops arise from tiling and bufferizing them out-of-place is
+  /// generally a deal breaker. When used with loops, this ends up cloning the
+  /// whole tensor on every single iteration and is a symptom of a
+  /// catastrophically bad scheduling decision.
+  /// TODO: be very loud about it or even consider failing the pass.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(insertSliceOp);
+    Location loc = insertSliceOp.getLoc();
+
+    // When bufferizing out-of-place, `getResultBuffer` allocates. A null
+    // buffer signals failure.
+    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
+    if (!dstMemref)
+      return failure();
+
+    // A copy of the source buffer is needed if either:
+    //   - The producer of `source` is not inplace. This is the case where a
+    //     slice is computed out of place into the inplace full tensor.
+    //   - The result is not inplace. This is the case where the whole tensor
+    //     is cloned and the clone needs to be updated.
+    // TODO: Is this necessary?
+    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+                        state.aliasInfo, insertSliceOp) ||
+                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
+    if (needCopy) {
+      // Take a (rank-reduced) subview of the dst.
+      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+      auto subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              insertSliceOp.getSourceType().getRank(), dstMemrefType,
+              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+              insertSliceOp.getMixedStrides())
+              .cast<MemRefType>();
+      Value subView = b.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+      // Insert new alias.
+      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+      // Copy the source buffer into the subview.
+      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
+      state.allocationFns.memCpyFn(b, loc, srcMemref, subView);
+    }
+
+    state.mapBuffer(insertSliceOp.result(), dstMemref);
+    return success();
+  }
+};
+
+} // namespace tensor_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+/// Register the external BufferizableOpInterface models defined in this file
+/// for tensor::CastOp, tensor::DimOp, tensor::ExtractSliceOp,
+/// tensor::ExtractOp and tensor::InsertSliceOp.
+void mlir::linalg::comprehensive_bufferize::tensor_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
+  registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
+  registry.addOpInterface<tensor::ExtractSliceOp,
+                          tensor_ext::ExtractSliceOpInterface>();
+  registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
+  registry.addOpInterface<tensor::InsertSliceOp,
+                          tensor_ext::InsertSliceOpInterface>();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 25368352bcd19..a098e9e29fa4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRStandardOpsTransforms
MLIRStandardToLLVM
MLIRTensor
+ MLIRTensorBufferizableOpInterfaceImpl
MLIRTransforms
MLIRTransformUtils
MLIRVector
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 626eafa31bbb2..69182f92b70b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -38,6 +39,7 @@ struct LinalgComprehensiveModuleBufferize
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
}
};
} // end namespace
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 969d82c2c163d..aaeac453cc085 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6326,6 +6326,25 @@ cc_library(
],
)
+cc_library(
+ name = "TensorBufferizableOpInterfaceImpl",
+ srcs = [
+ "lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp",
+ ],
+ hdrs = [
+ "include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h",
+ ],
+ includes = ["include"],
+ deps = [
+ ":BufferizableOpInterface",
+ ":IR",
+ ":MemRefDialect",
+ ":Support",
+ ":TensorDialect",
+ "//llvm:Support",
+ ],
+)
+
td_library(
name = "LinalgDocTdFiles",
srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -6545,6 +6564,7 @@ cc_library(
":StandardOps",
":StandardOpsTransforms",
":Support",
+ ":TensorBufferizableOpInterfaceImpl",
":TensorDialect",
":TransformUtils",
":VectorOps",
@@ -6575,7 +6595,6 @@ cc_library(
":SCFDialect",
":StandardOps",
":Support",
- ":TensorDialect",
":TransformUtils",
":VectorOps",
"//llvm:Support",
More information about the Mlir-commits
mailing list