[Mlir-commits] [mlir] 55e3857 - Make buffer hoisting/promotion passes use AllocationOpInterface

Mehdi Amini llvmlistbot at llvm.org
Tue Aug 22 16:51:20 PDT 2023


Author: Xiaolei Shi
Date: 2023-08-22T16:51:04-07:00
New Revision: 55e3857931d1e187af574b322e026ced379ee1f2

URL: https://github.com/llvm/llvm-project/commit/55e3857931d1e187af574b322e026ced379ee1f2
DIFF: https://github.com/llvm/llvm-project/commit/55e3857931d1e187af574b322e026ced379ee1f2.diff

LOG: Make buffer hoisting/promotion passes use AllocationOpInterface

This update implements the usage of AllocationOpInterface in the buffer hoisting/promotion passes. Two interface methods, namely `getHoistingKind` and `buildPromotedAlloc`, have been added. The former indicates which kind of hoisting (loop, block) an allocation operation supports, while the latter builds a stack allocation operation for promotable allocations used by the promote-buffers-to-stack pass.

This update makes these passes work with user-defined allocation operations.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158398

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td
    mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h
index bfb0762211d31d..f785fd398f3923 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.h
@@ -15,6 +15,16 @@
 
 #include "mlir/IR/Builders.h"
 
+namespace mlir {
+// Bit flags describing the kinds of hoisting an allocation operation supports;
+// values may be combined with bitwise OR.
+enum class HoistingKind : uint8_t {
+  None = 0,      // No hoisting kind selected
+  Loop = 1 << 0, // Indicates loop hoisting kind
+  Block = 1 << 1 // Indicates dominated block hoisting kind
+};
+} // namespace mlir
+
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h.inc"
 
 #endif // MLIR_DIALECT_BUFFERIZATION_IR_ALLOCATIONOPINTERFACE_H_

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td
index e3fbee7bf96c51..f82be7bb582f27 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td
@@ -51,6 +51,25 @@ def AllocationOpInterface : OpInterface<"AllocationOpInterface"> {
       "::std::optional<::mlir::Value>", "buildClone",
       (ins "::mlir::OpBuilder&":$builder, "::mlir::Value":$alloc), [{}],
       /*defaultImplementation=*/[{ return std::nullopt; }]
+    >,
+    StaticInterfaceMethod<[{
+        Returns the hoisting kinds supported for the buffer allocated by this
+        operation, as a bitmask of HoistingKind flags (None disables hoisting).
+      }],
+      "::mlir::HoistingKind", "getHoistingKind",
+      (ins), [{}],
+      /*defaultImplementation=*/[{ return HoistingKind::None; }]
+    >,
+    StaticInterfaceMethod<[{
+        Builds a stack allocation operation equivalent to `alloc` using the
+        provided builder. `alloc` is a result of the operation implementing
+        this interface. Used by the promote-buffers-to-stack pass to replace
+        promotable allocations. Returns ::std::nullopt if there is no
+        compatible stack allocation operation.
+      }],
+      "::std::optional<::mlir::Operation*>", "buildPromotedAlloc",
+      (ins "::mlir::OpBuilder&":$builder, "::mlir::Value":$alloc), [{}],
+      /*defaultImplementation=*/[{ return std::nullopt; }]
     >
   ];
 }

diff  --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index f1233739524d28..53504edde40edf 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -175,4 +176,7 @@ class BufferizationTransformDialectExtension
 void mlir::bufferization::registerTransformDialectExtension(
     DialectRegistry &registry) {
   registry.addExtensions<BufferizationTransformDialectExtension>();
+  // Hoisting/promotion only recognize ops implementing AllocationOpInterface,
+  // so also register the default memref models with this extension.
+  bufferization::registerAllocationOpInterfaceExternalModels(registry);
 }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index 98184718611a54..35d48b87bd7016 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -644,6 +644,29 @@ struct DefaultAllocationInterface
     return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
         .getResult();
   }
+  /// memref.alloc opts into both loop hoisting and dominated-block hoisting.
+  static ::mlir::HoistingKind getHoistingKind() {
+    return static_cast<HoistingKind>(static_cast<uint8_t>(HoistingKind::Loop) |
+                                     static_cast<uint8_t>(HoistingKind::Block));
+  }
+  /// Builds a memref.alloca with the same location, result type, operands and
+  /// attributes as the memref.alloc that defines `alloc`.
+  static ::std::optional<::mlir::Operation *>
+  buildPromotedAlloc(OpBuilder &builder, Value alloc) {
+    Operation *definingOp = alloc.getDefiningOp();
+    assert(definingOp && "expected `alloc` to be the result of an operation");
+    return builder.create<memref::AllocaOp>(
+        definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
+        definingOp->getOperands(), definingOp->getAttrs());
+  }
+};
+
+/// AllocationOpInterface model for memref.alloca, which only supports loop
+/// hoisting; all other interface methods keep their default implementations.
+struct DefaultAutomaticAllocationHoistingInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
+  static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
 };
 
 struct DefaultReallocationInterface
@@ -720,6 +743,8 @@ void bufferization::registerAllocationOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
     memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
+    memref::AllocaOp::attachInterface<
+        DefaultAutomaticAllocationHoistingInterface>(*ctx);
     memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
   });
 }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
index b4f6a5f61fba20..90ea890d65e778 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -39,6 +40,22 @@ static bool isKnownControlFlowInterface(Operation *op) {
   return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
 }
 
+/// Returns true if the given operation implements the AllocationOpInterface
+/// and supports dominated-block hoisting (HoistingKind::Block).
+static bool allowAllocDominateBlockHoisting(Operation *op) {
+  auto allocOp = dyn_cast<AllocationOpInterface>(op);
+  return allocOp && (static_cast<uint8_t>(allocOp.getHoistingKind()) &
+                     static_cast<uint8_t>(HoistingKind::Block));
+}
+
+/// Returns true if the given operation implements the AllocationOpInterface
+/// and supports loop hoisting (HoistingKind::Loop).
+static bool allowAllocLoopHoisting(Operation *op) {
+  auto allocOp = dyn_cast<AllocationOpInterface>(op);
+  return allocOp && (static_cast<uint8_t>(allocOp.getHoistingKind()) &
+                     static_cast<uint8_t>(HoistingKind::Loop));
+}
+
 /// Check if the size of the allocation is less than the given size. The
 /// transformation is only applied to small buffers since large buffers could
 /// exceed the stack space.
@@ -279,7 +296,7 @@ struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
 
   /// Returns true if the given operation should be considered for hoisting.
   static bool shouldHoistOpType(Operation *op) {
-    return llvm::isa<memref::AllocOp>(op);
+    return allowAllocDominateBlockHoisting(op);
   }
 
   /// Sets the current placement block to the given block.
@@ -316,7 +333,7 @@ struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
 
   /// Returns true if the given operation should be considered for hoisting.
   static bool shouldHoistOpType(Operation *op) {
-    return llvm::isa<memref::AllocOp, memref::AllocaOp>(op);
+    return allowAllocLoopHoisting(op);
   }
 
   /// Does not change the internal placement block, as we want to move
@@ -356,13 +373,17 @@ class BufferPlacementPromotion : BufferPlacementTransformationBase {
       // `AutomaticAllocationScope` determined during the initialization phase.
       OpBuilder builder(startOperation);
       Operation *allocOp = alloc.getDefiningOp();
-      Operation *alloca = builder.create<memref::AllocaOp>(
-          alloc.getLoc(), cast<MemRefType>(alloc.getType()),
-          allocOp->getOperands(), allocOp->getAttrs());
-
-      // Replace the original alloc by a newly created alloca.
-      allocOp->replaceAllUsesWith(alloca);
-      allocOp->erase();
+      if (auto allocInterface = dyn_cast<AllocationOpInterface>(allocOp)) {
+        // buildPromotedAlloc may return std::nullopt when no compatible stack
+        // allocation exists; leave such allocations untouched.
+        std::optional<Operation *> alloca =
+            allocInterface.buildPromotedAlloc(builder, alloc);
+        if (!alloca || !*alloca)
+          continue;
+        // Replace the original alloc by a newly created alloca.
+        allocOp->replaceAllUsesWith(*alloca);
+        allocOp->erase();
+      }
     }
   }
 };


        


More information about the Mlir-commits mailing list