[Mlir-commits] [mlir] [mlir][bufferization] Use TensorLike, BufferLike type interfaces (PR #136736)

Andrei Golubev llvmlistbot at llvm.org
Tue Apr 22 11:02:40 PDT 2025


https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/136736

The general idea is to replace, in most of the places that currently rely on the builtin TensorType / BaseMemRefType, those builtin types with the newly added TensorLike / BufferLike type interfaces.

Thus far, do the bare minimum: refactor the API of the dialect and its options (almost) "blindly", leaving most of the logic as is. The exceptions are the bufferization.{to_tensor, to_memref} ops, which act as "glue" when bufferizing neighbouring operations and the enclosing functions.

>From fe90c52e99e4655eeabf7985944953e66dda6565 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Thu, 17 Apr 2025 15:36:01 +0000
Subject: [PATCH] [mlir][bufferization] Use TensorLike, BufferLike type
 interfaces

The general idea is to replace most of the places that rely on builtin's
TensorType / BaseMemRefType with the newly added type interfaces.

Thus far, do the bare minimum: refactor (almost) "blindly" the API of
the dialect and options, leaving most of the logic "as is". The
exceptions are the bufferization.{to_tensor, to_memref} ops that act as
"glue" when bufferizing neighbouring operations and the enclosing
functions.
---
 .../IR/BufferizableOpInterface.h              |  21 ++--
 .../IR/BufferizableOpInterface.td             |   2 +-
 .../Bufferization/IR/BufferizationOps.td      |  17 +--
 .../IR/BufferizationTypeInterfaces.h          |   1 +
 .../IR/BufferizationTypeInterfaces.td         |  13 ++-
 .../IR/UnstructuredControlFlow.h              |  35 +++---
 .../BufferizableOpInterfaceImpl.cpp           |  13 ++-
 .../IR/BufferizableOpInterface.cpp            |  94 ++++++++-------
 .../Bufferization/IR/BufferizationDialect.cpp |   6 +-
 .../Bufferization/IR/BufferizationOps.cpp     |  30 ++---
 .../IR/BufferizationTypeInterfaces.cpp        |  21 ++++
 .../Dialect/Bufferization/IR/CMakeLists.txt   |   1 +
 .../Transforms/BufferViewFlowAnalysis.cpp     |  17 +--
 .../Bufferization/Transforms/Bufferize.cpp    |  18 +--
 .../FuncBufferizableOpInterfaceImpl.cpp       |  30 ++---
 .../BufferizableOpInterfaceImpl.cpp           | 109 ++++++++++--------
 .../SparsificationAndBufferizationPass.cpp    |   4 +-
 .../Transforms/Utils/CodegenUtils.cpp         |   4 +-
 .../BufferizableOpInterfaceImpl.cpp           |   9 +-
 .../Transforms/one-shot-bufferize.mlir        |  21 +++-
 mlir/test/Dialect/Bufferization/invalid.mlir  |   8 +-
 .../Bufferization/TestTensorCopyInsertion.cpp |   4 +-
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  24 ++++
 mlir/test/lib/Dialect/Test/TestOps.h          |   1 +
 mlir/test/lib/Dialect/Test/TestOps.td         |  55 ++++++++-
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |   3 +
 26 files changed, 370 insertions(+), 191 deletions(-)
 create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ada9539e87121..70092908d961f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 
 namespace mlir {
 class OpBuilder;
@@ -259,18 +260,18 @@ struct BufferizationOptions {
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Tensor -> MemRef type converter.
-  /// Parameters: tensor type, memory space, func op, bufferization options
+  /// TensorLike -> BufferLike type converter.
+  /// Parameters: tensor like type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
-      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+      std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
-  /// Tensor -> MemRef type converter.
+  /// TensorLike -> BufferLike type converter.
   /// Parameters: Value, memory space, bufferization options
-  using UnknownTypeConverterFn = std::function<BaseMemRefType(
+  using UnknownTypeConverterFn = std::function<BufferLikeType(
       Value, Attribute memorySpace, const BufferizationOptions &)>;
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
-      std::function<std::optional<Attribute>(TensorType t)>;
+      std::function<std::optional<Attribute>(TensorLikeType t)>;
 
   BufferizationOptions();
 
@@ -360,7 +361,7 @@ struct BufferizationOptions {
   // Returning std::nullopt will cause bufferization to fail (useful to indicate
   // failure to determine memory space for a tensor type).
   DefaultMemorySpaceFn defaultMemorySpaceFn =
-      [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+      [](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };
 
   /// If set to `true`, the analysis is skipped. A buffer is copied before every
   /// write. This flag cannot be used together with `testAnalysisOnly = true`.
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         SmallVector<Value> &invocationStack);
 
@@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      SmallVector<Value> &invocationStack);
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..1de1742fab81a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fad78a63444b9..81ce0f3fb650b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         SmallVector<Value> &invocationStack);
 
@@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
                            "the reference to load from",
                            [MemReadAt<0, FullEffect>]>:$memref,
                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BufferLikeType>(getMemref().getType());
     }
   }];
 
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 // ToMemrefOp
 //===----------------------------------------------------------------------===//
 
+// TODO: rename to "to_buffer"
 def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
+                       UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..290f1298f2501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Attributes.h" // mlir::Attribute
 #include "mlir/IR/Types.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..c053a6bdc1a91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
   let description = [{
     Indicates that this type is a buffer type (similarly to a MLIR builtin
     memref) for bufferization purposes.
-
-    The interface currently has no methods as it is used by types to opt into
-    being supported by the bufferization procedures.
   }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the memory space in which data referred to by this buffer resides.
+      }],
+      /*retType=*/"::mlir::Attribute",
+      /*methodName=*/"getMemorySpace"
+    >,
+  ];
 }
 
 #endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 78109770efab7..89eb65c4a0942 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     // operand types of all forwarded values. If these are all the same type,
     // take that type. Otherwise, take only the memory space and fall back to a
     // buffer type with a fully dynamic layout map.
-    BaseMemRefType bufferType;
+    BufferLikeType bufferType;
     auto tensorType = cast<TensorType>(value.getType());
     for (OpOperand *opOperand :
          detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         continue;
 
       // Compute the bufferized type of the forwarded operand.
-      BaseMemRefType callerType;
-      if (auto memrefType =
-              dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+      BufferLikeType callerType;
+      if (auto bufferLikeType =
+              dyn_cast<BufferLikeType>(opOperand->get().getType())) {
         // The operand was already bufferized. Take its type directly.
-        callerType = memrefType;
+        callerType = bufferLikeType;
       } else {
-        FailureOr<BaseMemRefType> maybeCallerType =
+        FailureOr<BufferLikeType> maybeCallerType =
             bufferization::getBufferType(opOperand->get(), options,
                                          invocationStack);
         if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // of the earlier forwarded operands, fall back to a buffer type with a
         // fully dynamic layout map.
 #ifndef NDEBUG
+      assert(mlir::isa<BaseMemRefType>(bufferType) &&
+             mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
+      auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
+      auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
+
       if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
-        assert(bufferType.hasRank() && callerType.hasRank() &&
+        assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
                "expected ranked memrefs");
-        assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
-                                rankedTensorType.getShape()}) &&
-               "expected same shape");
+        assert(
+            llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
+                             rankedTensorType.getShape()}) &&
+            "expected same shape");
       } else {
-        assert(!bufferType.hasRank() && !callerType.hasRank() &&
+        assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
                "expected unranked memrefs");
       }
 #endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         return op->emitOpError("incoming operands of block argument have "
                                "inconsistent memory spaces");
 
-      bufferType = getMemRefTypeWithFullyDynamicLayout(
-          tensorType, bufferType.getMemorySpace());
+      bufferType =
+          mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+              tensorType, bufferType.getMemorySpace()));
     }
 
     if (!bufferType)
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..433757192bfd1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    auto type = dyn_cast<RankedTensorType>(constantOp.getType());
+    auto type = dyn_cast<TensorLikeType>(constantOp.getType());
 
     // Only ranked tensors are supported.
     if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
-        RankedTensorType::get(memrefType.getShape(),
-                              memrefType.getElementType()),
-        memrefType.getMemorySpace());
+    return mlir::cast<bufferization::BufferLikeType>(
+        getMemRefTypeWithFullyDynamicLayout(
+            RankedTensorType::get(memrefType.getShape(),
+                                  memrefType.getElementType()),
+            memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..82ff1bdfe5fd7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -206,12 +206,13 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+  FailureOr<BufferLikeType> copyBufferType = getBufferType(tensor, options);
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
   if (!memorySpace)
-    memorySpace = options.defaultMemorySpaceFn(tensorType);
+    memorySpace =
+        options.defaultMemorySpaceFn(mlir::cast<TensorLikeType>(tensorType));
   if (memorySpace.has_value())
     allocTensorOp.setMemorySpaceAttr(memorySpace.value());
   return allocTensorOp.getResult();
@@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   // Find all out-of-place OpOperands.
   for (OpOperand &opOperand : op->getOpOperands()) {
     Type operandType = opOperand.get().getType();
+    // Note: can only copy TensorType (any other TensorLikeType is rejected)
     if (!llvm::isa<TensorType>(operandType))
       continue;
     if (state.isInPlace(opOperand))
@@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 namespace {
 
 /// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
-                                func::FuncOp funcOp,
+bufferization::BufferLikeType
+defaultFunctionArgTypeConverter(bufferization::TensorLikeType type,
+                                Attribute memorySpace, func::FuncOp funcOp,
                                 const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(mlir::cast<TensorType>(type),
+                                          memorySpace));
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
-BaseMemRefType
+BufferLikeType
 defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                             const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(
-      llvm::cast<TensorType>(value.getType()), memorySpace);
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(
+          llvm::cast<TensorType>(value.getType()), memorySpace));
 }
 
 } // namespace
@@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
 
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
-  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
-                                   func::FuncOp funcOp,
+  functionArgTypeConverterFn = [=](TensorLikeType tensorType,
+                                   Attribute memorySpace, func::FuncOp funcOp,
                                    const BufferizationOptions &options) {
     if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
-                                                                  memorySpace);
-    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
-                                                              memorySpace);
+      return mlir::cast<bufferization::BufferLikeType>(
+          bufferization::getMemRefTypeWithStaticIdentityLayout(
+              mlir::cast<TensorType>(tensorType), memorySpace));
+    return mlir::cast<bufferization::BufferLikeType>(
+        bufferization::getMemRefTypeWithFullyDynamicLayout(
+            mlir::cast<TensorType>(tensorType), memorySpace));
   };
   inferFunctionResultLayout =
       layoutMapOption == LayoutMapOption::InferLayoutMap;
@@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
 /// read. Also takes into account ops that create an alias but do not read by
 /// themselves (e.g., ExtractSliceOp).
 bool AnalysisState::isValueRead(Value value) const {
-  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+  assert(llvm::isa<bufferization::TensorLikeType>(value.getType()) &&
+         "expected TensorLikeType");
   SmallVector<OpOperand *> workingSet;
   DenseSet<OpOperand *> visited;
   for (OpOperand &use : value.getUses())
@@ -663,7 +671,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
 FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                           const BufferizationOptions &options) {
 #ifndef NDEBUG
-  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+  auto tensorType =
+      llvm::dyn_cast<bufferization::TensorLikeType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
 #endif // NDEBUG
 
@@ -674,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
   // Insert to_memref op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+  FailureOr<BufferLikeType> memrefType = getBufferType(value, options);
   if (failed(memrefType))
     return failure();
   ensureToMemrefOpIsValid(value, *memrefType);
@@ -684,18 +693,18 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   SmallVector<Value> invocationStack;
   return getBufferType(value, options, invocationStack);
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              SmallVector<Value> &invocationStack) {
-  assert(llvm::isa<TensorType>(value.getType()) &&
-         "unexpected non-tensor type");
+  assert(llvm::isa<TensorLikeType>(value.getType()) &&
+         "unexpected non-tensor-like type");
   invocationStack.push_back(value);
   auto popFromStack =
       llvm::make_scope_exit([&]() { invocationStack.pop_back(); });
@@ -708,11 +717,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
 
   // Op is not bufferizable.
   auto memSpace =
-      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
+      options.defaultMemorySpaceFn(cast<TensorLikeType>(value.getType()));
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  return mlir::cast<BufferLikeType>(
+      getMemRefType(value, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -732,12 +742,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
   SmallVector<Value> replacements;
   for (OpResult opResult : op->getOpResults()) {
     Value replacement = values[opResult.getResultNumber()];
-    if (llvm::isa<TensorType>(opResult.getType())) {
-      // The OpResult is a tensor. Such values are replaced with memrefs during
+    if (llvm::isa<bufferization::TensorLikeType>(opResult.getType())) {
+      // The OpResult is a tensor. Such values are replaced with buffers during
       // bufferization.
-      assert((llvm::isa<MemRefType>(replacement.getType()) ||
-              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
-             "tensor op result should be replaced with a memref value");
+      assert(llvm::isa<bufferization::BufferLikeType>(replacement.getType()) &&
+             "tensor op result should be replaced with a buffer value");
       // The existing uses of the OpResult still expect a tensor. Insert a
       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
       // loose all of its users and eventually DCE away.
@@ -789,6 +798,8 @@ BaseMemRefType bufferization::getMemRefType(Value value,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
                                             Attribute memorySpace) {
+  assert(mlir::isa<TensorType>(value.getType()) &&
+         "expected tensor type in tensor -> memref conversion");
   auto tensorType = llvm::cast<TensorType>(value.getType());
 
   // Case 1: Unranked memref type.
@@ -807,7 +818,8 @@ BaseMemRefType bufferization::getMemRefType(Value value,
                            memorySpace);
   }
 
-  return options.unknownTypeConverterFn(value, memorySpace, options);
+  return mlir::cast<BaseMemRefType>(
+      options.unknownTypeConverterFn(value, memorySpace, options));
 }
 
 BaseMemRefType
@@ -928,7 +940,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   Operation *op = getOwnerOfValue(value);
   SmallVector<AliasingOpOperand> result;
   for (OpOperand &opOperand : op->getOpOperands()) {
-    if (!llvm::isa<TensorType>(opOperand.get().getType()))
+    if (!llvm::isa<bufferization::TensorLikeType>(opOperand.get().getType()))
       continue;
     AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
     for (const auto &it : aliasingValues)
@@ -938,14 +950,15 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     SmallVector<Value> &invocationStack) {
-  assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
+  assert(llvm::isa<TensorLikeType>(value.getType()) && "expected tensor type");
 
   // No further analysis is possible for a block argument.
   if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(value, options);
+    return mlir::cast<BufferLikeType>(
+        bufferization::getMemRefType(value, options));
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -963,11 +976,12 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   // If we do not know the memory space and there is no default memory space,
   // report a failure.
   auto memSpace =
-      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
+      options.defaultMemorySpaceFn(cast<TensorLikeType>(value.getType()));
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  return mlir::cast<BufferLikeType>(
+      getMemRefType(value, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
@@ -993,7 +1007,7 @@ bufferization::detail::unknownGetAliasingOpOperands(Value value) {
   // with every OpOperand.
   AliasingOpOperandList r;
   for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
-    if (isa<TensorType>(operand.get().getType()))
+    if (isa<bufferization::TensorLikeType>(operand.get().getType()))
       r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
   return r;
 }
@@ -1006,18 +1020,18 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
   // with every OpOperand.
   AliasingValueList r;
   for (OpResult result : opOperand.getOwner()->getOpResults())
-    if (llvm::isa<TensorType>(result.getType()))
+    if (llvm::isa<bufferization::TensorLikeType>(result.getType()))
       r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
   for (Region &region : opOperand.getOwner()->getRegions())
     if (!region.getBlocks().empty())
       for (BlockArgument bbArg : region.getBlocks().front().getArguments())
-        if (isa<TensorType>(bbArg.getType()))
+        if (isa<bufferization::TensorLikeType>(bbArg.getType()))
           r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
   return r;
 }
 
 bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
-  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
+  auto isaTensor = [](Type t) { return isa<bufferization::TensorLikeType>(t); };
   bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
     return any_of(r.getBlocks(), [&](Block &b) {
       return any_of(b.getArguments(), [&](BlockArgument bbArg) {
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6b9253a5d71da..02f9252dcb088 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -62,7 +62,11 @@ struct BuiltinTensorExternalModel
 template <typename MemRef>
 struct BuiltinMemRefExternalModel
     : BufferLikeType::ExternalModel<BuiltinMemRefExternalModel<MemRef>,
-                                    MemRef> {};
+                                    MemRef> {
+  mlir::Attribute getMemorySpace(mlir::Type type) const {
+    return mlir::cast<MemRef>(type).getMemorySpace();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 4fce9be390bd6..2ceb6795899c9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -220,7 +220,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<bufferization::BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              SmallVector<Value> &invocationStack) {
   assert(value == getResult() && "invalid value");
@@ -235,13 +235,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     if (failed(copyBufferType))
       return failure();
     memorySpace = copyBufferType->getMemorySpace();
-  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
+  } else if (auto ms = options.defaultMemorySpaceFn(
+                 mlir::cast<TensorLikeType>(getType()))) {
     memorySpace = *ms;
   } else {
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return mlir::cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {
@@ -585,7 +587,7 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
       return failure();
     buffer = *maybeBuffer;
   } else {
-    assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+    assert(isa<BufferLikeType>(getDest().getType()) && "expected buffer type");
     buffer = getDest();
   }
   auto srcBuffer = getBuffer(rewriter, getSource(), options);
@@ -632,7 +634,7 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
     return {};
 
   // Build a bufferization.to_tensor op.
-  assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+  assert(isa<BufferLikeType>(getDest().getType()) && "expected buffer type");
   assert(getRestrict() &&
          "expected that ops with memrefs dest have 'restrict'");
   setRestrict(false);
@@ -667,22 +669,22 @@ bool MaterializeInDestinationOp::operatesOnDisjointSubset(
 }
 
 LogicalResult MaterializeInDestinationOp::verify() {
-  if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
-    return emitOpError("'dest' must be a tensor or a memref");
+  if (!isa<TensorType, BufferLikeType>(getDest().getType()))
+    return emitOpError("'dest' must be a tensor or a buffer");
   if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
     if (getOperation()->getNumResults() != 1)
       return emitOpError("tensor 'dest' implies exactly one tensor result");
     if (destType != getResult().getType())
       return emitOpError("result and 'dest' types must match");
   }
-  if (isa<BaseMemRefType>(getDest().getType()) &&
+  if (isa<BufferLikeType>(getDest().getType()) &&
       getOperation()->getNumResults() != 0)
-    return emitOpError("memref 'dest' implies zero results");
-  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
-    return emitOpError("'restrict' is valid only for memref destinations");
-  if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("buffer 'dest' implies zero results");
+  if (getRestrict() && !isa<BufferLikeType>(getDest().getType()))
+    return emitOpError("'restrict' is valid only for buffer destinations");
+  if (getWritable() != isa<BufferLikeType>(getDest().getType()))
     return emitOpError("'writable' must be specified if and only if the "
-                       "destination is of memref type");
+                       "destination is of buffer type");
   TensorType srcType = getSource().getType();
   ShapedType destType = cast<ShapedType>(getDest().getType());
   if (srcType.hasRank() != destType.hasRank())
@@ -724,7 +726,7 @@ MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
 void MaterializeInDestinationOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  if (isa<BaseMemRefType>(getDest().getType()))
+  if (isa<BufferLikeType>(getDest().getType()))
     effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
                          SideEffects::DefaultResource::get());
 }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp
new file mode 100644
index 0000000000000..0e973915c6fc9
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp
@@ -0,0 +1,21 @@
+//===- BufferizationTypeInterfaces.cpp - Type Interfaces --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Bufferization Type Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 63dcc1eb233e9..5d8f0060f2c3f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferizationDialect.cpp
+  BufferizationTypeInterfaces.cpp
   BufferViewFlowOpInterface.cpp
   UnstructuredControlFlow.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 72f47b8b468ea..cb9db1288039a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -93,11 +94,11 @@ void BufferViewFlowAnalysis::build(Operation *op) {
   // given op as terminals.
   auto populateTerminalValues = [&](Operation *op) {
     for (Value v : op->getResults())
-      if (isa<BaseMemRefType>(v.getType()))
+      if (isa<BufferLikeType>(v.getType()))
         this->terminals.insert(v);
     for (Region &r : op->getRegions())
       for (BlockArgument v : r.getArguments())
-        if (isa<BaseMemRefType>(v.getType()))
+        if (isa<BufferLikeType>(v.getType()))
           this->terminals.insert(v);
   };
 
@@ -108,12 +109,12 @@ void BufferViewFlowAnalysis::build(Operation *op) {
     if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
       bufferViewFlowOp.populateDependencies(registerDependencies);
       for (Value v : op->getResults())
-        if (isa<BaseMemRefType>(v.getType()) &&
+        if (isa<BufferLikeType>(v.getType()) &&
             bufferViewFlowOp.mayBeTerminalBuffer(v))
           this->terminals.insert(v);
       for (Region &r : op->getRegions())
         for (BlockArgument v : r.getArguments())
-          if (isa<BaseMemRefType>(v.getType()) &&
+          if (isa<BufferLikeType>(v.getType()) &&
               bufferViewFlowOp.mayBeTerminalBuffer(v))
             this->terminals.insert(v);
       return WalkResult::advance();
@@ -201,7 +202,7 @@ void BufferViewFlowAnalysis::build(Operation *op) {
 }
 
 bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
-  assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
+  assert(isa<BufferLikeType>(value.getType()) && "expected buffer");
   return terminals.contains(value);
 }
 
@@ -240,8 +241,8 @@ static Value getViewBase(Value value) {
 BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
 
 std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
-  assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
-  assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+  assert(isa<BufferLikeType>(v1.getType()) && "expected buffer");
+  assert(isa<BufferLikeType>(v2.getType()) && "expected buffer");
 
   // Skip over all view-like ops.
   v1 = getViewBase(v1);
@@ -275,7 +276,7 @@ std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
                                       bool &allAllocs,
                                       bool &allAllocsOrFuncEntryArgs) {
     for (Value v : origin) {
-      if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
+      if (isa<BufferLikeType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
         terminal.insert(v);
         allAllocs &= hasAllocateSideEffect(v);
         allAllocsOrFuncEntryArgs &=
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 0b60c44ece5fd..a296b617024d8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -80,14 +80,14 @@ struct OneShotBufferizePass
 
       if (mustInferMemorySpace) {
         opt.defaultMemorySpaceFn =
-            [](TensorType t) -> std::optional<Attribute> {
+            [](TensorLikeType t) -> std::optional<Attribute> {
           return std::nullopt;
         };
       }
 
       if (useEncodingForMemorySpace) {
         opt.defaultMemorySpaceFn =
-            [](TensorType t) -> std::optional<Attribute> {
+            [](TensorLikeType t) -> std::optional<Attribute> {
           if (auto rtt = dyn_cast<RankedTensorType>(t))
             return rtt.getEncoding();
           return std::nullopt;
@@ -113,13 +113,15 @@ struct OneShotBufferizePass
                                        const BufferizationOptions &options) {
         auto tensorType = cast<TensorType>(value.getType());
         if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
-          return bufferization::getMemRefTypeWithStaticIdentityLayout(
-              tensorType, memorySpace);
+          return mlir::cast<BufferLikeType>(
+              bufferization::getMemRefTypeWithStaticIdentityLayout(
+                  tensorType, memorySpace));
         assert(unknownTypeConversionOption ==
                    LayoutMapOption::FullyDynamicLayoutMap &&
                "invalid layout map option");
-        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
-                                                                  memorySpace);
+        return mlir::cast<BufferLikeType>(
+            bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+                                                               memorySpace));
       };
 
       // Configure op filter.
@@ -407,7 +409,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
       continue;
     }
 
-    FailureOr<BaseMemRefType> memrefType =
+    FailureOr<BufferLikeType> memrefType =
         bufferization::getBufferType(bbArg, options);
     if (failed(memrefType))
       return failure();
@@ -458,7 +460,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
         newOperands.push_back(operand);
         continue;
       }
-      FailureOr<BaseMemRefType> operandBufferType =
+      FailureOr<BufferLikeType> operandBufferType =
           bufferization::getBufferType(operand, options);
       if (failed(operandBufferType))
         return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..4d39d9b795bed 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -53,14 +53,14 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
   auto tensorType =
-      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
-  assert(tensorType && "expected TensorType");
+      dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+  assert(tensorType && "expected TensorLikeType");
 
-  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
+  BufferLikeType memrefType = options.functionArgTypeConverterFn(
       tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
 
   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
@@ -70,9 +70,9 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 
   auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
   assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(
+  return mlir::cast<BufferLikeType>(MemRefType::get(
       rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
-      layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
+      layoutAttr.getValue(), rankedMemrefType.getMemorySpace()));
 }
 
 /// Return the FuncOp called by `callOp`.
@@ -195,7 +195,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto callOp = cast<func::CallOp>(op);
@@ -207,11 +207,11 @@ struct CallOpInterface
     FunctionType funcType = funcOp.getFunctionType();
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
-    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+    if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
       return bufferizedType;
 
     // Otherwise, call the type converter to compute the bufferized type.
-    auto tensorType = cast<TensorType>(resultType);
+    auto tensorType = cast<TensorLikeType>(resultType);
     return options.functionArgTypeConverterFn(
         tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
   }
@@ -233,7 +233,7 @@ struct CallOpInterface
       }
 
       // Returning a memref.
-      FailureOr<BaseMemRefType> resultType =
+      FailureOr<BufferLikeType> resultType =
           bufferization::getBufferType(result, options);
       if (failed(resultType))
         return failure();
@@ -263,11 +263,11 @@ struct CallOpInterface
 
       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
-      if (!isa<BaseMemRefType>(memRefType)) {
+      if (!isa<BufferLikeType>(memRefType)) {
         // The called function was not bufferized yet. This can happen when
         // there cycles in the function call graph. Compute the bufferized
         // result type.
-        FailureOr<BaseMemRefType> maybeMemRefType =
+        FailureOr<BufferLikeType> maybeMemRefType =
             bufferization::getBufferType(
                 funcOp.getArgument(opOperand.getOperandNumber()), options);
         if (failed(maybeMemRefType))
@@ -371,7 +371,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto funcOp = cast<FuncOp>(op);
@@ -413,8 +413,8 @@ struct FuncOpInterface
     // Compute the result types.
     SmallVector<Type> retTypes;
     for (Type resultType : funcType.getResults()) {
-      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
-        BaseMemRefType resultType = options.functionArgTypeConverterFn(
+      if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+        BufferLikeType resultType = options.functionArgTypeConverterFn(
             tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
             options);
         retTypes.push_back(resultType);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index cf62ee8bc45b5..523ee48be2003 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -102,11 +102,11 @@ struct ConditionOpInterface
     SmallVector<Value> newArgs;
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
       Value value = it.value();
-      if (isa<TensorType>(value.getType())) {
+      if (isa<bufferization::TensorLikeType>(value.getType())) {
         FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
         if (failed(maybeBuffer))
           return failure();
-        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+        auto resultType = bufferization::getBufferType(
             whileOp.getAfterArguments()[it.index()], options);
         if (failed(resultType))
           return failure();
@@ -201,7 +201,7 @@ struct ExecuteRegionOpInterface
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
     for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
-      if (isa<TensorType>(it.value())) {
+      if (isa<bufferization::TensorLikeType>(it.value())) {
         newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
             executeRegionOp.getLoc(), it.value(),
             newOp->getResult(it.index())));
@@ -244,7 +244,7 @@ struct IfOpInterface
     // Compute bufferized result types.
     SmallVector<Type> newTypes;
     for (Value result : ifOp.getResults()) {
-      if (!isa<TensorType>(result.getType())) {
+      if (!isa<bufferization::TensorLikeType>(result.getType())) {
         newTypes.push_back(result.getType());
         continue;
       }
@@ -270,7 +270,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto ifOp = cast<scf::IfOp>(op);
@@ -282,10 +282,10 @@ struct IfOpInterface
     auto opResult = cast<OpResult>(value);
     auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
     auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
-    BaseMemRefType thenBufferType, elseBufferType;
-    if (isa<BaseMemRefType>(thenValue.getType())) {
+    bufferization::BufferLikeType thenBufferType, elseBufferType;
+    if (isa<bufferization::BufferLikeType>(thenValue.getType())) {
       // True branch was already bufferized.
-      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
+      thenBufferType = cast<bufferization::BufferLikeType>(thenValue.getType());
     } else {
       auto maybeBufferType =
           bufferization::getBufferType(thenValue, options, invocationStack);
@@ -293,9 +293,9 @@ struct IfOpInterface
         return failure();
       thenBufferType = *maybeBufferType;
     }
-    if (isa<BaseMemRefType>(elseValue.getType())) {
+    if (isa<bufferization::BufferLikeType>(elseValue.getType())) {
       // False branch was already bufferized.
-      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
+      elseBufferType = cast<bufferization::BufferLikeType>(elseValue.getType());
     } else {
       auto maybeBufferType =
           bufferization::getBufferType(elseValue, options, invocationStack);
@@ -313,8 +313,10 @@ struct IfOpInterface
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return mlir::cast<bufferization::BufferLikeType>(
+        getMemRefTypeWithFullyDynamicLayout(
+            cast<TensorType>(opResult.getType()),
+            thenBufferType.getMemorySpace()));
   }
 };
 
@@ -354,7 +356,7 @@ struct IndexSwitchOpInterface
     // Compute bufferized result types.
     SmallVector<Type> newTypes;
     for (Value result : switchOp.getResults()) {
-      if (!isa<TensorType>(result.getType())) {
+      if (!isa<bufferization::TensorLikeType>(result.getType())) {
         newTypes.push_back(result.getType());
         continue;
       }
@@ -384,7 +386,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto switchOp = cast<scf::IndexSwitchOp>(op);
@@ -392,11 +394,13 @@ struct IndexSwitchOpInterface
     int64_t resultNum = cast<OpResult>(value).getResultNumber();
 
     // Helper function to get buffer type of a case.
-    SmallVector<BaseMemRefType> yieldedTypes;
-    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
+    SmallVector<bufferization::BufferLikeType> yieldedTypes;
+    auto getYieldedBufferType =
+        [&](Block &b) -> FailureOr<bufferization::BufferLikeType> {
       auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
       Value yieldedValue = yieldOp->getOperand(resultNum);
-      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
+      if (auto bufferType =
+              dyn_cast<bufferization::BufferLikeType>(yieldedValue.getType()))
         return bufferType;
       auto maybeBufferType =
           bufferization::getBufferType(yieldedValue, options, invocationStack);
@@ -409,7 +413,7 @@ struct IndexSwitchOpInterface
     auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
     if (failed(maybeBufferType))
       return failure();
-    BaseMemRefType bufferType = *maybeBufferType;
+    auto bufferType = *maybeBufferType;
 
     // Compute buffer types of all other cases.
     for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
@@ -426,8 +430,9 @@ struct IndexSwitchOpInterface
         return op->emitError("inconsistent memory space on switch cases");
 
       // Layout maps are different: Promote to fully dynamic layout map.
-      bufferType = getMemRefTypeWithFullyDynamicLayout(
-          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
+      bufferType = mlir::cast<bufferization::BufferLikeType>(
+          getMemRefTypeWithFullyDynamicLayout(cast<TensorType>(value.getType()),
+                                              bufferType.getMemorySpace()));
     }
 
     return bufferType;
@@ -439,7 +444,7 @@ struct IndexSwitchOpInterface
 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
   DenseSet<int64_t> result;
   for (const auto &it : llvm::enumerate(values))
-    if (isa<TensorType>(it.value().getType()))
+    if (isa<bufferization::TensorLikeType>(it.value().getType()))
       result.insert(it.index());
   return result;
 }
@@ -452,8 +457,8 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
   unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
   DenseSet<int64_t> result;
   for (unsigned int i = 0; i < minSize; ++i) {
-    if (!isa<TensorType>(bbArgs[i].getType()) ||
-        !isa<TensorType>(yieldedValues[i].getType()))
+    if (!isa<bufferization::TensorLikeType>(bbArgs[i].getType()) ||
+        !isa<bufferization::TensorLikeType>(yieldedValues[i].getType()))
       continue;
     if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
       result.insert(i);
@@ -468,7 +473,7 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
            const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
-    if (isa<TensorType>(opOperand.get().getType())) {
+    if (isa<bufferization::TensorLikeType>(opOperand.get().getType())) {
       FailureOr<Value> resultBuffer =
           getBuffer(rewriter, opOperand.get(), options);
       if (failed(resultBuffer))
@@ -516,9 +521,11 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
-    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
-    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
+static FailureOr<bufferization::BufferLikeType>
+computeLoopRegionIterArgBufferType(Operation *loopOp, BlockArgument iterArg,
+                                   Value initArg, Value yieldedValue,
+                                   const BufferizationOptions &options,
+                                   SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
   auto initArgBufferType =
       bufferization::getBufferType(initArg, options, invocationStack);
@@ -540,10 +547,11 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
-  if (isa<BaseMemRefType>(yieldedValue.getType())) {
+  bufferization::BufferLikeType yieldedValueBufferType;
+  if (isa<bufferization::BufferLikeType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType =
+        cast<bufferization::BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
@@ -576,8 +584,9 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(iterTensorType,
+                                          yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -696,12 +705,13 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto forOp = cast<scf::ForOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
-    assert(isa<TensorType>(value.getType()) && "expected tensor type");
+    assert(isa<bufferization::TensorLikeType>(value.getType()) &&
+           "expected tensor type");
 
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
@@ -744,7 +754,7 @@ struct ForOpInterface
       Value initArg = it.value();
       Value result = forOp->getResult(it.index());
       // If the type is not a tensor, bufferization doesn't need to touch it.
-      if (!isa<TensorType>(result.getType())) {
+      if (!isa<bufferization::TensorLikeType>(result.getType())) {
         castedInitArgs.push_back(initArg);
         continue;
       }
@@ -795,7 +805,7 @@ struct ForOpInterface
     auto forOp = cast<scf::ForOp>(op);
     auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     for (OpResult opResult : op->getOpResults()) {
-      if (!isa<TensorType>(opResult.getType()))
+      if (!isa<bufferization::TensorLikeType>(opResult.getType()))
         continue;
 
       // Note: This is overly strict. We should check for aliasing bufferized
@@ -920,7 +930,7 @@ struct WhileOpInterface
     for (int64_t idx = 0;
          idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
       Value value = conditionOp.getArgs()[idx];
-      if (!isa<TensorType>(value.getType()) ||
+      if (!isa<bufferization::TensorLikeType>(value.getType()) ||
           (equivalentYieldsAfter.contains(idx) &&
            equivalentYieldsBefore.contains(idx))) {
         beforeYieldValues.push_back(value);
@@ -962,7 +972,7 @@ struct WhileOpInterface
       Value initArg = it.value();
       Value beforeArg = whileOp.getBeforeArguments()[it.index()];
       // If the type is not a tensor, bufferization doesn't need to touch it.
-      if (!isa<TensorType>(beforeArg.getType())) {
+      if (!isa<bufferization::TensorLikeType>(beforeArg.getType())) {
         castedInitArgs.push_back(initArg);
         continue;
       }
@@ -975,7 +985,7 @@ struct WhileOpInterface
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          if (!isa<TensorType>(bbArg.getType()))
+          if (!isa<bufferization::TensorLikeType>(bbArg.getType()))
             return bbArg.getType();
           // TODO: error handling
           return llvm::cast<Type>(
@@ -1022,12 +1032,13 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto whileOp = cast<scf::WhileOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
-    assert(isa<TensorType>(value.getType()) && "expected tensor type");
+    assert(isa<bufferization::TensorLikeType>(value.getType()) &&
+           "expected tensor type");
 
     // Case 1: Block argument of the "before" region.
     if (auto bbArg = dyn_cast<BlockArgument>(value)) {
@@ -1053,9 +1064,9 @@ struct WhileOpInterface
       llvm_unreachable("invalid value");
     }
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
-    if (!isa<TensorType>(conditionYieldedVal.getType())) {
+    if (!isa<bufferization::TensorLikeType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<bufferization::BufferLikeType>(conditionYieldedVal.getType());
     }
     return bufferization::getBufferType(conditionYieldedVal, options,
                                         invocationStack);
@@ -1082,7 +1093,7 @@ struct WhileOpInterface
     auto conditionOp = whileOp.getConditionOp();
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
       Block *block = conditionOp->getBlock();
-      if (!isa<TensorType>(it.value().getType()))
+      if (!isa<bufferization::TensorLikeType>(it.value().getType()))
         continue;
       if (it.index() >= block->getNumArguments() ||
           !state.areEquivalentBufferizedValues(it.value(),
@@ -1095,7 +1106,7 @@ struct WhileOpInterface
     auto yieldOp = whileOp.getYieldOp();
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Block *block = yieldOp->getBlock();
-      if (!isa<TensorType>(it.value().getType()))
+      if (!isa<bufferization::TensorLikeType>(it.value().getType()))
         continue;
       if (it.index() >= block->getNumArguments() ||
           !state.areEquivalentBufferizedValues(it.value(),
@@ -1154,7 +1165,7 @@ struct YieldOpInterface
     SmallVector<Value> newResults;
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Value value = it.value();
-      if (isa<TensorType>(value.getType())) {
+      if (isa<bufferization::TensorLikeType>(value.getType())) {
         FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
         if (failed(maybeBuffer))
           return failure();
@@ -1162,14 +1173,14 @@ struct YieldOpInterface
         // We may have to cast the value before yielding it.
         if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                 yieldOp->getParentOp())) {
-          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+          auto resultType = bufferization::getBufferType(
               yieldOp->getParentOp()->getResult(it.index()), options);
           if (failed(resultType))
             return failure();
           buffer = castBuffer(rewriter, buffer, *resultType);
         } else if (auto whileOp =
                        dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
-          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+          auto resultType = bufferization::getBufferType(
               whileOp.getBeforeArguments()[it.index()], options);
           if (failed(resultType))
             return failure();
@@ -1274,7 +1285,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto forallOp = cast<ForallOp>(op);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 6e882a8d0ff30..068c248c1bcd7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -220,8 +220,8 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
   options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
   options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
                                       const BufferizationOptions &options) {
-    return getMemRefTypeWithStaticIdentityLayout(
-        cast<TensorType>(value.getType()), memorySpace);
+    return llvm::cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
+        cast<TensorType>(value.getType()), memorySpace));
   };
   if (analysisOnly) {
     options.testAnalysisOnly = true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index f92382472b478..742a92566a31e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -550,8 +550,8 @@ TypedValue<BaseMemRefType>
 sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
   auto tTp = llvm::cast<TensorType>(tensor.getType());
   auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
-  return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
-      .getResult();
+  return llvm::cast<TypedValue<BaseMemRefType>>(
+      builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor).getResult());
 }
 
 Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 31014172a9555..fb0dd151a4448 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -487,8 +487,7 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    FailureOr<BaseMemRefType> memrefType =
-        bufferization::getBufferType(*tensorAlloc, options);
+    auto memrefType = bufferization::getBufferType(*tensorAlloc, options);
     if (failed(memrefType))
       return failure();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
@@ -592,7 +591,8 @@ struct GenerateOpInterface
     auto type = generateOp.getResult().getType();
 
     // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(type) != Attribute())
+    if (options.defaultMemorySpaceFn(llvm::cast<TensorLikeType>(type)) !=
+        Attribute())
       return op->emitError("memory space not implemented yet");
 
     // Allocate memory.
@@ -1031,7 +1031,8 @@ struct SplatOpInterface
     auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
 
     // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+    if (options.defaultMemorySpaceFn(llvm::cast<TensorLikeType>(tensorType)) !=
+        Attribute())
       return op->emitError("memory space not implemented yet");
 
     auto linalgOp =
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index e65c5b92949f6..6fb421675fab6 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -268,4 +268,23 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
   %r = tensor.extract %dest_filled[%idx] : tensor<5xf32>
 
   return %0, %r : tensor<5xf32>, f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_dialect_op(
+// CHECK-SAME:    %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
+// CHECK-SAME:  ) -> !test.test_tensor<[32, 128], f64> {
+func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
+    -> !test.test_tensor<[32, 128], f64> {
+  // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[ARG]]
+  // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
+  // CHECK-SAME: : (!test.test_memref<[32, 64], f64>)
+  // CHECK-SAME: -> !test.test_memref<[32, 128], f64>
+  // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[32, 64], f64>)
+    -> !test.test_tensor<[32, 128], f64>
+
+  // CHECK: return %[[OUT]]
+  return %out : !test.test_tensor<[32, 128], f64>
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 2c8807b66de74..86b541d95924b 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -58,14 +58,14 @@ func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: ten
 // -----
 
 func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %arg1: vector<5xf32>) {
-  // expected-error @below{{'dest' must be a tensor or a memref}}
+  // expected-error @below{{'dest' must be a tensor or a buffer}}
   bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5xf32>, vector<5xf32>) -> ()
 }
 
 // -----
 
 func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
-  // expected-error @below{{memref 'dest' implies zero results}}
+  // expected-error @below{{buffer 'dest' implies zero results}}
   bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
 }
 
@@ -79,14 +79,14 @@ func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32
 // -----
 
 func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
-  // expected-error @below{{'restrict' is valid only for memref destinations}}
+  // expected-error @below{{'restrict' is valid only for buffer destinations}}
   bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
 }
 
 // -----
 
 func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
-  // expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
+  // expected-error @below{{'writable' must be specified if and only if the destination is of buffer type}}
   bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
 }
 
diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
index 2991a3c165ee2..95d6158d7c67f 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
@@ -46,7 +46,9 @@ struct TestTensorCopyInsertionPass
     options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
     if (mustInferMemorySpace) {
       options.defaultMemorySpaceFn =
-          [](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
+          [](bufferization::TensorLikeType t) -> std::optional<Attribute> {
+        return std::nullopt;
+      };
     }
     if (failed(bufferization::insertTensorCopies(getOperation(), options)))
       signalPassFailure();
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 454a12bac9ab3..df7586976280c 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -8,6 +8,7 @@
 
 #include "TestDialect.h"
 #include "TestOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
@@ -1386,3 +1387,26 @@ TestMultiSlotAlloca::handleDestructuringComplete(
     const DestructurableMemorySlot &slot, OpBuilder &builder) {
   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
 }
+
+::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
+    ::mlir::RewriterBase &rewriter,
+    const ::mlir::bufferization::BufferizationOptions &options) {
+  const auto inType = getInput().getType();
+  const auto bufferizedInType = test::TestMemrefType::get(
+      getContext(), inType.getShape(), inType.getElementType(), nullptr);
+  const auto outType = getOutput().getType();
+  const auto bufferizedOutType = test::TestMemrefType::get(
+      getContext(), outType.getShape(), outType.getElementType(), nullptr);
+
+  // Replace the op with its memref analogue, preserving the correct types at the boundaries.
+  auto toMemref = rewriter.create<::mlir::bufferization::ToMemrefOp>(
+      getLoc(), bufferizedInType, getInput());
+  auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>(
+      getLoc(), bufferizedOutType, toMemref.getResult());
+  auto toTensor = rewriter.create<::mlir::bufferization::ToTensorOp>(
+      getLoc(), outType, dummyMemrefOp.getOutput());
+
+  rewriter.replaceOp(*this, toTensor);
+
+  return mlir::success();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f070c3bedd92c..ea8867e3fc41d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -13,6 +13,7 @@
 #include "TestInterfaces.h"
 #include "TestTypes.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 85a49e05d4c73..976b4963a29f7 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -30,7 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-
+include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 
 // Include the attribute definitions.
 include "TestAttrDefs.td"
@@ -3499,4 +3499,57 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Ops bufferization
+//===----------------------------------------------------------------------===//
+
+def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> {
+  let arguments = (ins
+    Arg<TestTensorType>:$input
+  );
+  let results = (outs
+    Arg<TestTensorType>:$output
+  );
+  let extraClassDeclaration = [{
+    // BufferizableOpInterface
+    bool bufferizesToMemoryRead(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    bool bufferizesToMemoryWrite(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    mlir::LogicalResult bufferize(
+      mlir::RewriterBase& rewriter,
+      const mlir::bufferization::BufferizationOptions& options);
+  }];
+
+  let extraClassDefinition = [{
+    bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    bool test::TestDummyTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    ::mlir::bufferization::AliasingValueList
+    test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return {};
+    }
+  }];
+}
+
+def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
+  let arguments = (ins
+    Arg<TestMemrefType>:$input
+  );
+  let results = (outs
+    Arg<TestMemrefType>:$output
+  );
+}
+
 #endif // TEST_OPS
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index e9785594d3332..cee6888a7196c 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -446,6 +446,9 @@ def TestMemrefType : Test_Type<"TestMemref",
       return test::TestMemrefType::get(
         getContext(), shape.value_or(getShape()), elementType, getMemSpace());
     }
+
+    // BufferLikeTypeInterface:
+    ::mlir::Attribute getMemorySpace() const { return getMemSpace(); }
   }];
 }
 



More information about the Mlir-commits mailing list