[Mlir-commits] [mlir] b2499bf - [mlir][bufferize][NFC] Refactor createAlloc function signature

Matthias Springer llvmlistbot at llvm.org
Mon Jan 24 03:25:59 PST 2022


Author: Matthias Springer
Date: 2022-01-24T20:25:35+09:00
New Revision: b2499bf3e851c67ef623766b922de520de9235d5

URL: https://github.com/llvm/llvm-project/commit/b2499bf3e851c67ef623766b922de520de9235d5
DIFF: https://github.com/llvm/llvm-project/commit/b2499bf3e851c67ef623766b922de520de9235d5.diff

LOG: [mlir][bufferize][NFC] Refactor createAlloc function signature

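Pass a ValueRange instead of an ArrayRef<Value>, so that callers can also pass operand and result ranges directly. Also provide an additional overload that, if requested, inserts a deallocation of the buffer right before the terminator of the block containing the allocation.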

Differential Revision: https://reviews.llvm.org/D118025

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f679a22fa7a6c..bbac6e59aeeb2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -36,8 +36,8 @@ class BufferizationState;
 
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
-  using AllocationFn = std::function<FailureOr<Value>(
-      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
+  using AllocationFn = std::function<FailureOr<Value>(OpBuilder &, Location,
+                                                      MemRefType, ValueRange)>;
   using DeallocationFn =
       std::function<LogicalResult(OpBuilder &, Location, Value)>;
   using MemCpyFn =
@@ -298,15 +298,23 @@ UnrankedMemRefType getUnrankedMemRefType(Type elementType,
 MemRefType getDynamicMemRefType(RankedTensorType tensorType,
                                 unsigned addressSpace = 0);
 
-/// Creates a memref allocation.
+/// Creates a memref allocation with the given type and dynamic extents.
 FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                             ArrayRef<Value> dynShape,
+                             ValueRange dynShape,
+                             const BufferizationOptions &options);
+
+/// Creates a memref allocation with the given type and dynamic extents. If
+/// `deallocMemref` is set, a deallocation op is also inserted right before the
+/// terminator of the block that contains the allocation.
+FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                             ValueRange dynShape, bool deallocMemref,
                              const BufferizationOptions &options);
 
 /// Creates a memref allocation for the given shaped value. This function may
 /// perform additional optimizations such as buffer allocation hoisting. If
 /// `createDealloc`, a deallocation op is inserted at the point where the
 /// allocation goes out of scope.
+// TODO: Allocation hoisting should be a cleanup pass.
 FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
                              bool deallocMemref,
                              const BufferizationOptions &options);

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index fb081d3d6c3cd..e565f41a39d5a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -433,10 +433,10 @@ bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
   return casted;
 }
 
-/// Create a memref allocation.
+/// Create a memref allocation with the given type and dynamic extents.
 FailureOr<Value>
 bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
-                           ArrayRef<Value> dynShape,
+                           ValueRange dynShape,
                            const BufferizationOptions &options) {
   if (options.allocationFn)
     return (*options.allocationFn)(b, loc, type, dynShape);
@@ -447,6 +447,28 @@ bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
   return allocated;
 }
 
+/// Create a memref allocation with the given type and dynamic extents. If
+/// `deallocMemref` is set, also insert a dealloc before the block terminator.
+FailureOr<Value>
+bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
+                           ValueRange dynShape, bool deallocMemref,
+                           const BufferizationOptions &options) {
+  OpBuilder::InsertionGuard g(b);
+
+  FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
+  if (failed(alloc))
+    return failure();
+
+  if (deallocMemref) {
+    // Dealloc before the terminator of the alloc's block (assumes one exists).
+    b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
+    if (failed(createDealloc(b, loc, *alloc, options)))
+      return failure();
+  }
+
+  return alloc;
+}
+
 /// Create a memref deallocation.
 LogicalResult
 bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 3c8b9c9606952..9409492e12dba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -73,7 +73,7 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
 
 static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
                                                 MemRefType type,
-                                                ArrayRef<Value> dynShape) {
+                                                ValueRange dynShape) {
   Value allocated = b.create<memref::AllocaOp>(
       loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
   return allocated;


        


More information about the Mlir-commits mailing list