[Mlir-commits] [mlir] 5fb46a9 - Revert "[mlir][Linalg] Allow comprehensive bufferization to use callbacks for alloc/dealloc."

llvmlistbot at llvm.org
Mon Oct 25 08:58:00 PDT 2021


Author: MaheshRavishankar
Date: 2021-10-25T08:57:53-07:00
New Revision: 5fb46a9fa3aeb72ed12a830441d0bc736780c23b

URL: https://github.com/llvm/llvm-project/commit/5fb46a9fa3aeb72ed12a830441d0bc736780c23b
DIFF: https://github.com/llvm/llvm-project/commit/5fb46a9fa3aeb72ed12a830441d0bc736780c23b.diff

LOG: Revert "[mlir][Linalg] Allow comprehensive bufferization to use callbacks for alloc/dealloc."

This reverts commit c86f218fe4ca661a4348d20b66210324224870e8.

Revert because it causes a build failure.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d75437ed6077..d0744a891328 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -40,10 +40,7 @@ def LinalgComprehensiveModuleBufferize :
            "Only runs inplaceability analysis (for testing purposes only)">,
     Option<"allowReturnMemref", "allow-return-memref", "bool",
             /*default=*/"false",
-           "Allows the return of memrefs (for testing purposes only)">,
-    Option<"useAlloca", "use-alloca", "bool",
-           /*default=*/"false",
-           "Use stack allocations for memrefs (for testing purposes only)">
+           "Allows the return of memrefs (for testing purposes only)">
   ];
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index da5505504999..5226ab394cd7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -175,36 +175,14 @@ LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
                               BufferizationAliasInfo &aliasInfo,
                               const DominanceInfo &domInfo);
 
-/// Default allocation function that is used by the comprehensive bufferization
-/// pass. The default currently creates a ranked memref using `memref.alloc`.
-Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
-                                    Value shapedValue);
-
-/// Default deallocation function that is used by the comprehensive
-/// bufferization pass. It expects to receive the value returned by
-/// `defaultAllocationFn`.
-void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
-
-/// Callback functions that are used by the comprehensive bufferization pass to
-/// allocate/deallocate memory. These default to use the
-/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
-/// caller. The `deallocationFn` is guaranteed to receive the `Value` returned
-/// by the `allocationFn`.
-struct AllocationCallbacks {
-  std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
-      allocationFn = defaultAllocationFn;
-  std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
-      defaultDeallocationFn;
-};
-
 /// Bufferize one particular op.
 /// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
 /// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
 LogicalResult
 bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
             BufferizationAliasInfo &aliasInfo,
-            AllocationCallbacks allocationFns,
-            DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
+            DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr,
+            GlobalCreator *globalCreator = nullptr);
 
 } // namespace linalg
 } // namespace mlir
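
For context on what this hunk removes: the `AllocationCallbacks` struct above let a client of comprehensive bufferization substitute its own allocation and deallocation strategy and hand it to `bufferizeOp`. The sketch below shows how a hypothetical caller might have used that now-reverted interface. It assumes the removed signatures shown in this diff, handles only statically shaped values, and the helper name `bufferizeWithStackAllocations` and the 128-byte alignment are illustrative rather than part of the upstream API.

#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

using namespace mlir;
using namespace mlir::linalg;

// Hypothetical client of the reverted callback interface: allocate with
// memref.alloca and emit no dealloc, since stack allocations die with the
// enclosing function. The default callbacks also reposition the builder
// (e.g. to hoist static allocations); that logic is omitted here.
static LogicalResult
bufferizeWithStackAllocations(Operation *op, BlockAndValueMapping &bvm,
                              BufferizationAliasInfo &aliasInfo) {
  AllocationCallbacks callbacks;
  callbacks.allocationFn = [](OpBuilder &b, Location loc,
                              Value shapedValue) -> Optional<Value> {
    auto shapedType = shapedValue.getType().cast<ShapedType>();
    if (!shapedType.hasStaticShape())
      return llvm::None; // Dynamic shapes are out of scope for this sketch.
    auto allocType =
        MemRefType::get(shapedType.getShape(), shapedType.getElementType());
    return b.create<memref::AllocaOp>(loc, allocType, b.getI64IntegerAttr(128))
        .getResult();
  };
  // Nothing to free for stack allocations.
  callbacks.deallocationFn = [](OpBuilder &, Location, Value) {};
  return bufferizeOp(op, bvm, aliasInfo, callbacks);
}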

diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 840978bf5648..4b1de5f2b471 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -33,7 +33,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRAnalysis
   MLIRArithmetic
   MLIRComplex
-  MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRef
   MLIRLinalgAnalysis

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index ac01bd80ccc8..9e970f3e5e14 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -118,12 +118,12 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
@@ -983,6 +983,7 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
     const DenseSet<OpOperand *> &usesRead,
     const DenseSet<OpOperand *> &usesWrite,
     const DominanceInfo &domInfo) const {
+
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
 
@@ -1414,27 +1415,66 @@ Operation *getFirstParentOfType(Value v) {
 /// Create an AllocOp/DeallocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-static Value createNewAllocDeallocPairForShapedValue(
-    OpBuilder &b, Location loc, Value shapedValue,
-    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
+static Value
+createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
+                                        Value shapedValue,
+                                        BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
+  // TODO: non-zero address space.
+  // TODO: layout information if relevant.
+  // Cannot allocate an unranked memref so just always go for the contiguous
+  // form.
+  MemRefType allocMemRefType =
+      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
   assert(shapedValue.getType().isa<ShapedType>());
   MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
+  memRefType = memRefType ? memRefType : allocMemRefType;
+
+  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+    b.setInsertionPointToStart(bbArg.getOwner());
+    loc = bbArg.getOwner()->getParentOp()->getLoc();
+  } else {
+    b.setInsertionPoint(shapedValue.getDefiningOp());
+    loc = shapedValue.getDefiningOp()->getLoc();
+  }
+
+  // Compute the dynamic part of the shape.
+  SmallVector<Value> dynShape;
+  for (auto dim : enumerate(memRefType.getShape()))
+    if (dim.value() == ShapedType::kDynamicSize)
+      dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
 
-  Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
-  // TODO: For now just assert the value is returned. Eventually need to
-  // error-propagate.
-  assert(allocated && "allocation failed");
-  Value casted = allocated.getValue();
-  MemRefType allocMemRefType = allocated->getType().cast<MemRefType>();
-  if (memRefType && memRefType != allocMemRefType) {
-    casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
-    aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
+  // If the buffer is statically shaped, try to hoist it to the first enclosing
+  // parallel region.
+  // TODO: this concept of parallel region and threadlocal needs interfaces.
+  // TODO: also hoist in the dynamic case. For now this relies on subsequent
+  // calls to LICM and buffer hoisting which will most likely not succeed.
+  // TODO: when packing, allocate a static bounding box which will enable more
+  // hoisting.
+  Value allocated;
+  { // Guarded insertion point to potentially hoist the AllocOp.
+    OpBuilder::InsertionGuard g(b);
+    if (dynShape.empty()) {
+      Operation *parent =
+          getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
+                               AffineParallelOp>(shapedValue);
+      if (parent)
+        b.setInsertionPointToStart(&(parent->getRegion(0).front()));
+    }
+    allocated = b.create<memref::AllocOp>(
+        loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+    aliasInfo.createAliasInfoEntry(allocated);
   }
+  Value casted = allocated;
+  if (memRefType != allocMemRefType) {
+    casted = b.create<memref::CastOp>(loc, memRefType, allocated);
+    aliasInfo.insertNewBufferEquivalence(casted, allocated);
+  }
+  b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
+  b.create<memref::DeallocOp>(loc, allocated);
 
-  allocationFns.deallocationFn(b, loc, allocated.getValue());
   return casted;
 }
 
@@ -1448,7 +1488,6 @@ static Value createNewAllocDeallocPairForShapedValue(
 static Value getResultBuffer(OpBuilder &b, OpResult result,
                              const BlockAndValueMapping &bvm,
                              BufferizationAliasInfo &aliasInfo,
-                             AllocationCallbacks allocationFns,
                              bool skipCopy = false) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
@@ -1476,8 +1515,8 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
         "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
     Location loc = op->getLoc();
     // Allocate the result buffer.
-    Value resultBuffer = createNewAllocDeallocPairForShapedValue(
-        b, loc, operand, aliasInfo, allocationFns);
+    Value resultBuffer =
+        createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
     // Do not copy the result of an InitTensorOp.
     if (isInitTensorOp(operand))
       skipCopy = true;
@@ -1499,10 +1538,11 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
 /// Helper function for LinalgOp bufferization.
 /// When allocating a new buffer, analyze whether `op` wants to read from that
 /// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult allocateBuffersForResults(
-    OpBuilder &b, Location loc, LinalgOp op,
-    SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
-    BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
+static LogicalResult
+allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
+                          SmallVectorImpl<Value> &resultBuffers,
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1513,8 +1553,7 @@ static LogicalResult allocateBuffersForResults(
     OpResult opResult = getInplaceableOpResult(*opOperand);
     assert(opResult && "could not find corresponding OpResult");
     bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
-    Value resultBuffer =
-        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy);
+    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
     if (!resultBuffer)
       return failure();
     resultBuffers.push_back(resultBuffer);
@@ -1529,8 +1568,7 @@ static LogicalResult allocateBuffersForResults(
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFns) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1553,7 +1591,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
   if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
-                                       aliasInfo, allocationFns)))
+                                       aliasInfo)))
     return failure();
 
   // Clone the newly bufferized op.
@@ -1578,7 +1616,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
 /// to allow FuncOp that are inplaceable to write inPlace.
 static LogicalResult
 bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
-          BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
+          BufferizationAliasInfo &aliasInfo,
           DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   FuncOp funcOp = getCalledFunction(callOp);
   assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@@ -1717,14 +1755,12 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
 /// tensor::CastOp bufferizes to memref::CastOp.
 static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(castOp);
 
-  Value resultBuffer =
-      getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
+  Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
   if (!resultBuffer)
     return failure();
   Type sourceType = resultBuffer.getType();
@@ -1750,15 +1786,10 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
 
 static LogicalResult bufferize(OpBuilder &b, arith::ConstantOp constantOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo,
+                               GlobalCreator &globalCreator) {
   assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
          "not a constant ranked tensor");
-  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
-  if (!moduleOp) {
-    return constantOp.emitError(
-        "cannot bufferize constants not within builtin.module op");
-  }
-  GlobalCreator globalCreator(moduleOp);
 
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1793,8 +1824,7 @@ static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
 
 static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1807,8 +1837,7 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
            "unsupported unranked tensor");
 
     // TODO: More general: Matching bbArg does not bufferize to a read.
-    Value resultBuffer =
-        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
     if (!resultBuffer)
       return failure();
 
@@ -1851,8 +1880,7 @@ static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1878,8 +1906,7 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
 /// TODO: consider hoisting across function boundaries prior to bufferization.
 static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationAliasInfo &aliasInfo) {
   // The InitTensorOp may have been eliminated.
   if (initTensorOp->getUses().empty())
     return success();
@@ -1889,8 +1916,7 @@ static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
   b.setInsertionPoint(initTensorOp);
 
   Value alloc = createNewAllocDeallocPairForShapedValue(
-      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
-      allocationFn);
+      b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
   map(bvm, initTensorOp.result(), alloc);
   return success();
 }
@@ -1923,8 +1949,7 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
 /// Bufferization for TiledLoopOp.
 static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1964,8 +1989,7 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
 
     const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
     OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
-    Value resultBuffer =
-        getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
+    Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
     if (!resultBuffer)
       return failure();
 
@@ -2049,8 +2073,7 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
 /// isolation.
 static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -2070,7 +2093,7 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
   auto inPlace = getInPlace(extractSliceOp->getResult(0));
   if (inPlace != InPlaceSpec::True)
     alloc = createNewAllocDeallocPairForShapedValue(
-        b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
+        b, loc, extractSliceOp.result(), aliasInfo);
 
   // Set insertion point now that potential alloc/dealloc are introduced.
   b.setInsertionPoint(extractSliceOp);
@@ -2102,8 +2125,7 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
 
 static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(insertSliceOp);
@@ -2118,8 +2140,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
   // TODO: be very loud about it or even consider failing the pass.
   // Alloc a copy for `insertSliceOp.dest()`, it will become the result
   // buffer.
-  Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
-                                    aliasInfo, allocationFn);
+  Value dstMemref =
+      getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
   if (!dstMemref)
     return failure();
   auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
@@ -2162,8 +2184,7 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
 
 static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
                                BlockAndValueMapping &bvm,
-                               BufferizationAliasInfo &aliasInfo,
-                               AllocationCallbacks &allocationFn) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -2184,8 +2205,7 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
   // Leave the previous transfer_write to dead code as it still has uses at
   // this point.
   auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
-  Value resultBuffer =
-      getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
+  Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
   if (!resultBuffer)
     return failure();
   b.create<vector::TransferWriteOp>(
@@ -2416,107 +2436,18 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
 // Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
-/// Compute the type of the `memref` to use for allocating the buffer for
-/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
-/// dynamic dimensions in the returned `memref` type. The function also sets the
-/// insertion point of the builder `b` to the position where the allocation is
-/// to be inserted.
-static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
-                                            Value shapedValue,
-                                            SmallVectorImpl<Value> &dynShape) {
-  MemRefType allocMemRefType =
-      getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
-  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
-    b.setInsertionPointToStart(bbArg.getOwner());
-    loc = bbArg.getOwner()->getParentOp()->getLoc();
-  } else {
-    b.setInsertionPoint(shapedValue.getDefiningOp());
-    loc = shapedValue.getDefiningOp()->getLoc();
-  }
-
-  // Compute the dynamic part of the shape.
-  bool foundDynamicShapes = false;
-  if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
-          shapedValue.getDefiningOp())) {
-    ReifiedRankedShapedTypeDims resultDims;
-    if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
-      foundDynamicShapes = true;
-      OpResult resultValue = shapedValue.dyn_cast<OpResult>();
-      auto &shape = resultDims[resultValue.getResultNumber()];
-      for (auto dim : enumerate(allocMemRefType.getShape()))
-        if (dim.value() == ShapedType::kDynamicSize)
-          dynShape.push_back(shape[dim.index()]);
-    }
-  }
-  if (!foundDynamicShapes) {
-    for (auto dim : enumerate(allocMemRefType.getShape()))
-      if (dim.value() == ShapedType::kDynamicSize)
-        dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
-  }
-
-  // If the buffer is statically shaped, try to hoist it to the first enclosing
-  // parallel region.
-  // TODO: this concept of parallel region and threadlocal needs interfaces.
-  // TODO: also hoist in the dynamic case. For now this relies on subsequent
-  // calls to LICM and buffer hoisting which will most likely not succeed.
-  // TODO: when packing, allocate a static bounding box which will enable more
-  // hoisting.
-  if (dynShape.empty()) {
-    Operation *parent =
-        getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
-                             AffineParallelOp>(shapedValue);
-    if (parent)
-      b.setInsertionPointToStart(&(parent->getRegion(0).front()));
-  }
-  return allocMemRefType;
-}
-
-Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
-                                                  Value shapedValue) {
-  // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-  SmallVector<Value> dynShape;
-  MemRefType allocMemRefType =
-      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  Value allocated = b.create<memref::AllocOp>(
-      loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-  return allocated;
-}
-
-static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
-                                               Value shapedValue) {
-  OpBuilder::InsertionGuard g(b);
-  SmallVector<Value> dynShape;
-  MemRefType allocMemRefType =
-      getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
-  Value allocated = b.create<memref::AllocaOp>(
-      loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-  return allocated;
-}
-
-void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
-                                         Value allocatedBuffer) {
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
-  b.create<memref::DeallocOp>(loc, allocatedBuffer);
-}
-
 LogicalResult mlir::linalg::bufferizeOp(
     Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    AllocationCallbacks allocationFns,
-    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
+    DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes,
+    GlobalCreator *globalCreator) {
   OpBuilder b(op->getContext());
   return TypeSwitch<Operation *, LogicalResult>(op)
       // Skip BufferCast and TensorLoad ops.
       .Case<memref::BufferCastOp, memref::TensorLoadOp>(
           [&](auto) { return success(); })
-      .Case<ExtractSliceOp, InitTensorOp, InsertSliceOp, LinalgOp, scf::ForOp,
-            tensor::CastOp, TiledLoopOp, VectorTransferOpInterface>(
-          [&](auto op) {
-            LDBG("Begin bufferize:\n" << op << '\n');
-            return bufferize(b, op, bvm, aliasInfo, allocationFns);
-          })
-      .Case<tensor::DimOp, tensor::ExtractOp, ReturnOp, linalg::YieldOp,
+      .Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, scf::ForOp,
+            InitTensorOp, InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
+            TiledLoopOp, VectorTransferOpInterface, linalg::YieldOp,
             scf::YieldOp>([&](auto op) {
         LDBG("Begin bufferize:\n" << op << '\n');
         return bufferize(b, op, bvm, aliasInfo);
@@ -2533,14 +2464,15 @@ LogicalResult mlir::linalg::bufferizeOp(
         if (!bufferizedFunctionTypes)
           llvm_unreachable(
               "null bufferizedFunctionTypes when bufferizing CallOpInterface");
-        return bufferize(b, op, bvm, aliasInfo, allocationFns,
-                         *bufferizedFunctionTypes);
+        return bufferize(b, op, bvm, aliasInfo, *bufferizedFunctionTypes);
       })
       .Case([&](arith::ConstantOp op) {
         if (!isaTensor(op.getResult().getType()))
           return success();
         LDBG("Begin bufferize:\n" << op << '\n');
-        return bufferize(b, op, bvm, aliasInfo);
+        if (!globalCreator)
+          llvm_unreachable("null globalCreator when bufferizing ConstantOp");
+        return bufferize(b, op, bvm, aliasInfo, *globalCreator);
       })
       .Default([&](Operation *op) -> LogicalResult {
         auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -2553,13 +2485,15 @@ LogicalResult mlir::linalg::bufferizeOp(
 
 static LogicalResult bufferizeFuncOpInternals(
     FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
-    AllocationCallbacks &allocationFns,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes,
+    GlobalCreator &globalCreator) {
+
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
-  /// Start by bufferizing `funcOp` arguments.
-  if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
+
+  // Start by bufferizing `funcOp` arguments.
+  if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
 
   // Cannot erase ops during the traversal. Do that afterwards.
@@ -2582,13 +2516,13 @@ static LogicalResult bufferizeFuncOpInternals(
     }
 
     for (Operation *op : llvm::reverse(preorderBufferize))
-      if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
-                             &bufferizedFunctionTypes)))
+      if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
+                             &globalCreator)))
         return failure();
 
     if (!bufferizedOps.contains(op) &&
-        failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
-                           &bufferizedFunctionTypes)))
+        failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
+                           &globalCreator)))
       return failure();
 
     // Register post-walk erasure, if necessary.
@@ -2859,19 +2793,12 @@ namespace {
 struct LinalgComprehensiveModuleBufferize
     : public LinalgComprehensiveModuleBufferizeBase<
           LinalgComprehensiveModuleBufferize> {
-  LinalgComprehensiveModuleBufferize() {}
-
-  LinalgComprehensiveModuleBufferize(
-      const LinalgComprehensiveModuleBufferize &p) {}
 
   void runOnOperation() override;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, memref::MemRefDialect>();
   }
-
-private:
-  std::unique_ptr<AllocationCallbacks> allocationFns;
 };
 } // end namespace
 
@@ -3056,22 +2983,6 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
-  if (!allocationFns) {
-    // The allocation functions to use needs to be set here. The flag for the
-    // pass and flag for the use of alloca map to LLVM command line
-    // options. These being static global objects have no set order in which
-    // they are defined. So ideally this should be in the constructor, but the
-    // constructor might be called before the flag is initialized using the
-    // command line option. So this is set up at the start of the pass.
-    if (useAlloca) {
-      AllocationCallbacks allocaAllocationFns = {
-          allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
-      allocationFns =
-          std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
-    } else {
-      allocationFns = std::make_unique<AllocationCallbacks>();
-    }
-  }
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
 
@@ -3081,6 +2992,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return signalPassFailure();
 
+  GlobalCreator globalCreator(moduleOp);
   DominanceInfo domInfo(moduleOp);
   BufferizationAliasInfo aliasInfo(moduleOp);
   // Interestingly, all function args that are not visible outside of a module
@@ -3120,8 +3032,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     if (!testAnalysisOnly) {
       BlockAndValueMapping tensorToBufferMap;
       if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
-                                          *allocationFns,
-                                          bufferizedFunctionTypes))) {
+                                          bufferizedFunctionTypes,
+                                          globalCreator))) {
         signalPassFailure();
         return;
       }
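
With the callbacks gone, `bufferizeOp` is again driven with an optional `bufferizedFunctionTypes` map and an explicit `GlobalCreator`, as the restored code above shows. The following is a minimal sketch of a caller under the post-revert signature; it assumes the same headers as ComprehensiveBufferize.cpp, and the driver name and the naive walk are illustrative only (the in-tree `bufferizeFuncOpInternals` performs a more careful traversal with pre-order handling and deferred erasure).

// Hypothetical driver using the post-revert bufferizeOp signature.
static LogicalResult bufferizeFuncNaively(FuncOp funcOp, ModuleOp moduleOp,
                                          BufferizationAliasInfo &aliasInfo) {
  BlockAndValueMapping bvm;
  DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
  // After this revert, constant bufferization uses a GlobalCreator supplied
  // by the caller rather than creating one from the enclosing module.
  GlobalCreator globalCreator(moduleOp);
  WalkResult status = funcOp.body().walk([&](Operation *op) {
    if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
                           &globalCreator)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return failure(status.wasInterrupted());
}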

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
deleted file mode 100644
index 71d631c85e0d..000000000000
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
+++ /dev/null
@@ -1,65 +0,0 @@
-// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-memref use-alloca}" -split-input-file | FileCheck %s
-
-//  CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
-//  CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-
-//      CHECK:  func @init_and_dot(
-// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
-// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
-// CHECK-SAME:    %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
-func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
-  // CHECK-NEXT:   %[[C0:.*]] = arith.constant 0{{.*}} : f32
-  %v0 = arith.constant 0.0 : f32
-
-  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
-  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
-
-  // CHECK-NEXT:   linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
-  %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
-    outs(%d: tensor<f32>) -> tensor<f32>
-
-  // CHECK-NEXT:   return
-  return %e : tensor<f32>
-}
-
-//      CHECK:  func @main()
-func @main() {
-  //  CHECK-DAG:   %[[C0:.*]] = arith.constant 0{{.*}} : f32
-  //  CHECK-DAG:   %[[C1:.*]] = arith.constant 1{{.*}} : f32
-  //  CHECK-DAG:   %[[C2:.*]] = arith.constant 2{{.*}} : f32
-  %v0 = arith.constant 0.0 : f32
-  %v1 = arith.constant 1.0 : f32
-  %v2 = arith.constant 2.0 : f32
-
-  // CHECK-NEXT:   %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
-  // CHECK-NEXT:   %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
-  // CHECK-NEXT:   %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
-  %A = linalg.init_tensor [64] : tensor<64xf32>
-  %B = linalg.init_tensor [64] : tensor<64xf32>
-  %C = linalg.init_tensor [] : tensor<f32>
-
-  // CHECK-NEXT:   linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
-  // CHECK-NEXT:   linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
-  // CHECK-NEXT:   linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
-  %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
-  %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
-  %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
-
-  // CHECK-NEXT:   %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
-  // CHECK-NEXT:   %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
-  // CHECK-NEXT:   %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
-  // CHECK-NEXT:   call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
-  %res = call @init_and_dot(%AA, %BB, %CC) :
-    (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
-
-  // CHECK-NEXT:   %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
-  %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
-
-  // CHECK-NEXT:   call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
-  call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
-
-  return
-}
-
-//     CHECK:   func private @print_memref_f32(memref<*xf32>)
-func private @print_memref_f32(tensor<*xf32>)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 39c258ade3fe..a0339919063f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6314,7 +6314,6 @@ cc_library(
         ":ComplexDialect",
         ":DialectUtils",
         ":IR",
-        ":InferTypeOpInterface",
         ":LinalgOps",
         ":LinalgPassIncGen",
         ":LinalgStructuredOpsIncGen",


        

