[Mlir-commits] [mlir] Transform loading separation (PR #67686)

Ingo Müller llvmlistbot at llvm.org
Thu Sep 28 07:27:27 PDT 2023


https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/67686

None

>From 14ac86570446888a4c50957911ab1a4d27fa255f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 28 Sep 2023 09:20:59 +0000
Subject: [PATCH 1/2] Make variable names for shared and library modules
 consistent. (NFC)

This commit renames the arguments of several static implementation
functions of the transform interpreter base class to match the names of
the corresponding member variables in order to clarify their intent.
Similarly, it renames some local variables to reflect their relationship
with corresponding member variables. Finally, this commit also asserts
in `interpreterBaseRunOnOperationImpl` that at most one of the shared and
library modules is set (which the initialization function guarantees)
and simplifies some related `if` conditions.
---
 .../TransformInterpreterPassBase.cpp          | 56 +++++++++++--------
 1 file changed, 34 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..671f6902fecdb16 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -379,7 +379,7 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
-    const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
@@ -387,6 +387,16 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
+  bool hasSharedTransformModule =
+      sharedTransformModule && *sharedTransformModule;
+  bool hasTransformLibraryModule =
+      transformLibraryModule && *transformLibraryModule;
+  assert((!hasSharedTransformModule || !hasTransformLibraryModule) &&
+         "at most one of shared or library transform module can be set");
+
+  // Step 0
+  // ------
+  // If debugPayloadRootTag or debugTransformRootTag was passed, then we are in)
 
   // Step 1
   // ------
@@ -407,9 +417,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // transform is embedded in the payload IR. If debugTransformRootTag was
   // passed, then we are in user-specified selection of the transforming IR.
   // This corresponds to REPL debug mode.
-  bool sharedTransform = (sharedTransformModule && *sharedTransformModule);
   Operation *transformContainer =
-      sharedTransform ? sharedTransformModule->get() : target;
+      hasSharedTransformModule ? sharedTransformModule->get() : target;
   Operation *transformRoot =
       debugTransformRootTag.empty()
           ? findTopLevelTransform(transformContainer,
@@ -430,7 +439,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // Copy external defintions for symbols if provided. Be aware of potential
   // concurrent execution (normally, the error shouldn't be triggered unless the
   // transform IR modifies itself in a pass, which is also forbidden elsewhere).
-  if (!sharedTransform && libraryModule && *libraryModule) {
+  if (hasTransformLibraryModule) {
+    assert(!hasSharedTransformModule);
     if (!target->isProperAncestor(transformRoot)) {
       InFlightDiagnostic diag =
           transformRoot->emitError()
@@ -439,7 +449,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       return diag;
     }
     if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
-                                     libraryModule->get())))
+                                     transformLibraryModule->get())))
       return failure();
   }
 
@@ -461,25 +471,27 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
-  OwningOpRef<ModuleOp> parsed;
-  if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
+  OwningOpRef<ModuleOp> parsedTransformModule;
+  if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                          parsedTransformModule)))
     return failure();
-  if (parsed && failed(mlir::verify(*parsed)))
+  if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
     return failure();
 
-  OwningOpRef<ModuleOp> parsedLibrary;
+  OwningOpRef<ModuleOp> parsedLibraryModule;
   if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibrary)))
+                                          parsedLibraryModule)))
     return failure();
-  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+  if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
     return failure();
 
-  if (parsed) {
-    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  if (parsedTransformModule) {
+    sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        std::move(parsedTransformModule));
   } else if (moduleBuilder) {
     // TODO: better location story.
     auto location = UnknownLoc::get(context);
@@ -491,20 +503,20 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
       if (failed(*result))
         return failure();
-      module = std::move(localModule);
+      sharedTransformModule = std::move(localModule);
     }
   }
 
-  if (!parsedLibrary || !*parsedLibrary)
+  if (!parsedLibraryModule || !*parsedLibraryModule)
     return success();
 
-  if (module && *module) {
-    if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
+  if (sharedTransformModule && *sharedTransformModule) {
+    if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
+                                     parsedLibraryModule.get())))
       return failure();
   } else {
-    libraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+    transformLibraryModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
   }
   return success();
 }

>From 54f083ae9343cd362412b695691cee9bf5f5cf6a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 28 Sep 2023 13:32:25 +0000
Subject: [PATCH 2/2] Do not modify the target when loading the transform
 library module.

Until now, if the transform script was embedded into the input IR, the
transform dialect interpreter injected the externally resolved symbols
into that IR, which then became part of the output. This is not always
desirable.

This commit is a first step towards separating the logic of loading,
resolution, and injection from the interpreter. The modification consists
of cloning the IR that contains the main transform script if necessary
(i.e., if we actually need to load a library into it and it is part of
the input op of the pass). The next step will be to introduce a dedicated
pass for loading and injecting the transform script and/or library.
---
 .../Transforms/TransformInterpreterPassBase.h |  4 ++--
 .../TransformInterpreterPassBase.cpp          | 23 +++++++++++++++----
 ...test-interpreter-external-symbol-decl.mlir | 16 ++++++-------
 .../TestTransformDialectInterpreter.cpp       |  4 ++--
 4 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 91903e254b0d5b3..9c67f3af61cc12d 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -64,7 +64,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     will be interpreted.
 ///   - transformLibraryFileName: if non-empty, the name of the file containing
 ///     definitions of external symbols referenced in the transform script.
-///     These definitions will be used to replace declarations.
+///     These definitions will be used to resolve declarations.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -85,7 +85,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 /// as template arguments. They are *not* expected to to implement `initialize`
 /// or `runOnOperation`. They *are* expected to call the copy constructor of
 /// this class in their copy constructors, short of which the file-based
-/// transform dialect script injection facility will become nonoperational.
+/// transform dialect script resolution facility will become non-operational.
 ///
 /// Concrete passes may implement the `runBeforeInterpreter` and
 /// `runAfterInterpreter` to customize the behavior of the pass.
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 671f6902fecdb16..9c245fda7f567e7 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -417,8 +417,24 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // transform is embedded in the payload IR. If debugTransformRootTag was
   // passed, then we are in user-specified selection of the transforming IR.
   // This corresponds to REPL debug mode.
-  Operation *transformContainer =
-      hasSharedTransformModule ? sharedTransformModule->get() : target;
+
+  OwningOpRef<Operation *> transformContainerClone;
+  Operation *transformContainer;
+  if (hasTransformLibraryModule) {
+    // If we have a library module, then the transform script is embedded in the
+    // target, which we don't want to modify when loading the library. We thus
+    // clone the target and use that as transform container.
+    assert(!hasSharedTransformModule);
+    transformContainerClone = target->clone();
+    transformContainer = transformContainerClone.get();
+  } else {
+    // A shared transform module already had any library merged into it during
+    // initialization; otherwise, there is no library to load. In both cases,
+    // the container is not modified by loading, so it needs no clone.
+    transformContainer =
+        hasSharedTransformModule ? sharedTransformModule->get() : target;
+  }
+
   Operation *transformRoot =
       debugTransformRootTag.empty()
           ? findTopLevelTransform(transformContainer,
@@ -440,8 +456,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // concurrent execution (normally, the error shouldn't be triggered unless the
   // transform IR modifies itself in a pass, which is also forbidden elsewhere).
   if (hasTransformLibraryModule) {
-    assert(!hasSharedTransformModule);
-    if (!target->isProperAncestor(transformRoot)) {
+    if (!transformContainer->isProperAncestor(transformRoot)) {
       InFlightDiagnostic diag =
           transformRoot->emitError()
           << "cannot inject transform definitions next to pass anchor op";
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..076a2171094808a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -1,27 +1,25 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
-// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
-
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // The definition of the @foo named sequence is provided in another file. It
-// will be included because of the pass option. Repeated application of the
-// same pass, with or without the library option, should not be a problem.
+// will be available because of the pass option but not included in the output.
+// Repeated application of the same pass works, but only if the library is
+// provided to each application of the pass.
 // Note that the same diagnostic produced twice at the same location only
 // needs to be matched once.
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
 module attributes {transform.with_named_sequence} {
-  // CHECK: transform.named_sequence @foo
-  // CHECK: test_print_remark_at_operand %{{.*}}, "message"
+  // CHECK: transform.named_sequence private @foo
+  // CHECK-NOT: test_print_remark_at_operand
   transform.named_sequence private @foo(!transform.any_op {transform.readonly})
 
-  // CHECK: transform.named_sequence @unannotated
-  // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
+  // CHECK: transform.named_sequence private @unannotated
+  // CHECK-NOT: test_print_remark_at_operand
   transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
 
   transform.sequence failures(propagate) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..675b5ecd50346fa 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -219,8 +219,8 @@ class TestTransformDialectInterpreterPass
   Option<std::string> transformLibraryFileName{
       *this, "transform-library-file-name", llvm::cl::init(""),
       llvm::cl::desc(
-          "Optional name of the file containing transform dialect symbol "
-          "definitions to be injected into the transform module.")};
+          "Optional name of the file providing transform dialect definitions "
+          "from which declarations in the transform module can be resolved.")};
 
   Option<bool> testModuleGeneration{
       *this, "test-module-generation", llvm::cl::init(false),



More information about the Mlir-commits mailing list