[Mlir-commits] [mlir] 0a58982 - [mlir][Linalg] Remove alloc/dealloc pair as a callback.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 24 10:36:45 PST 2021


Author: MaheshRavishankar
Date: 2021-11-24T10:36:34-08:00
New Revision: 0a58982b082df3b0b995a19f2cbed34228fda73d

URL: https://github.com/llvm/llvm-project/commit/0a58982b082df3b0b995a19f2cbed34228fda73d
DIFF: https://github.com/llvm/llvm-project/commit/0a58982b082df3b0b995a19f2cbed34228fda73d.diff

LOG: [mlir][Linalg] Remove alloc/dealloc pair as a callback.

The alloc/dealloc pair generation callback is central to the
bufferization algorithm: it modifies the state in a way that affects
correctness, so it is not really a configurable option. Moving it to
BufferizationState removes what was probably the reason it was added
as a callback.

Differential Revision: https://reviews.llvm.org/D114417

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 5ce675101c5a..ee01a99017a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -212,16 +212,13 @@ struct BufferizationState;
 // functions in the future.
 struct AllocationCallbacks {
   using AllocationFn = std::function<Optional<Value>(
-      OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
+      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
   using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
   using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
-  using CreateAllocDeallocFn =
-      std::function<Value(OpBuilder &, Location, Value, BufferizationState &)>;
 
   AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
-                      MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
-      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn),
-        createAllocDeallocFn(allocDeallocFn) {}
+                      MemCpyFn copyFn)
+      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
 
   /// A function that allocates memory.
   AllocationFn allocationFn;
@@ -231,11 +228,6 @@ struct AllocationCallbacks {
 
   /// A function that copies memory between two allocations.
   MemCpyFn memCpyFn;
-
-  /// A function that creates an alloc-dealloc pair. This function may perform
-  /// additional optimizations such as buffer allocation hoisting. This function
-  /// calls `allocationFn` and `deallocationFn` to create (de)allocations.
-  CreateAllocDeallocFn createAllocDeallocFn;
 };
 
 /// BufferizationState keeps track of bufferization state and provides access to
@@ -247,6 +239,12 @@ struct BufferizationState {
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
 
+  /// A function that creates an alloc-dealloc pair. This function may perform
+  /// additional optimizations such as buffer allocation hoisting. This function
+  /// calls `allocationFn` and `deallocationFn` to create (de)allocations.
+  Value createAllocDeallocFn(OpBuilder &builder, Location loc,
+                             Value shapedValue);
+
   /// Map tensor values to memref buffers.
   void mapBuffer(ValueRange tensors, ValueRange buffers);
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index a530139f1e50..46c881d15c06 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -360,8 +361,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
       b.setInsertionPointAfter(operandBuffer.getDefiningOp());
     }
     // Allocate the result buffer.
-    Value resultBuffer =
-        state.allocationFns.createAllocDeallocFn(b, loc, operandBuffer, state);
+    Value resultBuffer = state.createAllocDeallocFn(b, loc, operandBuffer);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -442,6 +442,118 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
   return op->emitError() << "unsupported op with tensors";
 }
 
+//===----------------------------------------------------------------------===//
+// Bufferization-specific scoped alloc/dealloc insertion support.
+//===----------------------------------------------------------------------===//
+
+/// Move the insertion point of the given builder to the beginning of a
+/// surrounding block as much as possible, while not crossing any allocation
+/// hoisting barriers.
+static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
+  Operation *op = b.getInsertionBlock()->getParentOp();
+  while (op) {
+    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+      if (bufferizableOp.isAllocationHoistingBarrier())
+        break;
+    op = op->getParentOp();
+  }
+
+  // FuncOp is an allocation hoisting barrier, so the above loop should never
+  // run out of parents.
+  assert(
+      (op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
+      "expected traversal to end at allocation hoisting barrier");
+
+  // TODO: Handle cases where allocation hoisting barrier has more than one
+  // region or block.
+  assert(op->getNumRegions() == 1 &&
+         "allocation hoisting barriers with >1 regions not supported");
+  assert(op->getRegion(0).getBlocks().size() == 1 &&
+         "allocation hoisting barriers with >1 blocks not supported");
+  b.setInsertionPointToStart(&(op->getRegion(0).front()));
+}
+
+/// Compute the type of the `memref` to use for allocating the buffer for
+/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
+/// dynamic dimensions in the returned `memref` type. The function may also set
+/// the insertion point to an earlier location, where the allocation should
+/// happen ("allocation hoisting").
+static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
+                                            Value shapedValue,
+                                            SmallVectorImpl<Value> &dynShape) {
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+
+  // Compute the dynamic part of the shape.
+  bool reifiedShapes = false;
+  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+          shapedValue.getDefiningOp())) {
+    ReifiedRankedShapedTypeDims resultDims;
+    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+      reifiedShapes = true;
+      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
+      auto &shape = resultDims[resultValue.getResultNumber()];
+      for (auto dim : enumerate(allocMemRefType.getShape()))
+        if (ShapedType::isDynamic(dim.value()))
+          dynShape.push_back(shape[dim.index()]);
+    }
+  }
+
+  if (!reifiedShapes) {
+    for (auto dim : enumerate(allocMemRefType.getShape()))
+      if (ShapedType::isDynamic(dim.value())) {
+        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
+                shapedValue.getType().isa<MemRefType>()) &&
+               "expected MemRef type");
+        dynShape.push_back(
+            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
+      }
+  }
+
+  // If the buffer is statically shaped, try to hoist it to the first enclosing
+  // parallel region.
+  // TODO: also hoist in the dynamic case. For now this relies on subsequent
+  // calls to LICM and buffer hoisting which will most likely not succeed.
+  // TODO: when packing, allocate a static bounding box which will enable more
+  // hoisting.
+  if (dynShape.empty())
+    moveInsertionPointToAllocationHoistingBarrier(b);
+
+  return allocMemRefType;
+}
+
+/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
+/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
+/// bbArg) and the DeallocOp is at the end of the block.
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::
+    createAllocDeallocFn(OpBuilder &b, Location loc, Value shapedValue) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+
+  // 1. Create memory allocation.
+  assert(shapedValue.getType().isa<ShapedType>());
+  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
+  SmallVector<Value> dynShape;
+  // Note: getAllocationTypeAndShape also sets the insertion point.
+  MemRefType allocMemRefType =
+      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+  Optional<Value> allocated =
+      allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
+  // TODO: For now just assert the value is returned. Eventually need to
+  // error-propagate.
+  assert(allocated && "allocation failed");
+  Value casted = allocated.getValue();
+  if (memRefType && memRefType != allocMemRefType) {
+    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
+    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
+  }
+
+  // 2. Create memory deallocation.
+  b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
+  allocationFns.deallocationFn(b, loc, allocated.getValue());
+  return casted;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 0dd0fa41d7ce..22f5bf3b06e5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -627,118 +627,6 @@ static FunctionType getOrCreateBufferizedFunctionType(
   return it2.first->second;
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization-specific scoped alloc/dealloc insertion support.
-//===----------------------------------------------------------------------===//
-
-/// Move the insertion point of the given builder to the beginning of a
-/// surrounding block as much as possible, while not crossing any allocation
-/// hoisting barriers.
-static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
-  Operation *op = b.getInsertionBlock()->getParentOp();
-  while (op) {
-    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
-      if (bufferizableOp.isAllocationHoistingBarrier())
-        break;
-    op = op->getParentOp();
-  }
-
-  // FuncOp is an allocation hoisting barrier, so the above loop should never
-  // run out of parents.
-  assert(
-      (op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
-      "expected traversal to end at allocation hoisting barrier");
-
-  // TODO: Handle cases where allocation hoisting barrier has more than one
-  // region or block.
-  assert(op->getNumRegions() == 1 &&
-         "allocation hoisting barriers with >1 regions not supported");
-  assert(op->getRegion(0).getBlocks().size() == 1 &&
-         "allocation hoisting barriers with >1 blocks not supported");
-  b.setInsertionPointToStart(&(op->getRegion(0).front()));
-}
-
-/// Compute the type of the `memref` to use for allocating the buffer for
-/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
-/// dynamic dimensions in the returned `memref` type. The function may also set
-/// the insertion point to an earlier location, where the allocation should
-/// happen ("allocation hoisting").
-static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
-                                            Value shapedValue,
-                                            SmallVectorImpl<Value> &dynShape) {
-  MemRefType allocMemRefType =
-      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
-
-  // Compute the dynamic part of the shape.
-  bool reifiedShapes = false;
-  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
-          shapedValue.getDefiningOp())) {
-    ReifiedRankedShapedTypeDims resultDims;
-    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
-      reifiedShapes = true;
-      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
-      auto &shape = resultDims[resultValue.getResultNumber()];
-      for (auto dim : enumerate(allocMemRefType.getShape()))
-        if (ShapedType::isDynamic(dim.value()))
-          dynShape.push_back(shape[dim.index()]);
-    }
-  }
-
-  if (!reifiedShapes) {
-    for (auto dim : enumerate(allocMemRefType.getShape()))
-      if (ShapedType::isDynamic(dim.value())) {
-        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
-                shapedValue.getType().isa<MemRefType>()) &&
-               "expected MemRef type");
-        dynShape.push_back(
-            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
-      }
-  }
-
-  // If the buffer is statically shaped, try to hoist it to the first enclosing
-  // parallel region.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  if (dynShape.empty())
-    moveInsertionPointToAllocationHoistingBarrier(b);
-
-  return allocMemRefType;
-}
-
-/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
-/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
-/// bbArg) and the DeallocOp is at the end of the block.
-static Value createNewAllocDeallocPairForShapedValue(
-    OpBuilder &b, Location loc, Value shapedValue, BufferizationState &state) {
-  // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-
-  // 1. Create memory allocation.
-  assert(shapedValue.getType().isa<ShapedType>());
-  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
-  SmallVector<Value> dynShape;
-  // Note: getAllocationTypeAndShape also sets the insertion point.
-  MemRefType allocMemRefType =
-      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  Optional<Value> allocated =
-      state.allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
-  // TODO: For now just assert the value is returned. Eventually need to
-  // error-propagate.
-  assert(allocated && "allocation failed");
-  Value casted = allocated.getValue();
-  if (memRefType && memRefType != allocMemRefType) {
-    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
-    state.aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
-  }
-
-  // 2. Create memory deallocation.
-  b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-  state.allocationFns.deallocationFn(b, loc, allocated.getValue());
-  return casted;
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization as simple BlockAndValueMapping rewrites.
 //===----------------------------------------------------------------------===//
@@ -1358,7 +1246,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
 /// pass. The default currently creates a ranked memref using `memref.alloc`.
 static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
                                            MemRefType type,
-                                           const SmallVector<Value> &dynShape) {
+                                           ArrayRef<Value> dynShape) {
   Value allocated = b.create<memref::AllocOp>(
       loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
   return allocated;
@@ -1381,8 +1269,7 @@ static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
 std::unique_ptr<AllocationCallbacks>
 mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
   return std::make_unique<AllocationCallbacks>(
-      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn,
-      createNewAllocDeallocPairForShapedValue);
+      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
 }
 
 // Default constructor for BufferizationOptions that sets all allocation

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 52fed305d06c..7bea450c343f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -167,8 +167,8 @@ struct InitTensorOpInterface
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(initTensorOp);
 
-    Value alloc = state.allocationFns.createAllocDeallocFn(
-        b, initTensorOp->getLoc(), initTensorOp.result(), state);
+    Value alloc = state.createAllocDeallocFn(b, initTensorOp->getLoc(),
+                                             initTensorOp.result());
     state.mapBuffer(initTensorOp.result(), alloc);
     return success();
   }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index f72a23d7ca81..13cc7d7ce4f3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -154,8 +154,7 @@ struct ExtractSliceOpInterface
     bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0));
     Value alloc;
     if (!inplace)
-      alloc = state.allocationFns.createAllocDeallocFn(
-          b, loc, extractSliceOp.result(), state);
+      alloc = state.createAllocDeallocFn(b, loc, extractSliceOp.result());
 
     // Bufferize to subview.
     auto subviewMemRefType =

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 70e6dd7088c5..9acb2c1d5fc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -52,9 +52,9 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
   (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
 }
 
-static Optional<Value>
-allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type,
-                        const SmallVector<Value> &dynShape) {
+static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+                                               MemRefType type,
+                                               ArrayRef<Value> dynShape) {
   Value allocated = b.create<memref::AllocaOp>(
       loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
   return allocated;


        


More information about the Mlir-commits mailing list