[Mlir-commits] [mlir] [mlir][bufferization] Support custom types at function boundaries (PR #159766)
Andrei Golubev
llvmlistbot at llvm.org
Mon Sep 22 04:32:28 PDT 2025
https://github.com/andrey-golubev updated https://github.com/llvm/llvm-project/pull/159766
>From 9ba18926c8c9d79069dbc86f6eee937f2e0bbc53 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Wed, 9 Jul 2025 09:56:39 +0000
Subject: [PATCH 1/2] [mlir][bufferization] Support custom types at function
boundaries
Support custom types (3/N): allow custom tensor and buffer types in
function signatures and at call sites. This is one of the major building
blocks for moving towards module-level one-shot-bufferization support.
To enable this, TensorLikeType is extended with a new interface method
that is invoked solely during function-boundary bufferization.
---
.../IR/BufferizationTypeInterfaces.h | 1 +
.../IR/BufferizationTypeInterfaces.td | 12 +++
.../Bufferization/IR/BufferizationDialect.cpp | 13 +++
.../Bufferization/Transforms/Bufferize.cpp | 2 +-
.../FuncBufferizableOpInterfaceImpl.cpp | 90 ++++++++++---------
.../Transforms/one-shot-module-bufferize.mlir | 56 ++++++++++++
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 5 ++
mlir/test/lib/Dialect/Test/TestTypes.cpp | 8 ++
8 files changed, 146 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index a2bfcb7ed2b75..9b052b8bb7e14 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Func/IR/FuncOps.h" // to access mlir::func::FuncOp
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index fb6fc4f5ad964..c4235cd067999 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -43,6 +43,18 @@ def Bufferization_TensorLikeTypeInterface
/*args=*/(ins
"::mlir::bufferization::BufferLikeType":$bufferType,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+ >,
+ InterfaceMethod<[{
+ Returns a BufferLike type for this TensorLike type in the context of
+ this type being function argument or result.
+ }],
+ /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
+ /*methodName=*/"getBufferTypeAtFunctionBoundary",
+ /*args=*/(ins
+ "::mlir::func::FuncOp":$funcOp,
+ "const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
+ )
>
];
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..9b907922a24c4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -87,6 +87,19 @@ struct BuiltinTensorExternalModel
return mlir::success();
}
+
+ llvm::FailureOr<BufferLikeType> getBufferTypeAtFunctionBoundary(
+ mlir::Type tensor, mlir::func::FuncOp funcOp,
+ const BufferizationOptions &options,
+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+ auto tensorType = cast<TensorType>(tensor);
+ auto memSpace = options.defaultMemorySpaceFn(tensorType);
+ if (!memSpace.has_value())
+ return emitError() << "could not infer memory space";
+
+ return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+ tensorType, *memSpace, funcOp, options));
+ }
};
template <typename MemRef>
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 68ef51992efee..701ab52a491a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -401,7 +401,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
// Compute the new signature.
SmallVector<Type> newTypes;
for (BlockArgument &bbArg : block->getArguments()) {
- auto tensorType = dyn_cast<TensorType>(bbArg.getType());
+ auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
if (!tensorType) {
newTypes.push_back(bbArg.getType());
continue;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index f69efd1b3fa8c..b7bac9f4623f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -52,26 +52,35 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
auto tensorType =
- dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
- assert(tensorType && "expected TensorType");
-
- BaseMemRefType memrefType = options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
-
- auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
- index, BufferizationDialect::kBufferLayoutAttrName);
- if (!layoutAttr)
- return memrefType;
-
- auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
- assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
- return MemRefType::get(rankedMemrefType.getShape(),
- rankedMemrefType.getElementType(), layoutAttr,
- rankedMemrefType.getMemorySpace());
+ dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+ assert(tensorType && "expected TensorLikeType");
+ auto maybeBufferType = tensorType.getBufferTypeAtFunctionBoundary(
+ funcOp, options, [&]() { return funcOp->emitError(); });
+ assert(mlir::succeeded(maybeBufferType) &&
+ "a valid buffer is always expected");
+
+ auto bufferType = *maybeBufferType;
+
+ // Note: For builtin tensors there is additional logic related to layout.
+ if (isa<TensorType>(tensorType)) {
+ auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
+ index, BufferizationDialect::kBufferLayoutAttrName);
+ if (!layoutAttr)
+ return bufferType;
+
+ auto rankedMemrefType = dyn_cast<MemRefType>(bufferType);
+ assert(rankedMemrefType &&
+ "buffer layout not supported on unranked tensors");
+ return cast<BufferLikeType>(MemRefType::get(
+ rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+ layoutAttr, rankedMemrefType.getMemorySpace()));
+ }
+
+ return bufferType;
}
/// Return the FuncOp called by `callOp`.
@@ -227,14 +236,13 @@ struct CallOpInterface
FunctionType funcType = funcOp.getFunctionType();
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
- if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
- return cast<BufferLikeType>(bufferizedType);
+ if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
+ return bufferizedType;
// Otherwise, call the type converter to compute the bufferized type.
- auto tensorType = cast<TensorType>(resultType);
- return cast<BufferLikeType>(options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
- options));
+ auto tensorType = cast<TensorLikeType>(resultType);
+ return tensorType.getBufferTypeAtFunctionBoundary(
+ funcOp, options, [&]() { return funcOp->emitError(); });
}
/// All function arguments are writable. It is the responsibility of the
@@ -248,7 +256,7 @@ struct CallOpInterface
SmallVector<Type> resultTypes;
for (Value result : callOp.getResults()) {
Type returnType = result.getType();
- if (!isa<TensorType>(returnType)) {
+ if (!isa<TensorLikeType>(returnType)) {
// Non-tensor values are returned.
resultTypes.push_back(returnType);
continue;
@@ -272,7 +280,7 @@ struct CallOpInterface
for (OpOperand &opOperand : callOp->getOpOperands()) {
// Non-tensor operands are just copied.
- if (!isa<TensorType>(opOperand.get().getType())) {
+ if (!isa<TensorLikeType>(opOperand.get().getType())) {
newOperands.push_back(opOperand.get());
continue;
}
@@ -285,8 +293,8 @@ struct CallOpInterface
Value buffer = *maybeBuffer;
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
- auto memRefType = funcType.getInput(opOperand.getOperandNumber());
- if (!isa<BaseMemRefType>(memRefType)) {
+ auto bufferType = funcType.getInput(opOperand.getOperandNumber());
+ if (!isa<BufferLikeType>(bufferType)) {
// The called function was not bufferized yet. This can happen when
// there cycles in the function call graph. Compute the bufferized
// result type.
@@ -296,7 +304,7 @@ struct CallOpInterface
state);
if (failed(maybeBufferType))
return failure();
- memRefType = *maybeBufferType;
+ bufferType = *maybeBufferType;
}
// Since we don't yet have a clear layout story, to_buffer may
@@ -305,8 +313,8 @@ struct CallOpInterface
// that will either canonicalize away or fail compilation until we can do
// something better. Insert a reallocation + copy if it cannot be
// statically guaranteed that a direct cast would be valid.
- if (buffer.getType() != memRefType) {
- auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+ if (buffer.getType() != bufferType) {
+ auto memrefDstType = dyn_cast<MemRefType>(bufferType);
assert(memrefDstType &&
"buffer layout not supported on unranked tensors");
FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
@@ -370,7 +378,7 @@ struct FuncOpInterface
static bool supportsUnstructuredControlFlow() { return true; }
bool hasTensorSemantics(Operation *op) const {
- auto isaTensor = llvm::IsaPred<TensorType>;
+ auto isaTensor = llvm::IsaPred<TensorLikeType>;
// A function has tensor semantics if it has tensor arguments/results.
auto funcOp = cast<FuncOp>(op);
@@ -406,8 +414,8 @@ struct FuncOpInterface
// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
- return cast<BufferLikeType>(
- getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
+ return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+ options);
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
@@ -430,7 +438,7 @@ struct FuncOpInterface
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
- if (isa<TensorType>(argType)) {
+ if (isa<TensorLikeType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
@@ -441,11 +449,13 @@ struct FuncOpInterface
// Compute the result types.
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
- if (auto tensorType = dyn_cast<TensorType>(resultType)) {
- BaseMemRefType resultType = options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
- options);
- retTypes.push_back(resultType);
+ if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+ FailureOr<BufferLikeType> resultType =
+ tensorType.getBufferTypeAtFunctionBoundary(
+ funcOp, options, [&]() { return funcOp->emitError(); });
+ assert(mlir::succeeded(resultType) &&
+ "a valid buffer is always expected");
+ retTypes.push_back(*resultType);
continue;
}
retTypes.push_back(resultType);
@@ -473,7 +483,7 @@ struct FuncOpInterface
SmallVector<Value> returnValues;
for (auto [returnVal, bufferizedType] :
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
- auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+ auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);
// If not a tensor type just forward it.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 2efb5893c8511..eb0093106dc11 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -810,3 +810,59 @@ module @inner_module {
return %t : tensor<5xf32>
}
}
+
+// -----
+
+// CHECK: func.func @custom_types(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>,
+// CHECK-SAME: !test.test_memref<[4, 8], f64>)
+func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>)
+ -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) {
+ // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) :
+ // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+ %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 8], f64>
+
+ // CHECK: %[[alloc:.*]] = "test.create_memref_op"
+ // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]])
+ // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+ %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64>
+ %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 8], f64>
+
+ // CHECK: return %[[out1]], %[[out2]]
+ return %out1, %out2 :
+ !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>
+}
+
+// -----
+
+// CHECK: func.func @custom_types_foo(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64> {
+ // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]])
+ %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+ // CHECK: return %[[out]]
+ return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK: func.func @custom_types_bar(
+// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64>
+func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 8], f64> {
+ // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]])
+ %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+
+ // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]])
+ %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 8], f64>
+
+ // CHECK: return %[[out]]
+ return %out : !test.test_tensor<[4, 8], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index ea20597231d58..562fc66acea2a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -444,6 +444,11 @@ def TestTensorType : Test_Type<"TestTensor",
::mlir::LogicalResult verifyCompatibleBufferType(
::mlir::bufferization::BufferLikeType bufferType,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
+
+ ::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+ getBufferTypeAtFunctionBoundary(mlir::func::FuncOp funcOp,
+ const ::mlir::bufferization::BufferizationOptions& options,
+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
}];
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..3c92fb94aebee 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -573,3 +573,11 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
getElementType() == testMemref.getElementType();
return mlir::success(valid);
}
+
+::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+TestTensorType::getBufferTypeAtFunctionBoundary(
+ mlir::func::FuncOp,
+ const ::mlir::bufferization::BufferizationOptions &options,
+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
+ return getBufferType(options, emitError);
+}
>From d7373f3e12a764e4e0966b8f2cdbf8b0eec3cca5 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Mon, 22 Sep 2025 11:22:53 +0000
Subject: [PATCH 2/2] Rely on BufferizationOptions::FunctionArgTypeConverterFn
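Transform FunctionArgTypeConverterFn into a tensor-like -> buffer-like
converter so that it can be used as a generic function-boundary
conversion utility. This replaces the
TensorLikeType::getBufferTypeAtFunctionBoundary interface method
introduced in the previous patch, which is removed again here.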
---
.../IR/BufferizableOpInterface.h | 16 ++++---
.../IR/BufferizationTypeInterfaces.h | 1 -
.../IR/BufferizationTypeInterfaces.td | 12 -----
.../IR/BufferizableOpInterface.cpp | 39 +++++++++++----
.../Bufferization/IR/BufferizationDialect.cpp | 13 -----
.../FuncBufferizableOpInterfaceImpl.cpp | 48 +++++++++++--------
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 5 --
mlir/test/lib/Dialect/Test/TestTypes.cpp | 8 ----
8 files changed, 67 insertions(+), 75 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f3b34f9fded7f..5bf3916630158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -260,10 +260,10 @@ struct BufferizationOptions {
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
- /// Tensor -> MemRef type converter.
- /// Parameters: tensor type, memory space, func op, bufferization options
+ /// Tensor-like -> Buffer-like type converter.
+ /// Parameters: tensor-like type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
- std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+ std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
/// Parameters: tensor type, memory space, bufferization options
@@ -335,10 +335,12 @@ struct BufferizationOptions {
/// predictable.
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
- /// Type converter from tensors to memrefs. This type converter is used to
- /// determine bufferized function argument and result types. By default, a
- /// type converter that returns a memref type with a fully dynamic layout map
- /// is used.
+ /// Type converter from tensors to buffers. This type converter is used to
+ /// determine bufferized function argument and result types.
+ ///
+ /// By default, if tensor is a (builtin) tensor type, a type converter that
+ /// returns a memref type with a fully dynamic layout map is used; if tensor
+ /// is a (generic) tensor-like type, TensorLikeType::getBufferType() is used.
///
/// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 9b052b8bb7e14..a2bfcb7ed2b75 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,7 +13,6 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Func/IR/FuncOps.h" // to access mlir::func::FuncOp
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index c4235cd067999..fb6fc4f5ad964 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -43,18 +43,6 @@ def Bufferization_TensorLikeTypeInterface
/*args=*/(ins
"::mlir::bufferization::BufferLikeType":$bufferType,
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
- >,
- InterfaceMethod<[{
- Returns a BufferLike type for this TensorLike type in the context of
- this type being function argument or result.
- }],
- /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
- /*methodName=*/"getBufferTypeAtFunctionBoundary",
- /*args=*/(ins
- "::mlir::func::FuncOp":$funcOp,
- "const ::mlir::bufferization::BufferizationOptions &":$options,
- "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
- )
>
];
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f7b0b87085f3d..fae1df69ed3e3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -338,11 +338,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
+BufferLikeType
+defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
func::FuncOp funcOp,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+ if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+ return cast<BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
+ }
+
+ // If not builtin, fallback to TensorLikeType::getBufferType()
+ auto bufferType =
+ type.getBufferType(options, [&]() { return funcOp->emitError(); });
+ assert(mlir::succeeded(bufferType) &&
+ "a valid buffer is always expected at function boundary");
+ return *bufferType;
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
@@ -385,14 +395,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
- functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
+ functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace,
func::FuncOp funcOp,
const BufferizationOptions &options) {
- if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
- memorySpace);
- return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
- memorySpace);
+ if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+ if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
+ return cast<BufferLikeType>(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
+ memorySpace));
+ return cast<BufferLikeType>(
+ bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+ memorySpace));
+ }
+
+ // If not builtin, fallback to TensorLikeType::getBufferType()
+ auto bufferType =
+ type.getBufferType(options, [&]() { return funcOp->emitError(); });
+ assert(mlir::succeeded(bufferType) &&
+ "a valid buffer is always expected at function boundary");
+ return *bufferType;
};
inferFunctionResultLayout =
layoutMapOption == LayoutMapOption::InferLayoutMap;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 9b907922a24c4..6c08cdfb669f3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -87,19 +87,6 @@ struct BuiltinTensorExternalModel
return mlir::success();
}
-
- llvm::FailureOr<BufferLikeType> getBufferTypeAtFunctionBoundary(
- mlir::Type tensor, mlir::func::FuncOp funcOp,
- const BufferizationOptions &options,
- llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
- auto tensorType = cast<TensorType>(tensor);
- auto memSpace = options.defaultMemorySpaceFn(tensorType);
- if (!memSpace.has_value())
- return emitError() << "could not infer memory space";
-
- return cast<BufferLikeType>(options.functionArgTypeConverterFn(
- tensorType, *memSpace, funcOp, options));
- }
};
template <typename MemRef>
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index b7bac9f4623f1..d9d69342e42a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -49,30 +49,38 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
#endif // NDEBUG
}
+// Note: this is a local adaptor to unify TensorType and TensorLikeType code
+// paths that both work with BufferizationOptions.
+static mlir::Attribute
+getDefaultMemorySpace(const BufferizationOptions &options,
+ TensorLikeType type) {
+ if (auto tensorType = dyn_cast<TensorType>(type)) {
+ return *options.defaultMemorySpaceFn(tensorType);
+ }
+ return nullptr;
+}
+
/// Return the index-th bufferized function argument type. This assumes that the
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
/// specified by the user (as per `options.functionArgTypeConverterFn`).
static BufferLikeType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
- auto tensorType =
+ auto type =
dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
- assert(tensorType && "expected TensorLikeType");
- auto maybeBufferType = tensorType.getBufferTypeAtFunctionBoundary(
- funcOp, options, [&]() { return funcOp->emitError(); });
- assert(mlir::succeeded(maybeBufferType) &&
- "a valid buffer is always expected");
-
- auto bufferType = *maybeBufferType;
+ assert(type && "expected TensorLikeType");
// Note: For builtin tensors there is additional logic related to layout.
- if (isa<TensorType>(tensorType)) {
+ if (auto tensorType = dyn_cast<TensorType>(type)) {
+ BufferLikeType memrefType = options.functionArgTypeConverterFn(
+ type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+
auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
index, BufferizationDialect::kBufferLayoutAttrName);
if (!layoutAttr)
- return bufferType;
+ return memrefType;
- auto rankedMemrefType = dyn_cast<MemRefType>(bufferType);
+ auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
assert(rankedMemrefType &&
"buffer layout not supported on unranked tensors");
return cast<BufferLikeType>(MemRefType::get(
@@ -80,7 +88,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
layoutAttr, rankedMemrefType.getMemorySpace()));
}
- return bufferType;
+ return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp,
+ options);
}
/// Return the FuncOp called by `callOp`.
@@ -241,8 +250,9 @@ struct CallOpInterface
// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorLikeType>(resultType);
- return tensorType.getBufferTypeAtFunctionBoundary(
- funcOp, options, [&]() { return funcOp->emitError(); });
+ return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+ tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
+ options));
}
/// All function arguments are writable. It is the responsibility of the
@@ -450,12 +460,10 @@ struct FuncOpInterface
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
- FailureOr<BufferLikeType> resultType =
- tensorType.getBufferTypeAtFunctionBoundary(
- funcOp, options, [&]() { return funcOp->emitError(); });
- assert(mlir::succeeded(resultType) &&
- "a valid buffer is always expected");
- retTypes.push_back(*resultType);
+ BufferLikeType resultType = options.functionArgTypeConverterFn(
+ tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
+ options);
+ retTypes.push_back(resultType);
continue;
}
retTypes.push_back(resultType);
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 562fc66acea2a..ea20597231d58 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -444,11 +444,6 @@ def TestTensorType : Test_Type<"TestTensor",
::mlir::LogicalResult verifyCompatibleBufferType(
::mlir::bufferization::BufferLikeType bufferType,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
-
- ::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
- getBufferTypeAtFunctionBoundary(mlir::func::FuncOp funcOp,
- const ::mlir::bufferization::BufferizationOptions& options,
- ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
}];
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 3c92fb94aebee..bea043f56fe21 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -573,11 +573,3 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
getElementType() == testMemref.getElementType();
return mlir::success(valid);
}
-
-::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
-TestTensorType::getBufferTypeAtFunctionBoundary(
- mlir::func::FuncOp,
- const ::mlir::bufferization::BufferizationOptions &options,
- ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
- return getBufferType(options, emitError);
-}
More information about the Mlir-commits
mailing list