[Mlir-commits] [mlir] [mlir][bufferization] Add support for recursive function calls (PR #114003)
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 4 17:05:31 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114003
From 9f95eb791b8e0882b338cf7acb0a43ba30e94bc6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 29 Oct 2024 06:26:33 +0100
Subject: [PATCH 1/2] [mlir][bufferization] Add support for recursive function
calls
---
.../FuncBufferizableOpInterfaceImpl.cpp | 25 ++++++-
.../Transforms/OneShotModuleBufferize.cpp | 67 ++++++++++++++-----
.../one-shot-module-bufferize-analysis.mlir | 12 ++++
.../one-shot-module-bufferize-invalid.mlir | 14 ----
.../Transforms/one-shot-module-bufferize.mlir | 49 ++++++++++++++
5 files changed, 135 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 6e91d3b89a7c79..11ed434f774a87 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -207,11 +207,18 @@ struct CallOpInterface
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- // The callee was already bufferized, so we can directly take the type from
+ // If the callee was already bufferized, we can directly take the type from
// its signature.
FunctionType funcType = funcOp.getFunctionType();
- return cast<BaseMemRefType>(
- funcType.getResult(cast<OpResult>(value).getResultNumber()));
+ Type resultType =
+ funcType.getResult(cast<OpResult>(value).getResultNumber());
+ if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+ return bufferizedType;
+
+ // Otherwise, call the type converter to compute the bufferized type.
+ auto tensorType = cast<TensorType>(resultType);
+ return options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
}
/// All function arguments are writable. It is the responsibility of the
@@ -261,6 +268,18 @@ struct CallOpInterface
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
+ if (!isa<BaseMemRefType>(memRefType)) {
+ // The called function was not bufferized yet. This can happen when
+ // there are cycles in the function call graph. Compute the bufferized
+ // result type.
+ FailureOr<BaseMemRefType> maybeMemRefType =
+ bufferization::getBufferType(
+ funcOp.getArgument(opOperand.getOperandNumber()), options);
+ if (failed(maybeMemRefType))
+ return failure();
+ memRefType = *maybeMemRefType;
+ }
+
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
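To illustrate the intended effect of the two hunks above (a sketch, not part of the patch; it assumes the default functionArgTypeConverterFn, which produces fully dynamic layout maps): when a call site is bufferized before its callee, the memref types are now computed through the type converter instead of being read off the callee's signature.

// Input: @foo is its own callee, so its signature still contains tensor
// types when the call op is bufferized.
func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}

// Expected output (sketch): operand and result types fall back to the
// converter's fully dynamic strided layout.
func.func @foo(%m: memref<5xf32, strided<[?], offset: ?>>)
    -> memref<5xf32, strided<[?], offset: ?>> {
  %0 = call @foo(%m) : (memref<5xf32, strided<[?], offset: ?>>)
      -> (memref<5xf32, strided<[?], offset: ?>>)
  return %0 : memref<5xf32, strided<[?], offset: ?>>
}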
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..4d0d232e6afae9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
}
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
-/// callee-caller order (i.e. callees without callers first).
+/// callee-caller order (i.e., callees without callers first). Store all
+/// remaining functions (i.e., the ones that call each other recursively) in
+/// `remainingFuncOps`.
+///
/// Store the map of FuncOp to all its callers in `callerMap`.
-/// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any func::CallOp.
-static LogicalResult
-getFuncOpsOrderedByCalls(ModuleOp moduleOp,
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
- FuncCallerMap &callerMap) {
+///
+/// Return `failure()` if we are unable to retrieve the called FuncOp from
+/// any func::CallOp.
+static LogicalResult getFuncOpsOrderedByCalls(
+ ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
});
if (res.wasInterrupted())
return failure();
+
// Iteratively remove function operations that do not call any of the
- // functions remaining in the callCounter map and add them to the worklist.
+ // functions remaining in the callCounter map and add them to the ordered list.
while (!numberCallOpsContainedInFuncOp.empty()) {
auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
[](auto entry) { return entry.getSecond() == 0; });
if (it == numberCallOpsContainedInFuncOp.end())
- return moduleOp.emitOpError(
- "expected callgraph to be free of circular dependencies.");
+ break;
orderedFuncOps.push_back(it->getFirst());
for (auto callee : calledBy[it->getFirst()])
numberCallOpsContainedInFuncOp[callee]--;
numberCallOpsContainedInFuncOp.erase(it);
}
+
+ // Put all other functions in the list of remaining functions. These are
+ // functions that call each each circularly.
+ for (auto it : numberCallOpsContainedInFuncOp)
+ remainingFuncOps.push_back(it.first);
+
return success();
}
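A worked example of the new ordering (hypothetical function names):

// @leaf calls nothing; @a and @b call each other.
func.func @leaf(%t: tensor<?xf32>) -> tensor<?xf32> {
  return %t : tensor<?xf32>
}
func.func @a(%t: tensor<?xf32>) -> tensor<?xf32> {
  %0 = call @b(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
  return %0 : tensor<?xf32>
}
func.func @b(%t: tensor<?xf32>) -> tensor<?xf32> {
  %0 = call @a(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
  return %0 : tensor<?xf32>
}
// orderedFuncOps = [@leaf]; remainingFuncOps = [@a, @b] (unspecified
// order). Previously, such a module was rejected with "expected callgraph
// to be free of circular dependencies".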
@@ -379,15 +388,17 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+ remainingFuncOps, callerMap)))
return failure();
- // Analyze ops.
+ // Analyze ops in order. Starting with functions that are not calling any
+ // other functions.
for (func::FuncOp funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;
@@ -411,6 +422,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
}
+ // Analyze all other ops.
+ for (func::FuncOp funcOp : remainingFuncOps) {
+ if (!state.getOptions().isOpAllowed(funcOp))
+ continue;
+
+ // Gather equivalence info for CallOps.
+ equivalenceAnalysis(funcOp, state, funcState);
+
+ // Analyze funcOp.
+ if (failed(analyzeOp(funcOp, state, statistics)))
+ return failure();
+
+ // TODO: We currently skip all function argument analyses for functions
+ // that call each other circularly. These analyses do not support recursive
+ // calls yet. The `BufferizableOpInterface` implementations of `func`
+ // dialect ops return conservative results in the absence of analysis
+ // information.
+ }
+
return success();
}
@@ -430,13 +460,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
IRRewriter rewriter(moduleOp.getContext());
// A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps;
+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+ // Try to bufferize functions in calling order. I.e., first bufferize
+ // functions that do not call other functions. This allows us to infer
+ // accurate buffer types for function return values. Functions that call
+ // each other recursively are bufferized in an unspecified order at the end.
+ // We may use unnecessarily "complex" (in terms of layout map) buffer types.
+ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+ remainingFuncOps, callerMap)))
return failure();
+ llvm::append_range(orderedFuncOps, remainingFuncOps);
// Bufferize functions.
for (func::FuncOp funcOp : orderedFuncOps) {
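Why the calling order matters for buffer types, in a small sketch (hypothetical functions; the bufferized output is assumed, not taken from the test suite): a callee that is bufferized before its caller exposes a precise memref return type, which the call site can then take verbatim.

// @callee returns a new allocation, so it can bufferize to an
// identity-layout memref<5xf32> return type.
func.func @callee() -> tensor<5xf32> {
  %0 = bufferization.alloc_tensor() : tensor<5xf32>
  return %0 : tensor<5xf32>
}
func.func @caller() -> tensor<5xf32> {
  %0 = call @callee() : () -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}
// Assumed output: @callee is processed first and returns memref<5xf32>;
// the call op result in @caller takes that exact type. For the recursive
// functions in remainingFuncOps, no such inference is possible, so the
// "complex" fully dynamic layout is used instead.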
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5a..3f6d182b57c031 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,15 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
%2 = tensor.extract %1[%c0] : tensor<6xf32>
return %2 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @recursive_function
+func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+ // The analysis does not support recursive function calls and is conservative
+ // around them.
+ // CHECK: call @recursive_function
+ // CHECK-SAME: {__inplace_operands_attr__ = ["false", "false"]}
+ %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
+ return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 2829eafb7c1c59..28ce0735e47b74 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -19,20 +19,6 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>
// -----
-// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-
-func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
- %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
- return %0 : tensor<5xf32>
-}
-
-func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
- %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
- return %0 : tensor<5xf32>
-}
-
-// -----
-
func.func @scf_for(%A : tensor<?xf32>,
%B : tensor<?xf32> {bufferization.writable = true},
%C : tensor<4xf32>,
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d31b43477beb9f..2b5b8631436705 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -722,3 +722,52 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
%0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
return %0 : memref<5xf32>
}
+
+// -----
+
+// A recursive function.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+ // We are conservative around recursive functions. The analysis cannot handle
+ // them, so we have to assume the op operand of the call op bufferizes to a
+ // memory read and write. This causes a copy in this test case.
+ // CHECK: %[[copy:.*]] = memref.alloc() {alignment = 64 : i64} : memref<5xf32>
+ // CHECK: memref.copy %[[arg0]], %[[copy]]
+ // CHECK: %[[cast:.*]] = memref.cast %[[copy]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+ // CHECK: %[[call:.*]] = call @foo(%[[cast]])
+ %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+
+ // CHECK: memref.load %[[arg0]]
+ %c0 = arith.constant 0 : index
+ %extr = tensor.extract %t[%c0] : tensor<5xf32>
+ vector.print %extr : f32
+
+ // CHECK: return %[[call]]
+ return %0 : tensor<5xf32>
+}
+
+// -----
+
+// Two functions calling each other recursively.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK: %[[call:.*]] = call @bar(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK: return %[[call]]
+// CHECK: }
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @bar(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK: %[[call:.*]] = call @foo(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK: return %[[call]]
+// CHECK: }
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+ %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
+}
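For reference, a sketch of how such test cases are typically driven (the exact RUN lines in the test files may differ):

// RUN: mlir-opt %s -split-input-file \
// RUN:   -one-shot-bufferize="bufferize-function-boundaries" | FileCheck %s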
From c4d579be8d1c874c1cb51ace8d197bc7d47e2e9a Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 5 Nov 2024 02:03:55 +0100
Subject: [PATCH 2/2] address comments
---
.../Transforms/OneShotModuleBufferize.cpp | 26 +++++++++++++------
1 file changed, 18 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 4d0d232e6afae9..a492bcdd0f3e38 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -344,7 +344,7 @@ static LogicalResult getFuncOpsOrderedByCalls(
}
// Put all other functions in the list of remaining functions. These are
- // functions that call each each circularly.
+ // functions that call each other circularly.
for (auto it : numberCallOpsContainedInFuncOp)
remainingFuncOps.push_back(it.first);
@@ -387,8 +387,13 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
"expected that function boundary bufferization is activated");
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
- // A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
+ // A list of non-circular functions in the order in which they are analyzed
+ // and bufferized.
+ SmallVector<func::FuncOp> orderedFuncOps;
+ // A list of all other functions. I.e., functions that call each other
+ // recursively. For these, we analyze the function body but not the function
+ // boundary.
+ SmallVector<func::FuncOp> remainingFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;
@@ -397,8 +402,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
remainingFuncOps, callerMap)))
return failure();
- // Analyze ops in order. Starting with functions that are not calling any
- // other functions.
+ // Analyze functions in order, starting with functions that do not call
+ // any other functions.
for (func::FuncOp funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;
@@ -422,7 +427,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
}
- // Analyze all other ops.
+ // Analyze all other functions. All function boundary analyses are skipped.
for (func::FuncOp funcOp : remainingFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;
@@ -459,8 +464,13 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
"expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());
- // A list of functions in the order in which they are analyzed + bufferized.
- SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
+ // A list of non-circular functions in the order in which they are analyzed
+ // and bufferized.
+ SmallVector<func::FuncOp> orderedFuncOps;
+ // A list of all other functions. I.e., functions that call each other
+ // recursively. For these, we analyze the function body but not the function
+ // boundary.
+ SmallVector<func::FuncOp> remainingFuncOps;
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;