[Mlir-commits] [mlir] [mlir][bufferization] Use TensorLike, BufferLike type interfaces (PR #136736)
Andrei Golubev
llvmlistbot at llvm.org
Tue Apr 22 11:02:40 PDT 2025
https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/136736
The general idea is to replace most of the places that rely on the builtin TensorType / BaseMemRefType with the newly added TensorLikeType / BufferLikeType type interfaces.
So far, this patch does the bare minimum: it refactors the API of the dialect and options (almost) "blindly", leaving most of the logic as is. The exceptions are the bufferization.{to_tensor, to_memref} ops, which act as "glue" when bufferizing neighbouring operations and the enclosing functions.
From fe90c52e99e4655eeabf7985944953e66dda6565 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Thu, 17 Apr 2025 15:36:01 +0000
Subject: [PATCH] [mlir][bufferization] Use TensorLike, BufferLike type
interfaces
The general idea is to replace most of the places that rely on the builtin
TensorType / BaseMemRefType with the newly added TensorLikeType /
BufferLikeType type interfaces.
So far, this patch does the bare minimum: it refactors the API of the
dialect and options (almost) "blindly", leaving most of the logic as is.
The exceptions are the bufferization.{to_tensor, to_memref} ops, which act
as "glue" when bufferizing neighbouring operations and the enclosing
functions.
---
.../IR/BufferizableOpInterface.h | 21 ++--
.../IR/BufferizableOpInterface.td | 2 +-
.../Bufferization/IR/BufferizationOps.td | 17 +--
.../IR/BufferizationTypeInterfaces.h | 1 +
.../IR/BufferizationTypeInterfaces.td | 13 ++-
.../IR/UnstructuredControlFlow.h | 35 +++---
.../BufferizableOpInterfaceImpl.cpp | 13 ++-
.../IR/BufferizableOpInterface.cpp | 94 ++++++++-------
.../Bufferization/IR/BufferizationDialect.cpp | 6 +-
.../Bufferization/IR/BufferizationOps.cpp | 30 ++---
.../IR/BufferizationTypeInterfaces.cpp | 21 ++++
.../Dialect/Bufferization/IR/CMakeLists.txt | 1 +
.../Transforms/BufferViewFlowAnalysis.cpp | 17 +--
.../Bufferization/Transforms/Bufferize.cpp | 18 +--
.../FuncBufferizableOpInterfaceImpl.cpp | 30 ++---
.../BufferizableOpInterfaceImpl.cpp | 109 ++++++++++--------
.../SparsificationAndBufferizationPass.cpp | 4 +-
.../Transforms/Utils/CodegenUtils.cpp | 4 +-
.../BufferizableOpInterfaceImpl.cpp | 9 +-
.../Transforms/one-shot-bufferize.mlir | 21 +++-
mlir/test/Dialect/Bufferization/invalid.mlir | 8 +-
.../Bufferization/TestTensorCopyInsertion.cpp | 4 +-
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 24 ++++
mlir/test/lib/Dialect/Test/TestOps.h | 1 +
mlir/test/lib/Dialect/Test/TestOps.td | 55 ++++++++-
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 3 +
26 files changed, 370 insertions(+), 191 deletions(-)
create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ada9539e87121..70092908d961f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
#include <optional>
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
namespace mlir {
class OpBuilder;
@@ -259,18 +260,18 @@ struct BufferizationOptions {
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
- /// Tensor -> MemRef type converter.
- /// Parameters: tensor type, memory space, func op, bufferization options
+ /// TensorLike -> BufferLike type converter.
+ /// Parameters: tensor like type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
- std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+ std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
- /// Tensor -> MemRef type converter.
+ /// TensorLike -> BufferLike type converter.
/// Parameters: Value, memory space, bufferization options
- using UnknownTypeConverterFn = std::function<BaseMemRefType(
+ using UnknownTypeConverterFn = std::function<BufferLikeType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
- std::function<std::optional<Attribute>(TensorType t)>;
+ std::function<std::optional<Attribute>(TensorLikeType t)>;
BufferizationOptions();
@@ -360,7 +361,7 @@ struct BufferizationOptions {
// Returning std::nullopt will cause bufferization to fail (useful to indicate
// failure to determine memory space for a tensor type).
DefaultMemorySpaceFn defaultMemorySpaceFn =
- [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+ [](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options);
/// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
@@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..1de1742fab81a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fad78a63444b9..81ce0f3fb650b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
@@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];
- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
"the reference to load from",
[MemReadAt<0, FullEffect>]>:$memref,
UnitAttr:$restrict, UnitAttr:$writable);
- let results = (outs AnyTensor:$result);
+ let results = (outs Bufferization_TensorLikeTypeInterface:$result);
let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+ return ::llvm::cast<BufferLikeType>(getMemref().getType());
}
}];
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
// ToMemrefOp
//===----------------------------------------------------------------------===//
+// TODO: rename to "to_buffer"
def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
BufferizableOpInterface,
SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
the returned buffer) will not be written to.
}];
- let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
- let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+ let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
+ UnitAttr:$read_only);
+ let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..290f1298f2501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h" // mlir::Attribute
#include "mlir/IR/Types.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..c053a6bdc1a91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
let description = [{
Indicates that this type is a buffer type (similarly to a MLIR builtin
memref) for bufferization purposes.
-
- The interface currently has no methods as it is used by types to opt into
- being supported by the bufferization procedures.
}];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the memory space in which data referred to by this buffer resides.
+ }],
+ /*retType=*/"::mlir::Attribute",
+ /*methodName=*/"getMemorySpace"
+ >,
+ ];
}
#endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 78109770efab7..89eb65c4a0942 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// operand types of all forwarded values. If these are all the same type,
// take that type. Otherwise, take only the memory space and fall back to a
// buffer type with a fully dynamic layout map.
- BaseMemRefType bufferType;
+ BufferLikeType bufferType;
auto tensorType = cast<TensorType>(value.getType());
for (OpOperand *opOperand :
detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
continue;
// Compute the bufferized type of the forwarded operand.
- BaseMemRefType callerType;
- if (auto memrefType =
- dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+ BufferLikeType callerType;
+ if (auto bufferLikeType =
+ dyn_cast<BufferLikeType>(opOperand->get().getType())) {
// The operand was already bufferized. Take its type directly.
- callerType = memrefType;
+ callerType = bufferLikeType;
} else {
- FailureOr<BaseMemRefType> maybeCallerType =
+ FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options,
invocationStack);
if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
#ifndef NDEBUG
+ assert(mlir::isa<BaseMemRefType>(bufferType) &&
+ mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
+ auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
+ auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
+
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
- assert(bufferType.hasRank() && callerType.hasRank() &&
+ assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
"expected ranked memrefs");
- assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
- rankedTensorType.getShape()}) &&
- "expected same shape");
+ assert(
+ llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
+ rankedTensorType.getShape()}) &&
+ "expected same shape");
} else {
- assert(!bufferType.hasRank() && !callerType.hasRank() &&
+ assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
"expected unranked memrefs");
}
#endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
return op->emitOpError("incoming operands of block argument have "
"inconsistent memory spaces");
- bufferType = getMemRefTypeWithFullyDynamicLayout(
- tensorType, bufferType.getMemorySpace());
+ bufferType =
+ mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ tensorType, bufferType.getMemorySpace()));
}
if (!bufferType)
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..433757192bfd1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
- auto type = dyn_cast<RankedTensorType>(constantOp.getType());
+ auto type = dyn_cast<TensorLikeType>(constantOp.getType());
// Only ranked tensors are supported.
if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
- RankedTensorType::get(memrefType.getShape(),
- memrefType.getElementType()),
- memrefType.getMemorySpace());
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ RankedTensorType::get(memrefType.getShape(),
+ memrefType.getElementType()),
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..82ff1bdfe5fd7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -206,12 +206,13 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+ FailureOr<BufferLikeType> copyBufferType = getBufferType(tensor, options);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
if (!memorySpace)
- memorySpace = options.defaultMemorySpaceFn(tensorType);
+ memorySpace =
+ options.defaultMemorySpaceFn(mlir::cast<TensorLikeType>(tensorType));
if (memorySpace.has_value())
allocTensorOp.setMemorySpaceAttr(memorySpace.value());
return allocTensorOp.getResult();
@@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// Find all out-of-place OpOperands.
for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType();
+ // Note: can only copy TensorType (any other TensorLikeType is rejected)
if (!llvm::isa<TensorType>(operandType))
continue;
if (state.isInPlace(opOperand))
@@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
- func::FuncOp funcOp,
+bufferization::BufferLikeType
+defaultFunctionArgTypeConverter(bufferization::TensorLikeType type,
+ Attribute memorySpace, func::FuncOp funcOp,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(mlir::cast<TensorType>(type),
+ memorySpace));
}
/// Default unknown type converter: Use a fully dynamic layout map.
-BaseMemRefType
+BufferLikeType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(
- llvm::cast<TensorType>(value.getType()), memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ llvm::cast<TensorType>(value.getType()), memorySpace));
}
} // namespace
@@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
- functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
- func::FuncOp funcOp,
+ functionArgTypeConverterFn = [=](TensorLikeType tensorType,
+ Attribute memorySpace, func::FuncOp funcOp,
const BufferizationOptions &options) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
- memorySpace);
- return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
- memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(
+ mlir::cast<TensorType>(tensorType), memorySpace));
+ return mlir::cast<bufferization::BufferLikeType>(
+ bufferization::getMemRefTypeWithFullyDynamicLayout(
+ mlir::cast<TensorType>(tensorType), memorySpace));
};
inferFunctionResultLayout =
layoutMapOption == LayoutMapOption::InferLayoutMap;
@@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
- assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+ assert(llvm::isa<bufferization::TensorLikeType>(value.getType()) &&
+ "expected TensorLikeType");
SmallVector<OpOperand *> workingSet;
DenseSet<OpOperand *> visited;
for (OpOperand &use : value.getUses())
@@ -663,7 +671,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options) {
#ifndef NDEBUG
- auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+ auto tensorType =
+ llvm::dyn_cast<bufferization::TensorLikeType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG
@@ -674,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+ FailureOr<BufferLikeType> memrefType = getBufferType(value, options);
if (failed(memrefType))
return failure();
ensureToMemrefOpIsValid(value, *memrefType);
@@ -684,18 +693,18 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
SmallVector<Value> invocationStack;
return getBufferType(value, options, invocationStack);
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
- assert(llvm::isa<TensorType>(value.getType()) &&
- "unexpected non-tensor type");
+ assert(llvm::isa<TensorLikeType>(value.getType()) &&
+ "unexpected non-tensor-like type");
invocationStack.push_back(value);
auto popFromStack =
llvm::make_scope_exit([&]() { invocationStack.pop_back(); });
@@ -708,11 +717,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
// Op is not bufferizable.
auto memSpace =
- options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
+ options.defaultMemorySpaceFn(cast<TensorLikeType>(value.getType()));
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ return mlir::cast<BufferLikeType>(
+ getMemRefType(value, options, /*layout=*/{}, *memSpace));
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -732,12 +742,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
Value replacement = values[opResult.getResultNumber()];
- if (llvm::isa<TensorType>(opResult.getType())) {
- // The OpResult is a tensor. Such values are replaced with memrefs during
+ if (llvm::isa<bufferization::TensorLikeType>(opResult.getType())) {
+ // The OpResult is a tensor. Such values are replaced with buffers during
// bufferization.
- assert((llvm::isa<MemRefType>(replacement.getType()) ||
- llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
- "tensor op result should be replaced with a memref value");
+ assert(llvm::isa<bufferization::BufferLikeType>(replacement.getType()) &&
+ "tensor op result should be replaced with a buffer value");
// The existing uses of the OpResult still expect a tensor. Insert a
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
// loose all of its users and eventually DCE away.
@@ -789,6 +798,8 @@ BaseMemRefType bufferization::getMemRefType(Value value,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
+ assert(mlir::isa<TensorType>(value.getType()) &&
+ "expected tensor type in tensor -> memref conversion");
auto tensorType = llvm::cast<TensorType>(value.getType());
// Case 1: Unranked memref type.
@@ -807,7 +818,8 @@ BaseMemRefType bufferization::getMemRefType(Value value,
memorySpace);
}
- return options.unknownTypeConverterFn(value, memorySpace, options);
+ return mlir::cast<BaseMemRefType>(
+ options.unknownTypeConverterFn(value, memorySpace, options));
}
BaseMemRefType
@@ -928,7 +940,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
Operation *op = getOwnerOfValue(value);
SmallVector<AliasingOpOperand> result;
for (OpOperand &opOperand : op->getOpOperands()) {
- if (!llvm::isa<TensorType>(opOperand.get().getType()))
+ if (!llvm::isa<bufferization::TensorLikeType>(opOperand.get().getType()))
continue;
AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
for (const auto &it : aliasingValues)
@@ -938,14 +950,15 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
return AliasingOpOperandList(std::move(result));
}
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
- assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
+ assert(llvm::isa<TensorLikeType>(value.getType()) && "expected tensor type");
// No further analysis is possible for a block argument.
if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(value, options);
+ return mlir::cast<BufferLikeType>(
+ bufferization::getMemRefType(value, options));
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -963,11 +976,12 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If we do not know the memory space and there is no default memory space,
// report a failure.
auto memSpace =
- options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
+ options.defaultMemorySpaceFn(cast<TensorLikeType>(value.getType()));
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ return mlir::cast<BufferLikeType>(
+ getMemRefType(value, options, /*layout=*/{}, *memSpace));
}
bool bufferization::detail::defaultIsRepetitiveRegion(
@@ -993,7 +1007,7 @@ bufferization::detail::unknownGetAliasingOpOperands(Value value) {
// with every OpOperand.
AliasingOpOperandList r;
for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
- if (isa<TensorType>(operand.get().getType()))
+ if (isa<bufferization::TensorLikeType>(operand.get().getType()))
r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}
@@ -1006,18 +1020,18 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
// with every OpOperand.
AliasingValueList r;
for (OpResult result : opOperand.getOwner()->getOpResults())
- if (llvm::isa<TensorType>(result.getType()))
+ if (llvm::isa<bufferization::TensorLikeType>(result.getType()))
r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
for (Region ®ion : opOperand.getOwner()->getRegions())
if (!region.getBlocks().empty())
for (BlockArgument bbArg : region.getBlocks().front().getArguments())
- if (isa<TensorType>(bbArg.getType()))
+ if (isa<bufferization::TensorLikeType>(bbArg.getType()))
r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}
bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
- auto isaTensor = [](Type t) { return isa<TensorType>(t); };
+ auto isaTensor = [](Type t) { return isa<bufferization::TensorLikeType>(t); };
bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
return any_of(r.getBlocks(), [&](Block &b) {
return any_of(b.getArguments(), [&](BlockArgument bbArg) {
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6b9253a5d71da..02f9252dcb088 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -62,7 +62,11 @@ struct BuiltinTensorExternalModel
template <typename MemRef>
struct BuiltinMemRefExternalModel
: BufferLikeType::ExternalModel<BuiltinMemRefExternalModel<MemRef>,
- MemRef> {};
+ MemRef> {
+ mlir::Attribute getMemorySpace(mlir::Type type) const {
+ return mlir::cast<MemRef>(type).getMemorySpace();
+ }
+};
} // namespace
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 4fce9be390bd6..2ceb6795899c9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -220,7 +220,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
return {};
}
-FailureOr<BaseMemRefType>
+FailureOr<bufferization::BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
assert(value == getResult() && "invalid value");
@@ -235,13 +235,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
if (failed(copyBufferType))
return failure();
memorySpace = copyBufferType->getMemorySpace();
- } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
+ } else if (auto ms = options.defaultMemorySpaceFn(
+ mlir::cast<TensorLikeType>(getType()))) {
memorySpace = *ms;
} else {
return getOperation()->emitError("could not infer memory space");
}
- return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+ return mlir::cast<BufferLikeType>(
+ getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
LogicalResult AllocTensorOp::verify() {
@@ -585,7 +587,7 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
return failure();
buffer = *maybeBuffer;
} else {
- assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+ assert(isa<BufferLikeType>(getDest().getType()) && "expected buffer type");
buffer = getDest();
}
auto srcBuffer = getBuffer(rewriter, getSource(), options);
@@ -632,7 +634,7 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
return {};
// Build a bufferization.to_tensor op.
- assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
+ assert(isa<BufferLikeType>(getDest().getType()) && "expected buffer type");
assert(getRestrict() &&
"expected that ops with memrefs dest have 'restrict'");
setRestrict(false);
@@ -667,22 +669,22 @@ bool MaterializeInDestinationOp::operatesOnDisjointSubset(
}
LogicalResult MaterializeInDestinationOp::verify() {
- if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
- return emitOpError("'dest' must be a tensor or a memref");
+ if (!isa<TensorType, BufferLikeType>(getDest().getType()))
+ return emitOpError("'dest' must be a tensor or a buffer");
if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
if (getOperation()->getNumResults() != 1)
return emitOpError("tensor 'dest' implies exactly one tensor result");
if (destType != getResult().getType())
return emitOpError("result and 'dest' types must match");
}
- if (isa<BaseMemRefType>(getDest().getType()) &&
+ if (isa<BufferLikeType>(getDest().getType()) &&
getOperation()->getNumResults() != 0)
- return emitOpError("memref 'dest' implies zero results");
- if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
- return emitOpError("'restrict' is valid only for memref destinations");
- if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
+ return emitOpError("buffer 'dest' implies zero results");
+ if (getRestrict() && !isa<BufferLikeType>(getDest().getType()))
+ return emitOpError("'restrict' is valid only for buffer destinations");
+ if (getWritable() != isa<BufferLikeType>(getDest().getType()))
return emitOpError("'writable' must be specified if and only if the "
- "destination is of memref type");
+ "destination is of buffer type");
TensorType srcType = getSource().getType();
ShapedType destType = cast<ShapedType>(getDest().getType());
if (srcType.hasRank() != destType.hasRank())
@@ -724,7 +726,7 @@ MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
void MaterializeInDestinationOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- if (isa<BaseMemRefType>(getDest().getType()))
+ if (isa<BufferLikeType>(getDest().getType()))
effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
SideEffects::DefaultResource::get());
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp
new file mode 100644
index 0000000000000..0e973915c6fc9
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp
@@ -0,0 +1,21 @@
+//===- BufferizationTypeInterfaces.cpp - Type Interfaces ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Bufferization Type Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 63dcc1eb233e9..5d8f0060f2c3f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferizationDialect.cpp
BufferViewFlowOpInterface.cpp
UnstructuredControlFlow.cpp
+ BufferizationTypeInterfaces.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 72f47b8b468ea..cb9db1288039a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -93,11 +94,11 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// given op as terminals.
auto populateTerminalValues = [&](Operation *op) {
for (Value v : op->getResults())
- if (isa<BaseMemRefType>(v.getType()))
+ if (isa<BufferLikeType>(v.getType()))
this->terminals.insert(v);
for (Region &r : op->getRegions())
for (BlockArgument v : r.getArguments())
- if (isa<BaseMemRefType>(v.getType()))
+ if (isa<BufferLikeType>(v.getType()))
this->terminals.insert(v);
};
@@ -108,12 +109,12 @@ void BufferViewFlowAnalysis::build(Operation *op) {
if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
bufferViewFlowOp.populateDependencies(registerDependencies);
for (Value v : op->getResults())
- if (isa<BaseMemRefType>(v.getType()) &&
+ if (isa<BufferLikeType>(v.getType()) &&
bufferViewFlowOp.mayBeTerminalBuffer(v))
this->terminals.insert(v);
for (Region &r : op->getRegions())
for (BlockArgument v : r.getArguments())
- if (isa<BaseMemRefType>(v.getType()) &&
+ if (isa<BufferLikeType>(v.getType()) &&
bufferViewFlowOp.mayBeTerminalBuffer(v))
this->terminals.insert(v);
return WalkResult::advance();
@@ -201,7 +202,7 @@ void BufferViewFlowAnalysis::build(Operation *op) {
}
bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
- assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
+ assert(isa<BufferLikeType>(value.getType()) && "expected buffer");
return terminals.contains(value);
}
@@ -240,8 +241,8 @@ static Value getViewBase(Value value) {
BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
- assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
- assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+ assert(isa<BufferLikeType>(v1.getType()) && "expected buffer");
+ assert(isa<BufferLikeType>(v2.getType()) && "expected buffer");
// Skip over all view-like ops.
v1 = getViewBase(v1);
@@ -275,7 +276,7 @@ std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
bool &allAllocs,
bool &allAllocsOrFuncEntryArgs) {
for (Value v : origin) {
- if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
+ if (isa<BufferLikeType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
terminal.insert(v);
allAllocs &= hasAllocateSideEffect(v);
allAllocsOrFuncEntryArgs &=
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 0b60c44ece5fd..a296b617024d8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -80,14 +80,14 @@ struct OneShotBufferizePass
if (mustInferMemorySpace) {
opt.defaultMemorySpaceFn =
- [](TensorType t) -> std::optional<Attribute> {
+ [](TensorLikeType t) -> std::optional<Attribute> {
return std::nullopt;
};
}
if (useEncodingForMemorySpace) {
opt.defaultMemorySpaceFn =
- [](TensorType t) -> std::optional<Attribute> {
+ [](TensorLikeType t) -> std::optional<Attribute> {
if (auto rtt = dyn_cast<RankedTensorType>(t))
return rtt.getEncoding();
return std::nullopt;
@@ -113,13 +113,15 @@ struct OneShotBufferizePass
const BufferizationOptions &options) {
auto tensorType = cast<TensorType>(value.getType());
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
- return bufferization::getMemRefTypeWithStaticIdentityLayout(
- tensorType, memorySpace);
+ return mlir::cast<BufferLikeType>(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(
+ tensorType, memorySpace));
assert(unknownTypeConversionOption ==
LayoutMapOption::FullyDynamicLayoutMap &&
"invalid layout map option");
- return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
- memorySpace);
+ return mlir::cast<BufferLikeType>(
+ bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+ memorySpace));
};
// Configure op filter.
@@ -407,7 +409,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
continue;
}
- FailureOr<BaseMemRefType> memrefType =
+ FailureOr<BufferLikeType> memrefType =
bufferization::getBufferType(bbArg, options);
if (failed(memrefType))
return failure();
@@ -458,7 +460,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
newOperands.push_back(operand);
continue;
}
- FailureOr<BaseMemRefType> operandBufferType =
+ FailureOr<BufferLikeType> operandBufferType =
bufferization::getBufferType(operand, options);
if (failed(operandBufferType))
return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index c45678f1e4b4d..4d39d9b795bed 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -53,14 +53,14 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
auto tensorType =
- dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
- assert(tensorType && "expected TensorType");
+ dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+ assert(tensorType && "expected TensorLikeType");
- BaseMemRefType memrefType = options.functionArgTypeConverterFn(
+ BufferLikeType memrefType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
@@ -70,9 +70,9 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
- return MemRefType::get(
+ return mlir::cast<BufferLikeType>(MemRefType::get(
rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
- layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
+ layoutAttr.getValue(), rankedMemrefType.getMemorySpace()));
}
/// Return the FuncOp called by `callOp`.
@@ -195,7 +195,7 @@ struct CallOpInterface
return result;
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
@@ -207,11 +207,11 @@ struct CallOpInterface
FunctionType funcType = funcOp.getFunctionType();
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
- if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+ if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
return bufferizedType;
// Otherwise, call the type converter to compute the bufferized type.
- auto tensorType = cast<TensorType>(resultType);
+ auto tensorType = cast<TensorLikeType>(resultType);
return options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
}
@@ -233,7 +233,7 @@ struct CallOpInterface
}
// Returning a memref.
- FailureOr<BaseMemRefType> resultType =
+ FailureOr<BufferLikeType> resultType =
bufferization::getBufferType(result, options);
if (failed(resultType))
return failure();
@@ -263,11 +263,11 @@ struct CallOpInterface
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
- if (!isa<BaseMemRefType>(memRefType)) {
+ if (!isa<BufferLikeType>(memRefType)) {
// The called function was not bufferized yet. This can happen when
// there cycles in the function call graph. Compute the bufferized
// result type.
- FailureOr<BaseMemRefType> maybeMemRefType =
+ FailureOr<BufferLikeType> maybeMemRefType =
bufferization::getBufferType(
funcOp.getArgument(opOperand.getOperandNumber()), options);
if (failed(maybeMemRefType))
@@ -371,7 +371,7 @@ struct FuncOpInterface
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto funcOp = cast<FuncOp>(op);
@@ -413,8 +413,8 @@ struct FuncOpInterface
// Compute the result types.
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
- if (auto tensorType = dyn_cast<TensorType>(resultType)) {
- BaseMemRefType resultType = options.functionArgTypeConverterFn(
+ if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+ BufferLikeType resultType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
options);
retTypes.push_back(resultType);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index cf62ee8bc45b5..523ee48be2003 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -102,11 +102,11 @@ struct ConditionOpInterface
SmallVector<Value> newArgs;
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
Value value = it.value();
- if (isa<TensorType>(value.getType())) {
+ if (isa<bufferization::TensorLikeType>(value.getType())) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();
- FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ auto resultType = bufferization::getBufferType(
whileOp.getAfterArguments()[it.index()], options);
if (failed(resultType))
return failure();
@@ -201,7 +201,7 @@ struct ExecuteRegionOpInterface
rewriter.setInsertionPointAfter(newOp);
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
- if (isa<TensorType>(it.value())) {
+ if (isa<bufferization::TensorLikeType>(it.value())) {
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
executeRegionOp.getLoc(), it.value(),
newOp->getResult(it.index())));
@@ -244,7 +244,7 @@ struct IfOpInterface
// Compute bufferized result types.
SmallVector<Type> newTypes;
for (Value result : ifOp.getResults()) {
- if (!isa<TensorType>(result.getType())) {
+ if (!isa<bufferization::TensorLikeType>(result.getType())) {
newTypes.push_back(result.getType());
continue;
}
@@ -270,7 +270,7 @@ struct IfOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto ifOp = cast<scf::IfOp>(op);
@@ -282,10 +282,10 @@ struct IfOpInterface
auto opResult = cast<OpResult>(value);
auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
- BaseMemRefType thenBufferType, elseBufferType;
- if (isa<BaseMemRefType>(thenValue.getType())) {
+ bufferization::BufferLikeType thenBufferType, elseBufferType;
+ if (isa<bufferization::BufferLikeType>(thenValue.getType())) {
// True branch was already bufferized.
- thenBufferType = cast<BaseMemRefType>(thenValue.getType());
+ thenBufferType = cast<bufferization::BufferLikeType>(thenValue.getType());
} else {
auto maybeBufferType =
bufferization::getBufferType(thenValue, options, invocationStack);
@@ -293,9 +293,9 @@ struct IfOpInterface
return failure();
thenBufferType = *maybeBufferType;
}
- if (isa<BaseMemRefType>(elseValue.getType())) {
+ if (isa<bufferization::BufferLikeType>(elseValue.getType())) {
// False branch was already bufferized.
- elseBufferType = cast<BaseMemRefType>(elseValue.getType());
+ elseBufferType = cast<bufferization::BufferLikeType>(elseValue.getType());
} else {
auto maybeBufferType =
bufferization::getBufferType(elseValue, options, invocationStack);
@@ -313,8 +313,10 @@ struct IfOpInterface
return op->emitError("inconsistent memory space on then/else branches");
// Layout maps are different: Promote to fully dynamic layout map.
- return getMemRefTypeWithFullyDynamicLayout(
- cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ cast<TensorType>(opResult.getType()),
+ thenBufferType.getMemorySpace()));
}
};
@@ -354,7 +356,7 @@ struct IndexSwitchOpInterface
// Compute bufferized result types.
SmallVector<Type> newTypes;
for (Value result : switchOp.getResults()) {
- if (!isa<TensorType>(result.getType())) {
+ if (!isa<bufferization::TensorLikeType>(result.getType())) {
newTypes.push_back(result.getType());
continue;
}
@@ -384,7 +386,7 @@ struct IndexSwitchOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto switchOp = cast<scf::IndexSwitchOp>(op);
@@ -392,11 +394,13 @@ struct IndexSwitchOpInterface
int64_t resultNum = cast<OpResult>(value).getResultNumber();
// Helper function to get buffer type of a case.
- SmallVector<BaseMemRefType> yieldedTypes;
- auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
+ SmallVector<bufferization::BufferLikeType> yieldedTypes;
+ auto getYieldedBufferType =
+ [&](Block &b) -> FailureOr<bufferization::BufferLikeType> {
auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
Value yieldedValue = yieldOp->getOperand(resultNum);
- if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
+ if (auto bufferType =
+ dyn_cast<bufferization::BufferLikeType>(yieldedValue.getType()))
return bufferType;
auto maybeBufferType =
bufferization::getBufferType(yieldedValue, options, invocationStack);
@@ -409,7 +413,7 @@ struct IndexSwitchOpInterface
auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
if (failed(maybeBufferType))
return failure();
- BaseMemRefType bufferType = *maybeBufferType;
+ auto bufferType = *maybeBufferType;
// Compute buffer types of all other cases.
for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
@@ -426,8 +430,9 @@ struct IndexSwitchOpInterface
return op->emitError("inconsistent memory space on switch cases");
// Layout maps are different: Promote to fully dynamic layout map.
- bufferType = getMemRefTypeWithFullyDynamicLayout(
- cast<TensorType>(value.getType()), bufferType.getMemorySpace());
+ bufferType = mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(cast<TensorType>(value.getType()),
+ bufferType.getMemorySpace()));
}
return bufferType;
@@ -439,7 +444,7 @@ struct IndexSwitchOpInterface
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
DenseSet<int64_t> result;
for (const auto &it : llvm::enumerate(values))
- if (isa<TensorType>(it.value().getType()))
+ if (isa<bufferization::TensorLikeType>(it.value().getType()))
result.insert(it.index());
return result;
}
@@ -452,8 +457,8 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
DenseSet<int64_t> result;
for (unsigned int i = 0; i < minSize; ++i) {
- if (!isa<TensorType>(bbArgs[i].getType()) ||
- !isa<TensorType>(yieldedValues[i].getType()))
+ if (!isa<bufferization::TensorLikeType>(bbArgs[i].getType()) ||
+ !isa<bufferization::TensorLikeType>(yieldedValues[i].getType()))
continue;
if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
result.insert(i);
@@ -468,7 +473,7 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
- if (isa<TensorType>(opOperand.get().getType())) {
+ if (isa<bufferization::TensorLikeType>(opOperand.get().getType())) {
FailureOr<Value> resultBuffer =
getBuffer(rewriter, opOperand.get(), options);
if (failed(resultBuffer))
@@ -516,9 +521,11 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// If both buffer types are equal, no casts are needed the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
- Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
- const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
+static FailureOr<bufferization::BufferLikeType>
+computeLoopRegionIterArgBufferType(Operation *loopOp, BlockArgument iterArg,
+ Value initArg, Value yieldedValue,
+ const BufferizationOptions &options,
+ SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
auto initArgBufferType =
bufferization::getBufferType(initArg, options, invocationStack);
@@ -540,10 +547,11 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
}
// Compute the buffer type of the yielded value.
- BaseMemRefType yieldedValueBufferType;
- if (isa<BaseMemRefType>(yieldedValue.getType())) {
+ bufferization::BufferLikeType yieldedValueBufferType;
+ if (isa<bufferization::BufferLikeType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
- yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+ yieldedValueBufferType =
+ cast<bufferization::BufferLikeType>(yieldedValue.getType());
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
@@ -576,8 +584,9 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same shape");
}
#endif // NDEBUG
- return getMemRefTypeWithFullyDynamicLayout(
- iterTensorType, yieldedBufferType.getMemorySpace());
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(iterTensorType,
+ yieldedBufferType.getMemorySpace()));
}
/// Return `true` if the given loop may have 0 iterations.
@@ -696,12 +705,13 @@ struct ForOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto forOp = cast<scf::ForOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
- assert(isa<TensorType>(value.getType()) && "expected tensor type");
+ assert(isa<bufferization::TensorLikeType>(value.getType()) &&
+ "expected tensor type");
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
@@ -744,7 +754,7 @@ struct ForOpInterface
Value initArg = it.value();
Value result = forOp->getResult(it.index());
// If the type is not a tensor, bufferization doesn't need to touch it.
- if (!isa<TensorType>(result.getType())) {
+ if (!isa<bufferization::TensorLikeType>(result.getType())) {
castedInitArgs.push_back(initArg);
continue;
}
@@ -795,7 +805,7 @@ struct ForOpInterface
auto forOp = cast<scf::ForOp>(op);
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (OpResult opResult : op->getOpResults()) {
- if (!isa<TensorType>(opResult.getType()))
+ if (!isa<bufferization::TensorLikeType>(opResult.getType()))
continue;
// Note: This is overly strict. We should check for aliasing bufferized
@@ -920,7 +930,7 @@ struct WhileOpInterface
for (int64_t idx = 0;
idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
Value value = conditionOp.getArgs()[idx];
- if (!isa<TensorType>(value.getType()) ||
+ if (!isa<bufferization::TensorLikeType>(value.getType()) ||
(equivalentYieldsAfter.contains(idx) &&
equivalentYieldsBefore.contains(idx))) {
beforeYieldValues.push_back(value);
@@ -962,7 +972,7 @@ struct WhileOpInterface
Value initArg = it.value();
Value beforeArg = whileOp.getBeforeArguments()[it.index()];
// If the type is not a tensor, bufferization doesn't need to touch it.
- if (!isa<TensorType>(beforeArg.getType())) {
+ if (!isa<bufferization::TensorLikeType>(beforeArg.getType())) {
castedInitArgs.push_back(initArg);
continue;
}
@@ -975,7 +985,7 @@ struct WhileOpInterface
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
- if (!isa<TensorType>(bbArg.getType()))
+ if (!isa<bufferization::TensorLikeType>(bbArg.getType()))
return bbArg.getType();
// TODO: error handling
return llvm::cast<Type>(
@@ -1022,12 +1032,13 @@ struct WhileOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto whileOp = cast<scf::WhileOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
- assert(isa<TensorType>(value.getType()) && "expected tensor type");
+ assert(isa<bufferization::TensorLikeType>(value.getType()) &&
+ "expected tensor type");
// Case 1: Block argument of the "before" region.
if (auto bbArg = dyn_cast<BlockArgument>(value)) {
@@ -1053,9 +1064,9 @@ struct WhileOpInterface
llvm_unreachable("invalid value");
}
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
- if (!isa<TensorType>(conditionYieldedVal.getType())) {
+ if (!isa<bufferization::TensorLikeType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
- return cast<BaseMemRefType>(conditionYieldedVal.getType());
+ return cast<bufferization::BufferLikeType>(conditionYieldedVal.getType());
}
return bufferization::getBufferType(conditionYieldedVal, options,
invocationStack);
@@ -1082,7 +1093,7 @@ struct WhileOpInterface
auto conditionOp = whileOp.getConditionOp();
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
Block *block = conditionOp->getBlock();
- if (!isa<TensorType>(it.value().getType()))
+ if (!isa<bufferization::TensorLikeType>(it.value().getType()))
continue;
if (it.index() >= block->getNumArguments() ||
!state.areEquivalentBufferizedValues(it.value(),
@@ -1095,7 +1106,7 @@ struct WhileOpInterface
auto yieldOp = whileOp.getYieldOp();
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Block *block = yieldOp->getBlock();
- if (!isa<TensorType>(it.value().getType()))
+ if (!isa<bufferization::TensorLikeType>(it.value().getType()))
continue;
if (it.index() >= block->getNumArguments() ||
!state.areEquivalentBufferizedValues(it.value(),
@@ -1154,7 +1165,7 @@ struct YieldOpInterface
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Value value = it.value();
- if (isa<TensorType>(value.getType())) {
+ if (isa<bufferization::TensorLikeType>(value.getType())) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();
@@ -1162,14 +1173,14 @@ struct YieldOpInterface
// We may have to cast the value before yielding it.
if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
yieldOp->getParentOp())) {
- FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ auto resultType = bufferization::getBufferType(
yieldOp->getParentOp()->getResult(it.index()), options);
if (failed(resultType))
return failure();
buffer = castBuffer(rewriter, buffer, *resultType);
} else if (auto whileOp =
dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
- FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ auto resultType = bufferization::getBufferType(
whileOp.getBeforeArguments()[it.index()], options);
if (failed(resultType))
return failure();
@@ -1274,7 +1285,7 @@ struct ForallOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto forallOp = cast<ForallOp>(op);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 6e882a8d0ff30..068c248c1bcd7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -220,8 +220,8 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithStaticIdentityLayout(
- cast<TensorType>(value.getType()), memorySpace);
+ return llvm::cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
+ cast<TensorType>(value.getType()), memorySpace));
};
if (analysisOnly) {
options.testAnalysisOnly = true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index f92382472b478..742a92566a31e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -550,8 +550,8 @@ TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
auto tTp = llvm::cast<TensorType>(tensor.getType());
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
- return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
- .getResult();
+ return llvm::cast<TypedValue<BaseMemRefType>>(
+ builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor).getResult());
}
Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 31014172a9555..fb0dd151a4448 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -487,8 +487,7 @@ struct FromElementsOpInterface
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
- FailureOr<BaseMemRefType> memrefType =
- bufferization::getBufferType(*tensorAlloc, options);
+ auto memrefType = bufferization::getBufferType(*tensorAlloc, options);
if (failed(memrefType))
return failure();
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
@@ -592,7 +591,7 @@ struct GenerateOpInterface
auto type = generateOp.getResult().getType();
// TODO: Implement memory space for this op.
- if (options.defaultMemorySpaceFn(type) != Attribute())
+ if (options.defaultMemorySpaceFn(cast<TensorLikeType>(type)) != Attribute())
return op->emitError("memory space not implemented yet");
// Allocate memory.
@@ -1031,7 +1030,8 @@ struct SplatOpInterface
auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
// TODO: Implement memory space for this op.
- if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+ if (options.defaultMemorySpaceFn(cast<TensorLikeType>(tensorType)) !=
+ Attribute())
return op->emitError("memory space not implemented yet");
auto linalgOp =
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index e65c5b92949f6..6fb421675fab6 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -268,4 +268,25 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
%r = tensor.extract %dest_filled[%idx] : tensor<5xf32>
return %0, %r : tensor<5xf32>, f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// An op on the custom !test.test_tensor type is bufferized by its own
+// bufferize() implementation into the !test.test_memref counterpart.
+// CHECK-LABEL: func.func @test_dialect_op(
+// CHECK-SAME: %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
+// CHECK-SAME: ) -> !test.test_tensor<[32, 128], f64> {
+func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
+ -> !test.test_tensor<[32, 128], f64> {
+ // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[ARG]]
+ // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
+ // CHECK-SAME: : (!test.test_memref<[32, 64], f64>)
+ // CHECK-SAME: -> !test.test_memref<[32, 128], f64>
+ // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
+ %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[32, 64], f64>)
+ -> !test.test_tensor<[32, 128], f64>
+
+ // CHECK: return %[[OUT]]
+ return %out : !test.test_tensor<[32, 128], f64>
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 2c8807b66de74..86b541d95924b 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -58,14 +58,14 @@ func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: ten
// -----
func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %arg1: vector<5xf32>) {
- // expected-error @below{{'dest' must be a tensor or a memref}}
+ // expected-error @below{{'dest' must be a tensor or a buffer}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5xf32>, vector<5xf32>) -> ()
}
// -----
func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
- // expected-error @below{{memref 'dest' implies zero results}}
+ // expected-error @below{{buffer 'dest' implies zero results}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
}
@@ -79,14 +79,14 @@ func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32
// -----
func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
- // expected-error @below{{'restrict' is valid only for memref destinations}}
+ // expected-error @below{{'restrict' is valid only for buffer destinations}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
}
// -----
func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
- // expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
+ // expected-error @below{{'writable' must be specified if and only if the destination is of buffer type}}
bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
}
diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
index 2991a3c165ee2..95d6158d7c67f 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
@@ -46,7 +46,10 @@ struct TestTensorCopyInsertionPass
options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
if (mustInferMemorySpace) {
+ // Report "no default memory space" (std::nullopt) for every tensor type.
options.defaultMemorySpaceFn =
- [](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
+ [](bufferization::TensorLikeType) -> std::optional<Attribute> {
+ return std::nullopt;
+ };
}
if (failed(bufferization::insertTensorCopies(getOperation(), options)))
signalPassFailure();
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 454a12bac9ab3..df7586976280c 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -8,6 +8,7 @@
#include "TestDialect.h"
#include "TestOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
@@ -1386,3 +1387,28 @@ TestMultiSlotAlloca::handleDestructuringComplete(
const DestructurableMemorySlot &slot, OpBuilder &builder) {
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
+
+::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
+ ::mlir::RewriterBase &rewriter,
+ const ::mlir::bufferization::BufferizationOptions &options) {
+ // Builds the test memref type mirroring a test tensor type: same shape and
+ // element type, null memory space.
+ auto toTestMemrefType = [&](auto tensorType) {
+ return test::TestMemrefType::get(getContext(), tensorType.getShape(),
+ tensorType.getElementType(), nullptr);
+ };
+ const auto outType = getOutput().getType();
+
+ // Rewrite this op as its memref counterpart: to_memref on the input, the
+ // dummy memref op, then to_tensor back to the original result type so the
+ // op's users still see a test tensor.
+ auto toMemref = rewriter.create<::mlir::bufferization::ToMemrefOp>(
+ getLoc(), toTestMemrefType(getInput().getType()), getInput());
+ auto dummyMemrefOp = rewriter.create<test::TestDummyMemrefOp>(
+ getLoc(), toTestMemrefType(outType), toMemref.getResult());
+ auto toTensor = rewriter.create<::mlir::bufferization::ToTensorOp>(
+ getLoc(), outType, dummyMemrefOp.getOutput());
+
+ rewriter.replaceOp(*this, toTensor);
+ return mlir::success();
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index f070c3bedd92c..ea8867e3fc41d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -13,6 +13,7 @@
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 85a49e05d4c73..976b4963a29f7 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -30,7 +30,8 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"

// Include the attribute definitions.
include "TestAttrDefs.td"
@@ -3499,4 +3500,61 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Test Ops bufferization
+//===----------------------------------------------------------------------===//
+
+// Test op on the custom tensor type; its bufferize() implementation rewrites
+// it into test.dummy_memref_op wrapped in to_memref/to_tensor.
+def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> {
+ let arguments = (ins
+ TestTensorType:$input
+ );
+ let results = (outs
+ TestTensorType:$output
+ );
+ let extraClassDeclaration = [{
+ // BufferizableOpInterface
+ bool bufferizesToMemoryRead(mlir::OpOperand&,
+ const mlir::bufferization::AnalysisState&);
+
+ bool bufferizesToMemoryWrite(mlir::OpOperand&,
+ const mlir::bufferization::AnalysisState&);
+
+ mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&,
+ const mlir::bufferization::AnalysisState&);
+
+ mlir::LogicalResult bufferize(
+ mlir::RewriterBase& rewriter,
+ const mlir::bufferization::BufferizationOptions& options);
+ }];
+
+ let extraClassDefinition = [{
+ bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
+ const ::mlir::bufferization::AnalysisState&) {
+ return true;
+ }
+ bool test::TestDummyTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
+ const ::mlir::bufferization::AnalysisState&) {
+ return true;
+ }
+ ::mlir::bufferization::AliasingValueList
+ test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&,
+ const ::mlir::bufferization::AnalysisState&) {
+ return {};
+ }
+ }];
+}
+
+// Buffer counterpart of test.dummy_tensor_op, operating on the custom memref
+// type.
+def TestDummyMemrefOp : TEST_Op<"dummy_memref_op"> {
+ let arguments = (ins
+ TestMemrefType:$input
+ );
+ let results = (outs
+ TestMemrefType:$output
+ );
+}
+
#endif // TEST_OPS
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index e9785594d3332..cee6888a7196c 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -446,6 +446,9 @@ def TestMemrefType : Test_Type<"TestMemref",
return test::TestMemrefType::get(
getContext(), shape.value_or(getShape()), elementType, getMemSpace());
}
+
+ // BufferLikeTypeInterface: the memory space is the type's getMemSpace().
+ ::mlir::Attribute getMemorySpace() const { return getMemSpace(); }
}];
}
More information about the Mlir-commits
mailing list