[Mlir-commits] [mlir] 09cd4a7 - Introduced AllocationOpInterface to create deallocation operations on-the-fly that are compatible with the allocation operation implementing this interface.

Julian Gross llvmlistbot at llvm.org
Wed Sep 29 06:55:11 PDT 2021


Author: Marcel Koester
Date: 2021-09-29T15:54:21+02:00
New Revision: 09cd4a71ed1ecf531dd1582718c4a424d0e3048a

URL: https://github.com/llvm/llvm-project/commit/09cd4a71ed1ecf531dd1582718c4a424d0e3048a
DIFF: https://github.com/llvm/llvm-project/commit/09cd4a71ed1ecf531dd1582718c4a424d0e3048a.diff

LOG: Introduced AllocationOpInterface to create deallocation operations on-the-fly that are compatible with the allocation operation implementing this interface.
Added interface implementations for AllocOp and CloneOp defined in the MemRef dialect.
Adapted the BufferDeallocation pass to be compatible with the interface introduced in this CL.

Differential Revision: https://reviews.llvm.org/D109350

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Interfaces/SideEffectInterfaces.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Transforms/BufferDeallocation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 630ab8621df2b..f5b8486c74d2b 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -119,7 +119,10 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
 // AllocOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource> {
+def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, [
+    DeclareOpInterfaceMethods<AllocationOpInterface,
+      ["buildDealloc", "buildClone"]>]
+  > {
   let summary = "memory allocation operation";
   let description = [{
     The `alloc` operation allocates a region of memory, as specified by its
@@ -413,7 +416,9 @@ def MemRef_CastOp : MemRef_Op<"cast", [
 
 def CloneOp : MemRef_Op<"clone", [
     CopyOpInterface,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<AllocationOpInterface,
+      ["buildDealloc", "buildClone"]>
   ]> {
   let builders = [
     OpBuilder<(ins "Value":$value), [{

diff  --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
index 1c12a5a56e7f7..09cb5c0cdc6a1 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td
@@ -16,6 +16,45 @@
 
 include "mlir/Interfaces/SideEffectInterfaceBase.td"
 
+//===----------------------------------------------------------------------===//
+// AllocationOpInterface
+//===----------------------------------------------------------------------===//
+
+def AllocationOpInterface : OpInterface<"AllocationOpInterface"> {
+  let description = [{
+    This interface provides general allocation-related methods that are
+    designed for allocation operations. For example, it offers the ability to
+    construct associated deallocation and clone operations that are compatible
+    with the current allocation operation.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    StaticInterfaceMethod<[{
+        Builds a deallocation operation for the given value using the provided
+        builder. The value is either a result of an operation implementing this
+        interface or an alias of such a result, so implementations must not
+        assume that `alloc` is defined by the current operation. If there is no
+        compatible deallocation operation, this method returns ::llvm::None.
+      }],
+      "::mlir::Optional<::mlir::Operation*>", "buildDealloc",
+      (ins "::mlir::OpBuilder&":$opBuilder, "::mlir::Value":$alloc), [{}],
+      /*defaultImplementation=*/[{ return ::llvm::None; }]
+    >,
+    StaticInterfaceMethod<[{
+        Builds a clone operation for the given value using the provided
+        builder. The value is either a result of an operation implementing this
+        interface or an alias of such a result, so implementations must not
+        assume that `alloc` is defined by the current operation. If there is no
+        compatible clone operation, this method returns ::llvm::None.
+      }],
+      "::mlir::Optional<::mlir::Value>", "buildClone",
+      (ins "::mlir::OpBuilder&":$opBuilder, "::mlir::Value":$alloc), [{}],
+      /*defaultImplementation=*/[{ return ::llvm::None; }]
+    >
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MemoryEffects
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 412f49232cadd..5c588c3533e48 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -190,6 +190,15 @@ struct SimplifyDeadAlloc : public OpRewritePattern<T> {
 };
 } // end anonymous namespace.
 
+Optional<Operation *> AllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
+  auto dealloc = builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+  return dealloc.getOperation();
+}
+
+Optional<Value> AllocOp::buildClone(OpBuilder &builder, Value alloc) {
+  return builder.create<memref::CloneOp>(alloc.getLoc(), alloc)->getResult(0);
+}
+
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
@@ -638,6 +647,15 @@ OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }
 
+Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
+  // Clones are deallocated exactly like memref.alloc results.
+  return AllocOp::buildDealloc(builder, alloc);
+}
+
+Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
+  return AllocOp::buildClone(builder, alloc);
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp
index 65a54931460d2..9e1f2123382e3 100644
--- a/mlir/lib/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Transforms/BufferDeallocation.cpp
@@ -64,14 +64,18 @@
 using namespace mlir;
 
 /// Walks over all immediate return-like terminators in the given region.
-static void walkReturnOperations(Region *region,
-                                 std::function<void(Operation *)> func) {
+static LogicalResult
+walkReturnOperations(Region *region,
+                     std::function<LogicalResult(Operation *)> func) {
   for (Block &block : *region) {
     Operation *terminator = block.getTerminator();
     // Skip non region-return-like terminators.
-    if (isRegionReturnLike(terminator))
-      func(terminator);
+    if (isRegionReturnLike(terminator)) {
+      if (failed(func(terminator)))
+        return failure();
+    }
   }
+  return success();
 }
 
 /// Checks if all operations in a given region that have at least one attached
@@ -187,24 +191,60 @@ class Backedges {
 /// The buffer deallocation transformation which ensures that all allocs in the
 /// program have a corresponding de-allocation. As a side-effect, it might also
 /// introduce clones that in turn leads to additional deallocations.
-class BufferDeallocation : BufferPlacementTransformationBase {
+class BufferDeallocation : public BufferPlacementTransformationBase {
 public:
+  using AliasAllocationMapT = llvm::DenseMap<Value, AllocationOpInterface>;
+
   BufferDeallocation(Operation *op)
       : BufferPlacementTransformationBase(op), dominators(op),
         postDominators(op) {}
 
+  /// Checks if all allocation operations either provide an already existing
+  /// deallocation operation or implement the AllocationOpInterface. In
+  /// addition, this method maps every allocation and its aliases to that
+  /// AllocationOpInterface; allocations without an implementation stay
+  /// unmapped so that buildDealloc/buildClone fall back to memref defaults.
+  LogicalResult prepare() {
+    for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
+      Value alloc = std::get<0>(entry);
+      auto allocationInterface = alloc.getDefiningOp<AllocationOpInterface>();
+      // If there is no existing deallocation operation and no implementation of
+      // the AllocationOpInterface, we cannot apply the BufferDeallocation pass.
+      if (!std::get<1>(entry) && !allocationInterface) {
+        return alloc.getDefiningOp()->emitError(
+            "Allocation is not deallocated explicitly nor does the operation "
+            "implement the AllocationOpInterface.");
+      }
+      // Never register a null interface: buildDealloc/buildClone would call
+      // through it (e.g. when an alias of an explicitly deallocated buffer
+      // needs a clone) and dereference a null implementation.
+      if (!allocationInterface)
+        continue;
+
+      // Register the allocation and all of its aliases. TODO: check for
+      // incompatible AllocationOpInterface implementations among aliases,
+      // e.g. by promoting the AllocationOpInterface to a DialectInterface.
+      aliasToAllocations[alloc] = allocationInterface;
+      for (Value alias : aliases.resolve(alloc))
+        aliasToAllocations[alias] = allocationInterface;
+    }
+    return success();
+  }
+
   /// Performs the actual placement/creation of all temporary clone and dealloc
   /// nodes.
-  void deallocate() {
+  LogicalResult deallocate() {
     // Add additional clones that are required.
-    introduceClones();
+    if (failed(introduceClones()))
+      return failure();
+
     // Place deallocations for all allocation entries.
-    placeDeallocs();
+    return placeDeallocs();
   }
 
 private:
   /// Introduces required clone operations to avoid memory leaks.
-  void introduceClones() {
+  LogicalResult introduceClones() {
     // Initialize the set of values that require a dedicated memory free
     // operation since their operands cannot be safely deallocated in a post
     // dominator.
@@ -256,21 +296,22 @@ class BufferDeallocation : BufferPlacementTransformationBase {
 
     // Add new allocs and additional clone operations.
     for (Value value : valuesToFree) {
-      if (auto blockArg = value.dyn_cast<BlockArgument>())
-        introduceBlockArgCopy(blockArg);
-      else
-        introduceValueCopyForRegionResult(value);
+      if (failed(value.isa<BlockArgument>()
+                     ? introduceBlockArgCopy(value.cast<BlockArgument>())
+                     : introduceValueCopyForRegionResult(value)))
+        return failure();
 
       // Register the value to require a final dealloc. Note that we do not have
       // to assign a block here since we do not want to move the allocation node
       // to another location.
       allocs.registerAlloc(std::make_tuple(value, nullptr));
     }
+    return success();
   }
 
   /// Introduces temporary clones in all predecessors and copies the source
   /// values into the newly allocated buffers.
-  void introduceBlockArgCopy(BlockArgument blockArg) {
+  LogicalResult introduceBlockArgCopy(BlockArgument blockArg) {
     // Allocate a buffer for the current block argument in the block of
     // the associated value (which will be a predecessor block by
     // definition).
@@ -284,18 +325,21 @@ class BufferDeallocation : BufferPlacementTransformationBase {
       Value sourceValue =
           branchInterface.getSuccessorOperands(it.getSuccessorIndex())
               .getValue()[blockArg.getArgNumber()];
-      // Create a new clone at the current location of the terminator.
-      Value clone = introduceCloneBuffers(sourceValue, terminator);
       // Wire new clone and successor operand.
       auto mutableOperands =
           branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
-      if (!mutableOperands.hasValue())
+      if (!mutableOperands) {
         terminator->emitError() << "terminators with immutable successor "
                                    "operands are not supported";
-      else
-        mutableOperands.getValue()
-            .slice(blockArg.getArgNumber(), 1)
-            .assign(clone);
+        continue;
+      }
+      // Create a new clone at the current location of the terminator.
+      auto clone = introduceCloneBuffers(sourceValue, terminator);
+      if (failed(clone))
+        return failure();
+      mutableOperands.getValue()
+          .slice(blockArg.getArgNumber(), 1)
+          .assign(*clone);
     }
 
     // Check whether the block argument has implicitly defined predecessors via
@@ -307,14 +351,15 @@ class BufferDeallocation : BufferPlacementTransformationBase {
     RegionBranchOpInterface regionInterface;
     if (!argRegion || &argRegion->front() != block ||
         !(regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)))
-      return;
+      return success();
 
-    introduceClonesForRegionSuccessors(
-        regionInterface, argRegion->getParentOp()->getRegions(), blockArg,
-        [&](RegionSuccessor &successorRegion) {
-          // Find a predecessor of our argRegion.
-          return successorRegion.getSuccessor() == argRegion;
-        });
+    if (failed(introduceClonesForRegionSuccessors(
+            regionInterface, argRegion->getParentOp()->getRegions(), blockArg,
+            [&](RegionSuccessor &successorRegion) {
+              // Find a predecessor of our argRegion.
+              return successorRegion.getSuccessor() == argRegion;
+            })))
+      return failure();
 
     // Check whether the block argument belongs to an entry region of the
     // parent operation. In this case, we have to introduce an additional clone
@@ -326,24 +371,27 @@ class BufferDeallocation : BufferPlacementTransformationBase {
           return successorRegion.getSuccessor() == argRegion;
         });
     if (it == successorRegions.end())
-      return;
+      return success();
 
     // Determine the actual operand to introduce a clone for and rewire the
     // operand to point to the clone instead.
     Value operand =
         regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber())
             [llvm::find(it->getSuccessorInputs(), blockArg).getIndex()];
-    Value clone = introduceCloneBuffers(operand, parentOp);
+    auto clone = introduceCloneBuffers(operand, parentOp);
+    if (failed(clone))
+      return failure();
 
     auto op = llvm::find(parentOp->getOperands(), operand);
     assert(op != parentOp->getOperands().end() &&
            "parentOp does not contain operand");
-    parentOp->setOperand(op.getIndex(), clone);
+    parentOp->setOperand(op.getIndex(), *clone);
+    return success();
   }
 
   /// Introduces temporary clones in front of all associated nested-region
   /// terminators and copies the source values into the newly allocated buffers.
-  void introduceValueCopyForRegionResult(Value value) {
+  LogicalResult introduceValueCopyForRegionResult(Value value) {
     // Get the actual result index in the scope of the parent terminator.
     Operation *operation = value.getDefiningOp();
     auto regionInterface = cast<RegionBranchOpInterface>(operation);
@@ -358,15 +406,15 @@ class BufferDeallocation : BufferPlacementTransformationBase {
     // been considered critical. Therefore, the algorithm assumes that a clone
     // of a previously allocated buffer is returned by the operation (like in
     // the case of a block argument).
-    introduceClonesForRegionSuccessors(regionInterface, operation->getRegions(),
-                                       value, regionPredicate);
+    return introduceClonesForRegionSuccessors(
+        regionInterface, operation->getRegions(), value, regionPredicate);
   }
 
   /// Introduces buffer clones for all terminators in the given regions. The
   /// regionPredicate is applied to every successor region in order to restrict
   /// the clones to specific regions.
   template <typename TPredicate>
-  void introduceClonesForRegionSuccessors(
+  LogicalResult introduceClonesForRegionSuccessors(
       RegionBranchOpInterface regionInterface, MutableArrayRef<Region> regions,
       Value argValue, const TPredicate &regionPredicate) {
     for (Region &region : regions) {
@@ -389,27 +437,33 @@ class BufferDeallocation : BufferPlacementTransformationBase {
       // Iterate over all immediate terminator operations to introduce
       // new buffer allocations. Thereby, the appropriate terminator operand
       // will be adjusted to point to the newly allocated buffer instead.
-      walkReturnOperations(&region, [&](Operation *terminator) {
-        // Get the actual mutable operands for this terminator op.
-        auto terminatorOperands = *getMutableRegionBranchSuccessorOperands(
-            terminator, region.getRegionNumber());
-        // Extract the source value from the current terminator.
-        // This conversion needs to exist on a separate line due to a bug in
-        // GCC conversion analysis.
-        OperandRange immutableTerminatorOperands = terminatorOperands;
-        Value sourceValue = immutableTerminatorOperands[operandIndex];
-        // Create a new clone at the current location of the terminator.
-        Value clone = introduceCloneBuffers(sourceValue, terminator);
-        // Wire clone and terminator operand.
-        terminatorOperands.slice(operandIndex, 1).assign(clone);
-      });
+      if (failed(walkReturnOperations(&region, [&](Operation *terminator) {
+            // Get the actual mutable operands for this terminator op.
+            auto terminatorOperands = *getMutableRegionBranchSuccessorOperands(
+                terminator, region.getRegionNumber());
+            // Extract the source value from the current terminator.
+            // This conversion needs to exist on a separate line due to a bug in
+            // GCC conversion analysis.
+            OperandRange immutableTerminatorOperands = terminatorOperands;
+            Value sourceValue = immutableTerminatorOperands[operandIndex];
+            // Create a new clone at the current location of the terminator.
+            auto clone = introduceCloneBuffers(sourceValue, terminator);
+            if (failed(clone))
+              return failure();
+            // Wire clone and terminator operand.
+            terminatorOperands.slice(operandIndex, 1).assign(*clone);
+            return success();
+          })))
+        return failure();
     }
+    return success();
   }
 
   /// Creates a new memory allocation for the given source value and clones
   /// its content into the newly allocated buffer. The terminator operation is
   /// used to insert the clone operation at the right place.
-  Value introduceCloneBuffers(Value sourceValue, Operation *terminator) {
+  FailureOr<Value> introduceCloneBuffers(Value sourceValue,
+                                         Operation *terminator) {
     // Avoid multiple clones of the same source value. This can happen in the
     // presence of loops when a branch acts as a backedge while also having
     // another successor that returns to its parent operation. Note: that
@@ -422,19 +476,18 @@ class BufferDeallocation : BufferPlacementTransformationBase {
       return sourceValue;
     // Create a new clone operation that copies the contents of the old
     // buffer to the new one.
-    OpBuilder builder(terminator);
-    auto cloneOp =
-        builder.create<memref::CloneOp>(terminator->getLoc(), sourceValue);
-
-    // Remember the clone of original source value.
-    clonedValues.insert(cloneOp);
-    return cloneOp;
+    auto clone = buildClone(terminator, sourceValue);
+    if (succeeded(clone)) {
+      // Remember the clone of original source value.
+      clonedValues.insert(*clone);
+    }
+    return clone;
   }
 
   /// Finds correct dealloc positions according to the algorithm described at
   /// the top of the file for all alloc nodes and block arguments that can be
   /// handled by this analysis.
-  void placeDeallocs() const {
+  LogicalResult placeDeallocs() {
     // Move or insert deallocs using the previously computed information.
     // These deallocations will be linked to their associated allocation nodes
     // since they don't have any aliases that can (potentially) increase their
@@ -492,10 +545,54 @@ class BufferDeallocation : BufferPlacementTransformationBase {
         if (!nextOp)
           continue;
         // If there is no dealloc node, insert one in the right place.
-        OpBuilder builder(nextOp);
-        builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+        if (failed(buildDealloc(nextOp, alloc)))
+          return failure();
       }
     }
+    return success();
+  }
+
+  /// Inserts, right before `op`, a deallocation for `alloc`. Values registered
+  /// by prepare() are freed through their AllocationOpInterface; any other
+  /// value (e.g. a function parameter) is freed with a plain
+  /// memref::DeallocOp. Emits an error and fails if no dealloc is available.
+  LogicalResult buildDealloc(Operation *op, Value alloc) {
+    OpBuilder builder(op);
+    auto entry = aliasToAllocations.find(alloc);
+    if (entry == aliasToAllocations.end()) {
+      // Build a "default" DeallocOp for unknown allocation sources.
+      builder.create<memref::DeallocOp>(alloc.getLoc(), alloc);
+      return success();
+    }
+
+    // Let the registered interface build a compatible deallocation.
+    AllocationOpInterface allocationOp = entry->second;
+    if (allocationOp.buildDealloc(builder, alloc))
+      return success();
+    return op->emitError()
+           << "allocations without compatible deallocations are "
+              "not supported";
+  }
+
+  /// Inserts, right before `op`, a clone of `alloc`. Values registered by
+  /// prepare() are cloned through their AllocationOpInterface; any other
+  /// value (e.g. a function parameter) is cloned with a memref::CloneOp.
+  /// Emits an error and returns failure if no compatible clone exists.
+  FailureOr<Value> buildClone(Operation *op, Value alloc) {
+    OpBuilder builder(op);
+    auto it = aliasToAllocations.find(alloc);
+    if (it != aliasToAllocations.end()) {
+      // Call the allocation op interface to build a supported and
+      // compatible clone operation.
+      if (Optional<Value> clone = it->second.buildClone(builder, alloc))
+        return *clone;
+      op->emitError() << "allocations without compatible clone ops "
+                         "are not supported";
+      return failure();
+    }
+    // Build a "default" CloneOp for unknown allocation sources, i.e. values
+    // that prepare() did not register.
+    return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
   }
 
   /// The dominator info to find the appropriate start operation to move the
@@ -508,6 +605,9 @@ class BufferDeallocation : BufferPlacementTransformationBase {
 
   /// Stores already cloned buffers to avoid additional clones of clones.
   ValueSetT clonedValues;
+
+  /// Maps aliases to their source allocation interfaces (inverse mapping).
+  AliasAllocationMapT aliasToAllocations;
 };
 
 //===----------------------------------------------------------------------===//
@@ -529,13 +629,20 @@ struct BufferDeallocationPass : BufferDeallocationBase<BufferDeallocationPass> {
     }
 
     // Check that the control flow structures are supported.
-    if (!validateSupportedControlFlow(func.getRegion())) {
+    if (!validateSupportedControlFlow(func.getRegion()))
       return signalPassFailure();
-    }
 
-    // Place all required temporary clone and dealloc nodes.
+    // Gather all required allocation nodes and prepare the deallocation phase.
     BufferDeallocation deallocation(func);
-    deallocation.deallocate();
+
+    // Check for supported AllocationOpInterface implementations and prepare the
+    // internal deallocation pass.
+    if (failed(deallocation.prepare()))
+      return signalPassFailure();
+
+    // Place all required temporary clone and dealloc nodes.
+    if (failed(deallocation.deallocate()))
+      return signalPassFailure();
   }
 };
 


        


More information about the Mlir-commits mailing list