[Mlir-commits] [mlir] [mlir][bufferization] Allow cyclic function graphs without tensors (PR #68632)
Matthias Springer
llvmlistbot at llvm.org
Mon Oct 9 14:30:11 PDT 2023
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/68632
Cyclic function call graphs are generally not supported by One-Shot Bufferize. However, they can be allowed when a function does not have tensor arguments or results. In that case, the callee no longer has to be bufferized before the caller, so calls to such a function do not constrain the bufferization order.
>From e7eadebadd007d3562387a3ac0f1d0733bb30830 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 9 Oct 2023 23:28:24 +0200
Subject: [PATCH] [mlir][bufferization] Allow cyclic function graphs without
tensors
Cyclic function call graphs are generally not supported by One-Shot Bufferize. However, they can be allowed when a function does not have tensor arguments or results. In that case, the callee no longer has to be bufferized before the caller, so calls to such a function do not constrain the bufferization order.
---
.../Transforms/OneShotModuleBufferize.cpp | 15 ++++++++++++-
.../one-shot-module-bufferize-invalid.mlir | 12 +++++------
.../Transforms/one-shot-module-bufferize.mlir | 21 +++++++++++++++++++
3 files changed, 41 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 417f457c8910ca9..786ebb23b457d52 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -274,6 +274,13 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
});
}
+/// Return "true" if at least one argument or result of `funcOp` is a tensor.
+static bool hasTensorSignature(func::FuncOp funcOp) {
+ auto isTensorType = [](Type type) -> bool { return isa<TensorType>(type); };
+ return llvm::any_of(funcOp.getFunctionType().getResults(), isTensorType) ||
+ llvm::any_of(funcOp.getFunctionType().getInputs(), isTensorType);
+}
+
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
@@ -297,10 +304,16 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
"without a unique ReturnOp";
}
+ // Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
+ // If the called function does not have any tensors in its signature, then
+ // it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+ return WalkResult::skip();
+
callerMap[calledFunction].insert(callOp);
if (calledBy[calledFunction].insert(funcOp).second) {
numberCallOpsContainedInFuncOp[funcOp]++;
@@ -310,7 +323,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
});
if (res.wasInterrupted())
return failure();
- // Iteratively remove function operation that do not call any of the
+ // Iteratively remove function operations that do not call any of the
// functions remaining in the callCounter map and add them to the worklist.
while (!numberCallOpsContainedInFuncOp.empty()) {
auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index fd74ae0b60dbbb8..ee0f71f668dc741 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -27,14 +27,14 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>
// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-func.func @foo() {
- call @bar() : () -> ()
- return
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
}
-func.func @bar() {
- call @foo() : () -> ()
- return
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32> {
+ %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
}
// -----
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index b9de4ba34e0e6d3..39f4835b28ffeb2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -662,3 +662,24 @@ func.func @br_in_func(%t: tensor<5xf32>) -> tensor<5xf32> {
^bb1(%arg1 : tensor<5xf32>):
func.return %arg1 : tensor<5xf32>
}
+
+// -----
+
+// Cyclic call graphs with tensors are not supported by One-Shot Bufferize.
+// However, if a function signature does not have any tensor arguments or
+// results, calls to that function are not seen as an "edge" in the function
+// call graph.
+
+// CHECK-LABEL: func.func @foo(%{{.*}}: memref<5xf32>) -> memref<5xf32>
+func.func @foo(%m: memref<5xf32>) -> memref<5xf32> {
+ %empty = tensor.empty() : tensor<5xf32>
+ %result = func.call @bar(%empty, %m)
+ : (tensor<5xf32>, memref<5xf32>) -> (memref<5xf32>)
+ return %result : memref<5xf32>
+}
+
+// CHECK: func.func @bar(%{{.*}}: memref<5xf32, strided<[?], offset: ?>>, %{{.*}}: memref<5xf32>) -> memref<5xf32>
+func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
+ %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
+ return %0 : memref<5xf32>
+}
More information about the Mlir-commits
mailing list