[Mlir-commits] [mlir] [mlir][bufferization] Support bufferization of external functions (PR #113999)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 28 21:17:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds support for bufferizing external functions that have no body. Such functions were previously rejected by One-Shot Bufferize if they returned a tensor value.

This commit is in preparation for removing the deprecated `func-bufferize` pass. That pass can bufferize external functions, so One-Shot Bufferize must be able to do so as well.

Also update a few comments.

---
Full diff: https://github.com/llvm/llvm-project/pull/113999.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+6-5) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+30-26) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir (-17) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+15) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index aceb9d059b95f3..4866e31b19d5de 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -60,7 +60,8 @@ struct AliasingValue {
   bool isDefinite;
 };
 
-template <typename T> class AliasList {
+template <typename T>
+class AliasList {
 public:
   /// Create an empty list of aliases.
   AliasList() = default;
@@ -259,7 +260,7 @@ struct BufferizationOptions {
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
   /// Tensor -> MemRef type converter.
-  /// Parameters: Value, memory space, func op, bufferization options
+  /// Parameters: tensor type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
       std::function<BaseMemRefType(TensorType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
@@ -344,9 +345,9 @@ struct BufferizationOptions {
   void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
 
   /// Type converter from tensors to memrefs. This type converter is used to
-  /// determine bufferized function argument types. By default, a type
-  /// converter that returns a memref type with a fully dynamic layout map is
-  /// used.
+  /// determine bufferized function argument and result types. By default, a
+  /// type converter that returns a memref type with a fully dynamic layout map
+  /// is used.
   ///
   /// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
   FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9fbe574ec392dc..a372e87d8335f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
 
 /// Return the FuncOp called by `callOp`.
 static FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+  SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
@@ -392,11 +393,11 @@ struct FuncOpInterface
     auto funcOp = cast<FuncOp>(op);
     FunctionType funcType = funcOp.getFunctionType();
 
-    // Construct the bufferized function type.
+    // Construct the bufferized function type. Compute the argument types.
     SmallVector<Type> argTypes;
     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
       Type argType = it.value();
-      if (dyn_cast<TensorType>(argType)) {
+      if (isa<TensorType>(argType)) {
         argTypes.push_back(
             getBufferizedFunctionArgType(funcOp, it.index(), options));
         continue;
@@ -404,24 +405,33 @@ struct FuncOpInterface
       argTypes.push_back(argType);
     }
 
-    // Bodiless functions are assumed opaque and we cannot know the
-    // bufferization contract they want to enforce. As a consequence, only
-    // support functions that don't return any tensors atm.
-    if (funcOp.isExternal()) {
-      SmallVector<Type> retTypes;
-      for (Type resultType : funcType.getResults()) {
-        if (isa<TensorType>(resultType))
-          return funcOp->emitError() << "cannot bufferize bodiless function "
-                                     << "that returns a tensor";
+    // Compute the result types.
+    SmallVector<Type> retTypes;
+    for (Type resultType : funcType.getResults()) {
+      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
+        BaseMemRefType resultType = options.functionArgTypeConverterFn(
+            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+            options);
         retTypes.push_back(resultType);
+        continue;
       }
-      funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes));
+      retTypes.push_back(resultType);
+    }
+
+    // Compute the new function type.
+    auto newFuncType = FunctionType::get(op->getContext(), argTypes, retTypes);
+
+    // If the function has no body, set the new function type and we are done.
+    if (funcOp.isExternal()) {
+      funcOp.setType(newFuncType);
       return success();
     }
 
     // TODO: Support functions with multiple returns.
     func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
     assert(returnOp && "expected func with single return op");
+    assert(returnOp->getNumOperands() == retTypes.size() &&
+           "incorrect number of return values");
     Location loc = returnOp.getLoc();
 
     // 1. Bufferize every block.
@@ -430,10 +440,10 @@ struct FuncOpInterface
                                                         options)))
         return failure();
 
-    // 2. For each result, keep track of which inplace argument it reuses.
+    // 2. Bufferize all operands of the return op.
     SmallVector<Value> returnValues;
-    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
-      Value returnVal = returnOperand.get();
+    for (auto [returnVal, bufferizedType] :
+         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
       auto tensorType = dyn_cast<TensorType>(returnVal.getType());
       rewriter.setInsertionPoint(returnOp);
 
@@ -443,23 +453,17 @@ struct FuncOpInterface
         continue;
       }
 
-      // Note: If `inferFunctionResultLayout = true`, cast are later folded
+      // Note: If `inferFunctionResultLayout = true`, casts are later folded
       // away.
-      BaseMemRefType resultType = options.functionArgTypeConverterFn(
-          tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
-          options);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, resultType, returnVal);
+          loc, bufferizedType, returnVal);
       returnValues.push_back(toMemrefOp);
     }
 
-    // 3. Rewrite the terminator without the in-place bufferizable values.
     returnOp.getOperandsMutable().assign(returnValues);
 
-    // 4. Rewrite the FuncOp type to buffer form.
-    funcOp.setType(FunctionType::get(op->getContext(), argTypes,
-                                     ValueRange(returnValues).getTypes()));
-
+    // 3. Set the new function type.
+    funcOp.setType(newFuncType);
     return success();
   }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index ee0f71f668dc74..2829eafb7c1c59 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,11 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-// expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
-// expected-error @+1 {{failed to bufferize op}}
-func.func private @foo() -> tensor<?xf32>
-
-// -----
-
 // expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
 func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
     -> (tensor<f32>, tensor<f32>)
@@ -123,17 +117,6 @@ func.func @to_tensor_op_unsupported(%m: memref<?xf32>, %idx: index) -> (f32) {
 
 // -----
 
-// expected-error @+2 {{failed to bufferize op}}
-// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
-func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
-
-func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
-  call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
-  return
-}
-
-// -----
-
 func.func @yield_alloc_dominance_test_2(%cst : f32, %idx : index,
                                         %idx2 : index) -> f32 {
   %1 = bufferization.alloc_tensor(%idx) : tensor<?xf32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0d5224514e3a02..d31b43477beb9f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -42,6 +42,21 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32
 
 // -----
 
+// Bufferization of bodiless function that returns a tensor.
+
+// CHECK: func.func private @foo(memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
+func.func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+// CHECK: func.func @call_to_unknown_tensor_returning_func(
+// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
+func.func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+  // CHECK: call @foo(%[[arg0]]) : (memref<?xf32, strided<[?], offset: ?>>) -> (f32, memref<?xf32, strided<[?], offset: ?>>, f32)
+  call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+  return
+}
+
+// -----
+
 // A function that returns a non-equivalent tensor with layout map.
 
 // CHECK-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32, strided<[10, 1], offset: ?>>

``````````

</details>


https://github.com/llvm/llvm-project/pull/113999


More information about the Mlir-commits mailing list