[Mlir-commits] [mlir] [mlir][bufferization] Support custom types (1/N) (PR #142986)

Andrei Golubev llvmlistbot at llvm.org
Thu Jun 5 08:10:43 PDT 2025


https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/142986

Following the introduction of the TensorLike and BufferLike type interfaces (see 00eaff3e9c897c263a879416d0f151d7ca7eeaff), introduce the minimal changes required to bufferize a custom tensor operation into a custom buffer operation.

To achieve this, a new conversion dialect interface is added that abstracts away the differences between the existing (tensor -> memref) conversion and custom conversions.

The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first establish the basics and reach consensus on the design.
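
For illustration, a downstream dialect opts in by attaching an implementation of the new conversion dialect interface to its dialect. Below is a minimal sketch of what such an implementation could look like, assuming a hypothetical MyDialect whose MyTensorType / MyBufferType implement the TensorLike / BufferLike type interfaces (the patch adds an equivalent TestConverter to the test dialect):

// Minimal sketch; MyDialect, MyTensorType and MyBufferType are hypothetical.
#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"

using namespace mlir;

struct MyConverter : bufferization::ConversionDialectInterface {
  MyConverter(Dialect *dialect)
      : bufferization::ConversionDialectInterface(dialect) {}

  // Tensor-like -> buffer-like: preserve shape and element type.
  FailureOr<bufferization::BufferLikeType>
  getBufferType(Value value, const bufferization::BufferizationOptions &options,
                const bufferization::BufferizationState &state,
                function_ref<InFlightDiagnostic(const Twine &)> emitError)
      const override {
    auto tensor = dyn_cast<MyTensorType>(value.getType());
    if (!tensor)
      return emitError("expected MyTensorType");
    return cast<bufferization::BufferLikeType>(MyBufferType::get(
        value.getContext(), tensor.getShape(), tensor.getElementType()));
  }

  // Consistency check used by the to_tensor / to_buffer verifiers.
  LogicalResult typesMatch(bufferization::TensorLikeType tensor,
                           bufferization::BufferLikeType buffer,
                           function_ref<InFlightDiagnostic(const Twine &)>
                               emitError) const override {
    auto t = dyn_cast<MyTensorType>(tensor);
    auto b = dyn_cast<MyBufferType>(buffer);
    if (!t || !b)
      return emitError("expected MyTensorType and MyBufferType");
    return success(t.getShape() == b.getShape() &&
                   t.getElementType() == b.getElementType());
  }

  // Buffer-like -> tensor-like: the inverse mapping.
  bufferization::TensorLikeType
  getTensorFromBuffer(bufferization::BufferLikeType buffer) const override {
    auto b = cast<MyBufferType>(buffer);
    return cast<bufferization::TensorLikeType>(
        MyTensorType::get(b.getContext(), b.getShape(), b.getElementType()));
  }
};

// Registered from MyDialect::initialize():
//   addInterface<MyConverter>();

With such an interface registered, bufferization::getBufferType() and the to_tensor / to_buffer verifiers dispatch to the dialect-specific hooks instead of assuming the builtin tensor / memref types.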

From 6599da157f41246174509faecd727b9ed8682264 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Wed, 4 Jun 2025 15:03:26 +0000
Subject: [PATCH] [mlir][bufferization] Support custom types (1/N)

Following the introduction of the TensorLike and BufferLike type
interfaces (see 00eaff3e9c897c263a879416d0f151d7ca7eeaff), introduce the
minimal changes required to bufferize a custom tensor operation into a
custom buffer operation.

To achieve this, a new conversion dialect interface is added that
abstracts away the differences between the existing (tensor -> memref)
conversion and custom conversions.

The scope of the changes is intentionally limited (for example,
BufferizableOpInterface is untouched) in order to first establish the
basics and reach consensus on the design.
---
 .../IR/BufferizableOpInterface.h              | 17 ++++-
 .../IR/BufferizationConversionInterface.h     | 72 ++++++++++++++++++
 .../Bufferization/IR/BufferizationOps.td      | 48 +++++++-----
 .../IR/UnstructuredControlFlow.h              |  5 +-
 .../BufferizableOpInterfaceImpl.cpp           | 14 ++--
 .../IR/BufferizableOpInterface.cpp            | 76 +++++++++++++------
 .../IR/BufferizationConversionInterface.cpp   | 67 ++++++++++++++++
 .../Bufferization/IR/BufferizationOps.cpp     | 21 ++---
 .../Dialect/Bufferization/IR/CMakeLists.txt   |  1 +
 .../Bufferization/Transforms/Bufferize.cpp    |  8 +-
 .../FuncBufferizableOpInterfaceImpl.cpp       |  8 +-
 .../BufferizableOpInterfaceImpl.cpp           | 51 +++++++------
 .../Transforms/Utils/CodegenUtils.cpp         |  4 +-
 .../BufferizableOpInterfaceImpl.cpp           | 14 ++--
 .../Transforms/one-shot-bufferize.mlir        | 21 ++++-
 mlir/test/lib/Dialect/Test/TestDialect.cpp    | 49 ++++++++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 23 ++++++
 mlir/test/lib/Dialect/Test/TestOps.h          |  1 +
 mlir/test/lib/Dialect/Test/TestOps.td         | 58 +++++++++++++-
 19 files changed, 451 insertions(+), 107 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
 create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..8390da956444d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 
 namespace mlir {
 class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state);
 
@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
@@ -738,6 +739,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
 /// This is the default implementation of
 /// BufferizableOpInterface::hasTensorSemantics
 bool defaultHasTensorSemantics(Operation *op);
+
+/// A helper used when the buffer type is guaranteed to be a memref.
+FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to verify that the given tensor and
+/// buffer types match after bufferization.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to convert a buffer back to a tensor.
+Type getTensorFromBuffer(Type buffer);
 } // namespace detail
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
new file mode 100644
index 0000000000000..4164d1dcb9ea6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
@@ -0,0 +1,72 @@
+//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// This class defines a virtual interface for conversions between tensor-like
+/// and buffer-like types.
+struct ConversionDialectInterface
+    : DialectInterface::Base<ConversionDialectInterface> {
+  using Base::Base;
+
+  /// Hook to customize tensor-like -> buffer-like conversion within a given
+  /// dialect. Returns a buffer-like type for the specific tensor-like type.
+  virtual FailureOr<BufferLikeType> getBufferType(
+      Value value, const BufferizationOptions &options,
+      const BufferizationState &state,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+  /// Hook to customize type checking between tensor-like and buffer-like types.
+  /// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to
+  /// `typesMatch(T, B)` must succeed.
+  virtual LogicalResult typesMatch(
+      TensorLikeType tensor, BufferLikeType buffer,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+  /// Hook to customize buffer-like -> tensor-like conversion, which is the
+  /// opposite of bufferization.
+  virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
+};
+
+/// Interface collection for conversion between tensor-like and buffer-like
+/// types, dispatches to a concrete interface implementation based on the
+/// dialect to which the given type belongs.
+struct ConversionInterface
+    : DialectInterfaceCollection<ConversionDialectInterface> {
+  using Base::Base;
+
+  /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect
+  /// associated with the value type.
+  FailureOr<BufferLikeType> getBufferType(
+      Value value, const BufferizationOptions &options,
+      const BufferizationState &state,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+  /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect
+  /// associated with the tensor type.
+  LogicalResult
+  typesMatch(TensorLikeType tensor, BufferLikeType buffer,
+             function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+  /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
+  /// dialect associated with the buffer type.
+  TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 3d4dcdee2663b..277d56bc3f647 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,20 +388,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
 // ToTensorOp
 //===----------------------------------------------------------------------===//
 
+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
+  "specified tensor and buffer types match",
+  CPred<
+    "::mlir::bufferization::detail::typesMatchAfterBufferization("
+        "$_op, $" # tensor # ", $" # buffer #")"
+  >
+>;
+
 def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    AllElementTypesMatch<["memref", "result"]>
+    Bufferization_TensorAndBufferMatch<"result", "buffer">
   ]> {
-  let summary = "create a tensor from a `memref`";
+  let summary = "create a buffer-like type from a tensor-like type";
   let description = [{
-    An operation that creates a tensor from a `memref`. The result value is a
-    tensor whose shape and element type match the memref operand.
+    An operation that creates a tensor from a buffer. The result value is a
+    tensor-like type whose shape and element type match the buffer-like operand.
 
     The opposite of this op is `to_buffer`. Together, these two ops are
     useful for source/target materializations when doing type conversions
-    involving tensors and memrefs.
+    involving tensors and buffers.
 
     Example:
 
@@ -442,11 +451,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
                            "the reference to load from",
-                           [MemReadAt<0, FullEffect>]>:$memref,
+                           [MemReadAt<0, FullEffect>]>:$buffer,
                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
@@ -473,19 +482,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
     }
   }];
 
   let assemblyFormat = [{
-    $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref) `to` type($result)
+    $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+      `:` type($buffer) `to` type($result)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
-      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
-      build($_builder, $_state, rtt, memref, restrict, writeable);
+    OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
+      build($_builder, $_state, rtt, buffer, restrict, writeable);
     }]>
   ];
 
@@ -503,10 +512,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
     Pure,
-    AllShapesMatch<["memref", "tensor"]>,
-    AllElementTypesMatch<["memref", "tensor"]>
+    Bufferization_TensorAndBufferMatch<"tensor", "buffer">
   ]> {
-  let summary = "cast a tensor to memref";
+  let summary = "cast a tensor-like type to buffer-like type";
   let description = [{
     An operation that returns the future buffer of a `tensor`.
 
@@ -524,8 +532,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
@@ -560,7 +568,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index a441b8b66659e..f56c10555f02c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // The operand was already bufferized. Take its type directly.
         callerType = memrefType;
       } else {
-        FailureOr<BaseMemRefType> maybeCallerType =
+        FailureOr<BufferLikeType> maybeCallerType =
             bufferization::getBufferType(opOperand->get(), options, state,
                                          invocationStack);
         if (failed(maybeCallerType))
           return failure();
-        callerType = *maybeCallerType;
+        assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
+        callerType = cast<BaseMemRefType>(*maybeCallerType);
       }
 
       if (!bufferType) {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index a57d58ab28d28..021a557f68b4b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -164,8 +164,8 @@ struct SelectOpInterface
     // buffers have different types, they differ only in their layout map. Cast
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
-      auto targetType =
-          bufferization::getBufferType(selectOp.getResult(), options, state);
+      auto targetType = bufferization::detail::castToMemRef(
+          bufferization::getBufferType(selectOp.getResult(), options, state));
       if (failed(targetType))
         return failure();
       if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
-    auto trueType = bufferization::getBufferType(
-        selectOp.getTrueValue(), options, state, invocationStack);
-    auto falseType = bufferization::getBufferType(
-        selectOp.getFalseValue(), options, state, invocationStack);
+    auto trueType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            selectOp.getTrueValue(), options, state, invocationStack));
+    auto falseType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            selectOp.getFalseValue(), options, state, invocationStack));
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..d00605a7b9a17 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -211,8 +212,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType =
-      getBufferType(tensor, options, state);
+  auto copyBufferType =
+      detail::castToMemRef(getBufferType(tensor, options, state));
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -673,28 +674,28 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                           const BufferizationOptions &options,
                                           const BufferizationState &state) {
 #ifndef NDEBUG
-  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+  auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
 #endif // NDEBUG
 
   // Replace "%t = to_tensor %m" with %m.
   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
-    return toTensorOp.getMemref();
+    return toTensorOp.getBuffer();
 
   // Insert to_buffer op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
-  if (failed(memrefType))
+  FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
+  if (failed(bufferType))
     return failure();
-  ensureToBufferOpIsValid(value, *memrefType);
+  ensureToBufferOpIsValid(value, *bufferType);
   return rewriter
-      .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
+      .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
       .getResult();
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state) {
   SmallVector<Value> invocationStack;
@@ -702,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
-  assert(llvm::isa<TensorType>(value.getType()) &&
+  assert(llvm::isa<TensorLikeType>(value.getType()) &&
          "unexpected non-tensor type");
   invocationStack.push_back(value);
   auto popFromStack =
@@ -718,13 +719,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   if (bufferizableOp)
     return bufferizableOp.getBufferType(value, options, state, invocationStack);
 
-  // Op is not bufferizable.
-  auto memSpace =
-      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
-  if (!memSpace.has_value())
-    return op->emitError("could not infer memory space");
-
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  // Op is not bufferizable, use conversion interface.
+  bufferization::ConversionInterface iface(value.getContext());
+  return iface.getBufferType(value, options, state, [&](const Twine &message) {
+    return op->emitError(message);
+  });
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -744,12 +743,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
   SmallVector<Value> replacements;
   for (OpResult opResult : op->getOpResults()) {
     Value replacement = values[opResult.getResultNumber()];
-    if (llvm::isa<TensorType>(opResult.getType())) {
+    if (llvm::isa<TensorLikeType>(opResult.getType())) {
       // The OpResult is a tensor. Such values are replaced with memrefs during
       // bufferization.
-      assert((llvm::isa<MemRefType>(replacement.getType()) ||
-              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
-             "tensor op result should be replaced with a memref value");
+      assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
+             "tensor op result should be replaced with a buffer value");
       // The existing uses of the OpResult still expect a tensor. Insert a
       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
       // loose all of its users and eventually DCE away.
@@ -970,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return getBufferType(equivalentOperand, options, bufferizationState,
-                         invocationStack);
+    return castToMemRef(getBufferType(equivalentOperand, options,
+                                      bufferizationState, invocationStack));
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -1031,7 +1029,7 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
 }
 
 bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
-  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
+  auto isaTensor = [](Type t) { return isa<TensorLikeType>(t); };
   bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
     return any_of(r.getBlocks(), [&](Block &b) {
       return any_of(b.getArguments(), [&](BlockArgument bbArg) {
@@ -1046,3 +1044,31 @@ bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
     return true;
   return any_of(op->getOperandTypes(), isaTensor);
 }
+
+FailureOr<BaseMemRefType>
+bufferization::detail::castToMemRef(FailureOr<BufferLikeType> bufferType) {
+  if (failed(bufferType))
+    return failure();
+  assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
+  return cast<BaseMemRefType>(*bufferType);
+}
+
+bool bufferization::detail::typesMatchAfterBufferization(Operation &op,
+                                                         Value tensor,
+                                                         Value buffer) {
+  assert(isa<TensorLikeType>(tensor.getType()) && "expected TensorLikeType");
+  assert(isa<BufferLikeType>(buffer.getType()) && "expected BufferLikeType");
+
+  // Delegate the type check to the conversion interface.
+  bufferization::ConversionInterface iface(op.getContext());
+  return succeeded(iface.typesMatch(
+      cast<TensorLikeType>(tensor.getType()),
+      cast<BufferLikeType>(buffer.getType()),
+      [&](const Twine &message) { return op.emitError(message); }));
+}
+
+Type bufferization::detail::getTensorFromBuffer(Type buffer) {
+  assert(isa<BufferLikeType>(buffer) && "expected BufferLikeType");
+  bufferization::ConversionInterface iface(buffer.getContext());
+  return iface.getTensorFromBuffer(cast<BufferLikeType>(buffer));
+}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp
new file mode 100644
index 0000000000000..287e9bf85002f
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp
@@ -0,0 +1,67 @@
+//===- BufferizationConversionInterface.cpp - Dialect Interface -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h" // getTensorTypeFromMemRefType
+
+namespace mlir {
+namespace bufferization {
+
+FailureOr<BufferLikeType> ConversionInterface::getBufferType(
+    Value value, const BufferizationOptions &options,
+    const BufferizationState &state,
+    function_ref<InFlightDiagnostic(const Twine &)> emitError) const {
+  Dialect *dialect = &value.getType().getDialect();
+  if (const ConversionDialectInterface *iface = getInterfaceFor(dialect))
+    return iface->getBufferType(value, options, state, emitError);
+
+  // Fall back to tensor -> memref conversion.
+  auto memSpace =
+      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
+  if (!memSpace.has_value())
+    return emitError("could not infer memory space");
+
+  return cast<BufferLikeType>(
+      getMemRefType(value, options, /*layout=*/{}, *memSpace));
+}
+
+LogicalResult ConversionInterface::typesMatch(
+    TensorLikeType tensor, BufferLikeType buffer,
+    function_ref<InFlightDiagnostic(const Twine &)> emitError) const {
+  Dialect *dialect = &tensor.getDialect();
+  if (const ConversionDialectInterface *iface = getInterfaceFor(dialect))
+    return iface->typesMatch(tensor, buffer, emitError);
+
+  // Fall back to builtin tensor/memref checking.
+  assert(isa<TensorType>(tensor) && "expected tensor type");
+  assert(isa<BaseMemRefType>(buffer) && "expected memref type");
+
+  if (cast<ShapedType>(tensor).getShape() !=
+      cast<ShapedType>(buffer).getShape()) {
+    return emitError("shapes do not match");
+  }
+
+  if (cast<ShapedType>(tensor).getElementType() !=
+      cast<ShapedType>(buffer).getElementType()) {
+    return emitError("element types do not match");
+  }
+
+  return success();
+}
+
+TensorLikeType
+ConversionInterface::getTensorFromBuffer(BufferLikeType buffer) const {
+  Dialect *dialect = &buffer.getDialect();
+  if (const ConversionDialectInterface *iface = getInterfaceFor(dialect))
+    return iface->getTensorFromBuffer(buffer);
+
+  return cast<TensorLikeType>(memref::getTensorTypeFromMemRefType(buffer));
+}
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index dc54ac94aed32..79af1e8fee79f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -90,12 +90,12 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair(
   if (!bufferToTensor)
     return failure();
 
-  Type srcType = bufferToTensor.getMemref().getType();
+  Type srcType = bufferToTensor.getBuffer().getType();
   Type destType = toBuffer.getType();
 
   // Directly rewrite if the type did not change.
   if (srcType == destType) {
-    rewriter.replaceOp(toBuffer, bufferToTensor.getMemref());
+    rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
     return success();
   }
 
@@ -106,7 +106,7 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair(
   // Ranked memref -> Ranked memref cast.
   if (rankedSrcType && rankedDestType) {
     FailureOr<Value> replacement = castOrReallocMemRefValue(
-        rewriter, bufferToTensor.getMemref(), rankedDestType, options);
+        rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
     if (failed(replacement))
       return failure();
 
@@ -124,7 +124,7 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair(
   assert(memref::CastOp::areCastCompatible(srcType, destType) &&
          "expected that types are cast compatible");
   rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
-                                              bufferToTensor.getMemref());
+                                              bufferToTensor.getBuffer());
   return success();
 }
 
@@ -233,8 +233,9 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
   if (getMemorySpace().has_value()) {
     memorySpace = *getMemorySpace();
   } else if (getCopy()) {
-    auto copyBufferType = bufferization::getBufferType(getCopy(), options,
-                                                       state, invocationStack);
+    auto copyBufferType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            getCopy(), options, state, invocationStack));
     if (failed(copyBufferType))
       return failure();
     memorySpace = copyBufferType->getMemorySpace();
@@ -744,7 +745,7 @@ bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
 }
 
 OpFoldResult ToTensorOp::fold(FoldAdaptor) {
-  if (auto toBuffer = getMemref().getDefiningOp<ToBufferOp>())
+  if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
     // Approximate alias analysis by conservatively folding only when no there
     // is no interleaved operation.
     if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
@@ -764,7 +765,7 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<memref::DimOp>(
-        dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
+        dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
     return success();
   }
 };
@@ -781,8 +782,8 @@ void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 OpFoldResult ToBufferOp::fold(FoldAdaptor) {
   if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
-    if (memrefToTensor.getMemref().getType() == getType())
-      return memrefToTensor.getMemref();
+    if (memrefToTensor.getBuffer().getType() == getType())
+      return memrefToTensor.getBuffer();
   return {};
 }
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 63dcc1eb233e9..a47c1569e4c33 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferizationDialect.cpp
   BufferViewFlowOpInterface.cpp
   UnstructuredControlFlow.cpp
+  BufferizationConversionInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7681d309a4af..e3ffa2125af70 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -412,11 +412,11 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
       continue;
     }
 
-    FailureOr<BaseMemRefType> memrefType =
+    FailureOr<BufferLikeType> bufferType =
         bufferization::getBufferType(bbArg, options, state);
-    if (failed(memrefType))
+    if (failed(bufferType))
       return failure();
-    newTypes.push_back(*memrefType);
+    newTypes.push_back(*bufferType);
   }
 
   // Change the type of all block arguments.
@@ -463,7 +463,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
         newOperands.push_back(operand);
         continue;
       }
-      FailureOr<BaseMemRefType> operandBufferType =
+      FailureOr<BufferLikeType> operandBufferType =
           bufferization::getBufferType(operand, options, state);
       if (failed(operandBufferType))
         return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index a0168da44b7b3..453ed43bcadd2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -255,7 +255,7 @@ struct CallOpInterface
       }
 
       // Returning a memref.
-      FailureOr<BaseMemRefType> resultType =
+      FailureOr<BufferLikeType> resultType =
           bufferization::getBufferType(result, options, state);
       if (failed(resultType))
         return failure();
@@ -290,13 +290,13 @@ struct CallOpInterface
         // The called function was not bufferized yet. This can happen when
         // there cycles in the function call graph. Compute the bufferized
         // result type.
-        FailureOr<BaseMemRefType> maybeMemRefType =
+        FailureOr<BufferLikeType> maybeBufferType =
             bufferization::getBufferType(
                 funcOp.getArgument(opOperand.getOperandNumber()), options,
                 state);
-        if (failed(maybeMemRefType))
+        if (failed(maybeBufferType))
           return failure();
-        memRefType = *maybeMemRefType;
+        memRefType = *maybeBufferType;
       }
 
       // Since we don't yet have a clear layout story, to_buffer may
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 46fa77a7dc4e6..efa9fc1a070aa 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -108,7 +108,7 @@ struct ConditionOpInterface
             getBuffer(rewriter, value, options, state);
         if (failed(maybeBuffer))
           return failure();
-        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+        FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
             whileOp.getAfterArguments()[it.index()], options, state);
         if (failed(resultType))
           return failure();
@@ -292,8 +292,9 @@ struct IfOpInterface
       // True branch was already bufferized.
       thenBufferType = cast<BaseMemRefType>(thenValue.getType());
     } else {
-      auto maybeBufferType = bufferization::getBufferType(
-          thenValue, options, state, invocationStack);
+      auto maybeBufferType =
+          bufferization::detail::castToMemRef(bufferization::getBufferType(
+              thenValue, options, state, invocationStack));
       if (failed(maybeBufferType))
         return failure();
       thenBufferType = *maybeBufferType;
@@ -302,8 +303,9 @@ struct IfOpInterface
       // False branch was already bufferized.
       elseBufferType = cast<BaseMemRefType>(elseValue.getType());
     } else {
-      auto maybeBufferType = bufferization::getBufferType(
-          elseValue, options, state, invocationStack);
+      auto maybeBufferType =
+          bufferization::detail::castToMemRef(bufferization::getBufferType(
+              elseValue, options, state, invocationStack));
       if (failed(maybeBufferType))
         return failure();
       elseBufferType = *maybeBufferType;
@@ -406,9 +408,7 @@ struct IndexSwitchOpInterface
         return bufferType;
       auto maybeBufferType = bufferization::getBufferType(
           yieldedValue, options, state, invocationStack);
-      if (failed(maybeBufferType))
-        return failure();
-      return maybeBufferType;
+      return bufferization::detail::castToMemRef(maybeBufferType);
     };
 
     // Compute buffer type of the default case.
@@ -527,8 +527,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType =
-      bufferization::getBufferType(initArg, options, state, invocationStack);
+  auto initArgBufferType = bufferization::detail::castToMemRef(
+      bufferization::getBufferType(initArg, options, state, invocationStack));
   if (failed(initArgBufferType))
     return failure();
 
@@ -554,8 +554,9 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
-                                                        state, invocationStack);
+    auto maybeBufferType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            yieldedValue, options, state, invocationStack));
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -718,8 +719,12 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      return bufferization::getBufferType(bbArg, options, state,
-                                          invocationStack);
+      auto bufferType =
+          bufferization::getBufferType(bbArg, options, state, invocationStack);
+      if (failed(bufferType))
+        return failure();
+      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
+      return cast<BaseMemRefType>(*bufferType);
     }
 
     // Compute result/argument number.
@@ -1078,8 +1083,8 @@ struct WhileOpInterface
       // scf.condition was already bufferized.
       return cast<BaseMemRefType>(conditionYieldedVal.getType());
     }
-    return bufferization::getBufferType(conditionYieldedVal, options, state,
-                                        invocationStack);
+    return bufferization::detail::castToMemRef(bufferization::getBufferType(
+        conditionYieldedVal, options, state, invocationStack));
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1185,14 +1190,14 @@ struct YieldOpInterface
         // We may have to cast the value before yielding it.
         if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                 yieldOp->getParentOp())) {
-          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
               yieldOp->getParentOp()->getResult(it.index()), options, state);
           if (failed(resultType))
             return failure();
           buffer = castBuffer(rewriter, buffer, *resultType);
         } else if (auto whileOp =
                        dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
-          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
               whileOp.getBeforeArguments()[it.index()], options, state);
           if (failed(resultType))
             return failure();
@@ -1307,15 +1312,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::getBufferType(
-          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
-          invocationStack);
+      return bufferization::detail::castToMemRef(
+          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
+                                       options, state, invocationStack));
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::getBufferType(
+    return bufferization::detail::castToMemRef(bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack);
+        state, invocationStack));
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index 57291064eba22..1bd9563b3db07 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -549,8 +549,8 @@ TypedValue<BaseMemRefType>
 sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
   auto tTp = llvm::cast<TensorType>(tensor.getType());
   auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
-  return builder.create<bufferization::ToBufferOp>(loc, mTp, tensor)
-      .getResult();
+  return cast<TypedValue<BaseMemRefType>>(
+      builder.create<bufferization::ToBufferOp>(loc, mTp, tensor).getResult());
 }
 
 Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 4b778b768d136..40b710f17fe44 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -54,8 +54,9 @@ struct CastOpInterface
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
     auto castOp = cast<tensor::CastOp>(op);
-    auto maybeSrcBufferType = bufferization::getBufferType(
-        castOp.getSource(), options, state, invocationStack);
+    auto maybeSrcBufferType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            castOp.getSource(), options, state, invocationStack));
     if (failed(maybeSrcBufferType))
       return failure();
     Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
@@ -500,8 +501,8 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    FailureOr<BaseMemRefType> memrefType =
-        bufferization::getBufferType(*tensorAlloc, options, state);
+    FailureOr<BaseMemRefType> memrefType = bufferization::detail::castToMemRef(
+        bufferization::getBufferType(*tensorAlloc, options, state));
     if (failed(memrefType))
       return failure();
     Value buffer = rewriter.create<bufferization::ToBufferOp>(
@@ -758,8 +759,9 @@ struct PadOpInterface
                 SmallVector<Value> &invocationStack) const {
     // Infer memory space from the source tensor.
     auto padOp = cast<tensor::PadOp>(op);
-    auto maybeSrcBufferType = bufferization::getBufferType(
-        padOp.getSource(), options, state, invocationStack);
+    auto maybeSrcBufferType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            padOp.getSource(), options, state, invocationStack));
     if (failed(maybeSrcBufferType))
       return failure();
     MemRefLayoutAttrInterface layout;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index cd19e3a5e82aa..da3c26ce36ba5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -268,4 +268,23 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
   %r = tensor.extract %dest_filled[%idx] : tensor<5xf32>
 
   return %0, %r : tensor<5xf32>, f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_dialect_op(
+// CHECK-SAME:    %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
+// CHECK-SAME:  ) -> !test.test_tensor<[32, 128], f64> {
+func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
+    -> !test.test_tensor<[32, 128], f64> {
+  // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]]
+  // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
+  // CHECK-SAME: : (!test.test_memref<[32, 64], f64>)
+  // CHECK-SAME: -> !test.test_memref<[32, 128], f64>
+  // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[32, 64], f64>)
+    -> !test.test_tensor<[32, 128], f64>
+
+  // CHECK: return %[[OUT]]
+  return %out : !test.test_tensor<[32, 128], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 1bbf2cc7481d9..03985874f910d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -11,6 +11,7 @@
 #include "TestTypes.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AsmState.h"
@@ -284,6 +285,53 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) {
                                   verifier, regionVerifier, parser, printer);
 }
 
+namespace {
+
+struct TestConverter : bufferization::ConversionDialectInterface {
+  TestConverter(Dialect *dialect)
+      : bufferization::ConversionDialectInterface(dialect) {}
+
+  FailureOr<bufferization::BufferLikeType>
+  getBufferType(Value value, const bufferization::BufferizationOptions &options,
+                const bufferization::BufferizationState &state,
+                function_ref<InFlightDiagnostic(const Twine &)> emitError)
+      const override {
+    auto testTensor = dyn_cast<TestTensorType>(value.getType());
+    if (!testTensor)
+      return emitError("expected TestTensorType");
+
+    return cast<bufferization::BufferLikeType>(
+        TestMemrefType::get(value.getContext(), testTensor.getShape(),
+                            testTensor.getElementType(), nullptr));
+  }
+
+  LogicalResult typesMatch(bufferization::TensorLikeType tensor,
+                           bufferization::BufferLikeType buffer,
+                           function_ref<InFlightDiagnostic(const Twine &)>
+                               emitError) const override {
+    auto testTensor = dyn_cast<TestTensorType>(tensor);
+    auto testMemref = dyn_cast<TestMemrefType>(buffer);
+    if (!testTensor || !testMemref)
+      return emitError("expected TestTensorType and TestMemrefType");
+
+    const bool valid =
+        testTensor.getShape() == testMemref.getShape() &&
+        testTensor.getElementType() == testMemref.getElementType();
+    return success(valid);
+  }
+
+  bufferization::TensorLikeType
+  getTensorFromBuffer(bufferization::BufferLikeType buffer) const override {
+    auto testMemref = dyn_cast<TestMemrefType>(buffer);
+    assert(testMemref && "expected TestMemrefType");
+    return cast<bufferization::TensorLikeType>(
+        TestTensorType::get(testMemref.getContext(), testMemref.getShape(),
+                            testMemref.getElementType()));
+  }
+};
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
@@ -333,6 +381,7 @@ void TestDialect::initialize() {
   registerDynamicOp(getDynamicCustomParserPrinterOp(this));
   registerInterfaces();
   allowUnknownOperations();
+  addInterface<TestConverter>();
 
   // Instantiate our fallback op interface that we'll use on specific
   // unregistered op.
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b5a8bd10d6b68..78e44c6ec7a9b 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -8,6 +8,7 @@
 
 #include "TestDialect.h"
 #include "TestOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
@@ -1387,3 +1388,25 @@ TestMultiSlotAlloca::handleDestructuringComplete(
     const DestructurableMemorySlot &slot, OpBuilder &builder) {
   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
 }
+
+::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
+    ::mlir::RewriterBase &rewriter,
+    const ::mlir::bufferization::BufferizationOptions &options,
+    ::mlir::bufferization::BufferizationState &state) {
+  auto buffer =
+      mlir::bufferization::getBuffer(rewriter, getInput(), options, state);
+  if (mlir::failed(buffer))
+    return failure();
+
+  const auto outType = getOutput().getType();
+  const auto bufferizedOutType = test::TestMemrefType::get(
+      getContext(), outType.getShape(), outType.getElementType(), nullptr);
+  // Replace the op with its memref counterpart.
+  auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>(
+      getLoc(), bufferizedOutType, *buffer);
+
+  mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(),
+                                                     dummyMemrefOp.getResult());
+
+  return mlir::success();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index c2ee5f9ab9a57..b414b47c87425 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -13,6 +13,7 @@
 #include "TestInterfaces.h"
 #include "TestTypes.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 59330fdb1bb2c..79bcd9c2e0a9a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -31,7 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-
+include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 
 // Include the attribute definitions.
 include "TestAttrDefs.td"
@@ -2825,7 +2825,7 @@ def TestNVVMRequiresSMArchCondOp :
   let assemblyFormat = "attr-dict";
 }
 
-def TestNVVMRequirestSMArchCondMultiOp : 
+def TestNVVMRequirestSMArchCondMultiOp :
     TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> {
   let arguments = (ins );
   let assemblyFormat = "attr-dict";
@@ -3552,4 +3552,58 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Ops bufferization
+//===----------------------------------------------------------------------===//
+
+def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> {
+  let arguments = (ins
+    Arg<TestTensorType>:$input
+  );
+  let results = (outs
+    Arg<TestTensorType>:$output
+  );
+  let extraClassDeclaration = [{
+    // BufferizableOpInterface
+    bool bufferizesToMemoryRead(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    bool bufferizesToMemoryWrite(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    mlir::LogicalResult bufferize(
+      mlir::RewriterBase& rewriter,
+      const mlir::bufferization::BufferizationOptions& options,
+      mlir::bufferization::BufferizationState &state);
+  }];
+
+  let extraClassDefinition = [{
+    bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    bool test::TestDummyTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    ::mlir::bufferization::AliasingValueList
+    test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return {};
+    }
+  }];
+}
+
+def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
+  let arguments = (ins
+    Arg<TestMemrefType>:$input
+  );
+  let results = (outs
+    Arg<TestMemrefType>:$output
+  );
+}
+
 #endif // TEST_OPS



More information about the Mlir-commits mailing list