[Mlir-commits] [mlir] [mlir][bufferization] Support custom types (1/N) (PR #142986)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 5 08:11:20 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-bufferization
Author: Andrei Golubev (andrey-golubev)
<details>
<summary>Changes</summary>
Following the introduction of the TensorLike and BufferLike type interfaces (see commit 00eaff3e9c897c263a879416d0f151d7ca7eeaff), this patch introduces the minimal changes required to bufferize a custom tensor operation into a custom buffer operation.
To achieve this, it adds a new conversion dialect interface that abstracts away the differences between the existing (tensor -> memref) conversion and custom conversions.
The scope of the changes is intentionally limited (for example, BufferizableOpInterface is left untouched) so that the basics can be understood and a design consensus reached first.
---
Patch is 49.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142986.diff
19 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+15-2)
- (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h (+72)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+28-20)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+3-2)
- (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+8-6)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+51-25)
- (added) mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp (+67)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+11-10)
- (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+4-4)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+4-4)
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-23)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+2-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+8-6)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+20-1)
- (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+49)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+23)
- (modified) mlir/test/lib/Dialect/Test/TestOps.h (+1)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+56-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..8390da956444d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
#include <optional>
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
namespace mlir {
class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state);
@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -738,6 +739,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
/// This is the default implementation of
/// BufferizableOpInterface::hasTensorSemantics
bool defaultHasTensorSemantics(Operation *op);
+
+/// This is a helper function used when buffer type is guaranteed to be memref.
+FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to verify the types in tensor and buffer
+/// worlds match.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to perform the conversion.
+Type getTensorFromBuffer(Type buffer);
} // namespace detail
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
new file mode 100644
index 0000000000000..4164d1dcb9ea6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
@@ -0,0 +1,72 @@
+//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// This class defines a virtual interface for conversions between tensor-like
+/// and buffer-like types.
+struct ConversionDialectInterface
+ : DialectInterface::Base<ConversionDialectInterface> {
+ using Base::Base;
+
+ /// Hook to customize tensor-like -> buffer-like conversion within a given
+ /// dialect. Returns the buffer-like type for the tensor type of `value`.
+ virtual FailureOr<BufferLikeType> getBufferType(
+ Value value, const BufferizationOptions &options,
+ const BufferizationState &state,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+ /// Hook to customize type checking between tensor-like and buffer-like types.
+ /// For a tensor value `t` of type `T` and `B = getBufferType(t, ...)`,
+ /// `typesMatch(T, B)` must succeed.
+ virtual LogicalResult typesMatch(
+ TensorLikeType tensor, BufferLikeType buffer,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+ /// Hook to customize buffer-like -> tensor-like conversion, which is the
+ /// opposite of bufferization.
+ virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
+};
+
+/// Interface collection for conversion between tensor-like and buffer-like
+/// types, dispatches to a concrete interface implementation based on the
+/// dialect to which the given type belongs.
+struct ConversionInterface
+ : DialectInterfaceCollection<ConversionDialectInterface> {
+ using Base::Base;
+
+ /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect
+ /// associated with the value type.
+ FailureOr<BufferLikeType> getBufferType(
+ Value value, const BufferizationOptions &options,
+ const BufferizationState &state,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+ /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect
+ /// associated with the value type.
+ LogicalResult
+ typesMatch(TensorLikeType tensor, BufferLikeType buffer,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+ /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
+ /// dialect associated with the buffer type.
+ TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 3d4dcdee2663b..277d56bc3f647 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,20 +388,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
// ToTensorOp
//===----------------------------------------------------------------------===//
+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait< // Op trait: named tensor-like and buffer-like values must have matching types.
+ "specified tensor and buffer types match",
+ CPred<
+ "::mlir::bufferization::detail::typesMatchAfterBufferization("
+ "$_op, $" # tensor # ", $" # buffer #")" // `tensor`/`buffer` are ODS value names, e.g. "result"/"buffer" for to_tensor.
+ >
+>;
+
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
BufferizableOpInterface,
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
- AllElementTypesMatch<["memref", "result"]>
+ Bufferization_TensorAndBufferMatch<"result", "buffer">
]> {
- let summary = "create a tensor from a `memref`";
+ let summary = "create a tensor-like value from a buffer-like value";
let description = [{
- An operation that creates a tensor from a `memref`. The result value is a
- tensor whose shape and element type match the memref operand.
+ An operation that creates a tensor from a buffer. The result is a value of
+ tensor-like type whose shape and element type match the buffer-like operand.
The opposite of this op is `to_buffer`. Together, these two ops are
useful for source/target materializations when doing type conversions
- involving tensors and memrefs.
+ involving tensors and buffers.
Example:
@@ -442,11 +451,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];
- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
"the reference to load from",
- [MemReadAt<0, FullEffect>]>:$memref,
+ [MemReadAt<0, FullEffect>]>:$buffer,
UnitAttr:$restrict, UnitAttr:$writable);
- let results = (outs AnyTensor:$result);
+ let results = (outs Bufferization_TensorLikeTypeInterface:$result);
let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
@@ -473,19 +482,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+ return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
}
}];
let assemblyFormat = [{
- $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
- `:` type($memref) `to` type($result)
+ $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+ `:` type($buffer) `to` type($result)
}];
let builders = [
- OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
- auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
- build($_builder, $_state, rtt, memref, restrict, writeable);
+ OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+ auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
+ build($_builder, $_state, rtt, buffer, restrict, writeable);
}]>
];
@@ -503,10 +512,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
Pure,
- AllShapesMatch<["memref", "tensor"]>,
- AllElementTypesMatch<["memref", "tensor"]>
+ Bufferization_TensorAndBufferMatch<"tensor", "buffer">
]> {
- let summary = "cast a tensor to memref";
+ let summary = "cast a tensor-like value to a buffer-like value";
let description = [{
An operation that returns the future buffer of a `tensor`.
@@ -524,8 +532,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
the returned buffer) will not be written to.
}];
- let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
- let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+ let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+ let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
@@ -560,7 +568,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}];
let assemblyFormat = [{
- $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
}];
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index a441b8b66659e..f56c10555f02c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// The operand was already bufferized. Take its type directly.
callerType = memrefType;
} else {
- FailureOr<BaseMemRefType> maybeCallerType =
+ FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
- callerType = *maybeCallerType;
+ assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
+ callerType = cast<BaseMemRefType>(*maybeCallerType);
}
if (!bufferType) {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index a57d58ab28d28..021a557f68b4b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -164,8 +164,8 @@ struct SelectOpInterface
// buffers have different types, they differ only in their layout map. Cast
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
- auto targetType =
- bufferization::getBufferType(selectOp.getResult(), options, state);
+ auto targetType = bufferization::detail::castToMemRef(
+ bufferization::getBufferType(selectOp.getResult(), options, state));
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
- auto trueType = bufferization::getBufferType(
- selectOp.getTrueValue(), options, state, invocationStack);
- auto falseType = bufferization::getBufferType(
- selectOp.getFalseValue(), options, state, invocationStack);
+ auto trueType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ selectOp.getTrueValue(), options, state, invocationStack));
+ auto falseType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ selectOp.getFalseValue(), options, state, invocationStack));
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..d00605a7b9a17 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -211,8 +212,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType =
- getBufferType(tensor, options, state);
+ auto copyBufferType =
+ detail::castToMemRef(getBufferType(tensor, options, state));
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -673,28 +674,28 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
const BufferizationState &state) {
#ifndef NDEBUG
- auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+ auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
- return toTensorOp.getMemref();
+ return toTensorOp.getBuffer();
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
- if (failed(memrefType))
+ FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
+ if (failed(bufferType))
return failure();
- ensureToBufferOpIsValid(value, *memrefType);
+ ensureToBufferOpIsValid(value, *bufferType);
return rewriter
- .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
+ .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
.getResult();
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state) {
SmallVector<Value> invocationStack;
@@ -702,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
- assert(llvm::isa<TensorType>(value.getType()) &&
+ assert(llvm::isa<TensorLikeType>(value.getType()) &&
"unexpected non-tensor type");
invocationStack.push_back(value);
auto popFromStack =
@@ -718,13 +719,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
if (bufferizableOp)
return bufferizableOp.getBufferType(value, options, state, invocationStack);
- // Op is not bufferizable.
- auto memSpace =
- options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
- if (!memSpace.has_value())
- return op->emitError("could not infer memory space");
-
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ // Op is not bufferizable, use conversion interface.
+ bufferization::ConversionInterface iface(value.getContext());
+ return iface.getBufferType(value, options, state, [&](const Twine &message) {
+ return op->emitError(message);
+ });
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -744,12 +743,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
Value replacement = values[opResult.getResultNumber()];
- if (llvm::isa<TensorType>(opResult.getType())) {
+ if (llvm::isa<TensorLikeType>(opResult.getType())) {
// The OpResult is a tensor. Such values are replaced with memrefs during
// bufferization.
- assert((llvm::isa<MemRefType>(replacement.getType()) ||
- llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
- "tensor op result should be replaced with a memref value");
+ assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
+ "tensor op result should be replaced with a buffer value");
// The existing uses of the OpResult still expect a tensor. Insert a
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
// loose all of its users and eventually DCE away.
@@ -970,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return getBufferType(equivalentOperand, options, bufferizationState,
- invocationStack);
+ return castToMemRef(getBufferType(equivalentOperand, options,
+ bufferizationState, invocationStack));
}
// If we do not know the memory space and there is no default memory space,
@@ -1031,7 +1029,7 @@ bufferization::detail::unknownGe...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/142986
More information about the Mlir-commits mailing list