[Mlir-commits] [mlir] [mlir][bufferization] Add support for recursive function calls (PR #114003)
Matthias Springer
llvmlistbot at llvm.org
Mon Oct 28 22:30:08 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/114003
This commit adds support for recursive function calls to One-Shot Bufferize.
The analysis does not support recursive function calls: the function body itself can be analyzed, but we cannot make any assumptions about the aliasing relation between function results and function arguments. Similarly, when looking at a `call` op, we do not know whether the operands will bufferize to a memory read/write. In the absence of such information, we have to conservatively assume that they do.
This commit is in preparation for removing the deprecated `func-bufferize` pass. That pass can bufferize recursive functions, so One-Shot Bufferize must support them before the pass can be removed.
>From c4310981673ddf90f8006cfa35576531a0f44dbb Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 29 Oct 2024 06:26:33 +0100
Subject: [PATCH] [mlir][bufferization] Add support for recursive function
calls
---
.../FuncBufferizableOpInterfaceImpl.cpp | 28 +++++++--
.../Transforms/OneShotModuleBufferize.cpp | 59 ++++++++++++++-----
.../one-shot-module-bufferize-analysis.mlir | 12 ++++
.../one-shot-module-bufferize-invalid.mlir | 14 -----
.../Transforms/one-shot-module-bufferize.mlir | 49 +++++++++++++++
5 files changed, 130 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..48d61fc223f4c7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
@@ -206,11 +207,18 @@ struct CallOpInterface
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- // The callee was already bufferized, so we can directly take the type from
+ // If the callee was already bufferized, we can directly take the type from
// its signature.
FunctionType funcType = funcOp.getFunctionType();
- return cast<BaseMemRefType>(
- funcType.getResult(cast<OpResult>(value).getResultNumber()));
+ Type resultType =
+ funcType.getResult(cast<OpResult>(value).getResultNumber());
+ if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+ return bufferizedType;
+
+ // Otherwise, call the type converter to compute the bufferized type.
+ auto tensorType = cast<TensorType>(resultType);
+ return options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
}
/// All function arguments are writable. It is the responsibility of the
@@ -260,6 +268,18 @@ struct CallOpInterface
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
+ if (!isa<BaseMemRefType>(memRefType)) {
+ // The called function was not bufferized yet. This can happen when
+ // there are cycles in the function call graph. Compute the bufferized
+ // argument type.
+ FailureOr<BaseMemRefType> maybeMemRefType =
+ bufferization::getBufferType(
+ funcOp.getArgument(opOperand.getOperandNumber()), options);
+ if (failed(maybeMemRefType))
+ return failure();
+ memRefType = *maybeMemRefType;
+ }
+
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..fde8d92fe610ce 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -287,12 +287,11 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
-/// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any func::CallOp.
-static LogicalResult
-getFuncOpsOrderedByCalls(ModuleOp moduleOp,
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
- FuncCallerMap &callerMap) {
+/// Return `failure()` if we are unable to retrieve the called FuncOp from
+/// any func::CallOp.
+static LogicalResult getFuncOpsOrderedByCalls(
+ ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +325,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
});
if (res.wasInterrupted())
return failure();
+
// Iteratively remove function operations that do not call any of the
- // functions remaining in the callCounter map and add them to the worklist.
+ // functions remaining in the callCounter map and add them to ordered list.
while (!numberCallOpsContainedInFuncOp.empty()) {
auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
[](auto entry) { return entry.getSecond() == 0; });
if (it == numberCallOpsContainedInFuncOp.end())
- return moduleOp.emitOpError(
- "expected callgraph to be free of circular dependencies.");
+ break;
orderedFuncOps.push_back(it->getFirst());
for (auto callee : calledBy[it->getFirst()])
numberCallOpsContainedInFuncOp[callee]--;
numberCallOpsContainedInFuncOp.erase(it);
}
+
+ // Put all other functions in the list of remaining functions. These are
+ // functions that call each other circularly.
+ for (auto it : numberCallOpsContainedInFuncOp)
+ remainingFuncOps.push_back(it.first);
+
return success();
}
@@ -379,15 +384,17 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+ remainingFuncOps, callerMap)))
return failure();
- // Analyze ops.
+ // Analyze ops in order, starting with functions that do not call any
+ // other functions.
for (func::FuncOp funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;
@@ -411,6 +418,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
}
+ // Analyze all other ops.
+ for (func::FuncOp funcOp : remainingFuncOps) {
+ if (!state.getOptions().isOpAllowed(funcOp))
+ continue;
+
+ // Gather equivalence info for CallOps.
+ equivalenceAnalysis(funcOp, state, funcState);
+
+ // Analyze funcOp.
+ if (failed(analyzeOp(funcOp, state, statistics)))
+ return failure();
+
+ // TODO: We currently skip all function argument analyses for functions
+ // that call each other circularly. These analyses do not support recursive
+ // calls yet.
+ }
+
return success();
}
@@ -430,13 +454,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+ // Try to bufferize functions in calling order. I.e., first bufferize
+ // functions that do not call other functions. This allows us to infer
+ // accurate buffer types for function return values. Functions that call
+ // each other recursively are bufferized in an unspecified order at the end.
+ // We may use unnecessarily "complex" (in terms of layout map) buffer types.
+ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+ remainingFuncOps, callerMap)))
return failure();
+ llvm::append_range(orderedFuncOps, remainingFuncOps);
// Bufferize functions.
for (func::FuncOp funcOp : orderedFuncOps) {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5a..3f6d182b57c031 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,15 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
%2 = tensor.extract %1[%c0] : tensor<6xf32>
return %2 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @recursive_function
+func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ // The analysis does not support recursive function calls and is conservative
+ // around them.
+ // CHECK: call @recursive_function
+ // CHECK-SAME: {__inplace_operands_attr__ = ["false", "false"]}
+ %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
+ return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index ee0f71f668dc74..888310169f3043 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -25,20 +25,6 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>
// -----
-// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-
-func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
- %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
- return %0 : tensor<5xf32>
-}
-
-func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
- %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
- return %0 : tensor<5xf32>
-}
-
-// -----
-
func.func @scf_for(%A : tensor<?xf32>,
%B : tensor<?xf32> {bufferization.writable = true},
%C : tensor<4xf32>,
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0d5224514e3a02..da145291960680 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -707,3 +707,52 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
%0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
return %0 : memref<5xf32>
}
+
+// -----
+
+// A recursive function.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+ // We are conservative around recursive functions. The analysis cannot handle
+ // them, so we have to assume the op operand of the call op bufferizes to a
+ // memory read and write. This causes a copy in this test case.
+ // CHECK: %[[copy:.*]] = memref.alloc() {alignment = 64 : i64} : memref<5xf32>
+ // CHECK: memref.copy %[[arg0]], %[[copy]]
+ // CHECK: %[[cast:.*]] = memref.cast %[[copy]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+ // CHECK: %[[call:.*]] = call @foo(%[[cast]])
+ %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+
+ // CHECK: memref.load %[[arg0]]
+ %c0 = arith.constant 0 : index
+ %extr = tensor.extract %t[%c0] : tensor<5xf32>
+ vector.print %extr : f32
+
+ // CHECK: return %[[call]]
+ return %0 : tensor<5xf32>
+}
+
+// -----
+
+// Two functions calling each other recursively.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK: %[[call:.*]] = call @bar(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK: return %[[call]]
+// CHECK: }
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @bar(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK: %[[call:.*]] = call @foo(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK: return %[[call]]
+// CHECK: }
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+ %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
+}
More information about the Mlir-commits
mailing list