[Mlir-commits] [mlir] Enable custom alloc-like ops in `promoteBufferResultsToOutParams` (PR #120288)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 20 11:40:53 PST 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/120288

>From 1db8e794fc718ba3720789659ebba47c9e7e05a8 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Tue, 17 Dec 2024 12:57:55 -0600
Subject: [PATCH 1/3] Enable any `AllocationOpInterface` op with the
 `hoistStaticAllocs` option

---
 .../Bufferization/Transforms/BufferResultsToOutParams.cpp     | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b7755b2be8483b..b4d2d6b0c5da8f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -121,7 +122,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
       if (hoistStaticAllocs &&
-          isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
+          isa_and_nonnull<bufferization::AllocationOpInterface>(
+              orig.getDefiningOp()) &&
           mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();

>From bf145cfeb45263d879eb998ce47cf6dd93aa2b06 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Dec 2024 11:46:37 -0600
Subject: [PATCH 2/3] Add and use a custom allocation function

---
 .../Dialect/Bufferization/Transforms/Passes.h | 13 +++++++++-
 .../Transforms/BufferResultsToOutParams.cpp   | 24 ++++++++++++++-----
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index fe43a05c81fdc3..966438956fc6cd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -6,6 +6,7 @@
 
 namespace mlir {
 class FunctionOpInterface;
+class MemRefType;
 class ModuleOp;
 class RewritePatternSet;
 class OpBuilder;
@@ -38,7 +39,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
     DeallocationOptions options = DeallocationOptions());
 
 /// Creates a pass that finds all temporary allocations
-/// and attempts to move the deallocation after the last user/dependency 
+/// and attempts to move the deallocation after the last user/dependency
 /// of the allocation, thereby optimizing allocation liveness.
 std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
 
@@ -157,6 +158,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
 // Options struct for BufferResultsToOutParams pass.
 // Note: defined only here, not in tablegen.
 struct BufferResultsToOutParamsOpts {
+  /// Allocator function: Generate a memref allocation with the given type.
+  /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
+  /// results, we don't allow passing a range of values for dynamic dims.
+  using AllocationFn =
+      std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
+
   /// Memcpy function: Generate a memcpy between two memrefs.
   using MemCpyFn =
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -167,6 +174,10 @@ struct BufferResultsToOutParamsOpts {
     return true;
   };
 
+  /// Allocation function; used to allocate a memref.
+  /// If this is empty, memref.alloc is used
+  std::optional<AllocationFn> allocationFn;
+
   /// Memcpy function; used to create a copy between two memrefs.
   /// If this is empty, memref.copy is used.
   std::optional<MemCpyFn> memCpyFn;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index b4d2d6b0c5da8f..545b6ca009c03c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -22,6 +22,7 @@ namespace bufferization {
 } // namespace mlir
 
 using namespace mlir;
+using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
 using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
 
 /// Return `true` if the given MemRef type has a fully dynamic layout.
@@ -141,9 +142,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
 
 // Updates all CallOps in the scope of the given ModuleOp by allocating
 // temporary buffers for newly introduced out params.
-static LogicalResult
-updateCalls(ModuleOp module,
-            const bufferization::BufferResultsToOutParamsOpts &options) {
+static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
+                                 std::function<bool(func::FuncOp *)> filterFn) {
   bool didFail = false;
   SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
@@ -154,7 +154,7 @@ updateCalls(ModuleOp module,
       didFail = true;
       return;
     }
-    if (!options.filterFn(&callee))
+    if (!filterFn(&callee))
       return;
     SmallVector<Value, 6> replaceWithNewCallResults;
     SmallVector<Value, 6> replaceWithOutParams;
@@ -177,7 +177,13 @@ updateCalls(ModuleOp module,
       auto allocType =
           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                           AffineMap(), memrefType.getMemorySpace());
-      Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
+      auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType);
+      if (failed(maybeOutParam)) {
+        op.emitError() << "failed to create allocation op";
+        didFail = true;
+        return;
+      }
+      Value outParam = maybeOutParam.value();
       if (!hasStaticIdentityLayout(memrefType)) {
         // Layout maps are already checked in `updateFuncOp`.
         assert(hasFullyDynamicLayoutMap(memrefType) &&
@@ -226,7 +232,13 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return failure();
     }
   }
-  if (failed(updateCalls(module, options)))
+  auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
+                                MemRefType type) {
+    return builder.create<memref::AllocOp>(loc, type).getResult();
+  };
+  if (failed(updateCalls(module,
+                         options.allocationFn.value_or(defaultAllocationFn),
+                         options.filterFn)))
     return failure();
   return success();
 }

>From cc788e0b16ac112c0ea214c04a1de943d8101f32 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 20 Dec 2024 10:14:42 -0600
Subject: [PATCH 3/3] Move default alloc/copy functions into the
 `BufferResultsToOutParamsOpts` struct

---
 .../Dialect/Bufferization/Transforms/Passes.h | 16 ++++++--
 .../Transforms/BufferResultsToOutParams.cpp   | 38 +++++++------------
 2 files changed, 25 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 966438956fc6cd..c8e456a1d7e380 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -2,6 +2,7 @@
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
 
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -175,12 +176,19 @@ struct BufferResultsToOutParamsOpts {
   };
 
   /// Allocation function; used to allocate a memref.
-  /// If this is empty, memref.alloc is used
-  std::optional<AllocationFn> allocationFn;
+  /// Default memref.alloc is used
+  AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
+                                 MemRefType type) {
+    return builder.create<memref::AllocOp>(loc, type).getResult();
+  };
 
   /// Memcpy function; used to create a copy between two memrefs.
-  /// If this is empty, memref.copy is used.
-  std::optional<MemCpyFn> memCpyFn;
+  /// Default memref.copy is used.
+  MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
+                         Value to) {
+    builder.create<memref::CopyOp>(loc, from, to);
+    return success();
+  };
 
   /// If true, the pass adds a "bufferize.result" attribute to each output
   /// parameter.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 545b6ca009c03c..2502744cb3f580 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -107,10 +107,9 @@ updateFuncOp(func::FuncOp func,
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
 // keeping them as return values or copying the associated buffer contents into
 // the given out-params.
-static LogicalResult updateReturnOps(func::FuncOp func,
-                                     ArrayRef<BlockArgument> appendedEntryArgs,
-                                     MemCpyFn memCpyFn,
-                                     bool hoistStaticAllocs) {
+static LogicalResult
+updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
+                const bufferization::BufferResultsToOutParamsOpts &options) {
   auto res = func.walk([&](func::ReturnOp op) {
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
@@ -122,14 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
     }
     OpBuilder builder(op);
     for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
-      if (hoistStaticAllocs &&
+      if (options.hoistStaticAllocs &&
           isa_and_nonnull<bufferization::AllocationOpInterface>(
               orig.getDefiningOp()) &&
           mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
         orig.replaceAllUsesWith(arg);
         orig.getDefiningOp()->erase();
       } else {
-        if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
+        if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
           return WalkResult::interrupt();
       }
     }
@@ -142,8 +141,9 @@ static LogicalResult updateReturnOps(func::FuncOp func,
 
 // Updates all CallOps in the scope of the given ModuleOp by allocating
 // temporary buffers for newly introduced out params.
-static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
-                                 std::function<bool(func::FuncOp *)> filterFn) {
+static LogicalResult
+updateCalls(ModuleOp module,
+            const bufferization::BufferResultsToOutParamsOpts &options) {
   bool didFail = false;
   SymbolTable symtab(module);
   module.walk([&](func::CallOp op) {
@@ -154,7 +154,7 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
       didFail = true;
       return;
     }
-    if (!filterFn(&callee))
+    if (!options.filterFn(&callee))
       return;
     SmallVector<Value, 6> replaceWithNewCallResults;
     SmallVector<Value, 6> replaceWithOutParams;
@@ -177,7 +177,8 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
       auto allocType =
           MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
                           AffineMap(), memrefType.getMemorySpace());
-      auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType);
+      auto maybeOutParam =
+          options.allocationFn(builder, op.getLoc(), allocType);
       if (failed(maybeOutParam)) {
         op.emitError() << "failed to create allocation op";
         didFail = true;
@@ -221,24 +222,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
       return failure();
     if (func.isExternal())
       continue;
-    auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
-                              Value to) {
-      builder.create<memref::CopyOp>(loc, from, to);
-      return success();
-    };
-    if (failed(updateReturnOps(func, appendedEntryArgs,
-                               options.memCpyFn.value_or(defaultMemCpyFn),
-                               options.hoistStaticAllocs))) {
+    if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
       return failure();
     }
   }
-  auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
-                                MemRefType type) {
-    return builder.create<memref::AllocOp>(loc, type).getResult();
-  };
-  if (failed(updateCalls(module,
-                         options.allocationFn.value_or(defaultAllocationFn),
-                         options.filterFn)))
+  if (failed(updateCalls(module, options)))
     return failure();
   return success();
 }



More information about the Mlir-commits mailing list