[Mlir-commits] [mlir] c86f218 - [mlir][Linalg] Allow comprehensive bufferization to use callbacks for alloc/dealloc.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 25 08:50:39 PDT 2021


Author: MaheshRavishankar
Date: 2021-10-25T08:50:25-07:00
New Revision: c86f218fe4ca661a4348d20b66210324224870e8

URL: https://github.com/llvm/llvm-project/commit/c86f218fe4ca661a4348d20b66210324224870e8
DIFF: https://github.com/llvm/llvm-project/commit/c86f218fe4ca661a4348d20b66210324224870e8.diff

LOG: [mlir][Linalg] Allow comprehensive bufferization to use callbacks for alloc/dealloc.

Using callbacks for allocation/deallocation allows users to override
the default.
Also add an option to the comprehensive bufferization pass to use `alloca`
instead of `alloc`. Note that this option is only for testing: the option
to use `alloca` does not work well with the option that allows returning
memrefs, because stack-allocated buffers are released when their function returns.

Differential Revision: https://reviews.llvm.org/D112166

Added: 
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d0744a891328..d75437ed6077 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -40,7 +40,10 @@ def LinalgComprehensiveModuleBufferize :
            "Only runs inplaceability analysis (for testing purposes only)">,
     Option<"allowReturnMemref", "allow-return-memref", "bool",
             /*default=*/"false",
-           "Allows the return of memrefs (for testing purposes only)">
+           "Allows the return of memrefs (for testing purposes only)">,
+    Option<"useAlloca", "use-alloca", "bool",
+           /*default=*/"false",
+           "Use stack allocations for memrefs (for testing purposes only)">
   ];
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index 5226ab394cd7..da5505504999 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -175,14 +175,36 @@ LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
                               BufferizationAliasInfo &aliasInfo,
                               const DominanceInfo &domInfo);
 
+/// Default allocation function that is used by the comprehensive bufferization
+/// pass. The default currently creates a ranked memref using `memref.alloc`.
+Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+                                    Value shapedValue);
+
+/// Default deallocation function that is used by the comprehensive
+/// bufferization pass. It expects the value returned by `defaultAllocationFn`
+/// and frees it with `memref.dealloc` before the terminator of its block.
+void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
+
+/// Callbacks used by the comprehensive bufferization pass to allocate and
+/// deallocate memory. They default to `defaultAllocationFn` and
+/// `defaultDeallocationFn` but can be overridden by the caller. The
+/// `deallocationFn` is guaranteed to receive the `Value` returned by the
+/// `allocationFn`.
+struct AllocationCallbacks {
+  std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
+      allocationFn = defaultAllocationFn;
+  std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
+      defaultDeallocationFn;
+};
+
 /// Bufferize one particular op.
-/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
-/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
+/// `bufferizedFunctionTypes` is expected to be non-null if `op` is a
+/// CallOpInterface. New buffers are obtained through `allocationFns`.
 LogicalResult
 bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
             BufferizationAliasInfo &aliasInfo,
-            DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr,
-            GlobalCreator *globalCreator = nullptr);
+            AllocationCallbacks allocationFns,
+            DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
 
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4b1de5f2b471..840978bf5648 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRAnalysis
   MLIRArithmetic
   MLIRComplex
+  MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRef
   MLIRLinalgAnalysis

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 9e970f3e5e14..ac01bd80ccc8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -118,12 +118,12 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
-
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
@@ -983,7 +983,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
     const DenseSet<OpOperand *> &usesRead,
     const DenseSet<OpOperand *> &usesWrite,
     const DominanceInfo &domInfo) const {
-
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
 
@@ -1415,66 +1414,27 @@ Operation *getFirstParentOfType(Value v) {
 /// Create an Allocop/DeAllocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-static Value
-createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
-                                        Value shapedValue,
-                                        BufferizationAliasInfo &aliasInfo) {
+static Value createNewAllocDeallocPairForShapedValue(
+    OpBuilder &b, Location loc, Value shapedValue,
+    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
-  // TODO: non-zero address space.
-  // TODO: layout information if relevant.
-  // Cannot allocate an unranked memref so just always go for the contiguous
-  // form.
-  MemRefType allocMemRefType =
-      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
-  memRefType = memRefType ? memRefType : allocMemRefType;
-
-  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
-    b.setInsertionPointToStart(bbArg.getOwner());
-    loc = bbArg.getOwner()->getParentOp()->getLoc();
-  } else {
-    b.setInsertionPoint(shapedValue.getDefiningOp());
-    loc = shapedValue.getDefiningOp()->getLoc();
-  }
-
-  // Compute the dynamic part of the shape.
-  SmallVector<Value> dynShape;
-  for (auto dim : enumerate(memRefType.getShape()))
-    if (dim.value() == ShapedType::kDynamicSize)
-      dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
 
-  // If the buffer is statically shaped, try to hoist it to the first enclosing
-  // parallel region.
-  // TODO: this concept of parallel region and threadlocal needs interfaces.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  Value allocated;
-  { // Guarded insertion point to potentially hoist the AllocOp.
-    OpBuilder::InsertionGuard g(b);
-    if (dynShape.empty()) {
-      Operation *parent =
-          getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
-                               AffineParallelOp>(shapedValue);
-      if (parent)
-        b.setInsertionPointToStart(&(parent->getRegion(0).front()));
-    }
-    allocated = b.create<memref::AllocOp>(
-        loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-    aliasInfo.createAliasInfoEntry(allocated);
+  Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
+  // TODO: Propagate allocation failures as errors instead of asserting.
+  assert(allocated && "allocation failed");
+  aliasInfo.createAliasInfoEntry(*allocated);
+  Value casted = *allocated;
+  MemRefType allocMemRefType = casted.getType().cast<MemRefType>();
+  if (memRefType && memRefType != allocMemRefType) {
+    casted = b.create<memref::CastOp>(loc, memRefType, *allocated);
+    aliasInfo.insertNewBufferEquivalence(casted, *allocated);
   }
-  Value casted = allocated;
-  if (memRefType != allocMemRefType) {
-    casted = b.create<memref::CastOp>(loc, memRefType, allocated);
-    aliasInfo.insertNewBufferEquivalence(casted, allocated);
-  }
-  b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
-  b.create<memref::DeallocOp>(loc, allocated);
 
+  allocationFns.deallocationFn(b, loc, *allocated);
   return casted;
 }
 
@@ -1488,6 +1448,7 @@ createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
 static Value getResultBuffer(OpBuilder &b, OpResult result,
                              const BlockAndValueMapping &bvm,
                              BufferizationAliasInfo &aliasInfo,
+                             AllocationCallbacks allocationFns,
                              bool skipCopy = false) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
@@ -1515,8 +1476,8 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
         "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
     Location loc = op->getLoc();
     // Allocate the result buffer.
-    Value resultBuffer =
-        createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
+    Value resultBuffer = createNewAllocDeallocPairForShapedValue(
+        b, loc, operand, aliasInfo, allocationFns);
     // Do not copy the result of an InitTensorOp.
     if (isInitTensorOp(operand))
       skipCopy = true;
@@ -1538,11 +1499,10 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
 /// Helper function for LinalgOp bufferization.
 /// When allocating a new buffer, analyze whether `op` wants to read form that
 /// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult
-allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
-                          SmallVectorImpl<Value> &resultBuffers,
-                          BlockAndValueMapping &bvm,
-                          BufferizationAliasInfo &aliasInfo) {
+static LogicalResult allocateBuffersForResults(
+    OpBuilder &b, Location loc, LinalgOp op,
+    SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
+    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1553,7 +1513,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
     OpResult opResult = getInplaceableOpResult(*opOperand);
     assert(opResult && "could not find correspond OpResult");
     bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy);
     if (!resultBuffer)
       return failure();
     resultBuffers.push_back(resultBuffer);
@@ -1568,7 +1529,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFns) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1591,7 +1553,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
   if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
-                                       aliasInfo)))
+                                       aliasInfo, allocationFns)))
     return failure();
 
   // Clone the newly bufferized op.
@@ -1616,7 +1578,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
 /// to allow FuncOp that are inplaceable to write inPlace.
 static LogicalResult
 bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
-          BufferizationAliasInfo &aliasInfo,
+          BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
           DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   FuncOp funcOp = getCalledFunction(callOp);
   assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@@ -1755,12 +1717,14 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
 /// tensor::CastOp bufferizes to memref::CastOp.
 static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);
 
-  Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
+  Value resultBuffer =
+      getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
   if (!resultBuffer)
     return failure();
   Type sourceType = resultBuffer.getType();
@@ -1786,10 +1750,15 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
 
 static LogicalResult bufferize(OpBuilder &b, arith::ConstantOp constantOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               GlobalCreator &globalCreator) {
+                               BufferizationAliasInfo &aliasInfo) {
   assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
          "not a constant ranked tensor");
+  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
+  if (!moduleOp) {
+    return constantOp.emitError(
+        "cannot bufferize constants not within builtin.module op");
+  }
+  GlobalCreator globalCreator(moduleOp);
 
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1824,7 +1793,8 @@ static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
 
 static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1837,7 +1807,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
            "unsupported unranked tensor");
 
     // TODO: More general: Matching bbArg does not bufferize to a read.
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
     if (!resultBuffer)
       return failure();
 
@@ -1880,7 +1851,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1906,7 +1878,8 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
 /// TODO: consider hoisting across function boundaries prior to bufferization.
 static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // The InitTensorOp may have been eliminated.
   if (initTensorOp->getUses().empty())
     return success();
@@ -1916,7 +1889,8 @@ static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
   b.setInsertionPoint(initTensorOp);
 
   Value alloc = createNewAllocDeallocPairForShapedValue(
-      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
+      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
+      allocationFn);
   map(bvm, initTensorOp.result(), alloc);
   return success();
 }
@@ -1949,7 +1923,8 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
 /// Bufferization for TiledLoopOp..
 static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1989,7 +1964,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
 
     const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
     OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+    Value resultBuffer =
+        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
     if (!resultBuffer)
       return failure();
 
@@ -2073,7 +2049,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
 /// isolation.
 static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -2093,7 +2070,7 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
   auto inPlace = getInPlace(extractSliceOp->getResult(0));
   if (inPlace != InPlaceSpec::True)
     alloc = createNewAllocDeallocPairForShapedValue(
-        b, loc, extractSliceOp.result(), aliasInfo);
+        b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
 
   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(extractSliceOp);
@@ -2125,7 +2102,8 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
 
 static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(insertSliceOp);
@@ -2140,8 +2118,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
   // TODO: be very loud about it or even consider failing the pass.
   // Alloc a copy for `insertSliceOp.dest()`, it will become the result
   // buffer.
-  Value dstMemref =
-      getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
+  Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
+                                    aliasInfo, allocationFn);
   if (!dstMemref)
     return failure();
   auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
@@ -2184,7 +2162,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
 
 static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               AllocationCallbacks &allocationFn) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -2205,7 +2184,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   // Leave the previous transfer_write to dead code as it still has uses at
   // this point.
   auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
-  Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
+  Value resultBuffer =
+      getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
   if (!resultBuffer)
     return failure();
   b.create<vector::TransferWriteOp>(
@@ -2436,18 +2416,107 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
 // Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
+/// Compute the type of the `memref` to use for allocating the buffer for
+/// `shapedValue`. Also returns (by reference in `dynShape`) the values of the
+/// dynamic dimensions of the returned `memref` type and (by reference in
+/// `loc`) the location for the allocation. The builder `b` is left at the
+/// position where the allocation is to be inserted.
+static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location &loc,
+                                            Value shapedValue,
+                                            SmallVectorImpl<Value> &dynShape) {
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+    b.setInsertionPointToStart(bbArg.getOwner());
+    loc = bbArg.getOwner()->getParentOp()->getLoc();
+  } else {
+    b.setInsertionPoint(shapedValue.getDefiningOp());
+    loc = shapedValue.getDefiningOp()->getLoc();
+  }
+
+  // Compute the dynamic part of the shape (nothing to do if it is static).
+  bool dynShapeComputed = allocMemRefType.hasStaticShape();
+  auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+      shapedValue.getDefiningOp());
+  if (!dynShapeComputed && rankedOp) {
+    ReifiedRankedShapedTypeDims resultDims;
+    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+      dynShapeComputed = true;
+      auto &shape = resultDims[shapedValue.cast<OpResult>().getResultNumber()];
+      for (auto dim : enumerate(allocMemRefType.getShape()))
+        if (dim.value() == ShapedType::kDynamicSize)
+          dynShape.push_back(shape[dim.index()]);
+    }
+  }
+  if (!dynShapeComputed) {
+    for (auto dim : enumerate(allocMemRefType.getShape()))
+      if (dim.value() == ShapedType::kDynamicSize)
+        dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+  }
+
+  // If the buffer is statically shaped, try to hoist it to the first enclosing
+  // parallel region.
+  // TODO: this concept of parallel region and threadlocal needs interfaces.
+  // TODO: also hoist in the dynamic case. For now this relies on subsequent
+  // calls to LICM and buffer hoisting which will most likely not succeed.
+  // TODO: when packing, allocate a static bounding box which will enable more
+  // hoisting.
+  if (dynShape.empty()) {
+    Operation *parent =
+        getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
+                             AffineParallelOp>(shapedValue);
+    if (parent)
+      b.setInsertionPointToStart(&(parent->getRegion(0).front()));
+  }
+  return allocMemRefType;
+}
+
+/// Shared body of the allocation callbacks: computes the type, dynamic sizes
+/// and insertion point for `shapedValue`, then creates an `AllocLikeOp` there.
+template <typename AllocLikeOp>
+static Value createAllocLikeOp(OpBuilder &b, Location loc, Value shapedValue) {
+  OpBuilder::InsertionGuard g(b);
+  SmallVector<Value> dynShape;
+  MemRefType allocMemRefType =
+      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+  return b.create<AllocLikeOp>(loc, allocMemRefType, dynShape,
+                               b.getI64IntegerAttr(kBufferAlignments));
+}
+
+Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
+                                                  Value shapedValue) {
+  return createAllocLikeOp<memref::AllocOp>(b, loc, shapedValue);
+}
+
+/// Allocation callback used by the pass when `use-alloca` is set.
+static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+                                               Value shapedValue) {
+  return createAllocLikeOp<memref::AllocaOp>(b, loc, shapedValue);
+}
+
+void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
+                                         Value allocatedBuffer) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
+  b.create<memref::DeallocOp>(loc, allocatedBuffer);
+}
+
 LogicalResult mlir::linalg::bufferizeOp(
     Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes,
-    GlobalCreator *globalCreator) {
+    AllocationCallbacks allocationFns,
+    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
   OpBuilder b(op->getContext());
   return TypeSwitch<Operation *, LogicalResult>(op)
       // Skip BufferCast and TensorLoad ops.
       .Case<memref::BufferCastOp, memref::TensorLoadOp>(
           [&](auto) { return success(); })
-      .Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, scf::ForOp,
-            InitTensorOp, InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
-            TiledLoopOp, VectorTransferOpInterface, linalg::YieldOp,
+      .Case<ExtractSliceOp, InitTensorOp, InsertSliceOp, LinalgOp, scf::ForOp,
+            tensor::CastOp, TiledLoopOp, VectorTransferOpInterface>(
+          [&](auto op) {
+            LDBG("Begin bufferize:\n" << op << '\n');
+            return bufferize(b, op, bvm, aliasInfo, allocationFns);
+          })
+      .Case<tensor::DimOp, tensor::ExtractOp, ReturnOp, linalg::YieldOp,
             scf::YieldOp>([&](auto op) {
         LDBG("Begin bufferize:\n" << op << '\n');
         return bufferize(b, op, bvm, aliasInfo);
@@ -2464,15 +2533,14 @@ LogicalResult mlir::linalg::bufferizeOp(
         if (!bufferizedFunctionTypes)
           llvm_unreachable(
               "null bufferizedFunctionTypes when bufferizing CallOpInterface");
-        return bufferize(b, op, bvm, aliasInfo, *bufferizedFunctionTypes);
+        return bufferize(b, op, bvm, aliasInfo, allocationFns,
+                         *bufferizedFunctionTypes);
       })
       .Case([&](arith::ConstantOp op) {
         if (!isaTensor(op.getResult().getType()))
           return success();
         LDBG("Begin bufferize:\n" << op << '\n');
-        if (!globalCreator)
-          llvm_unreachable("null globalCreator when bufferizing ConstantOp");
-        return bufferize(b, op, bvm, aliasInfo, *globalCreator);
+        return bufferize(b, op, bvm, aliasInfo);
       })
       .Default([&](Operation *op) -> LogicalResult {
         auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -2485,15 +2553,13 @@ LogicalResult mlir::linalg::bufferizeOp(
 
 static LogicalResult bufferizeFuncOpInternals(
     FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes,
-    GlobalCreator &globalCreator) {
-
+    AllocationCallbacks &allocationFns,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
-
-  // Start by bufferizing `funcOp` arguments.
-  if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
+  /// Start by bufferizing `funcOp` arguments.
+  if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
     return failure();
 
   // Cannot erase ops during the traversal. Do that afterwards.
@@ -2516,13 +2582,13 @@ static LogicalResult bufferizeFuncOpInternals(
     }
 
     for (Operation *op : llvm::reverse(preorderBufferize))
-      if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                             &globalCreator)))
+      if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+                             &bufferizedFunctionTypes)))
         return failure();
 
     if (!bufferizedOps.contains(op) &&
-        failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
-                           &globalCreator)))
+        failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+                           &bufferizedFunctionTypes)))
       return failure();
 
     // Register post-walk erasure, if necessary.
@@ -2793,12 +2859,19 @@ namespace {
 struct LinalgComprehensiveModuleBufferize
     : public LinalgComprehensiveModuleBufferizeBase<
           LinalgComprehensiveModuleBufferize> {
+  LinalgComprehensiveModuleBufferize() = default;
+  // Copyable for pass cloning; `allocationFns` is rebuilt lazily per copy.
+  LinalgComprehensiveModuleBufferize(
+      const LinalgComprehensiveModuleBufferize &p) : Base(p) {}
 
   void runOnOperation() override;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, memref::MemRefDialect>();
   }
+
+private:
+  std::unique_ptr<AllocationCallbacks> allocationFns;
 };
 } // end namespace
 
@@ -2983,6 +3056,22 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
+  if (!allocationFns) {
+    // The allocation functions to use needs to be set here. The flag for the
+    // pass and flag for the use of alloca map to LLVM command line
+    // options. These being static global objects have no set order in which
+    // they are defined. So ideally this should be in the constructor, but the
+    // constructor might be called before the flag is initialized using the
+    // command line option. So this is set up at the start of the pass.
+    if (useAlloca) {
+      AllocationCallbacks allocaAllocationFns = {
+          allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
+      allocationFns =
+          std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
+    } else {
+      allocationFns = std::make_unique<AllocationCallbacks>();
+    }
+  }
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
 
@@ -2992,7 +3081,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return signalPassFailure();
 
-  GlobalCreator globalCreator(moduleOp);
   DominanceInfo domInfo(moduleOp);
   BufferizationAliasInfo aliasInfo(moduleOp);
   // Interestingly, all function args that are not visible outside of a module
@@ -3032,8 +3120,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     if (!testAnalysisOnly) {
       BlockAndValueMapping tensorToBufferMap;
       if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
-                                          bufferizedFunctionTypes,
-                                          globalCreator))) {
+                                          *allocationFns,
+                                          bufferizedFunctionTypes))) {
         signalPassFailure();
         return;
       }

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
new file mode 100644
index 000000000000..71d631c85e0d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-memref use-alloca}" -split-input-file | FileCheck %s
+
+//  CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+//  CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+//      CHECK:  func @init_and_dot(
+// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME:    %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+  // CHECK-NEXT:   %[[C0:.*]] = arith.constant 0{{.*}} : f32
+  %v0 = arith.constant 0.0 : f32
+  // Expected in place: fill and dot write straight into the %c buffer.
+  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
+  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT:   linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+  %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
+    outs(%d: tensor<f32>) -> tensor<f32>
+
+  // CHECK-NEXT:   return
+  return %e : tensor<f32>
+}
+
+//      CHECK:  func @main()
+func @main() {
+  //  CHECK-DAG:   %[[C0:.*]] = arith.constant 0{{.*}} : f32
+  //  CHECK-DAG:   %[[C1:.*]] = arith.constant 1{{.*}} : f32
+  //  CHECK-DAG:   %[[C2:.*]] = arith.constant 2{{.*}} : f32
+  %v0 = arith.constant 0.0 : f32
+  %v1 = arith.constant 1.0 : f32
+  %v2 = arith.constant 2.0 : f32
+  // With `use-alloca`, the init_tensors bufferize to memref.alloca ops.
+  // CHECK-NEXT:   %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
+  // CHECK-NEXT:   %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  // CHECK-NEXT:   %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+  %A = linalg.init_tensor [64] : tensor<64xf32>
+  %B = linalg.init_tensor [64] : tensor<64xf32>
+  %C = linalg.init_tensor [] : tensor<f32>
+
+  // CHECK-NEXT:   linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
+  // CHECK-NEXT:   linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
+  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
+  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+  // CHECK-NEXT:   %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT:   %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+  // CHECK-NEXT:   %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
+  // CHECK-NEXT:   call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
+  %res = call @init_and_dot(%AA, %BB, %CC) :
+    (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+  // The call result aliases the alloca of %C, so the cast below reuses it.
+  // CHECK-NEXT:   %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
+  %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
+
+  // CHECK-NEXT:   call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
+  call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+  return
+}
+
+//     CHECK:   func private @print_memref_f32(memref<*xf32>)
+func private @print_memref_f32(tensor<*xf32>) // Declaration only; signature is bufferized too.

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a0339919063f..39c258ade3fe 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6314,6 +6314,7 @@ cc_library(
         ":ComplexDialect",
         ":DialectUtils",
         ":IR",
+        ":InferTypeOpInterface",
         ":LinalgOps",
         ":LinalgPassIncGen",
         ":LinalgStructuredOpsIncGen",


        


More information about the Mlir-commits mailing list