[Mlir-commits] [mlir] 8906b7b - Enable custom alloc-like ops in `promoteBufferResultsToOutParams` (#120288)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 26 09:32:55 PST 2024


Author: srcarroll
Date: 2024-12-26T11:32:51-06:00
New Revision: 8906b7be918be653d3c5f2ef3dbd923561603969

URL: https://github.com/llvm/llvm-project/commit/8906b7be918be653d3c5f2ef3dbd923561603969
DIFF: https://github.com/llvm/llvm-project/commit/8906b7be918be653d3c5f2ef3dbd923561603969.diff

LOG: Enable custom alloc-like ops in `promoteBufferResultsToOutParams` (#120288)

In `buffer-results-to-out-params`, when the `hoist-static-allocs` option is
enabled, the pass looked for `memref.alloc`s in order to avoid copies
where possible. This made it not extensible to external ops
that have allocation-like properties. This patch simply changes
`memref::AllocOp` to `AllocationOpInterface` in the check to enable it for
any allocation op.
Moreover, for function call updates, we enable setting an allocation
function callback in `BufferResultsToOutParamsOpts` to allow users to
emit their own alloc-like op.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
    mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index fe43a05c81fdc3..c8e456a1d7e380 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -2,10 +2,12 @@
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
 
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
 class FunctionOpInterface;
+class MemRefType;
 class ModuleOp;
 class RewritePatternSet;
 class OpBuilder;
@@ -38,7 +40,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
     DeallocationOptions options = DeallocationOptions());
 
 /// Creates a pass that finds all temporary allocations
-/// and attempts to move the deallocation after the last user/dependency 
+/// and attempts to move the deallocation after the last user/dependency
 /// of the allocation, thereby optimizing allocation liveness.
 std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
 
@@ -157,6 +159,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
 struct BufferResultsToOutParamsOpts {
+  /// Allocator function: Generate a memref allocation with the given type.
+  /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
+  /// results, we don't allow passing a range of values for dynamic dims.
+  using AllocationFn =
+      std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
+
   /// Memcpy function: Generate a memcpy between two memrefs.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -167,9 +175,20 @@ struct BufferResultsToOutParamsOpts {
     return true;
   };
 
+  /// Allocation function; used to allocate a memref.
+  /// Default memref.alloc is used
+  AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
+                                 MemRefType type) {
+    return builder.create<memref::AllocOp>(loc, type).getResult();
+  };
+
   /// Memcpy function; used to create a copy between two memrefs.
-  /// If this is empty, memref.copy is used.
-  std::optional<MemCpyFn> memCpyFn;
+  /// Default memref.copy is used.
+  MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
+                         Value to) {
+    builder.create<memref::CopyOp>(loc, from, to);
+    return success();
+  };
 
   /// If true, the pass adds a "bufferize.result" attribute to each output
   /// parameter.

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b7755b2be8483b..2502744cb3f580 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,6 +22,7 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
+using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
 using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
@@ -105,10 +107,9 @@ updateFuncOp(func::FuncOp func,
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
-static LogicalResult updateReturnOps(func::FuncOp func,
-                                     ArrayRef<BlockArgument> appendedEntryArgs,
-                                     MemCpyFn memCpyFn,
-                                     bool hoistStaticAllocs) {
+static LogicalResult
+updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
+                const bufferization::BufferResultsToOutParamsOpts &options) {
   auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
@@ -120,13 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     }
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (hoistStaticAllocs &&
-          isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
+      if (options.hoistStaticAllocs &&
+          isa_and_nonnull<bufferization::AllocationOpInterface>(
+              orig.getDefiningOp()) &&
           mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();
       } else {
-        if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
+        if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
           return WalkResult::interrupt();
       }
     }
@@ -175,7 +177,14 @@ updateCalls(ModuleOp module,
       auto allocType =
           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                           AffineMap(), memrefType.getMemorySpace());
-      Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
+      auto maybeOutParam =
+          options.allocationFn(builder, op.getLoc(), allocType);
+      if (failed(maybeOutParam)) {
+        op.emitError() << "failed to create allocation op";
+        didFail = true;
+        return;
+      }
+      Value outParam = maybeOutParam.value();
       if (!hasStaticIdentityLayout(memrefType)) {
         // Layout maps are already checked in `updateFuncOp`.
         assert(hasFullyDynamicLayoutMap(memrefType) &&
@@ -213,14 +222,7 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return failure();
     if (func.isExternal())
       continue;
-    auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
-                              Value to) {
-      builder.create<memref::CopyOp>(loc, from, to);
-      return success();
-    };
-    if (failed(updateReturnOps(func, appendedEntryArgs,
-                               options.memCpyFn.value_or(defaultMemCpyFn),
-                               options.hoistStaticAllocs))) {
+    if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
       return failure();
     }
   }


        


More information about the Mlir-commits mailing list