[Mlir-commits] [mlir] 2f5539e - [mlir][linalg][bufferize][NFC] Move `getResultBuffer` to op interface

Matthias Springer llvmlistbot at llvm.org
Wed Nov 10 21:43:39 PST 2021


Author: Matthias Springer
Date: 2021-11-11T14:38:18+09:00
New Revision: 2f5539e300774c3828c08c8871325c0542ddb605

URL: https://github.com/llvm/llvm-project/commit/2f5539e300774c3828c08c8871325c0542ddb605
DIFF: https://github.com/llvm/llvm-project/commit/2f5539e300774c3828c08c8871325c0542ddb605.diff

LOG: [mlir][linalg][bufferize][NFC] Move `getResultBuffer` to op interface

This is in preparation for decoupling Comprehensive Bufferize from the various dialects.

Differential Revision: https://reviews.llvm.org/D113387

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 1b2b2950776e..42908bc6c5c2 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -197,24 +197,47 @@ findValueInReverseUseDefChain(Value value,
 /// is returned regardless of whether it is a memory write or not.
 Value findLastPrecedingWrite(Value value);
 
-/// Callback functions that are used by the comprehensive bufferization pass to
-/// allocate/deallocate memory. The `deallocationFn` is gauranteed to recieve
-/// the `Value` returned by the `allocationFn`.
+/// Callback functions that are used to allocate/deallocate/copy memory buffers.
+/// Comprehensive Bufferize provides default implementations of these functions.
+// TODO: Could be replaced with a "bufferization strategy" object with virtual
+// functions in the future.
 struct AllocationCallbacks {
   using AllocationFn = std::function<Optional<Value>(
       OpBuilder &, Location, MemRefType, const SmallVector<Value> &)>;
   using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
   using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
+  using CreateAllocDeallocFn =
+      std::function<Value(OpBuilder &, Location, Value,
+                          BufferizationAliasInfo &, AllocationCallbacks &)>;
 
   AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
-                      MemCpyFn copyFn)
-      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
+                      MemCpyFn copyFn, CreateAllocDeallocFn allocDeallocFn)
+      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn),
+        createAllocDeallocFn(allocDeallocFn) {}
 
+  /// A function that allocates memory.
   AllocationFn allocationFn;
+
+  /// A function that deallocates a buffer allocated by `allocationFn`.
   DeallocationFn deallocationFn;
+
+  /// A function that copies memory between two allocations.
   MemCpyFn memCpyFn;
+
+  /// A function that creates an alloc-dealloc pair. This function may perform
+  /// additional optimizations such as buffer allocation hoisting. This function
+  /// calls `allocationFn` and `deallocationFn` to create (de)allocations.
+  CreateAllocDeallocFn createAllocDeallocFn;
 };
 
+/// Return the result buffer (memref) for a given OpResult (tensor). Allocate a
+/// new buffer (and copy over data when needed) for out-of-place bufferization.
+/// Emits an error and returns a null Value if the result buffer is ambiguous.
+Value getResultBuffer(OpBuilder &b, OpResult result,
+                      const BlockAndValueMapping &bvm,
+                      BufferizationAliasInfo &aliasInfo,
+                      AllocationCallbacks allocationFns);
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 1bc2b55f314d..3c5c95e87a04 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/Support/Debug.h"
 
@@ -313,3 +314,70 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
   assert(result.size() == 1 && "expected exactly one result");
   return result.front();
 }
+
+/// Return the result buffer (memref) for a given OpResult (tensor). Allocate a
+/// new buffer (and copy over data when needed) for out-of-place bufferization.
+/// Emits an error and returns a null Value if the result buffer is ambiguous.
+Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
+    OpBuilder &b, OpResult result, const BlockAndValueMapping &bvm,
+    BufferizationAliasInfo &aliasInfo, AllocationCallbacks allocationFns) {
+  OpBuilder::InsertionGuard guard(b);
+  Operation *op = result.getOwner();
+  SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
+  assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
+  OpOperand *opOperand = aliasingOperands.front();
+  Value operand = opOperand->get();
+  Value operandBuffer = bvm.lookupOrNull(operand);
+  assert(operandBuffer && "operand buffer not found");
+  // Make sure that all OpOperands map to the same buffer. Otherwise, we would
+  // have to materialize a memref value.
+  // TODO: Should be checking for "equivalent buffers" instead of
+  // operator== here, but equivalent buffers for scf.if yield values are not
+  // set up yet.
+  if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
+        return bvm.lookup(o->get()) == operandBuffer;
+      })) {
+    op->emitError("result buffer is ambiguous");
+    return Value();
+  }
+
+  // If bufferizing out-of-place, allocate a new buffer.
+  if (!aliasInfo.isInPlace(result)) {
+    // Ops with multiple aliasing operands currently cannot bufferize
+    // out-of-place.
+    assert(
+        aliasingOperands.size() == 1 &&
+        "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
+    Location loc = op->getLoc();
+    // Allocate the result buffer.
+    Value resultBuffer = allocationFns.createAllocDeallocFn(
+        b, loc, operand, aliasInfo, allocationFns);
+    bool skipCopy = false;
+    // Do not copy if the last preceding write of `operand` is an op that does
+    // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
+    // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
+    // use-def chain, it returns that value, regardless of whether it is a
+    // memory write or not.
+    Value lastWrite = findLastPrecedingWrite(operand);
+    if (auto bufferizableOp =
+            lastWrite.getDefiningOp<BufferizableOpInterface>())
+      if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>()))
+        skipCopy = true;
+    // Do not copy if the copied data is never read.
+    if (!isValueRead(result))
+      skipCopy = true;
+    // Do not copy if this op does not read the data, but writes it.
+    if (bufferizesToMemoryWrite(*opOperand) &&
+        !bufferizesToMemoryRead(*opOperand))
+      skipCopy = true;
+    if (!skipCopy) {
+      // Set the insertion point now that any alloc/dealloc has been created.
+      b.setInsertionPoint(op);
+      allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
+    }
+    return resultBuffer;
+  }
+
+  // Bufferizing in-place. No need to allocate a new buffer.
+  return operandBuffer;
+}

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index ba94049bdfb8..233c896648b9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -911,74 +911,6 @@ static Value createNewAllocDeallocPairForShapedValue(
 // Bufferization as simple BlockAndValueMapping rewrites.
 //===----------------------------------------------------------------------===//
 
-/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
-/// a new buffer and copy over data from the existing buffer if out-of-place
-/// bufferization is necessary.
-static Value getResultBuffer(OpBuilder &b, OpResult result,
-                             const BlockAndValueMapping &bvm,
-                             BufferizationAliasInfo &aliasInfo,
-                             AllocationCallbacks allocationFns) {
-  OpBuilder::InsertionGuard guard(b);
-  Operation *op = result.getOwner();
-  SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
-  assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
-  OpOperand *opOperand = aliasingOperands.front();
-  Value operand = opOperand->get();
-  Value operandBuffer = lookup(bvm, operand);
-  assert(operandBuffer && "operand buffer not found");
-  // Make sure that all OpOperands are the same buffer. If this is not the case,
-  // we would have to materialize a memref value.
-  // TODO: Should be looking for checking for "equivalent buffers" instead of
-  // operator== here, but equivalent buffers for scf.if yield values are not
-  // set up yet.
-  if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
-        return lookup(bvm, o->get()) == operandBuffer;
-      })) {
-    op->emitError("result buffer is ambiguous");
-    return Value();
-  }
-
-  // If bufferizing out-of-place, allocate a new buffer.
-  if (!aliasInfo.isInPlace(result)) {
-    // Ops with multiple aliasing operands can currently not bufferize
-    // out-of-place.
-    assert(
-        aliasingOperands.size() == 1 &&
-        "ops with multiple aliasing OpOperands cannot bufferize out-of-place");
-    Location loc = op->getLoc();
-    // Allocate the result buffer.
-    Value resultBuffer = createNewAllocDeallocPairForShapedValue(
-        b, loc, operand, aliasInfo, allocationFns);
-    bool skipCopy = false;
-    // Do not copy if the last preceding write of `operand` is an op that does
-    // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
-    // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
-    // use-def chain, it returns that value, regardless of whether it is a
-    // memory write or not.
-    Value lastWrite = findLastPrecedingWrite(operand);
-    if (auto bufferizableOp =
-            lastWrite.getDefiningOp<BufferizableOpInterface>())
-      if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>()))
-        skipCopy = true;
-    // Do not copy if the copied data is never read.
-    if (!isValueRead(result))
-      skipCopy = true;
-    // Do not copy if this op does not read the data, but writes it.
-    if (bufferizesToMemoryWrite(*opOperand) &&
-        !bufferizesToMemoryRead(*opOperand))
-      skipCopy = true;
-    if (!skipCopy) {
-      // Set insertion point now that potential alloc/dealloc are introduced.
-      b.setInsertionPoint(op);
-      allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
-    }
-    return resultBuffer;
-  }
-
-  // Bufferizing in-place. No need to allocate a new buffer.
-  return operandBuffer;
-}
-
 /// In a first approximation, all the function arguments of a FuncOp are marked
 /// inplaceable. For now, it is the responsibility of the `callOp` bufferization
 /// to allow FuncOp that are inplaceable to write inPlace.
@@ -1964,7 +1896,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
 std::unique_ptr<AllocationCallbacks>
 mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
   return std::make_unique<AllocationCallbacks>(
-      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
+      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn,
+      createNewAllocDeallocPairForShapedValue);
 }
 
 // Default constructor for BufferizationOptions that sets all allocation


        


More information about the Mlir-commits mailing list