[llvm] [mlir][bufferization] Move memref specific implementation of AllocationOpInterface to memref dialect directory (PR #66637)

Martin Erhart via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 18 05:33:19 PDT 2023


https://github.com/maerhart created https://github.com/llvm/llvm-project/pull/66637

Follow-up on #65578

>From de0e61c655e950ed46b8f5869595053d36163ab0 Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Mon, 18 Sep 2023 12:28:00 +0000
Subject: [PATCH] [mlir][bufferization] Move memref specific implementation of
 AllocationOpInterface to memref dialect directory

---
 .../Dialect/Bufferization/Transforms/Passes.h |  3 -
 .../Transforms/AllocationOpInterfaceImpl.h    | 20 ++++++
 mlir/include/mlir/InitAllDialects.h           |  2 +
 .../BufferizationTransformOps.cpp             |  1 -
 .../Transforms/BufferDeallocation.cpp         |  1 -
 .../Bufferization/Transforms/Bufferize.cpp    | 57 ---------------
 .../Transforms/AllocationOpInterfaceImpl.cpp  | 69 +++++++++++++++++++
 .../Dialect/MemRef/Transforms/CMakeLists.txt  |  1 +
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 9 files changed, 93 insertions(+), 62 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 92520eb13da6875..a6f668b26aa10e4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -211,9 +211,6 @@ std::unique_ptr<Pass> createBufferizationBufferizePass();
 // Registration
 //===----------------------------------------------------------------------===//
 
-/// Register external models for AllocationOpInterface.
-void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
-
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h
new file mode 100644
index 000000000000000..aea05821fd1167c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h
@@ -0,0 +1,25 @@
+//===- AllocationOpInterfaceImpl.h - Impl. of AllocationOpInterface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Register external models for AllocationOpInterface on memref.alloc,
+/// memref.alloca and memref.realloc.
+///
+/// The bufferization passes do not register these models on their own; tools
+/// that do not use `registerAllDialects` must call this function explicitly.
+void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 5b2b1ed24d5173d..f36b79e86832171 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -50,6 +50,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -147,6 +148,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
   linalg::registerValueBoundsOpInterfaceExternalModels(registry);
+  memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index b84cc452d0141cd..7a6d1858489d1e6 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -175,5 +175,4 @@ class BufferizationTransformDialectExtension
 void mlir::bufferization::registerTransformDialectExtension(
     DialectRegistry &registry) {
   registry.addExtensions<BufferizationTransformDialectExtension>();
-  bufferization::registerAllocationOpInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index f74c6255c196ba5..a0a81d4add71210 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -634,7 +634,6 @@ struct BufferDeallocationPass
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<bufferization::BufferizationDialect>();
     registry.insert<memref::MemRefDialect>();
-    registerAllocationOpInterfaceExternalModels(registry);
   }
 
   void runOnOperation() override {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7358d0d465d3e3d..2edb27da98fe910 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -196,7 +196,6 @@ struct OneShotBufferizePass
   void getDependentDialects(DialectRegistry &registry) const override {
     registry
         .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
-    registerAllocationOpInterfaceExternalModels(registry);
   }
 
   void runOnOperation() override {
@@ -682,59 +681,3 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
   options.opFilter.allowDialect<BufferizationDialect>();
   return options;
 }
-
-//===----------------------------------------------------------------------===//
-// Default AllocationOpInterface implementation and registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct DefaultAllocationInterface
-    : public bufferization::AllocationOpInterface::ExternalModel<
-          DefaultAllocationInterface, memref::AllocOp> {
-  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
-                                                 Value alloc) {
-    return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
-        .getOperation();
-  }
-  static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
-    return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
-        .getResult();
-  }
-  static ::mlir::HoistingKind getHoistingKind() {
-    return HoistingKind::Loop | HoistingKind::Block;
-  }
-  static ::std::optional<::mlir::Operation *>
-  buildPromotedAlloc(OpBuilder &builder, Value alloc) {
-    Operation *definingOp = alloc.getDefiningOp();
-    return builder.create<memref::AllocaOp>(
-        definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
-        definingOp->getOperands(), definingOp->getAttrs());
-  }
-};
-
-struct DefaultAutomaticAllocationHoistingInterface
-    : public bufferization::AllocationOpInterface::ExternalModel<
-          DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
-  static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
-};
-
-struct DefaultReallocationInterface
-    : public bufferization::AllocationOpInterface::ExternalModel<
-          DefaultAllocationInterface, memref::ReallocOp> {
-  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
-                                                 Value realloc) {
-    return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
-        .getOperation();
-  }
-};
-} // namespace
-
-void bufferization::registerAllocationOpInterfaceExternalModels(
-    DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
-    memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
-    memref::AllocaOp::attachInterface<
-        DefaultAutomaticAllocationHoistingInterface>(*ctx);
-    memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
-  });
-}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..c4334159443236e
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
@@ -0,0 +1,91 @@
+//===- AllocationOpInterfaceImpl.cpp - Impl. of AllocationOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+
+namespace {
+/// AllocationOpInterface model for `memref.alloc`: deallocated with
+/// `memref.dealloc`, cloned with `bufferization.clone`, hoistable out of
+/// loops and blocks, and promotable to an equivalent `memref.alloca`.
+struct DefaultAllocationInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultAllocationInterface, memref::AllocOp> {
+  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
+                                                 Value alloc) {
+    return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
+        .getOperation();
+  }
+  static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
+    return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
+        .getResult();
+  }
+  static ::mlir::HoistingKind getHoistingKind() {
+    return HoistingKind::Loop | HoistingKind::Block;
+  }
+  /// Builds a `memref.alloca` with the result type, operands and attributes
+  /// of the op defining `alloc` (expected to be this `memref.alloc`).
+  static ::std::optional<::mlir::Operation *>
+  buildPromotedAlloc(OpBuilder &builder, Value alloc) {
+    Operation *definingOp = alloc.getDefiningOp();
+    return builder.create<memref::AllocaOp>(
+        definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
+        definingOp->getOperands(), definingOp->getAttrs());
+  }
+};
+
+/// AllocationOpInterface model for `memref.alloca`: only loop hoisting is
+/// advertised; no dealloc or clone builder is provided.
+struct DefaultAutomaticAllocationHoistingInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
+  static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
+};
+
+/// AllocationOpInterface model for `memref.realloc`: deallocated with
+/// `memref.dealloc` and cloned with `bufferization.clone`. Hoisting and
+/// promotion are intentionally left to the interface defaults.
+///
+/// The first ExternalModel argument must be this struct itself (CRTP): the
+/// interface dispatches to that type's static methods. Naming
+/// `DefaultAllocationInterface` there would make `memref.realloc` use the
+/// `memref.alloc` model, i.e. claim loop/block hoisting and be "promoted" to
+/// a `memref.alloca` built from the realloc's operands.
+struct DefaultReallocationInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultReallocationInterface, memref::ReallocOp> {
+  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
+                                                 Value realloc) {
+    return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
+        .getOperation();
+  }
+  static std::optional<Value> buildClone(OpBuilder &builder, Value realloc) {
+    return builder.create<bufferization::CloneOp>(realloc.getLoc(), realloc)
+        .getResult();
+  }
+};
+} // namespace
+
+/// Attaches the models above to memref.alloc, memref.alloca and
+/// memref.realloc through a registry extension for the MemRef dialect.
+void mlir::memref::registerAllocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
+    memref::AllocaOp::attachInterface<
+        DefaultAutomaticAllocationHoistingInterface>(*ctx);
+    memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ddd674c37c4e536..b16c281c93640ea 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
+  AllocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   ComposeSubView.cpp
   ExpandOps.cpp
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9bea555f701757c..3449a9a1bbcabe0 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11722,6 +11722,7 @@ cc_library(
         ":AffineDialect",
         ":AffineTransforms",
         ":AffineUtils",
+        ":AllocationOpInterface",
         ":ArithDialect",
         ":ArithTransforms",
         ":ArithUtils",



More information about the llvm-commits mailing list