[Mlir-commits] [mlir] [mlir][bufferization] Support custom types (1/N) (PR #142986)
Andrei Golubev
llvmlistbot at llvm.org
Thu Jun 5 08:10:43 PDT 2025
https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/142986
Following the introduction of the TensorLike and BufferLike type interfaces (see 00eaff3e9c897c263a879416d0f151d7ca7eeaff), introduce the minimal changes required to bufferize a custom tensor operation into a custom buffer operation.
To achieve this, a new conversion dialect interface is added that abstracts away the differences between the existing (tensor -> memref) conversion and custom conversions.
The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus on the design.
From 6599da157f41246174509faecd727b9ed8682264 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Wed, 4 Jun 2025 15:03:26 +0000
Subject: [PATCH] [mlir][bufferization] Support custom types (1/N)
Following the introduction of the TensorLike and BufferLike type
interfaces (see 00eaff3e9c897c263a879416d0f151d7ca7eeaff), introduce
the minimal changes required to bufferize a custom tensor operation
into a custom buffer operation.
To achieve this, a new conversion dialect interface is added that
abstracts away the differences between the existing (tensor -> memref)
conversion and custom conversions.
The scope of the changes is intentionally limited (for example,
BufferizableOpInterface is untouched) in order to first understand the
basics and reach consensus on the design.
---
.../IR/BufferizableOpInterface.h | 17 ++++-
.../IR/BufferizationConversionInterface.h | 72 ++++++++++++++++++
.../Bufferization/IR/BufferizationOps.td | 48 +++++++-----
.../IR/UnstructuredControlFlow.h | 5 +-
.../BufferizableOpInterfaceImpl.cpp | 14 ++--
.../IR/BufferizableOpInterface.cpp | 76 +++++++++++++------
.../IR/BufferizationConversionInterface.cpp | 67 ++++++++++++++++
.../Bufferization/IR/BufferizationOps.cpp | 21 ++---
.../Dialect/Bufferization/IR/CMakeLists.txt | 1 +
.../Bufferization/Transforms/Bufferize.cpp | 8 +-
.../FuncBufferizableOpInterfaceImpl.cpp | 8 +-
.../BufferizableOpInterfaceImpl.cpp | 51 +++++++------
.../Transforms/Utils/CodegenUtils.cpp | 4 +-
.../BufferizableOpInterfaceImpl.cpp | 14 ++--
.../Transforms/one-shot-bufferize.mlir | 21 ++++-
mlir/test/lib/Dialect/Test/TestDialect.cpp | 49 ++++++++++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 23 ++++++
mlir/test/lib/Dialect/Test/TestOps.h | 1 +
mlir/test/lib/Dialect/Test/TestOps.td | 58 +++++++++++++-
19 files changed, 451 insertions(+), 107 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..8390da956444d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
#include <optional>
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
namespace mlir {
class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state);
@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -738,6 +739,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
/// This is the default implementation of
/// BufferizableOpInterface::hasTensorSemantics
bool defaultHasTensorSemantics(Operation *op);
+
+/// This is a helper function used when the buffer type is guaranteed to be a memref.
+FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to verify the types in tensor and buffer
+/// worlds match.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to perform the conversion.
+Type getTensorFromBuffer(Type buffer);
} // namespace detail
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
new file mode 100644
index 0000000000000..4164d1dcb9ea6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
@@ -0,0 +1,72 @@
+//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// This class defines a virtual interface for conversions between tensor-like
+/// and buffer-like types.
+struct ConversionDialectInterface
+ : DialectInterface::Base<ConversionDialectInterface> {
+ using Base::Base;
+
+ /// Hook to customize tensor-like -> buffer-like conversion within a given
+ /// dialect. Returns a buffer-like type for the specific tensor-like type.
+ virtual FailureOr<BufferLikeType> getBufferType(
+ Value value, const BufferizationOptions &options,
+ const BufferizationState &state,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+ /// Hook to customize type checking between tensor-like and buffer-like types.
+ /// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to
+ /// `typesMatch(T, B)` must succeed.
+ virtual LogicalResult typesMatch(
+ TensorLikeType tensor, BufferLikeType buffer,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+ /// Hook to customize buffer-like -> tensor-like conversion, which is the
+ /// opposite of bufferization.
+ virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
+};
+
+/// Interface collection for conversion between tensor-like and buffer-like
+/// types, dispatches to a concrete interface implementation based on the
+/// dialect to which the given type belongs.
+struct ConversionInterface
+ : DialectInterfaceCollection<ConversionDialectInterface> {
+ using Base::Base;
+
+ /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect
+ /// associated with the value type.
+ FailureOr<BufferLikeType> getBufferType(
+ Value value, const BufferizationOptions &options,
+ const BufferizationState &state,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+ /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect
+ /// associated with the tensor type.
+ LogicalResult
+ typesMatch(TensorLikeType tensor, BufferLikeType buffer,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+ /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
+ /// dialect associated with the buffer type.
+ TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 3d4dcdee2663b..277d56bc3f647 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,20 +388,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
// ToTensorOp
//===----------------------------------------------------------------------===//
+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
+ "specified tensor and buffer types match",
+ CPred<
+ "::mlir::bufferization::detail::typesMatchAfterBufferization("
+ "$_op, $" # tensor # ", $" # buffer #")"
+ >
+>;
+
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
BufferizableOpInterface,
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
- AllElementTypesMatch<["memref", "result"]>
+ Bufferization_TensorAndBufferMatch<"result", "buffer">
]> {
- let summary = "create a tensor from a `memref`";
+ let summary = "create a tensor-like type from a buffer-like type";
let description = [{
- An operation that creates a tensor from a `memref`. The result value is a
- tensor whose shape and element type match the memref operand.
+ An operation that creates a tensor from a buffer. The result value is a
+ tensor-like type whose shape and element type match the buffer-like operand.
The opposite of this op is `to_buffer`. Together, these two ops are
useful for source/target materializations when doing type conversions
- involving tensors and memrefs.
+ involving tensors and buffers.
Example:
@@ -442,11 +451,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];
- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
"the reference to load from",
- [MemReadAt<0, FullEffect>]>:$memref,
+ [MemReadAt<0, FullEffect>]>:$buffer,
UnitAttr:$restrict, UnitAttr:$writable);
- let results = (outs AnyTensor:$result);
+ let results = (outs Bufferization_TensorLikeTypeInterface:$result);
let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
@@ -473,19 +482,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+ return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
}
}];
let assemblyFormat = [{
- $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
- `:` type($memref) `to` type($result)
+ $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+ `:` type($buffer) `to` type($result)
}];
let builders = [
- OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
- auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
- build($_builder, $_state, rtt, memref, restrict, writeable);
+ OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+ auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
+ build($_builder, $_state, rtt, buffer, restrict, writeable);
}]>
];
@@ -503,10 +512,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
Pure,
- AllShapesMatch<["memref", "tensor"]>,
- AllElementTypesMatch<["memref", "tensor"]>
+ Bufferization_TensorAndBufferMatch<"tensor", "buffer">
]> {
- let summary = "cast a tensor to memref";
+ let summary = "cast a tensor-like type to a buffer-like type";
let description = [{
An operation that returns the future buffer of a `tensor`.
@@ -524,8 +532,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
the returned buffer) will not be written to.
}];
- let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
- let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+ let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+ let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
@@ -560,7 +568,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}];
let assemblyFormat = [{
- $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
}];
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index a441b8b66659e..f56c10555f02c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// The operand was already bufferized. Take its type directly.
callerType = memrefType;
} else {
- FailureOr<BaseMemRefType> maybeCallerType =
+ FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
- callerType = *maybeCallerType;
+ assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
+ callerType = cast<BaseMemRefType>(*maybeCallerType);
}
if (!bufferType) {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index a57d58ab28d28..021a557f68b4b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -164,8 +164,8 @@ struct SelectOpInterface
// buffers have different types, they differ only in their layout map. Cast
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
- auto targetType =
- bufferization::getBufferType(selectOp.getResult(), options, state);
+ auto targetType = bufferization::detail::castToMemRef(
+ bufferization::getBufferType(selectOp.getResult(), options, state));
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
- auto trueType = bufferization::getBufferType(
- selectOp.getTrueValue(), options, state, invocationStack);
- auto falseType = bufferization::getBufferType(
- selectOp.getFalseValue(), options, state, invocationStack);
+ auto trueType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ selectOp.getTrueValue(), options, state, invocationStack));
+ auto falseType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ selectOp.getFalseValue(), options, state, invocationStack));
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..d00605a7b9a17 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -211,8 +212,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType =
- getBufferType(tensor, options, state);
+ auto copyBufferType =
+ detail::castToMemRef(getBufferType(tensor, options, state));
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -673,28 +674,28 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
const BufferizationState &state) {
#ifndef NDEBUG
- auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+ auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
- return toTensorOp.getMemref();
+ return toTensorOp.getBuffer();
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
- if (failed(memrefType))
+ FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
+ if (failed(bufferType))
return failure();
- ensureToBufferOpIsValid(value, *memrefType);
+ ensureToBufferOpIsValid(value, *bufferType);
return rewriter
- .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
+ .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
.getResult();
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state) {
SmallVector<Value> invocationStack;
@@ -702,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
- assert(llvm::isa<TensorType>(value.getType()) &&
+ assert(llvm::isa<TensorLikeType>(value.getType()) &&
"unexpected non-tensor type");
invocationStack.push_back(value);
auto popFromStack =
@@ -718,13 +719,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
if (bufferizableOp)
return bufferizableOp.getBufferType(value, options, state, invocationStack);
- // Op is not bufferizable.
- auto memSpace =
- options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
- if (!memSpace.has_value())
- return op->emitError("could not infer memory space");
-
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ // Op is not bufferizable, use conversion interface.
+ bufferization::ConversionInterface iface(value.getContext());
+ return iface.getBufferType(value, options, state, [&](const Twine &message) {
+ return op->emitError(message);
+ });
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -744,12 +743,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
Value replacement = values[opResult.getResultNumber()];
- if (llvm::isa<TensorType>(opResult.getType())) {
+ if (llvm::isa<TensorLikeType>(opResult.getType())) {
// The OpResult is a tensor. Such values are replaced with memrefs during
// bufferization.
- assert((llvm::isa<MemRefType>(replacement.getType()) ||
- llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
- "tensor op result should be replaced with a memref value");
+ assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
+ "tensor op result should be replaced with a buffer value");
// The existing uses of the OpResult still expect a tensor. Insert a
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
// loose all of its users and eventually DCE away.
@@ -970,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return getBufferType(equivalentOperand, options, bufferizationState,
- invocationStack);
+ return castToMemRef(getBufferType(equivalentOperand, options,
+ bufferizationState, invocationStack));
}
// If we do not know the memory space and there is no default memory space,
@@ -1031,7 +1029,7 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
}
bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
- auto isaTensor = [](Type t) { return isa<TensorType>(t); };
+ auto isaTensor = [](Type t) { return isa<TensorLikeType>(t); };
bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
return any_of(r.getBlocks(), [&](Block &b) {
return any_of(b.getArguments(), [&](BlockArgument bbArg) {
@@ -1046,3 +1044,31 @@ bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
return true;
return any_of(op->getOperandTypes(), isaTensor);
}
+
+FailureOr<BaseMemRefType>
+bufferization::detail::castToMemRef(FailureOr<BufferLikeType> bufferType) {
+ if (failed(bufferType))
+ return failure();
+ assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
+ return cast<BaseMemRefType>(*bufferType);
+}
+
+bool bufferization::detail::typesMatchAfterBufferization(Operation &op,
+ Value tensor,
+ Value buffer) {
+ assert(isa<TensorLikeType>(tensor.getType()) && "expected TensorLikeType");
+ assert(isa<BufferLikeType>(buffer.getType()) && "expected BufferLikeType");
+
+ // Dispatch the check to the conversion interface of the tensor's dialect.
+ bufferization::ConversionInterface iface(op.getContext());
+ return succeeded(iface.typesMatch(
+ cast<TensorLikeType>(tensor.getType()),
+ cast<BufferLikeType>(buffer.getType()),
+ [&](const Twine &message) { return op.emitError(message); }));
+}
+
+Type bufferization::detail::getTensorFromBuffer(Type buffer) {
+ assert(isa<BufferLikeType>(buffer) && "expected BufferLikeType");
+ bufferization::ConversionInterface iface(buffer.getContext());
+ return iface.getTensorFromBuffer(cast<BufferLikeType>(buffer));
+}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp
new file mode 100644
index 0000000000000..287e9bf85002f
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp
@@ -0,0 +1,67 @@
+//===- BufferizationConversionInterface.cpp - Dialect Interface -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h" // getTensorTypeFromMemRefType
+
+namespace mlir {
+namespace bufferization {
+
+FailureOr<BufferLikeType> ConversionInterface::getBufferType(
+ Value value, const BufferizationOptions &options,
+ const BufferizationState &state,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const {
+ Dialect *dialect = &value.getType().getDialect();
+ if (const ConversionDialectInterface *iface = getInterfaceFor(dialect))
+ return iface->getBufferType(value, options, state, emitError);
+
+ // Fall back to tensor -> memref conversion.
+ auto memSpace =
+ options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
+ if (!memSpace.has_value())
+ return emitError("could not infer memory space");
+
+ return cast<BufferLikeType>(
+ getMemRefType(value, options, /*layout=*/{}, *memSpace));
+}
+
+LogicalResult ConversionInterface::typesMatch(
+ TensorLikeType tensor, BufferLikeType buffer,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const {
+ Dialect *dialect = &tensor.getDialect();
+ if (const ConversionDialectInterface *iface = getInterfaceFor(dialect))
+ return iface->typesMatch(tensor, buffer, emitError);
+
+ // Fall back to builtin tensor -> memref type checking.
+ assert(isa<TensorType>(tensor) && "expected tensor type");
+ assert(isa<BaseMemRefType>(buffer) && "expected memref type");
+
+ if (cast<ShapedType>(tensor).getShape() !=
+ cast<ShapedType>(buffer).getShape()) {
+ return emitError("shapes do not match");
+ }
+
+ if (cast<ShapedType>(tensor).getElementType() !=
+ cast<ShapedType>(buffer).getElementType()) {
+ return emitError("element types do not match");
+ }
+
+ return success();
+}
+
+TensorLikeType
+ConversionInterface::getTensorFromBuffer(BufferLikeType buffer) const {
+ Dialect *dialect = &buffer.getDialect();
+ if (const ConversionDialectInterface *iface = getInterfaceFor(dialect))
+ return iface->getTensorFromBuffer(buffer);
+
+ return cast<TensorLikeType>(memref::getTensorTypeFromMemRefType(buffer));
+}
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index dc54ac94aed32..79af1e8fee79f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -90,12 +90,12 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair(
if (!bufferToTensor)
return failure();
- Type srcType = bufferToTensor.getMemref().getType();
+ Type srcType = bufferToTensor.getBuffer().getType();
Type destType = toBuffer.getType();
// Directly rewrite if the type did not change.
if (srcType == destType) {
- rewriter.replaceOp(toBuffer, bufferToTensor.getMemref());
+ rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
return success();
}
@@ -106,7 +106,7 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair(
// Ranked memref -> Ranked memref cast.
if (rankedSrcType && rankedDestType) {
FailureOr<Value> replacement = castOrReallocMemRefValue(
- rewriter, bufferToTensor.getMemref(), rankedDestType, options);
+ rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
if (failed(replacement))
return failure();
@@ -124,7 +124,7 @@ LogicalResult mlir::bufferization::foldToBufferToTensorPair(
assert(memref::CastOp::areCastCompatible(srcType, destType) &&
"expected that types are cast compatible");
rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
- bufferToTensor.getMemref());
+ bufferToTensor.getBuffer());
return success();
}
@@ -233,8 +233,9 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
if (getMemorySpace().has_value()) {
memorySpace = *getMemorySpace();
} else if (getCopy()) {
- auto copyBufferType = bufferization::getBufferType(getCopy(), options,
- state, invocationStack);
+ auto copyBufferType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ getCopy(), options, state, invocationStack));
if (failed(copyBufferType))
return failure();
memorySpace = copyBufferType->getMemorySpace();
@@ -744,7 +745,7 @@ bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
}
OpFoldResult ToTensorOp::fold(FoldAdaptor) {
- if (auto toBuffer = getMemref().getDefiningOp<ToBufferOp>())
+ if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
// Approximate alias analysis by conservatively folding only when no there
// is no interleaved operation.
if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
@@ -764,7 +765,7 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
return failure();
rewriter.replaceOpWithNewOp<memref::DimOp>(
- dimOp, memrefToTensorOp.getMemref(), dimOp.getIndex());
+ dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
return success();
}
};
@@ -781,8 +782,8 @@ void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
OpFoldResult ToBufferOp::fold(FoldAdaptor) {
if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
- if (memrefToTensor.getMemref().getType() == getType())
- return memrefToTensor.getMemref();
+ if (memrefToTensor.getBuffer().getType() == getType())
+ return memrefToTensor.getBuffer();
return {};
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 63dcc1eb233e9..a47c1569e4c33 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferizationDialect.cpp
BufferViewFlowOpInterface.cpp
UnstructuredControlFlow.cpp
+ BufferizationConversionInterface.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7681d309a4af..e3ffa2125af70 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -412,11 +412,11 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
continue;
}
- FailureOr<BaseMemRefType> memrefType =
+ FailureOr<BufferLikeType> bufferType =
bufferization::getBufferType(bbArg, options, state);
- if (failed(memrefType))
+ if (failed(bufferType))
return failure();
- newTypes.push_back(*memrefType);
+ newTypes.push_back(*bufferType);
}
// Change the type of all block arguments.
@@ -463,7 +463,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
newOperands.push_back(operand);
continue;
}
- FailureOr<BaseMemRefType> operandBufferType =
+ FailureOr<BufferLikeType> operandBufferType =
bufferization::getBufferType(operand, options, state);
if (failed(operandBufferType))
return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index a0168da44b7b3..453ed43bcadd2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -255,7 +255,7 @@ struct CallOpInterface
}
// Returning a memref.
- FailureOr<BaseMemRefType> resultType =
+ FailureOr<BufferLikeType> resultType =
bufferization::getBufferType(result, options, state);
if (failed(resultType))
return failure();
@@ -290,13 +290,13 @@ struct CallOpInterface
// The called function was not bufferized yet. This can happen when
// there cycles in the function call graph. Compute the bufferized
// result type.
- FailureOr<BaseMemRefType> maybeMemRefType =
+ FailureOr<BufferLikeType> maybeBufferType =
bufferization::getBufferType(
funcOp.getArgument(opOperand.getOperandNumber()), options,
state);
- if (failed(maybeMemRefType))
+ if (failed(maybeBufferType))
return failure();
- memRefType = *maybeMemRefType;
+ memRefType = *maybeBufferType;
}
// Since we don't yet have a clear layout story, to_buffer may
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 46fa77a7dc4e6..efa9fc1a070aa 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -108,7 +108,7 @@ struct ConditionOpInterface
getBuffer(rewriter, value, options, state);
if (failed(maybeBuffer))
return failure();
- FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
whileOp.getAfterArguments()[it.index()], options, state);
if (failed(resultType))
return failure();
@@ -292,8 +292,9 @@ struct IfOpInterface
// True branch was already bufferized.
thenBufferType = cast<BaseMemRefType>(thenValue.getType());
} else {
- auto maybeBufferType = bufferization::getBufferType(
- thenValue, options, state, invocationStack);
+ auto maybeBufferType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ thenValue, options, state, invocationStack));
if (failed(maybeBufferType))
return failure();
thenBufferType = *maybeBufferType;
@@ -302,8 +303,9 @@ struct IfOpInterface
// False branch was already bufferized.
elseBufferType = cast<BaseMemRefType>(elseValue.getType());
} else {
- auto maybeBufferType = bufferization::getBufferType(
- elseValue, options, state, invocationStack);
+ auto maybeBufferType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ elseValue, options, state, invocationStack));
if (failed(maybeBufferType))
return failure();
elseBufferType = *maybeBufferType;
@@ -406,9 +408,7 @@ struct IndexSwitchOpInterface
return bufferType;
auto maybeBufferType = bufferization::getBufferType(
yieldedValue, options, state, invocationStack);
- if (failed(maybeBufferType))
- return failure();
- return maybeBufferType;
+ return bufferization::detail::castToMemRef(maybeBufferType);
};
// Compute buffer type of the default case.
@@ -527,8 +527,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
const BufferizationOptions &options, const BufferizationState &state,
SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
- auto initArgBufferType =
- bufferization::getBufferType(initArg, options, state, invocationStack);
+ auto initArgBufferType = bufferization::detail::castToMemRef(
+ bufferization::getBufferType(initArg, options, state, invocationStack));
if (failed(initArgBufferType))
return failure();
@@ -554,8 +554,9 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
- auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
- state, invocationStack);
+ auto maybeBufferType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ yieldedValue, options, state, invocationStack));
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -718,8 +719,12 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
- return bufferization::getBufferType(bbArg, options, state,
- invocationStack);
+ auto bufferType =
+ bufferization::getBufferType(bbArg, options, state, invocationStack);
+ if (failed(bufferType))
+ return failure();
+ assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
+ return cast<BaseMemRefType>(*bufferType);
}
// Compute result/argument number.
@@ -1078,8 +1083,8 @@ struct WhileOpInterface
// scf.condition was already bufferized.
return cast<BaseMemRefType>(conditionYieldedVal.getType());
}
- return bufferization::getBufferType(conditionYieldedVal, options, state,
- invocationStack);
+ return bufferization::detail::castToMemRef(bufferization::getBufferType(
+ conditionYieldedVal, options, state, invocationStack));
}
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1185,14 +1190,14 @@ struct YieldOpInterface
// We may have to cast the value before yielding it.
if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
yieldOp->getParentOp())) {
- FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
yieldOp->getParentOp()->getResult(it.index()), options, state);
if (failed(resultType))
return failure();
buffer = castBuffer(rewriter, buffer, *resultType);
} else if (auto whileOp =
dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
- FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
whileOp.getBeforeArguments()[it.index()], options, state);
if (failed(resultType))
return failure();
@@ -1307,15 +1312,15 @@ struct ForallOpInterface
if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
- return bufferization::getBufferType(
- forallOp.getTiedOpOperand(bbArg)->get(), options, state,
- invocationStack);
+ return bufferization::detail::castToMemRef(
+ bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
+ options, state, invocationStack));
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
- return bufferization::getBufferType(
+ return bufferization::detail::castToMemRef(bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
- state, invocationStack);
+ state, invocationStack));
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index 57291064eba22..1bd9563b3db07 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -549,8 +549,8 @@ TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
auto tTp = llvm::cast<TensorType>(tensor.getType());
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
- return builder.create<bufferization::ToBufferOp>(loc, mTp, tensor)
- .getResult();
+ return cast<TypedValue<BaseMemRefType>>(
+ builder.create<bufferization::ToBufferOp>(loc, mTp, tensor).getResult());
}
Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 4b778b768d136..40b710f17fe44 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -54,8 +54,9 @@ struct CastOpInterface
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto castOp = cast<tensor::CastOp>(op);
- auto maybeSrcBufferType = bufferization::getBufferType(
- castOp.getSource(), options, state, invocationStack);
+ auto maybeSrcBufferType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ castOp.getSource(), options, state, invocationStack));
if (failed(maybeSrcBufferType))
return failure();
Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
@@ -500,8 +501,8 @@ struct FromElementsOpInterface
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
- FailureOr<BaseMemRefType> memrefType =
- bufferization::getBufferType(*tensorAlloc, options, state);
+ FailureOr<BaseMemRefType> memrefType = bufferization::detail::castToMemRef(
+ bufferization::getBufferType(*tensorAlloc, options, state));
if (failed(memrefType))
return failure();
Value buffer = rewriter.create<bufferization::ToBufferOp>(
@@ -758,8 +759,9 @@ struct PadOpInterface
SmallVector<Value> &invocationStack) const {
// Infer memory space from the source tensor.
auto padOp = cast<tensor::PadOp>(op);
- auto maybeSrcBufferType = bufferization::getBufferType(
- padOp.getSource(), options, state, invocationStack);
+ auto maybeSrcBufferType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ padOp.getSource(), options, state, invocationStack));
if (failed(maybeSrcBufferType))
return failure();
MemRefLayoutAttrInterface layout;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index cd19e3a5e82aa..da3c26ce36ba5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -268,4 +268,23 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
%r = tensor.extract %dest_filled[%idx] : tensor<5xf32>
return %0, %r : tensor<5xf32>, f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// Bufferizes a test-dialect op whose operand/result use the custom tensor type
+// `!test.test_tensor`: the op is expected to become `test.dummy_memref_op` on
+// the custom buffer type `!test.test_memref`, with `bufferization.to_buffer`
+// and `bufferization.to_tensor` materializing the conversion at the edges.
+// CHECK-LABEL: func.func @test_dialect_op(
+// CHECK-SAME: %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
+// CHECK-SAME: ) -> !test.test_tensor<[32, 128], f64> {
+func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
+ -> !test.test_tensor<[32, 128], f64> {
+ // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]]
+ // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
+ // CHECK-SAME: : (!test.test_memref<[32, 64], f64>)
+ // CHECK-SAME: -> !test.test_memref<[32, 128], f64>
+ // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
+ %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[32, 64], f64>)
+ -> !test.test_tensor<[32, 128], f64>
+
+ // CHECK: return %[[OUT]]
+ return %out : !test.test_tensor<[32, 128], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 1bbf2cc7481d9..03985874f910d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -11,6 +11,7 @@
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
@@ -284,6 +285,53 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) {
verifier, regionVerifier, parser, printer);
}
+namespace {
+
+/// Test-dialect implementation of the bufferization conversion interface.
+/// Maps `TestTensorType` values to `TestMemrefType` buffers (and back) so that
+/// ops using these custom types can be bufferized.
+struct TestConverter : bufferization::ConversionDialectInterface {
+  explicit TestConverter(Dialect *dialect)
+      : bufferization::ConversionDialectInterface(dialect) {}
+
+  /// Returns a `TestMemrefType` (no layout) with the same shape and element
+  /// type as the `TestTensorType` of `value`. Emits an error through
+  /// `emitError` and fails when `value` is not a `TestTensorType`.
+  FailureOr<bufferization::BufferLikeType>
+  getBufferType(Value value, const bufferization::BufferizationOptions &options,
+                const bufferization::BufferizationState &state,
+                function_ref<InFlightDiagnostic(const Twine &)> emitError)
+      const override {
+    auto testTensor = dyn_cast<TestTensorType>(value.getType());
+    if (!testTensor)
+      return emitError("expected TestTensorType");
+
+    return cast<bufferization::BufferLikeType>(
+        TestMemrefType::get(value.getContext(), testTensor.getShape(),
+                            testTensor.getElementType(), nullptr));
+  }
+
+  /// Succeeds iff `tensor` is a `TestTensorType` and `buffer` is a
+  /// `TestMemrefType` with identical shape and element type. Every failure
+  /// path reports a diagnostic through `emitError`.
+  LogicalResult typesMatch(bufferization::TensorLikeType tensor,
+                           bufferization::BufferLikeType buffer,
+                           function_ref<InFlightDiagnostic(const Twine &)>
+                               emitError) const override {
+    auto testTensor = dyn_cast<TestTensorType>(tensor);
+    auto testMemref = dyn_cast<TestMemrefType>(buffer);
+    if (!testTensor || !testMemref)
+      return emitError("expected TestTensorType and TestMemrefType");
+
+    // Report a shape/element-type mismatch instead of failing silently, so
+    // the failure is diagnosable like the type-kind mismatch above.
+    if (testTensor.getShape() != testMemref.getShape() ||
+        testTensor.getElementType() != testMemref.getElementType())
+      return emitError("shape or element type of TestTensorType and "
+                       "TestMemrefType do not match");
+    return success();
+  }
+
+  /// Returns the `TestTensorType` with the same shape and element type as
+  /// `buffer`, which must be a `TestMemrefType`.
+  bufferization::TensorLikeType
+  getTensorFromBuffer(bufferization::BufferLikeType buffer) const override {
+    // `cast` asserts on a kind mismatch, replacing a dyn_cast + assert pair.
+    auto testMemref = cast<TestMemrefType>(buffer);
+    return cast<bufferization::TensorLikeType>(
+        TestTensorType::get(testMemref.getContext(), testMemref.getShape(),
+                            testMemref.getElementType()));
+  }
+};
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
@@ -333,6 +381,7 @@ void TestDialect::initialize() {
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
registerInterfaces();
allowUnknownOperations();
+ addInterface<TestConverter>();
// Instantiate our fallback op interface that we'll use on specific
// unregistered op.
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b5a8bd10d6b68..78e44c6ec7a9b 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
@@ -1387,3 +1388,25 @@ TestMultiSlotAlloca::handleDestructuringComplete(
const DestructurableMemorySlot &slot, OpBuilder &builder) {
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
+
+/// Bufferizes `test.dummy_tensor_op` into `test.dummy_memref_op`. The buffer
+/// of the input tensor becomes the operand of a new `test.dummy_memref_op`
+/// whose result is a `TestMemrefType` (no layout) with the output tensor's
+/// shape and element type; that result replaces this op.
+::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
+    ::mlir::RewriterBase &rewriter,
+    const ::mlir::bufferization::BufferizationOptions &options,
+    ::mlir::bufferization::BufferizationState &state) {
+  auto buffer =
+      ::mlir::bufferization::getBuffer(rewriter, getInput(), options, state);
+  if (::mlir::failed(buffer))
+    return ::mlir::failure();
+
+  const auto outType = getOutput().getType();
+  const auto bufferizedOutType = test::TestMemrefType::get(
+      getContext(), outType.getShape(), outType.getElementType(), nullptr);
+
+  // Create the buffer-level counterpart op and replace this op with its result
+  // via the dedicated bufferization helper.
+  ::mlir::bufferization::replaceOpWithNewBufferizedOp<test::TestDummyMemrefOp>(
+      rewriter, getOperation(), bufferizedOutType, *buffer);
+  return ::mlir::success();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index c2ee5f9ab9a57..b414b47c87425 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -13,6 +13,7 @@
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 59330fdb1bb2c..79bcd9c2e0a9a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -31,7 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-
+include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
// Include the attribute definitions.
include "TestAttrDefs.td"
@@ -2825,7 +2825,7 @@ def TestNVVMRequiresSMArchCondOp :
let assemblyFormat = "attr-dict";
}
-def TestNVVMRequirestSMArchCondMultiOp :
+def TestNVVMRequirestSMArchCondMultiOp :
TEST_Op<"nvvm_requires_sm_90a_or_sm_100a", [NVVMRequiresSMa<[90, 100]>]> {
let arguments = (ins );
let assemblyFormat = "attr-dict";
@@ -3552,4 +3552,58 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Test Ops bufferization
+//===----------------------------------------------------------------------===//
+
+// Tensor-level op on the custom TensorLike type `TestTensorType`.
+// BufferizableOpInterface is implemented by hand: the methods are declared in
+// extraClassDeclaration, and `bufferize` (defined in TestOpDefs.cpp) rewrites
+// the op into `test.dummy_memref_op`.
+def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> {
+  let arguments = (ins TestTensorType:$input);
+  let results = (outs TestTensorType:$output);
+
+  let extraClassDeclaration = [{
+    // BufferizableOpInterface methods.
+    bool bufferizesToMemoryRead(mlir::OpOperand&,
+                                const mlir::bufferization::AnalysisState&);
+
+    bool bufferizesToMemoryWrite(mlir::OpOperand&,
+                                 const mlir::bufferization::AnalysisState&);
+
+    mlir::bufferization::AliasingValueList
+    getAliasingValues(mlir::OpOperand&,
+                      const mlir::bufferization::AnalysisState&);
+
+    mlir::LogicalResult
+    bufferize(mlir::RewriterBase& rewriter,
+              const mlir::bufferization::BufferizationOptions& options,
+              mlir::bufferization::BufferizationState &state);
+  }];
+
+  let extraClassDefinition = [{
+    // The input operand is always read.
+    bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+
+    // NOTE(review): reports a write to the input operand even though no
+    // aliasing result exists and bufferize() only passes the input buffer to
+    // test.dummy_memref_op; presumably deliberate conservatism — confirm.
+    bool test::TestDummyTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+
+    // The result never aliases the input operand.
+    ::mlir::bufferization::AliasingValueList
+    test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return {};
+    }
+  }];
+}
+
+// Buffer-level counterpart of `test.dummy_tensor_op`: the op it bufferizes
+// into, operating on the custom BufferLike type `TestMemrefType`.
+def TestDummyMemrefOp : TEST_Op<"dummy_memref_op"> {
+  let arguments = (ins TestMemrefType:$input);
+  let results = (outs TestMemrefType:$output);
+}
+
#endif // TEST_OPS
More information about the Mlir-commits
mailing list