[Mlir-commits] [mlir] [mlir][bufferization] Add support for recursive function calls (PR #114003)
Matthias Springer
llvmlistbot at llvm.org
Mon Oct 28 22:33:34 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114003
From 89f8a746e6b08e89c0b9a53a4a26d8b8ec5677d8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 29 Oct 2024 06:26:33 +0100
Subject: [PATCH] [mlir][bufferization] Add support for recursive function
calls
---
.../FuncBufferizableOpInterfaceImpl.cpp | 28 ++++++--
.../Transforms/OneShotModuleBufferize.cpp | 67 ++++++++++++++-----
.../one-shot-module-bufferize-analysis.mlir | 12 ++++
.../one-shot-module-bufferize-invalid.mlir | 14 ----
.../Transforms/one-shot-module-bufferize.mlir | 49 ++++++++++++++
5 files changed, 137 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..48d61fc223f4c7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
- SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+ SymbolRefAttr sym =
+ llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(
@@ -206,11 +207,18 @@ struct CallOpInterface
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- // The callee was already bufferized, so we can directly take the type from
+ // If the callee was already bufferized, we can directly take the type from
// its signature.
FunctionType funcType = funcOp.getFunctionType();
- return cast<BaseMemRefType>(
- funcType.getResult(cast<OpResult>(value).getResultNumber()));
+ Type resultType =
+ funcType.getResult(cast<OpResult>(value).getResultNumber());
+ if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+ return bufferizedType;
+
+ // Otherwise, call the type converter to compute the bufferized type.
+ auto tensorType = cast<TensorType>(resultType);
+ return options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
}
/// All function arguments are writable. It is the responsibility of the
@@ -260,6 +268,18 @@ struct CallOpInterface
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
+ if (!isa<BaseMemRefType>(memRefType)) {
+ // The called function was not bufferized yet. This can happen when
+ // there are cycles in the function call graph. Compute the bufferized
+ // type of the corresponding callee argument instead.
+ FailureOr<BaseMemRefType> maybeMemRefType =
+ bufferization::getBufferType(
+ funcOp.getArgument(opOperand.getOperandNumber()), options);
+ if (failed(maybeMemRefType))
+ return failure();
+ memRefType = *maybeMemRefType;
+ }
+
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..4d0d232e6afae9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
}
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
-/// callee-caller order (i.e. callees without callers first).
+/// callee-caller order (i.e., callees without callers first). Store all
+/// remaining functions (i.e., the ones that call each other recursively) in
+/// `remainingFuncOps`.
+///
/// Store the map of FuncOp to all its callers in `callerMap`.
-/// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any func::CallOp.
-static LogicalResult
-getFuncOpsOrderedByCalls(ModuleOp moduleOp,
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
- FuncCallerMap &callerMap) {
+///
+/// Return `failure()` if we are unable to retrieve the called FuncOp from
+/// any func::CallOp.
+static LogicalResult getFuncOpsOrderedByCalls(
+ ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
});
if (res.wasInterrupted())
return failure();
+
// Iteratively remove function operations that do not call any of the
- // functions remaining in the callCounter map and add them to the worklist.
+ // functions remaining in the callCounter map and add them to the ordered list.
while (!numberCallOpsContainedInFuncOp.empty()) {
auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
[](auto entry) { return entry.getSecond() == 0; });
if (it == numberCallOpsContainedInFuncOp.end())
- return moduleOp.emitOpError(
- "expected callgraph to be free of circular dependencies.");
+ break;
orderedFuncOps.push_back(it->getFirst());
for (auto callee : calledBy[it->getFirst()])
numberCallOpsContainedInFuncOp[callee]--;
numberCallOpsContainedInFuncOp.erase(it);
}
+
+ // Put all other functions in the list of remaining functions. These are
+ // functions that call each other circularly.
+ for (auto it : numberCallOpsContainedInFuncOp)
+ remainingFuncOps.push_back(it.first);
+
return success();
}
@@ -379,15 +388,17 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+ remainingFuncOps, callerMap)))
return failure();
- // Analyze ops.
+ // Analyze ops in order, starting with functions that do not call any
+ // other functions.
for (func::FuncOp funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;
@@ -411,6 +422,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
}
+ // Analyze all remaining functions (those involved in recursive call cycles).
+ for (func::FuncOp funcOp : remainingFuncOps) {
+ if (!state.getOptions().isOpAllowed(funcOp))
+ continue;
+
+ // Gather equivalence info for CallOps.
+ equivalenceAnalysis(funcOp, state, funcState);
+
+ // Analyze funcOp.
+ if (failed(analyzeOp(funcOp, state, statistics)))
+ return failure();
+
+ // TODO: We currently skip all function argument analyses for functions
+ // that call each other circularly. These analyses do not support recursive
+ // calls yet. The `BufferizableOpInterface` implementations of `func`
+ // dialect ops return conservative results in the absence of analysis
+ // information.
+ }
+
return success();
}
@@ -430,13 +460,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+ // Try to bufferize functions in calling order, i.e., first bufferize
+ // functions that do not call any other functions. This allows us to infer
+ // accurate buffer types for function return values. Functions that call
+ // each other recursively are bufferized in an unspecified order at the end;
+ // their buffer types may end up with unnecessarily "complex" layout maps.
+ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+ remainingFuncOps, callerMap)))
return failure();
+ llvm::append_range(orderedFuncOps, remainingFuncOps);
// Bufferize functions.
for (func::FuncOp funcOp : orderedFuncOps) {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5a..3f6d182b57c031 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,15 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
%2 = tensor.extract %1[%c0] : tensor<6xf32>
return %2 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @recursive_function
+func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ // The analysis does not support recursive function calls and is conservative
+ // around them.
+ // CHECK: call @recursive_function
+ // CHECK-SAME: {__inplace_operands_attr__ = ["false", "false"]}
+ %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
+ return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index ee0f71f668dc74..888310169f3043 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -25,20 +25,6 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>
// -----
-// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-
-func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
- %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
- return %0 : tensor<5xf32>
-}
-
-func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
- %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
- return %0 : tensor<5xf32>
-}
-
-// -----
-
func.func @scf_for(%A : tensor<?xf32>,
%B : tensor<?xf32> {bufferization.writable = true},
%C : tensor<4xf32>,
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0d5224514e3a02..da145291960680 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -707,3 +707,52 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
%0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
return %0 : memref<5xf32>
}
+
+// -----
+
+// A recursive function.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+ // We are conservative around recursive functions. The analysis cannot handle
+ // them, so we have to assume the op operand of the call op bufferizes to a
+ // memory read and write. This causes a copy in this test case.
+ // CHECK: %[[copy:.*]] = memref.alloc() {alignment = 64 : i64} : memref<5xf32>
+ // CHECK: memref.copy %[[arg0]], %[[copy]]
+ // CHECK: %[[cast:.*]] = memref.cast %[[copy]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+ // CHECK: %[[call:.*]] = call @foo(%[[cast]])
+ %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+
+ // CHECK: memref.load %[[arg0]]
+ %c0 = arith.constant 0 : index
+ %extr = tensor.extract %t[%c0] : tensor<5xf32>
+ vector.print %extr : f32
+
+ // CHECK: return %[[call]]
+ return %0 : tensor<5xf32>
+}
+
+// -----
+
+// Two functions calling each other recursively.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK: %[[call:.*]] = call @bar(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK: return %[[call]]
+// CHECK: }
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @bar(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK: %[[call:.*]] = call @foo(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK: return %[[call]]
+// CHECK: }
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+ %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
+}
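
For reference, here is a minimal standalone sketch of what this patch enables (the function names and the exact mlir-opt invocation below are illustrative, not part of the patch). A module like the following used to be rejected with "expected callgraph to be free of circular dependencies" (see the test removed from one-shot-module-bufferize-invalid.mlir); with this patch it bufferizes, conservatively using memrefs with fully dynamic strided layouts, as in the new tests above:

  // Hypothetical invocation, mirroring the RUN lines of the existing tests:
  //   mlir-opt example.mlir -one-shot-bufferize="bufferize-function-boundaries"

  // Two mutually recursive functions; neither can be bufferized strictly
  // before the other, so both end up in `remainingFuncOps`.
  func.func @ping(%t: tensor<8xf32>) -> tensor<8xf32> {
    %0 = call @pong(%t) : (tensor<8xf32>) -> tensor<8xf32>
    return %0 : tensor<8xf32>
  }

  func.func @pong(%t: tensor<8xf32>) -> tensor<8xf32> {
    %0 = call @ping(%t) : (tensor<8xf32>) -> tensor<8xf32>
    return %0 : tensor<8xf32>
  }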