[Mlir-commits] [mlir] ed1cbeb - [mlir][linalg][bufferize][NFC] Simplify AllocationCallbacks

Matthias Springer llvmlistbot at llvm.org
Thu Nov 4 20:01:02 PDT 2021


Author: Matthias Springer
Date: 2021-11-05T11:56:06+09:00
New Revision: ed1cbebafa84cabbfeae762d6062eb33e7eb72a4

URL: https://github.com/llvm/llvm-project/commit/ed1cbebafa84cabbfeae762d6062eb33e7eb72a4
DIFF: https://github.com/llvm/llvm-project/commit/ed1cbebafa84cabbfeae762d6062eb33e7eb72a4.diff

LOG: [mlir][linalg][bufferize][NFC] Simplify AllocationCallbacks

AllocationCallbacks functions allocate/deallocate only. They no longer set the insertion point.

This is in preparation for decoupling ComprehensiveBufferize from the Linalg dialect.

Differential Revision: https://reviews.llvm.org/D112991

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index 4decaee4238b..794ac4cea3d5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -122,8 +122,8 @@ LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
 
 /// Default allocation function that is used by the comprehensive bufferization
 /// pass. The default currently creates a ranked memref using `memref.alloc`.
-Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
-                                    Value shapedValue);
+Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type,
+                                    const SmallVector<Value> &dynShape);
 
 /// Default deallocation function that is used by the comprehensive
 /// bufferization pass. It expects to recieve back the value called from the
@@ -140,8 +140,8 @@ void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to);
 /// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned
 /// by the `allocationFn`.
 struct AllocationCallbacks {
-  using AllocationFn =
-      std::function<Optional<Value>(OpBuilder &, Location, Value)>;
+  using AllocationFn = std::function<Optional<Value>(
+      OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
   using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
   using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 28e432a88ad3..b85a5df2f507 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1208,6 +1208,61 @@ Operation *getFirstParentOfType(Value v) {
   return nullptr;
 }
 
+/// Compute the type of the `memref` to use for allocating the buffer for
+/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
+/// dynamic dimensions in the returned `memref` type. The function also sets the
+/// insertion point of the builder `b` to the position where the allocation is
+/// to be inserted.
+static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
+                                            Value shapedValue,
+                                            SmallVectorImpl<Value> &dynShape) {
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+    b.setInsertionPointToStart(bbArg.getOwner());
+    loc = bbArg.getOwner()->getParentOp()->getLoc();
+  } else {
+    b.setInsertionPoint(shapedValue.getDefiningOp());
+    loc = shapedValue.getDefiningOp()->getLoc();
+  }
+
+  // Compute the dynamic part of the shape.
+  bool foundDynamicShapes = false;
+  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+          shapedValue.getDefiningOp())) {
+    ReifiedRankedShapedTypeDims resultDims;
+    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+      foundDynamicShapes = true;
+      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
+      auto &shape = resultDims[resultValue.getResultNumber()];
+      for (auto dim : enumerate(allocMemRefType.getShape()))
+        if (dim.value() == ShapedType::kDynamicSize)
+          dynShape.push_back(shape[dim.index()]);
+    }
+  }
+  if (!foundDynamicShapes) {
+    for (auto dim : enumerate(allocMemRefType.getShape()))
+      if (dim.value() == ShapedType::kDynamicSize)
+        dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+  }
+
+  // If the buffer is statically shaped, try to hoist it to the first enclosing
+  // parallel region.
+  // TODO: this concept of parallel region and threadlocal needs interfaces.
+  // TODO: also hoist in the dynamic case. For now this relies on subsequent
+  // calls to LICM and buffer hoisting which will most likely not succeed.
+  // TODO: when packing, allocate a static bounding box which will enable more
+  // hoisting.
+  if (dynShape.empty()) {
+    Operation *parent =
+        getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
+                             AffineParallelOp>(shapedValue);
+    if (parent)
+      b.setInsertionPointToStart(&(parent->getRegion(0).front()));
+  }
+  return allocMemRefType;
+}
+
 /// Create an Allocop/DeAllocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
@@ -1217,20 +1272,26 @@ static Value createNewAllocDeallocPairForShapedValue(
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
+  // 1. Create memory allocation.
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
-
-  Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
+  SmallVector<Value> dynShape;
+  // Note: getAllocationTypeAndShape also sets the insertion point.
+  MemRefType allocMemRefType =
+      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+  Optional<Value> allocated =
+      allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
   // TODO: For now just assert the value is returned. Eventually need to
   // error-propagate.
   assert(allocated && "allocation failed");
   Value casted = allocated.getValue();
-  MemRefType allocMemRefType = allocated->getType().cast<MemRefType>();
   if (memRefType && memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
     aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
   }
 
+  // 2. Create memory deallocation.
+  b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
   allocationFns.deallocationFn(b, loc, allocated.getValue());
   return casted;
 }
@@ -1595,88 +1656,24 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
 // Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
-/// Compute the type of the `memref` to use for allocating the buffer for
-/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
-/// dynamic dimensions in the returned `memref` type. The function also sets the
-/// insertion point of the builder `b` to the position where the allocation is
-/// to be inserted.
-static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
-                                            Value shapedValue,
-                                            SmallVectorImpl<Value> &dynShape) {
-  MemRefType allocMemRefType =
-      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
-  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
-    b.setInsertionPointToStart(bbArg.getOwner());
-    loc = bbArg.getOwner()->getParentOp()->getLoc();
-  } else {
-    b.setInsertionPoint(shapedValue.getDefiningOp());
-    loc = shapedValue.getDefiningOp()->getLoc();
-  }
-
-  // Compute the dynamic part of the shape.
-  bool foundDynamicShapes = false;
-  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
-          shapedValue.getDefiningOp())) {
-    ReifiedRankedShapedTypeDims resultDims;
-    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
-      foundDynamicShapes = true;
-      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
-      auto &shape = resultDims[resultValue.getResultNumber()];
-      for (auto dim : enumerate(allocMemRefType.getShape()))
-        if (dim.value() == ShapedType::kDynamicSize)
-          dynShape.push_back(shape[dim.index()]);
-    }
-  }
-  if (!foundDynamicShapes) {
-    for (auto dim : enumerate(allocMemRefType.getShape()))
-      if (dim.value() == ShapedType::kDynamicSize)
-        dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
-  }
-
-  // If the buffer is statically shaped, try to hoist it to the first enclosing
-  // parallel region.
-  // TODO: this concept of parallel region and threadlocal needs interfaces.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  if (dynShape.empty()) {
-    Operation *parent =
-        getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
-                             AffineParallelOp>(shapedValue);
-    if (parent)
-      b.setInsertionPointToStart(&(parent->getRegion(0).front()));
-  }
-  return allocMemRefType;
-}
-
-Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
-                                                  Value shapedValue) {
-  // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-  SmallVector<Value> dynShape;
-  MemRefType allocMemRefType =
-      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+Optional<Value>
+mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type,
+                                  const SmallVector<Value> &dynShape) {
   Value allocated = b.create<memref::AllocOp>(
-      loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
   return allocated;
 }
 
-static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
-                                               Value shapedValue) {
-  OpBuilder::InsertionGuard g(b);
-  SmallVector<Value> dynShape;
-  MemRefType allocMemRefType =
-      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+static Optional<Value>
+allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type,
+                        const SmallVector<Value> &dynShape) {
   Value allocated = b.create<memref::AllocaOp>(
-      loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
   return allocated;
 }
 
 void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
                                          Value allocatedBuffer) {
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
   b.create<memref::DeallocOp>(loc, allocatedBuffer);
 }
 


        


More information about the Mlir-commits mailing list