[Mlir-commits] [mlir] 333ee21 - [mlir] Transform dialect: separate dependent and generated dialects

Alex Zinenko llvmlistbot at llvm.org
Mon Jul 25 02:59:58 PDT 2022


Author: Alex Zinenko
Date: 2022-07-25T09:59:53Z
New Revision: 333ee218ce9b50ef9ded1dbafbb02667e56d18ab

URL: https://github.com/llvm/llvm-project/commit/333ee218ce9b50ef9ded1dbafbb02667e56d18ab
DIFF: https://github.com/llvm/llvm-project/commit/333ee218ce9b50ef9ded1dbafbb02667e56d18ab.diff

LOG: [mlir] Transform dialect: separate dependent and generated dialects

In the Transform dialect extensions, provide separate mechanisms to
declare dependent dialects (the dialects the transform IR depends on)
and generated dialects (the dialects the payload IR may be
transformed into). This allows Transform dialect clients that only
construct the transform IR to avoid loading the dialects needed for
the payload IR along with the Transform dialect itself, thus
decreasing the build/link time.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D130289

Added: 
    mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp
    mlir/unittests/Dialect/Transform/CMakeLists.txt

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
    mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/unittests/Dialect/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index b897618e368f1..1f32a595d4c87 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -52,6 +52,26 @@ namespace transform {
 /// expected to derive this class and register operations in the constructor.
 /// They can be registered with the DialectRegistry and automatically applied
 /// to the Transform dialect when it is loaded.
+///
+/// Derived classes are expected to define a `void init()` function in which
+/// they can call various protected methods of the base class to register
+/// extension operations and declare their dependencies.
+///
+/// By default, the extension is configured both for construction of the
+/// Transform IR and for its application to some payload. If only the
+/// construction is desired, the extension can be switched to "build-only" mode
+/// that avoids loading the dialects that are only necessary for transforming
+/// the payload. To perform the switch, the extension must be wrapped into the
+/// `BuildOnly` class template (see below) when it is registered, as in:
+///
+///    dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>();
+///
+/// instead of:
+///
+///    dialectRegistry.addExtension<MyTransformDialectExt>();
+///
+/// Derived classes must reexport the constructor of this class or otherwise
+/// forward its boolean argument to support this behavior.
 template <typename DerivedTy, typename... ExtraDialects>
 class TransformDialectExtension
     : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
@@ -65,12 +85,31 @@ class TransformDialectExtension
              ExtraDialects *...) const final {
     for (const DialectLoader &loader : dialectLoaders)
       loader(context);
+
+    // Only load generated dialects if the user intends to apply
+    // transformations specified by the extension.
+    if (!buildOnly)
+      for (const DialectLoader &loader : generatedDialectLoaders)
+        loader(context);
+
     for (const Initializer &init : opInitializers)
       init(transformDialect);
     transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns));
   }
 
 protected:
+  using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>;
+
+  /// Extension constructor. The argument indicates whether to skip generated
+  /// dialects when applying the extension.
+  explicit TransformDialectExtension(bool buildOnly = false)
+      : buildOnly(buildOnly) {
+    static_cast<DerivedTy *>(this)->init();
+  }
+
+  /// Hook for derived classes to inject constructor behavior.
+  void init() {}
+
   /// Injects the operations into the Transform dialect. The operations must
   /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the
   /// implementations must be already available when the operation is injected.
@@ -85,13 +124,28 @@ class TransformDialectExtension
   /// provided as template parameter. When the Transform dialect is loaded,
   /// dependent dialects will be loaded as well. This is intended for dialects
   /// that contain attributes and types used in creation and canonicalization of
-  /// the injected operations.
+  /// the injected operations, similarly to how the dialect definition may list
+  /// dependent dialects. This is *not* intended for dialects entities from
+  /// which may be produced when applying the transformations specified by ops
+  /// registered by this extension.
   template <typename DialectTy>
   void declareDependentDialect() {
     dialectLoaders.push_back(
         [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
   }
 
+  /// Declares that the transformations associated with the operations
+  /// registered by this dialect extension may produce operations from the
+  /// dialect provided as template parameter while processing payload IR that
+  /// does not contain the operations from said dialect. This is similar to
+  /// dependent dialects of a pass. These dialects will be loaded along with the
+  /// transform dialect unless the extension is in the build-only mode.
+  template <typename DialectTy>
+  void declareGeneratedDialect() {
+    generatedDialectLoaders.push_back(
+        [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
+  }
+
   /// Injects the named constraint to make it available for use with the
   /// PDLMatchOp in the transform dialect.
   void registerPDLMatchConstraintFn(StringRef name,
@@ -108,14 +162,32 @@ class TransformDialectExtension
 
 private:
   SmallVector<Initializer> opInitializers;
+
+  /// Callbacks loading the dependent dialects, i.e. the dialect needed for the
+  /// extension ops.
   SmallVector<DialectLoader> dialectLoaders;
 
-  /// A list of constraints that should be made availble to PDL patterns
+  /// Callbacks loading the generated dialects, i.e. the dialects produced when
+  /// applying the transformations.
+  SmallVector<DialectLoader> generatedDialectLoaders;
+
+  /// A list of constraints that should be made available to PDL patterns
   /// processed by PDLMatchOp in the Transform dialect.
   ///
   /// Declared as mutable so its contents can be moved in the `apply` const
   /// method, which is only called once.
   mutable llvm::StringMap<PDLConstraintFunction> pdlMatchConstraintFns;
+
+  /// Indicates that the extension is in build-only mode.
+  bool buildOnly;
+};
+
+/// A wrapper for transform dialect extensions that forces them to be
+/// constructed in the build-only mode.
+template <typename DerivedTy>
+class BuildOnly : public DerivedTy {
+public:
+  BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
 };
 
 } // namespace transform

diff  --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 66144cf073b4b..fc3c386d74a4a 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -75,10 +75,14 @@ class BufferizationTransformDialectExtension
     : public transform::TransformDialectExtension<
           BufferizationTransformDialectExtension> {
 public:
-  BufferizationTransformDialectExtension() {
-    declareDependentDialect<bufferization::BufferizationDialect>();
+  using Base::Base;
+
+  void init() {
     declareDependentDialect<pdl::PDLDialect>();
-    declareDependentDialect<memref::MemRefDialect>();
+
+    declareGeneratedDialect<bufferization::BufferizationDialect>();
+    declareGeneratedDialect<memref::MemRefDialect>();
+
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5cc900cb035a9..58aa32047b15a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1069,12 +1069,16 @@ class LinalgTransformDialectExtension
     : public transform::TransformDialectExtension<
           LinalgTransformDialectExtension> {
 public:
-  LinalgTransformDialectExtension() {
-    declareDependentDialect<AffineDialect>();
-    declareDependentDialect<arith::ArithmeticDialect>();
+  using Base::Base;
+
+  void init() {
     declareDependentDialect<pdl::PDLDialect>();
-    declareDependentDialect<scf::SCFDialect>();
-    declareDependentDialect<vector::VectorDialect>();
+
+    declareGeneratedDialect<AffineDialect>();
+    declareGeneratedDialect<arith::ArithmeticDialect>();
+    declareGeneratedDialect<scf::SCFDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 0a8af59c68b05..5f32e78826c49 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
@@ -233,9 +234,14 @@ class SCFTransformDialectExtension
     : public transform::TransformDialectExtension<
           SCFTransformDialectExtension> {
 public:
-  SCFTransformDialectExtension() {
-    declareDependentDialect<AffineDialect>();
-    declareDependentDialect<func::FuncDialect>();
+  using Base::Base;
+
+  void init() {
+    declareDependentDialect<pdl::PDLDialect>();
+
+    declareGeneratedDialect<AffineDialect>();
+    declareGeneratedDialect<func::FuncDialect>();
+
     registerTransformOps<
 #define GET_OP_LIST
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 5b7a0b88752ee..3893508ff19f4 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -295,7 +295,9 @@ class TestTransformDialectExtension
     : public transform::TransformDialectExtension<
           TestTransformDialectExtension> {
 public:
-  TestTransformDialectExtension() {
+  using Base::Base;
+
+  void init() {
     declareDependentDialect<pdl::PDLDialect>();
     registerTransformOps<TestTransformOp,
                          TestTransformUnrestrictedOpNoInterface,

diff  --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index f8e5e46e5ac3b..cdff10654635b 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -12,4 +12,5 @@ add_subdirectory(MemRef)
 add_subdirectory(Quant)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
+add_subdirectory(Transform)
 add_subdirectory(Utils)

diff  --git a/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp b/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp
new file mode 100644
index 0000000000000..40fb752ffd6eb
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp
@@ -0,0 +1,45 @@
+//===- BuildOnlyExtensionTest.cpp - unit test for transform extensions ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+namespace {
+class Extension : public TransformDialectExtension<Extension> {
+public:
+  using Base::Base;
+  void init() { declareGeneratedDialect<func::FuncDialect>(); }
+};
+} // end namespace
+
+TEST(BuildOnlyExtensionTest, buildOnlyExtension) {
+  // Register the build-only version of the transform dialect extension. The
+  // func dialect is declared as generated so it should not be loaded along with
+  // the transform dialect.
+  DialectRegistry registry;
+  registry.addExtensions<BuildOnly<Extension>>();
+  MLIRContext ctx(registry);
+  ctx.getOrLoadDialect<TransformDialect>();
+  ASSERT_FALSE(ctx.getLoadedDialect<func::FuncDialect>());
+}
+
+TEST(BuildOnlyExtensionTest, buildAndApplyExtension) {
+  // Register the full version of the transform dialect extension. The func
+  // dialect should be loaded along with the transform dialect.
+  DialectRegistry registry;
+  registry.addExtensions<Extension>();
+  MLIRContext ctx(registry);
+  ctx.getOrLoadDialect<TransformDialect>();
+  ASSERT_TRUE(ctx.getLoadedDialect<func::FuncDialect>());
+}

diff  --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt
new file mode 100644
index 0000000000000..1fecd21221c91
--- /dev/null
+++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIRTransformDialectTests
+  BuildOnlyExtensionTest.cpp
+)
+target_link_libraries(MLIRTransformDialectTests
+  PRIVATE
+  MLIRFuncDialect
+  MLIRTransformDialect
+)

diff  --git a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
index d335e55694d42..ff520048803a8 100644
--- a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel
@@ -179,6 +179,21 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "transform_dialect_tests",
+    size = "small",
+    srcs = glob([
+        "Dialect/Transform/*.cpp",
+        "Dialect/Transform/*.h",
+    ]),
+    deps = [
+        "//llvm:TestingSupport",
+        "//llvm:gtest_main",
+        "//mlir:FuncDialect",
+        "//mlir:TransformDialect",
+    ],
+)
+
 cc_test(
     name = "dialect_utils_tests",
     size = "small",


        


More information about the Mlir-commits mailing list