[Mlir-commits] [mlir] [mlir][bufferization] Add support for recursive function calls (PR #114003)

Matthias Springer llvmlistbot at llvm.org
Mon Oct 28 22:33:34 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114003

>From 89f8a746e6b08e89c0b9a53a4a26d8b8ec5677d8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 29 Oct 2024 06:26:33 +0100
Subject: [PATCH] [mlir][bufferization] Add support for recursive function
 calls

---
 .../FuncBufferizableOpInterfaceImpl.cpp       | 28 ++++++--
 .../Transforms/OneShotModuleBufferize.cpp     | 67 ++++++++++++++-----
 .../one-shot-module-bufferize-analysis.mlir   | 12 ++++
 .../one-shot-module-bufferize-invalid.mlir    | 14 ----
 .../Transforms/one-shot-module-bufferize.mlir | 49 ++++++++++++++
 5 files changed, 137 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..48d61fc223f4c7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 
 /// Return the FuncOp called by `callOp`.
 static FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+  SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
@@ -206,11 +207,18 @@ struct CallOpInterface
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
 
-    // The callee was already bufferized, so we can directly take the type from
+    // If the callee was already bufferized, we can directly take the type from
     // its signature.
     FunctionType funcType = funcOp.getFunctionType();
-    return cast<BaseMemRefType>(
-        funcType.getResult(cast<OpResult>(value).getResultNumber()));
+    Type resultType =
+        funcType.getResult(cast<OpResult>(value).getResultNumber());
+    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+      return bufferizedType;
+
+    // Otherwise, call the type converter to compute the bufferized type.
+    auto tensorType = cast<TensorType>(resultType);
+    return options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -260,6 +268,18 @@ struct CallOpInterface
 
       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
+      if (!isa<BaseMemRefType>(memRefType)) {
+        // The called function was not bufferized yet. This can happen when
+        // there are cycles in the function call graph. Compute the bufferized
+        // type of the callee's function argument.
+        FailureOr<BaseMemRefType> maybeMemRefType =
+            bufferization::getBufferType(
+                funcOp.getArgument(opOperand.getOperandNumber()), options);
+        if (failed(maybeMemRefType))
+          return failure();
+        memRefType = *maybeMemRefType;
+      }
+
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..4d0d232e6afae9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
-/// callee-caller order (i.e. callees without callers first).
+/// callee-caller order (i.e., callees without callers first). Store all
+/// remaining functions (i.e., the ones that call each other recursively) in
+/// `remainingFuncOps`.
+///
 /// Store the map of FuncOp to all its callers in `callerMap`.
-/// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any func::CallOp.
-static LogicalResult
-getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
-                         FuncCallerMap &callerMap) {
+///
+/// Return `failure()` if we are unable to retrieve the called FuncOp from
+/// any func::CallOp.
+static LogicalResult getFuncOpsOrderedByCalls(
+    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   });
   if (res.wasInterrupted())
     return failure();
+
   // Iteratively remove function operations that do not call any of the
-  // functions remaining in the callCounter map and add them to the worklist.
+  // functions remaining in the callCounter map; add them to the ordered list.
   while (!numberCallOpsContainedInFuncOp.empty()) {
     auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                             [](auto entry) { return entry.getSecond() == 0; });
     if (it == numberCallOpsContainedInFuncOp.end())
-      return moduleOp.emitOpError(
-          "expected callgraph to be free of circular dependencies.");
+      break;
     orderedFuncOps.push_back(it->getFirst());
     for (auto callee : calledBy[it->getFirst()])
       numberCallOpsContainedInFuncOp[callee]--;
     numberCallOpsContainedInFuncOp.erase(it);
   }
+
+  // Put all other functions in the list of remaining functions. These are
+  // functions that call each other circularly.
+  for (auto it : numberCallOpsContainedInFuncOp)
+    remainingFuncOps.push_back(it.first);
+
   return success();
 }
 
@@ -379,15 +388,17 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
 
-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+                                      remainingFuncOps, callerMap)))
     return failure();
 
-  // Analyze ops.
+  // Analyze ops in order, starting with functions that do not call any
+  // other functions.
   for (func::FuncOp funcOp : orderedFuncOps) {
     if (!state.getOptions().isOpAllowed(funcOp))
       continue;
@@ -411,6 +422,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
   }
 
+  // Analyze all other ops.
+  for (func::FuncOp funcOp : remainingFuncOps) {
+    if (!state.getOptions().isOpAllowed(funcOp))
+      continue;
+
+    // Gather equivalence info for CallOps.
+    equivalenceAnalysis(funcOp, state, funcState);
+
+    // Analyze funcOp.
+    if (failed(analyzeOp(funcOp, state, statistics)))
+      return failure();
+
+    // TODO: We currently skip all function argument analyses for functions
+    // that call each other circularly. These analyses do not support recursive
+    // calls yet. The `BufferizableOpInterface` implementations of `func`
+    // dialect ops return conservative results in the absence of analysis
+    // information.
+  }
+
   return success();
 }
 
@@ -430,13 +460,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   IRRewriter rewriter(moduleOp.getContext());
 
   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
 
-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+  // Try to bufferize functions in calling order. I.e., first bufferize
+  // functions that do not call other functions. This allows us to infer
+  // accurate buffer types for function return values. Functions that call
+  // each other recursively are bufferized in an unspecified order at the end.
+  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+                                      remainingFuncOps, callerMap)))
     return failure();
+  llvm::append_range(orderedFuncOps, remainingFuncOps);
 
   // Bufferize functions.
   for (func::FuncOp funcOp : orderedFuncOps) {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5a..3f6d182b57c031 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,15 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
   %2 = tensor.extract %1[%c0] : tensor<6xf32>
   return %2 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @recursive_function
+func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  // The analysis does not support recursive function calls and is conservative
+  // around them.
+  // CHECK: call @recursive_function
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false", "false"]}
+  %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index ee0f71f668dc74..888310169f3043 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -25,20 +25,6 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>
 
 // -----
 
-// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-
-func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
-  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
-  return %0 : tensor<5xf32>
-}
-
-func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
-  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
-  return %0 : tensor<5xf32>
-}
-
-// -----
-
 func.func @scf_for(%A : tensor<?xf32>,
               %B : tensor<?xf32> {bufferization.writable = true},
               %C : tensor<4xf32>,
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0d5224514e3a02..da145291960680 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -707,3 +707,52 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
   %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
   return %0 : memref<5xf32>
 }
+
+// -----
+
+// A recursive function.
+
+// CHECK-LABEL: func.func @foo(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  // We are conservative around recursive functions. The analysis cannot handle
+  // them, so we have to assume the op operand of the call op bufferizes to a
+  // memory read and write. This causes a copy in this test case.
+  // CHECK: %[[copy:.*]] = memref.alloc() {alignment = 64 : i64} : memref<5xf32>
+  // CHECK: memref.copy %[[arg0]], %[[copy]]
+  // CHECK: %[[cast:.*]] = memref.cast %[[copy]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+  // CHECK: %[[call:.*]] = call @foo(%[[cast]])
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+
+  // CHECK: memref.load %[[arg0]]
+  %c0 = arith.constant 0 : index
+  %extr = tensor.extract %t[%c0] : tensor<5xf32>
+  vector.print %extr : f32
+
+  // CHECK: return %[[call]]
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// Two functions calling each other recursively.
+
+// CHECK-LABEL: func.func @foo(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+//       CHECK:   %[[call:.*]] = call @bar(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+//       CHECK:   return %[[call]]
+//       CHECK: }
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @bar(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+//       CHECK:   %[[call:.*]] = call @foo(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+//       CHECK:   return %[[call]]
+//       CHECK: }
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}



More information about the Mlir-commits mailing list