[Mlir-commits] [mlir] 542a8cf - [mlir][linalg][bufferize] Fix insertion point of result buffers

Matthias Springer llvmlistbot at llvm.org
Mon Nov 15 02:33:22 PST 2021


Author: Matthias Springer
Date: 2021-11-15T19:27:33+09:00
New Revision: 542a8cfba7fb5cc62aec4442b7d2f13d72da37fb

URL: https://github.com/llvm/llvm-project/commit/542a8cfba7fb5cc62aec4442b7d2f13d72da37fb
DIFF: https://github.com/llvm/llvm-project/commit/542a8cfba7fb5cc62aec4442b7d2f13d72da37fb.diff

LOG: [mlir][linalg][bufferize] Fix insertion point of result buffers

Differential Revision: https://reviews.llvm.org/D113723

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index bada67c7c1b79..0778f1dbd3b6b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -350,9 +350,16 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
         aliasingOperands.size() == 1 &&
         "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
     Location loc = op->getLoc();
+    // Move insertion point right after `operandBuffer`. That is where the
+    // allocation should be inserted (in the absence of allocation hoisting).
+    if (auto bbArg = operandBuffer.dyn_cast<BlockArgument>()) {
+      b.setInsertionPointToStart(bbArg.getOwner());
+    } else {
+      b.setInsertionPointAfter(operandBuffer.getDefiningOp());
+    }
     // Allocate the result buffer.
     Value resultBuffer =
-        state.allocationFns.createAllocDeallocFn(b, loc, operand, state);
+        state.allocationFns.createAllocDeallocFn(b, loc, operandBuffer, state);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -372,7 +379,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
         !bufferizesToMemoryRead(*opOperand))
       skipCopy = true;
     if (!skipCopy) {
-      // Set insertion point now that potential alloc/dealloc are introduced.
+      // The copy happens right before the op that is bufferized.
       b.setInsertionPoint(op);
       state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
     }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index c447e0d33fcea..f0024214909b3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -756,53 +756,68 @@ static FunctionType getOrCreateBufferizedFunctionType(
 // Bufferization-specific scoped alloc/dealloc insertion support.
 //===----------------------------------------------------------------------===//
 
-/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
-/// the type of `source`.
-static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
-                               int64_t dim) {
-  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
-    return b.createOrFold<memref::DimOp>(loc, source, dim);
-  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
-    return b.createOrFold<tensor::DimOp>(loc, source, dim);
-  llvm_unreachable("Expected MemRefType or TensorType");
+/// Move the insertion point of builder `b` to the start of the (single) block
+/// of the closest enclosing op that is an allocation hoisting barrier. Asserts
+/// that such an op exists and that it has exactly one region with one block.
+static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
+  Operation *op = b.getInsertionBlock()->getParentOp();
+  while (op) {
+    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+      if (bufferizableOp.isAllocationHoistingBarrier())
+        break;
+    op = op->getParentOp();
+  }
+
+  // FuncOp is an allocation hoisting barrier, so the above loop should never
+  // run out of parents.
+  assert(
+      (op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
+      "expected traversal to end at allocation hoisting barrier");
+
+  // TODO: Handle cases where the allocation hoisting barrier has more than one
+  // region or block.
+  assert(op->getNumRegions() == 1 &&
+         "allocation hoisting barriers with >1 regions not supported");
+  assert(op->getRegion(0).getBlocks().size() == 1 &&
+         "allocation hoisting barriers with >1 blocks not supported");
+  b.setInsertionPointToStart(&(op->getRegion(0).front()));
 }
 
 /// Compute the type of the `memref` to use for allocating the buffer for
 /// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
-/// dynamic dimensions in the returned `memref` type. The function also sets the
-/// insertion point of the builder `b` to the position where the allocation is
-/// to be inserted.
+/// dynamic dimensions in the returned `memref` type. The function may also set
+/// the insertion point to an earlier location, where the allocation should
+/// happen ("allocation hoisting").
 static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
                                             Value shapedValue,
                                             SmallVectorImpl<Value> &dynShape) {
   MemRefType allocMemRefType =
       getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
-  if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
-    b.setInsertionPointToStart(bbArg.getOwner());
-    loc = bbArg.getOwner()->getParentOp()->getLoc();
-  } else {
-    b.setInsertionPoint(shapedValue.getDefiningOp());
-    loc = shapedValue.getDefiningOp()->getLoc();
-  }
 
   // Compute the dynamic part of the shape.
-  bool foundDynamicShapes = false;
+  bool reifiedShapes = false;
   if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
           shapedValue.getDefiningOp())) {
     ReifiedRankedShapedTypeDims resultDims;
     if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
-      foundDynamicShapes = true;
+      reifiedShapes = true;
       OpResult resultValue = shapedValue.dyn_cast<OpResult>();
       auto &shape = resultDims[resultValue.getResultNumber()];
       for (auto dim : enumerate(allocMemRefType.getShape()))
-        if (dim.value() == ShapedType::kDynamicSize)
+        if (ShapedType::isDynamic(dim.value()))
           dynShape.push_back(shape[dim.index()]);
     }
   }
-  if (!foundDynamicShapes) {
+
+  if (!reifiedShapes) {
     for (auto dim : enumerate(allocMemRefType.getShape()))
-      if (dim.value() == ShapedType::kDynamicSize)
-        dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+      if (ShapedType::isDynamic(dim.value())) {
+        assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
+                shapedValue.getType().isa<MemRefType>()) &&
+               "expected MemRef type");
+        dynShape.push_back(
+            b.create<memref::DimOp>(loc, shapedValue, dim.index()));
+      }
   }
 
   // If the buffer is statically shaped, try to hoist it to the first enclosing
@@ -811,28 +826,9 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
   // calls to LICM and buffer hoisting which will most likely not succeed.
   // TODO: when packing, allocate a static bounding box which will enable more
   // hoisting.
-  if (dynShape.empty()) {
-    Operation *parent;
-    if (auto bbArg = shapedValue.dyn_cast<BlockArgument>())
-      parent = bbArg.getOwner()->getParentOp();
-    else
-      parent = shapedValue.getDefiningOp()->getParentOp();
-    while (parent) {
-      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(parent))
-        if (bufferizableOp.isAllocationHoistingBarrier())
-          break;
-      parent = parent->getParentOp();
-    }
+  if (dynShape.empty())
+    moveInsertionPointToAllocationHoistingBarrier(b);
 
-    // FuncOp is an allocation hoisting barrier, so the above loop should never
-    // run out of parents.
-    assert(
-        (parent &&
-         cast<BufferizableOpInterface>(parent).isAllocationHoistingBarrier()) &&
-        "expected traversal to end at allocation hoisting barrier");
-
-    b.setInsertionPointToStart(&(parent->getRegion(0).front()));
-  }
   return allocMemRefType;
 }
 
@@ -2247,6 +2243,7 @@ struct ExtractSliceOpInterface
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(extractSliceOp);
 
     LDBG("bufferize: " << *extractSliceOp << '\n');
 
@@ -2263,9 +2260,6 @@ struct ExtractSliceOpInterface
       alloc = createNewAllocDeallocPairForShapedValue(
           b, loc, extractSliceOp.result(), state);
 
-    // Set insertion point now that potential alloc/dealloc are introduced.
-    b.setInsertionPoint(extractSliceOp);
-
     // Bufferize to subview.
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index dcd409da32d06..583c58ed12fd9 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -168,9 +168,9 @@ func @insert_slice_fun(%A0 : tensor<?xf32>,
   ->  (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
 {
   // Hoisted allocs.
-  //      CHECK: %[[REALLOC_A1:.*]] = memref.alloc
   //      CHECK: %[[REALLOC_A0_2:.*]] = memref.alloc
   //      CHECK: %[[REALLOC_A0:.*]] = memref.alloc
+  //      CHECK: %[[REALLOC_A1:.*]] = memref.alloc
 
   // Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
   //      CHECK: linalg.copy(%[[A0]], %[[REALLOC_A0]]


        


More information about the Mlir-commits mailing list