[Mlir-commits] [mlir] [mlir][bufferization] Support bufferization of external functions (PR #113999)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 28 21:17:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value.
This commit is in preparation for removing the deprecated `func-bufferize` pass, which can already bufferize external functions.
This commit also updates a few comments.
---
Full diff: https://github.com/llvm/llvm-project/pull/113999.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+6-5)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+30-26)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir (-17)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+15)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index aceb9d059b95f3..4866e31b19d5de 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -60,7 +60,8 @@ struct AliasingValue {
bool isDefinite;
};
-template <typename T> class AliasList {
+template <typename T>
+class AliasList {
public:
/// Create an empty list of aliases.
AliasList() = default;
@@ -259,7 +260,7 @@ struct BufferizationOptions {
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
/// Tensor -> MemRef type converter.
- /// Parameters: Value, memory space, func op, bufferization options
+ /// Parameters: tensor type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
@@ -344,9 +345,9 @@ struct BufferizationOptions {
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
/// Type converter from tensors to memrefs. This type converter is used to
- /// determine bufferized function argument types. By default, a type
- /// converter that returns a memref type with a fully dynamic layout map is
- /// used.
+ /// determine bufferized function argument and result types. By default, a
+ /// type converter that returns a memref type with a fully dynamic layout map
+ /// is used.
///
/// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..a372e87d8335f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
@@ -392,11 +393,11 @@ struct FuncOpInterface
auto funcOp = cast<FuncOp>(op);
FunctionType funcType = funcOp.getFunctionType();
- // Construct the bufferized function type.
+ // Construct the bufferized function type. Compute the argument types.
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
- if (dyn_cast<TensorType>(argType)) {
+ if (isa<TensorType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
@@ -404,24 +405,33 @@ struct FuncOpInterface
argTypes.push_back(argType);
}
- // Bodiless functions are assumed opaque and we cannot know the
- // bufferization contract they want to enforce. As a consequence, only
- // support functions that don't return any tensors atm.
- if (funcOp.isExternal()) {
- SmallVector<Type> retTypes;
- for (Type resultType : funcType.getResults()) {
- if (isa<TensorType>(resultType))
- return funcOp->emitError() << "cannot bufferize bodiless function "
- << "that returns a tensor";
+ // Compute the result types.
+ SmallVector<Type> retTypes;
+ for (Type resultType : funcType.getResults()) {
+ if (auto tensorType = dyn_cast<TensorType>(resultType)) {
+ BaseMemRefType resultType = options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+ options);
retTypes.push_back(resultType);
+ continue;
}
- funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
+ retTypes.push_back(resultType);
+ }
+
+ // Compute the new function type.
+ auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);
+
+ // If the function has no body, set the new function type and we are done.
+ if (funcOp.isExternal()) {
+ funcOp.setType(newFuncType);
return success();
}
// TODO: Support functions with multiple returns.
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
+ assert(returnOp->getNumOperands() == retTypes.size() &&
+ "incorrect number of return values");
Location loc = returnOp.getLoc();
// 1. Bufferize every block.
@@ -430,10 +440,10 @@ struct FuncOpInterface
options)))
return failure();
- // 2. For each result, keep track of which inplace argument it reuses.
+ // 2. Bufferize all operands of the return op.
SmallVector<Value> returnValues;
- for (OpOperand &returnOperand : returnOp->getOpOperands()) {
- Value returnVal = returnOperand.get();
+ for (auto [returnVal, bufferizedType] :
+ llvm::zip_equal(returnOp->getOperands(), retTypes)) {
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);
@@ -443,23 +453,17 @@ struct FuncOpInterface
continue;
}
- // Note: If `inferFunctionResultLayout = true`, cast are later folded
+ // Note: If `inferFunctionResultLayout = true`, casts are later folded
// away.
- BaseMemRefType resultType = options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
- options);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
- loc, resultType, returnVal);
+ loc, bufferizedType, returnVal);
returnValues.push_back(toMemrefOp);
}
- // 3. Rewrite the terminator without the in-place bufferizable values.
returnOp.getOperandsMutable().assign(returnValues);
- // 4. Rewrite the FuncOp type to buffer form.
- funcOp.setType(FunctionType::get(op->getContext(), argTypes,
- ValueRange(returnValues).getTypes()));
-
+ // 3. Set the new function type.
+ funcOp.setType(newFuncType);
return success();
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index ee0f71f668dc74..2829eafb7c1c59 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,11 +1,5 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
-// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
-// expected-error @+1 {{failed to bufferize op}}
-func.func private @foo() -> tensor<?xf32>
-
-// -----
-
// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-> (tensor<f32>, tensor<f32>)
@@ -123,17 +117,6 @@ func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {
// -----
-// expected-error @+2 {{failed to bufferize op}}
-// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
-func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
-
-func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
- call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
- return
-}
-
-// -----
-
func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
%idx2 : index) -> f32 {
%1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0d5224514e3a02..d31b43477beb9f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -42,6 +42,21 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
// -----
+// Bufferization of bodiless function that returns a tensor.
+
+// CHECK: func.func private @foo(memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
+func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+// CHECK: func.func @call_to_unknown_tensor_returning_func(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
+func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+ // CHECK: call @foo(%[[arg0]]) : (memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
+ call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+ return
+}
+
+// -----
+
// A function that returns a non-equivalent tensor with layout map.
// CHECK-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32, strided<[10, 1], offset: ?>>
``````````
</details>
https://github.com/llvm/llvm-project/pull/113999
More information about the Mlir-commits
mailing list