[Mlir-commits] [mlir] bb273a3 - [mlir][linalg][bufferize][NFC] Move tensor interface impl to new build target

Matthias Springer llvmlistbot at llvm.org
Wed Nov 24 01:25:30 PST 2021


Author: Matthias Springer
Date: 2021-11-24T18:25:17+09:00
New Revision: bb273a35a02a00dbba8549e858df310f4b6a32b1

URL: https://github.com/llvm/llvm-project/commit/bb273a35a02a00dbba8549e858df310f4b6a32b1
DIFF: https://github.com/llvm/llvm-project/commit/bb273a35a02a00dbba8549e858df310f4b6a32b1.diff

LOG: [mlir][linalg][bufferize][NFC] Move tensor interface impl to new build target

This makes ComprehensiveBufferize entirely independent of the tensor dialect.

Differential Revision: https://reviews.llvm.org/D114217

Added: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 881f1edb11c47..5ce675101c5aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -322,6 +322,26 @@ struct PostAnalysisStep {
                             SmallVector<Operation *> &newOps) = 0;
 };
 
+/// Return a MemRefType with the same shape and element type as `shapedType`
+/// and the specified `layout` and `memorySpace`. With the default (empty)
+/// `layout`, the result is contiguous (canonical/empty layout map).
+MemRefType getContiguousMemRefType(ShapedType shapedType,
+                                   MemRefLayoutAttrInterface layout = {},
+                                   Attribute memorySpace = {});
+
+/// Return `getContiguousMemRefType(type, layout, memorySpace)` if `type` is a
+/// RankedTensorType or MemRefType. Otherwise, return an UnrankedMemRefType
+/// with the element type of `type` and `memorySpace`; `layout` must be empty.
+Type getContiguousOrUnrankedMemRefType(Type type,
+                                       MemRefLayoutAttrInterface layout = {},
+                                       Attribute memorySpace = {});
+
+/// Return a MemRefType to which the `tensorType` can be bufferized in a
+/// composable fashion. The layout must be the most dynamic possible and
+/// canonicalize away once bufferization is finished.
+MemRefType getDynamicMemRefType(RankedTensorType tensorType,
+                                unsigned addressSpace = 0);
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index e225baa498a4d..d5ee20bf33929 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -1,3 +1,11 @@
+//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
new file mode 100644
index 0000000000000..29355ef338f3a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- TensorInterfaceImpl.h - Tensor Impl. of BufferizableOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace tensor_ext {
+
+/// Register BufferizableOpInterface external models for tensor dialect ops.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace tensor_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index fc9f414f7cd9d..a530139f1e507 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "llvm/Support/Debug.h"
 
@@ -528,3 +529,31 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::
     op->erase();
   obsoleteOps.clear();
 }
+
+MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
+    ShapedType shapedType, MemRefLayoutAttrInterface layout,
+    Attribute memorySpace) {
+  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
+                         layout, memorySpace);
+}
+
+Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
+    Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
+  if (type.isa<RankedTensorType, MemRefType>())
+    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
+                                   memorySpace);
+  assert(!layout && "expected empty layout with UnrankedMemRefType");
+  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
+}
+
+MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
+    RankedTensorType tensorType, unsigned addressSpace) {
+  // TODO: Connect the address space decision with the actual allocation.
+  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
+  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+                                      ShapedType::kDynamicStrideOrOffset);
+  AffineMap stridedLayout = makeStridedLinearLayoutMap(
+      dynamicStrides, dynamicOffset, tensorType.getContext());
+  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                         stridedLayout, addressSpace);
+}

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index e55653464155f..80708504c9ed5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   BufferizableOpInterface.cpp
   ComprehensiveBufferize.cpp
   LinalgInterfaceImpl.cpp
+  TensorInterfaceImpl.cpp
 )
 
 add_mlir_dialect_library(MLIRBufferizableOpInterface
@@ -25,6 +26,16 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
   MLIRTensor
 )
 
+add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
+  TensorInterfaceImpl.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRBufferizableOpInterface
+  MLIRIR
+  MLIRMemRef
+  MLIRTensor
+)
+
 add_mlir_dialect_library(MLIRComprehensiveBufferize
   ComprehensiveBufferize.cpp
 
@@ -37,6 +48,5 @@ add_mlir_dialect_library(MLIRComprehensiveBufferize
   MLIRSCF
   MLIRStandard
   MLIRStandardOpsTransforms
-  MLIRTensor
   MLIRVector
 )

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index d062bbab4ad0c..7eecbc4f95331 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -587,45 +587,6 @@ getEquivalentEnclosingFuncBBArg(Value v,
 // Bufferization-specific MemRefType support.
 //===----------------------------------------------------------------------===//
 
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace`.
-static MemRefType getContiguousMemRefType(ShapedType shapedType,
-                                          MemRefLayoutAttrInterface layout = {},
-                                          Attribute memorySpace = {}) {
-  return MemRefType::get(shapedType.getShape(), shapedType.getElementType(),
-                         layout, memorySpace);
-}
-
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace` or an UnrankedMemRefType otherwise.
-static Type
-getContiguousOrUnrankedMemRefType(Type type,
-                                  MemRefLayoutAttrInterface layout = {},
-                                  Attribute memorySpace = {}) {
-  if (type.isa<RankedTensorType, MemRefType>())
-    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
-                                   memorySpace);
-  assert(!layout && "expected empty layout with UnrankedMemRefType");
-  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
-}
-
-/// Return a MemRefType to which the `tensorType` can be bufferized in a
-/// composable fashion. The layout must be the most dynamic possible and
-/// canonicalize away once bufferization is finished.
-static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
-                                       unsigned addressSpace = 0) {
-  // TODO: address space decisions to connect with the actual alloc.
-  int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
-  SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
-                                      ShapedType::kDynamicStrideOrOffset);
-  AffineMap stridedLayout = makeStridedLinearLayoutMap(
-      dynamicStrides, dynamicOffset, tensorType.getContext());
-  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
-                         stridedLayout, addressSpace);
-}
-
 /// Return the FunctionType with `argumentTypes` and `resultTypes` where each
 /// tensor is replaced by the corresponding buffer type.
 /// In order for all the callers to agree, this *must* bufferize to the most
@@ -1965,420 +1926,6 @@ struct ReturnOpInterface
 
 } // namespace std_ext
 
-namespace tensor_ext {
-
-struct CastOpInterface
-    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
-                                                    tensor::CastOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    return {&op->getOpOperand(0)};
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return op->getResult(0);
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto castOp = cast<tensor::CastOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(castOp);
-
-    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
-    if (!resultBuffer)
-      return failure();
-    Type sourceType = resultBuffer.getType();
-    auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
-    auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
-    assert(rankedMemRefType || unrankedMemRefType);
-    Attribute memorySpace = rankedMemRefType
-                                ? rankedMemRefType.getMemorySpace()
-                                : unrankedMemRefType.getMemorySpace();
-    TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
-    MemRefLayoutAttrInterface layout =
-        rankedMemRefType && tensorType.isa<RankedTensorType>()
-            ? rankedMemRefType.getLayout()
-            : MemRefLayoutAttrInterface();
-    Type memRefType = getContiguousOrUnrankedMemRefType(
-        castOp.getResult().getType(), layout, memorySpace);
-    Value res =
-        b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
-    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
-    state.mapBuffer(castOp.getResult(), res);
-    return success();
-  }
-};
-
-struct DimOpInterface
-    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
-                                                    tensor::DimOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return OpResult();
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto dimOp = cast<tensor::DimOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(dimOp);
-
-    if (dimOp.source().getType().isa<RankedTensorType>()) {
-      Value v = state.lookupBuffer(dimOp.source());
-      dimOp.result().replaceAllUsesWith(
-          b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
-    }
-    return success();
-  }
-};
-
-struct ExtractSliceOpInterface
-    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
-                                                    tensor::ExtractSliceOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    return {&op->getOpOperand(0) /*source*/};
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return &opOperand == &op->getOpOperand(0) /*source*/
-               ? op->getResult(0)
-               : OpResult();
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::None;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
-    LDBG("bufferize: " << *extractSliceOp << '\n');
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(extractSliceOp);
-
-    Location loc = extractSliceOp.getLoc();
-    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
-    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
-    auto dstTensorType =
-        extractSliceOp.result().getType().cast<RankedTensorType>();
-
-    // If not inplaceable, alloc.
-    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
-    Value alloc;
-    if (!inplace)
-      alloc = createNewAllocDeallocPairForShapedValue(
-          b, loc, extractSliceOp.result(), state);
-
-    // Bufferize to subview.
-    auto subviewMemRefType =
-        memref::SubViewOp::inferRankReducedResultType(
-            dstTensorType.getRank(), srcMemrefType,
-            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
-            extractSliceOp.getMixedStrides())
-            .cast<MemRefType>();
-    Value subView = b.create<memref::SubViewOp>(
-        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
-        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
-    // Insert new alias.
-    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
-
-    /// If not inplaceable, copy.
-    if (!inplace) {
-      // Do not copy if the copied data is never read.
-      if (isValueRead(extractSliceOp.result()))
-        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
-                                     alloc);
-      subView = alloc;
-    }
-
-    state.mapBuffer(extractSliceOp.result(), subView);
-    return success();
-  }
-};
-
-struct ExtractOpInterface
-    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
-                                                    tensor::ExtractOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return OpResult();
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto extractOp = cast<tensor::ExtractOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(extractOp);
-
-    Location loc = extractOp.getLoc();
-    Value srcMemref = state.lookupBuffer(extractOp.tensor());
-    Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
-    extractOp.replaceAllUsesWith(l);
-    return success();
-  }
-};
-
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-///
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and
-/// exposed as interfaces on the proper types.
-static bool
-areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
-                             ExtractSliceOp st, InsertSliceOp sti) {
-  if (!st || !sti)
-    return false;
-  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if the source of a `insertSliceOp` bufferizes to an
-/// equivalent ExtractSliceOp that bufferizes inplace.
-static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
-    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
-  LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
-                                                              << '\n');
-  bool foundOp = false;
-  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
-    auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
-    if (extractSliceOp &&
-        areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
-                                     insertSliceOp) &&
-        aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
-      LDBG("\tfound: " << extractSliceOp.getOperation() << '\n');
-      foundOp = true;
-    }
-  });
-
-  if (!foundOp)
-    LDBG("\tnot equivalent\n");
-
-  return foundOp;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
-                                      Value value, InsertSliceOp insertOp) {
-  auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
-        return true;
-    return false;
-  };
-
-  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
-                      condition);
-}
-
-struct InsertSliceOpInterface
-    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
-                                                    tensor::InsertSliceOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/;
-  }
-
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    return {&op->getOpOperand(1) /*dest*/};
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/
-               ? op->getResult(0)
-               : OpResult();
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
-  }
-
-  bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
-                        const BufferizationAliasInfo &aliasInfo) const {
-    Operation *readingOp = uRead->getOwner();
-    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-    // uRead is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-
-      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
-                                    insertSliceOp))
-        // Case 1: The main insight is that InsertSliceOp reads only part of
-        // the destination tensor. The overwritten area is not read. If
-        // uConflictingWrite writes into exactly the memory location that is
-        // being read by uRead, this is not a conflict.
-        //
-        // In the above example:
-        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-        //
-        // The read of %t does not conflict with the write of the FillOp
-        // (same aliases!) because the area that the FillOp operates on is
-        // exactly the one that is *not* read via %t.
-        return true;
-
-      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
-        // Case 2: The read of the source tensor and the write to the dest
-        // tensor via an InsertSliceOp is not a conflict if the read is
-        // reading exactly that part of an equivalent tensor that the
-        // InsertSliceOp is writing.
-        //
-        // In the above example:
-        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        return true;
-    }
-
-    // If uConflictingWrite is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-      // %3 = vector.transfer_read %1, %cst
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of vector.transfer_read
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      // lastWrite         = %1
-      //
-      // This is not a conflict because the InsertSliceOp overwrites the
-      // memory segment of %1 with the exact same data. (Effectively, there
-      // is no memory write here.)
-      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
-                                                  insertSliceOp.source()) &&
-          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
-                                    insertSliceOp))
-        return true;
-
-    return false;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    // insert_slice ops arise from tiling and bufferizing them out-of-place is
-    // generally a deal breaker. When used with loops, this ends up cloning the
-    // whole tensor on every single iteration and is a symptom of a
-    // catastrophically bad scheduling decision.
-    // TODO: be very loud about it or even consider failing the pass.
-    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
-    LDBG("bufferize: " << *insertSliceOp << '\n');
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(insertSliceOp);
-    Location loc = insertSliceOp.getLoc();
-
-    // When bufferizing out-of-place, `getResultBuffer` allocates.
-    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
-    if (!dstMemref)
-      return failure();
-
-    // A copy of the source buffer is needed if either:
-    //   - The producer of `source` is not inplace. This is the case where a
-    //     slice is computed out of place into the inplace full tensor.
-    //   - The result is not inplace. This is the case where the whole tensor is
-    //     cloned and the clone needs to be updated.
-    // TODO: Is this necessary?
-    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
-                        state.aliasInfo, insertSliceOp) ||
-                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
-    if (needCopy) {
-      LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
-                                                    << " -> copy\n");
-      // Take a subview of the dst.
-      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
-      auto subviewMemRefType =
-          memref::SubViewOp::inferRankReducedResultType(
-              insertSliceOp.getSourceType().getRank(), dstMemrefType,
-              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-              insertSliceOp.getMixedStrides())
-              .cast<MemRefType>();
-      Value subView = b.create<memref::SubViewOp>(
-          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
-          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
-      // Insert new alias.
-      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
-      // Copy tensor.
-      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
-      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
-                                   subView);
-    }
-
-    state.mapBuffer(insertSliceOp.result(), dstMemref);
-    return success();
-  }
-};
-
-} // namespace tensor_ext
-
 namespace vector_ext {
 
 struct TransferReadOpInterface
@@ -2484,13 +2031,6 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
   registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
   registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
   registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
-  registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
-  registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
-  registry.addOpInterface<tensor::ExtractSliceOp,
-                          tensor_ext::ExtractSliceOpInterface>();
-  registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
-  registry.addOpInterface<tensor::InsertSliceOp,
-                          tensor_ext::InsertSliceOpInterface>();
   registry.addOpInterface<vector::TransferReadOp,
                           vector_ext::TransferReadOpInterface>();
   registry.addOpInterface<vector::TransferWriteOp,

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
new file mode 100644
index 0000000000000..f72a23d7ca811
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -0,0 +1,437 @@
+//===- TensorInterfaceImpl.cpp - Tensor Impl. of BufferizableOpInterface --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace tensor_ext {
+
+using tensor::ExtractSliceOp;
+using tensor::InsertSliceOp;
+
+/// Bufferization model for `tensor.cast`. The op bufferizes to a `memref.cast`
+/// of the buffer returned by `getResultBuffer`; it reports neither a memory
+/// read nor a memory write, and its single result is equivalent to its single
+/// operand.
+struct CastOpInterface
+    : public BufferizableOpInterface::ExternalModel<CastOpInterface,
+                                                    tensor::CastOp> {
+  /// A cast only reinterprets the buffer type; no data is read.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  /// A cast never writes to memory.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  /// The result aliases the source operand (operand 0).
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(0)};
+  }
+
+  /// The source operand aliases the single result.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return op->getResult(0);
+  }
+
+  /// Source and result share the exact same buffer.
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Emits a `memref.cast` and maps the tensor result to it. Fails if no
+  /// result buffer could be obtained.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto castOp = cast<tensor::CastOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(castOp);
+
+    // Buffer backing the result: the aliasing operand's buffer, or a new
+    // allocation when bufferizing out-of-place. A null value signals failure.
+    Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state);
+    if (!resultBuffer)
+      return failure();
+    // The incoming buffer may be ranked or unranked; either way its memory
+    // space is carried over to the cast result.
+    Type sourceType = resultBuffer.getType();
+    auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
+    auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
+    assert(rankedMemRefType || unrankedMemRefType);
+    Attribute memorySpace = rankedMemRefType
+                                ? rankedMemRefType.getMemorySpace()
+                                : unrankedMemRefType.getMemorySpace();
+    TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
+    // Keep the source layout only for a ranked -> ranked cast. Otherwise pass
+    // a null layout and let getContiguousOrUnrankedMemRefType pick its default.
+    MemRefLayoutAttrInterface layout =
+        rankedMemRefType && tensorType.isa<RankedTensorType>()
+            ? rankedMemRefType.getLayout()
+            : MemRefLayoutAttrInterface();
+    Type memRefType = getContiguousOrUnrankedMemRefType(
+        castOp.getResult().getType(), layout, memorySpace);
+    Value res =
+        b.create<memref::CastOp>(castOp.getLoc(), memRefType, resultBuffer);
+    // The cast buffer is equivalent to the tensor result and becomes its
+    // buffer.
+    state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult());
+    state.mapBuffer(castOp.getResult(), res);
+    return success();
+  }
+};
+
+/// Bufferization model for `tensor.dim`. The op reads its tensor operand to
+/// query a dimension and produces a non-tensor value, so no result aliases the
+/// operand. Uses of the result are redirected to a `memref.dim` on the source
+/// buffer.
+struct DimOpInterface
+    : public BufferizableOpInterface::ExternalModel<DimOpInterface,
+                                                    tensor::DimOp> {
+  /// Querying a dimension counts as a read of the source tensor.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  /// The op never writes memory.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  /// The result is not a tensor, so it aliases no operand.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  /// Replaces all uses of the `tensor.dim` result with a new `memref.dim`. The
+  /// tensor op itself is not erased here (presumably the driver cleans it up —
+  /// confirm).
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto dimOp = cast<tensor::DimOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(dimOp);
+
+    // NOTE(review): for an unranked source no uses are replaced, yet success()
+    // is still returned. memref.dim appears to accept unranked memrefs as
+    // well — confirm whether this ranked-only guard is still needed.
+    if (dimOp.source().getType().isa<RankedTensorType>()) {
+      Value v = state.lookupBuffer(dimOp.source());
+      dimOp.result().replaceAllUsesWith(
+          b.create<memref::DimOp>(dimOp.getLoc(), v, dimOp.index()));
+    }
+    return success();
+  }
+};
+
+/// Bufferization model for `tensor.extract_slice`. The result bufferizes to a
+/// `memref.subview` of the source buffer. When the result is not in place, a
+/// new buffer is allocated, and the subview is copied into it only if the
+/// result is read.
+struct ExtractSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
+                                                    tensor::ExtractSliceOp> {
+  /// Reported as not reading: the op itself only creates a view of the source.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  /// The op never writes memory.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  /// The result aliases the source tensor (operand 0).
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(0) /*source*/};
+  }
+
+  /// Only the source operand aliases the result. The remaining (index)
+  /// operands do not.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(0) /*source*/
+               ? op->getResult(0)
+               : OpResult();
+  }
+
+  /// The result aliases only part of the source buffer, so the two are not
+  /// equivalent.
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::None;
+  }
+
+  /// Emits a (possibly rank-reduced) `memref.subview` of the source buffer.
+  /// Out-of-place, the subview is copied into a fresh allocation, which then
+  /// becomes the result buffer.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(extractSliceOp);
+
+    Location loc = extractSliceOp.getLoc();
+    Value srcMemref = state.lookupBuffer(extractSliceOp.source());
+    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
+    auto dstTensorType =
+        extractSliceOp.result().getType().cast<RankedTensorType>();
+
+    // If not inplaceable, alloc.
+    bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
+    Value alloc;
+    if (!inplace)
+      alloc = state.allocationFns.createAllocDeallocFn(
+          b, loc, extractSliceOp.result(), state);
+
+    // Bufferize to subview. The subview type is rank-reduced to the rank of
+    // the result tensor.
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            dstTensorType.getRank(), srcMemrefType,
+            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+            extractSliceOp.getMixedStrides())
+            .cast<MemRefType>();
+    Value subView = b.create<memref::SubViewOp>(
+        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
+        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+    // Insert new alias.
+    state.aliasInfo.insertNewBufferAlias(subView, srcMemref);
+
+    // If not inplaceable, copy.
+    if (!inplace) {
+      // Do not copy if the copied data is never read.
+      if (isValueRead(extractSliceOp.result()))
+        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
+                                     alloc);
+      subView = alloc;
+    }
+
+    state.mapBuffer(extractSliceOp.result(), subView);
+    return success();
+  }
+};
+
+/// Bufferization model for `tensor.extract`. The op reads its tensor operand,
+/// writes nothing, and its result (an element, not a tensor) aliases no
+/// operand. Uses of the result are redirected to a `memref.load` from the
+/// source buffer.
+struct ExtractOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
+                                                    tensor::ExtractOp> {
+  /// Extracting an element reads the source tensor.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  /// The op never writes memory.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return false;
+  }
+
+  /// The extracted element does not alias any tensor operand.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  /// Emits a `memref.load` at the op's indices and replaces all uses of the
+  /// extracted value with it. The tensor op itself is not erased here
+  /// (presumably the driver cleans it up — confirm).
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto extractOp = cast<tensor::ExtractOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(extractOp);
+
+    Location loc = extractOp.getLoc();
+    Value srcMemref = state.lookupBuffer(extractOp.tensor());
+    Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices());
+    extractOp.replaceAllUsesWith(l);
+    return success();
+  }
+};
+
+/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match, i.e. the
+/// extract_slice source and the insert_slice dest bufferize to equivalent
+/// values, and both ops use the same offsets/sizes/strides. Returns false if
+/// either op is null.
+///
+/// This is one particular type of relationship between ops on tensors that
+/// reduces to an equivalence on buffers. This should be generalized and
+/// exposed as interfaces on the proper types.
+static bool
+areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
+                             ExtractSliceOp st, InsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
+    return false;
+  // Offsets/sizes/strides are compared entry by entry with
+  // isEqualConstantIntOrValue (presumably: equal constant ints, or the same
+  // SSA value).
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
+/// Return true if the source of an `insertSliceOp` bufferizes to an
+/// equivalent ExtractSliceOp that bufferizes inplace.
+///
+/// Every value in the buffer equivalence class of the insert_slice source is
+/// inspected. The result is true if at least one of them is defined by an
+/// ExtractSliceOp that matches `insertSliceOp` (see
+/// areEquivalentExtractSliceOps) and whose result is in place.
+static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
+  bool foundOp = false;
+  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
+    auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
+    if (extractSliceOp &&
+        areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
+                                     insertSliceOp) &&
+        aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
+      foundOp = true;
+    }
+  });
+  return foundOp;
+}
+
+/// Return true if `value` is originating from an ExtractSliceOp that matches
+/// the given InsertSliceOp.
+///
+/// `condition` is passed to findValueInReverseUseDefChain (presumably as the
+/// stop criterion of the backward walk from `value`). It must then hold for
+/// every value that walk returns: each must be defined by an ExtractSliceOp
+/// matching `insertOp`.
+/// NOTE(review): llvm::all_of yields true for an empty range, so this returns
+/// true if the walk returns no values. Confirm that
+/// findValueInReverseUseDefChain never returns an empty set.
+static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+                                      Value value, InsertSliceOp insertOp) {
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
+/// Bufferization model for `tensor.insert_slice`. The result is equivalent to
+/// the dest operand (operand 1). Bufferization copies the source buffer into a
+/// `memref.subview` of the result buffer. The copy is skipped only when the
+/// source is equivalent to a matching in-place extract_slice AND the
+/// insert_slice result itself is in place.
+struct InsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
+                                                    tensor::InsertSliceOp> {
+  /// Both operands are read. The source is copied into the slice, and the
+  /// dest is read because the area outside the slice is preserved.
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  /// Only the dest (operand 1) is written.
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  /// The result aliases the dest operand.
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(1) /*dest*/};
+  }
+
+  /// Only the dest operand aliases the result. The source and index operands
+  /// do not.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/
+               ? op->getResult(0)
+               : OpResult();
+  }
+
+  /// The result occupies exactly the dest buffer.
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// Given a read `uRead` and a potentially conflicting write
+  /// `uConflictingWrite`, return true for the three patterns below, which
+  /// involve matching extract_slice/insert_slice pairs and are known not to
+  /// conflict. Return false otherwise.
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const BufferizationAliasInfo &aliasInfo) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+    // uRead is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that InsertSliceOp reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via an InsertSliceOp is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // InsertSliceOp is writing.
+        //
+        // In the above example:
+        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // Case 3: uConflictingWrite is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of vector.transfer_read
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      // lastWrite         = %1
+      //
+      // This is not a conflict because the InsertSliceOp overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
+                                                  insertSliceOp.source()) &&
+          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
+                                    insertSliceOp))
+        return true;
+
+    return false;
+  }
+
+  /// Maps the result to the buffer from getResultBuffer and, when needed,
+  /// copies the source buffer into a subview of it. Fails if no result buffer
+  /// could be obtained.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    // insert_slice ops arise from tiling and bufferizing them out-of-place is
+    // generally a deal breaker. When used with loops, this ends up cloning the
+    // whole tensor on every single iteration and is a symptom of a
+    // catastrophically bad scheduling decision.
+    // TODO: be very loud about it or even consider failing the pass.
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(insertSliceOp);
+    Location loc = insertSliceOp.getLoc();
+
+    // When bufferizing out-of-place, `getResultBuffer` allocates.
+    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
+    if (!dstMemref)
+      return failure();
+
+    // A copy of the source buffer is needed if either:
+    //   - The producer of `source` is not inplace. This is the case where a
+    //     slice is computed out of place into the inplace full tensor.
+    //   - The result is not inplace. This is the case where the whole tensor is
+    //     cloned and the clone needs to be updated.
+    // TODO: Is this necessary?
+    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+                        state.aliasInfo, insertSliceOp) ||
+                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
+    if (needCopy) {
+      // Take a subview of the dst, rank-reduced to the rank of the source.
+      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+      auto subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              insertSliceOp.getSourceType().getRank(), dstMemrefType,
+              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+              insertSliceOp.getMixedStrides())
+              .cast<MemRefType>();
+      Value subView = b.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+      // Insert new alias.
+      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+      // Copy tensor.
+      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
+      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
+                                   subView);
+    }
+
+    state.mapBuffer(insertSliceOp.result(), dstMemref);
+    return success();
+  }
+};
+
+} // namespace tensor_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+/// Attaches the tensor_ext BufferizableOpInterface external models to
+/// tensor.cast, tensor.dim, tensor.extract_slice, tensor.extract and
+/// tensor.insert_slice in `registry`. Clients of ComprehensiveBufferize must
+/// call this explicitly to bufferize these tensor ops.
+void mlir::linalg::comprehensive_bufferize::tensor_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
+  registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
+  registry.addOpInterface<tensor::ExtractSliceOp,
+                          tensor_ext::ExtractSliceOpInterface>();
+  registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
+  registry.addOpInterface<tensor::InsertSliceOp,
+                          tensor_ext::InsertSliceOpInterface>();
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 25368352bcd19..a098e9e29fa4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRStandardOpsTransforms
   MLIRStandardToLLVM
   MLIRTensor
+  MLIRTensorBufferizableOpInterfaceImpl
   MLIRTransforms
   MLIRTransformUtils
   MLIRVector

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 626eafa31bbb2..69182f92b70b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -38,6 +39,7 @@ struct LinalgComprehensiveModuleBufferize
                 arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
     registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
   }
 };
 } // end namespace

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 969d82c2c163d..aaeac453cc085 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6326,6 +6326,25 @@ cc_library(
     ],
 )
 
+# Tensor dialect implementations of BufferizableOpInterface. They live in a
+# separate target so that ComprehensiveBufferize does not depend on the tensor
+# dialect. Users register them explicitly via
+# tensor_ext::registerBufferizableOpInterfaceExternalModels.
+cc_library(
+    name = "TensorBufferizableOpInterfaceImpl",
+    srcs = [
+        "lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":BufferizableOpInterface",
+        ":IR",
+        ":MemRefDialect",
+        ":Support",
+        ":TensorDialect",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "LinalgDocTdFiles",
     srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"],
@@ -6545,6 +6564,7 @@ cc_library(
         ":StandardOps",
         ":StandardOpsTransforms",
         ":Support",
+        ":TensorBufferizableOpInterfaceImpl",
         ":TensorDialect",
         ":TransformUtils",
         ":VectorOps",
@@ -6575,7 +6595,6 @@ cc_library(
         ":SCFDialect",
         ":StandardOps",
         ":Support",
-        ":TensorDialect",
         ":TransformUtils",
         ":VectorOps",
         "//llvm:Support",


        


More information about the Mlir-commits mailing list