[Mlir-commits] [mlir] [mlir][bufferization] Allow cyclic function graphs without tensors (PR #68632)

Matthias Springer llvmlistbot at llvm.org
Mon Oct 9 14:30:11 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/68632

Cyclic function call graphs are generally not supported by One-Shot Bufferize. However, such cycles can be allowed when the called function does not have tensor arguments or results: in that case, the callee no longer needs to be bufferized before the caller.

>From e7eadebadd007d3562387a3ac0f1d0733bb30830 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 9 Oct 2023 23:28:24 +0200
Subject: [PATCH] [mlir][bufferization] Allow cyclic function graphs without
 tensors

Cyclic function call graphs are generally not supported by One-Shot Bufferize. However, such cycles can be allowed when the called function does not have tensor arguments or results: in that case, the callee no longer needs to be bufferized before the caller.
---
 .../Transforms/OneShotModuleBufferize.cpp     | 15 ++++++++++++-
 .../one-shot-module-bufferize-invalid.mlir    | 12 +++++------
 .../Transforms/one-shot-module-bufferize.mlir | 21 +++++++++++++++++++
 3 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 417f457c8910ca9..786ebb23b457d52 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -274,6 +274,13 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
   });
 }
 
+/// Return "true" if the given function signature has tensor semantics.
+static bool hasTensorSignature(func::FuncOp funcOp) {
+  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
+  return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
+         llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
+}
+
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e. callees without callers first).
 /// Store the map of FuncOp to all its callers in `callerMap`.
@@ -297,10 +304,16 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                   "without a unique ReturnOp";
     }
 
+    // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
+      // If the called function does not have any tensors in its signature, then
+      // it is not necessary to bufferize the callee before the caller.
+      if (!hasTensorSignature(calledFunction))
+        return WalkResult::skip();
+
       callerMap[calledFunction].insert(callOp);
       if (calledBy[calledFunction].insert(funcOp).second) {
         numberCallOpsContainedInFuncOp[funcOp]++;
@@ -310,7 +323,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   });
   if (res.wasInterrupted())
     return failure();
-  // Iteratively remove function operation that do not call any of the
+  // Iteratively remove function operations that do not call any of the
   // functions remaining in the callCounter map and add them to the worklist.
   while (!numberCallOpsContainedInFuncOp.empty()) {
     auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index fd74ae0b60dbbb8..ee0f71f668dc741 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -27,14 +27,14 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>
 
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
 
-func.func @foo() {
-  call @bar() : () -> ()
-  return
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
 }
 
-func.func @bar() {
-  call @foo() : () -> ()
-  return
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index b9de4ba34e0e6d3..39f4835b28ffeb2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -662,3 +662,24 @@ func.func @br_in_func(%t: tensor<5xf32>) -> tensor<5xf32> {
 ^bb1(%arg1 : tensor<5xf32>):
   func.return %arg1 : tensor<5xf32>
 }
+
+// -----
+
+// Cyclic call graphs with tensors are not supported by One-Shot Bufferize.
+// However, if a function signature does not have any tensor arguments or
+// results, calls to that function are not seen as an "edge" in the function
+// call graph.
+
+// CHECK-LABEL: func.func @foo(%{{.*}}: memref<5xf32>) -> memref<5xf32>
+func.func @foo(%m: memref<5xf32>) -> memref<5xf32> {
+  %0 = tensor.empty() : tensor<5xf32>
+  %1 = func.call @bar(%0, %m)
+      : (tensor<5xf32>, memref<5xf32>) -> (memref<5xf32>)
+  return %1 : memref<5xf32>
+}
+
+// CHECK: func.func @bar(%{{.*}}: memref<5xf32, strided<[?], offset: ?>>, %arg1: memref<5xf32>) -> memref<5xf32>
+func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
+  %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
+  return %0 : memref<5xf32>
+}



More information about the Mlir-commits mailing list