[Mlir-commits] [mlir] [mlir][transform] Don't modify the target in interpreter when loading library. (PR #67686)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 28 07:41:01 PDT 2023


Ingo Müller <ingomueller at google.com>
Message-ID:
In-Reply-To: <llvm/llvm-project/pull/67686/mlir at github.com>


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

Until now, if the transform script was embedded into the input IR, the transform dialect interpreter injected the externally resolved symbols into that IR, which then became part of the output. This is not always desirable.

This PR is a first step toward separating the logic of loading/resolution/injection from the interpreter. The modification consists of cloning the IR that contains the main transform script when necessary (i.e., when a transform library actually needs to be loaded and the script is part of the pass's input op). The next step will be to introduce a dedicated pass for loading and injecting the transform script and/or library.

The PR also improves some variable names and related if-conditions; these changes currently form an independent NFC commit and could be factored out into a dedicated PR.

---
Full diff: https://github.com/llvm/llvm-project/pull/67686.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h (+2-2) 
- (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+51-24) 
- (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir (+7-9) 
- (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 91903e254b0d5b3..9c67f3af61cc12d 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -64,7 +64,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     will be interpreted.
 ///   - transformLibraryFileName: if non-empty, the name of the file containing
 ///     definitions of external symbols referenced in the transform script.
-///     These definitions will be used to replace declarations.
+///     These definitions will be used to resolve declarations.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -85,7 +85,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 /// as template arguments. They are *not* expected to to implement `initialize`
 /// or `runOnOperation`. They *are* expected to call the copy constructor of
 /// this class in their copy constructors, short of which the file-based
-/// transform dialect script injection facility will become nonoperational.
+/// transform dialect script resolution facility will become non-operational.
 ///
 /// Concrete passes may implement the `runBeforeInterpreter` and
 /// `runAfterInterpreter` to customize the behavior of the pass.
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..9c245fda7f567e7 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -379,7 +379,7 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
-    const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
@@ -387,6 +387,16 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
+  bool hasSharedTransformModule =
+      sharedTransformModule && *sharedTransformModule;
+  bool hasTransformLibraryModule =
+      transformLibraryModule && *transformLibraryModule;
+  assert((!hasSharedTransformModule || !hasTransformLibraryModule) &&
+         "at most one of shared or library transform module can be set");
+
+  // Step 0
+  // ------
+  // If debugPayloadRootTag or debugTransformRootTag was passed, then we are in)
 
   // Step 1
   // ------
@@ -407,9 +417,24 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // transform is embedded in the payload IR. If debugTransformRootTag was
   // passed, then we are in user-specified selection of the transforming IR.
   // This corresponds to REPL debug mode.
-  bool sharedTransform = (sharedTransformModule && *sharedTransformModule);
-  Operation *transformContainer =
-      sharedTransform ? sharedTransformModule->get() : target;
+
+  OwningOpRef<Operation *> transformContainerClone;
+  Operation *transformContainer;
+  if (hasTransformLibraryModule) {
+    // If we have a library module, then the transform script is embedded in the
+    // target, which we don't want to modify when loading the library. We thus
+    // clone the target and use that as transform container.
+    assert(!hasSharedTransformModule);
+    transformContainerClone = target->clone();
+    transformContainer = transformContainerClone.get();
+  } else {
+    // If we have a shared library, which is private to us, we can modify it
+    // when loading the library, so we use that. Otherwise, we don't have any
+    // library to load, so we can use the target and won't modify it.
+    transformContainer =
+        hasSharedTransformModule ? sharedTransformModule->get() : target;
+  }
+
   Operation *transformRoot =
       debugTransformRootTag.empty()
           ? findTopLevelTransform(transformContainer,
@@ -430,8 +455,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // Copy external defintions for symbols if provided. Be aware of potential
   // concurrent execution (normally, the error shouldn't be triggered unless the
   // transform IR modifies itself in a pass, which is also forbidden elsewhere).
-  if (!sharedTransform && libraryModule && *libraryModule) {
-    if (!target->isProperAncestor(transformRoot)) {
+  if (hasTransformLibraryModule) {
+    if (!transformContainer->isProperAncestor(transformRoot)) {
       InFlightDiagnostic diag =
           transformRoot->emitError()
           << "cannot inject transform definitions next to pass anchor op";
@@ -439,7 +464,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       return diag;
     }
     if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
-                                     libraryModule->get())))
+                                     transformLibraryModule->get())))
       return failure();
   }
 
@@ -461,25 +486,27 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
-  OwningOpRef<ModuleOp> parsed;
-  if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
+  OwningOpRef<ModuleOp> parsedTransformModule;
+  if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                          parsedTransformModule)))
     return failure();
-  if (parsed && failed(mlir::verify(*parsed)))
+  if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
     return failure();
 
-  OwningOpRef<ModuleOp> parsedLibrary;
+  OwningOpRef<ModuleOp> parsedLibraryModule;
   if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibrary)))
+                                          parsedLibraryModule)))
     return failure();
-  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+  if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
     return failure();
 
-  if (parsed) {
-    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  if (parsedTransformModule) {
+    sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        std::move(parsedTransformModule));
   } else if (moduleBuilder) {
     // TODO: better location story.
     auto location = UnknownLoc::get(context);
@@ -491,20 +518,20 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
       if (failed(*result))
         return failure();
-      module = std::move(localModule);
+      sharedTransformModule = std::move(localModule);
     }
   }
 
-  if (!parsedLibrary || !*parsedLibrary)
+  if (!parsedLibraryModule || !*parsedLibraryModule)
     return success();
 
-  if (module && *module) {
-    if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
+  if (sharedTransformModule && *sharedTransformModule) {
+    if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
+                                     parsedLibraryModule.get())))
       return failure();
   } else {
-    libraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+    transformLibraryModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
   }
   return success();
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..076a2171094808a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -1,27 +1,25 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
-// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
-
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // The definition of the @foo named sequence is provided in another file. It
-// will be included because of the pass option. Repeated application of the
-// same pass, with or without the library option, should not be a problem.
+// will be available because of the pass option but not included in the output.
+// Repeated application of the same pass works, but only if the library is
+// provided in both.
 // Note that the same diagnostic produced twice at the same location only
 // needs to be matched once.
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
 module attributes {transform.with_named_sequence} {
-  // CHECK: transform.named_sequence @foo
-  // CHECK: test_print_remark_at_operand %{{.*}}, "message"
+  // CHECK: transform.named_sequence private @foo
+  // CHECK-NOT: test_print_remark_at_operand
   transform.named_sequence private @foo(!transform.any_op {transform.readonly})
 
-  // CHECK: transform.named_sequence @unannotated
-  // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
+  // CHECK: transform.named_sequence private @unannotated
+  // CHECK-NOT: test_print_remark_at_operand
   transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
 
   transform.sequence failures(propagate) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..675b5ecd50346fa 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -219,8 +219,8 @@ class TestTransformDialectInterpreterPass
   Option<std::string> transformLibraryFileName{
       *this, "transform-library-file-name", llvm::cl::init(""),
       llvm::cl::desc(
-          "Optional name of the file containing transform dialect symbol "
-          "definitions to be injected into the transform module.")};
+          "Optional name of the file providing transform dialect definitions "
+          "from which declarations in the transform module can be resolved.")};
 
   Option<bool> testModuleGeneration{
       *this, "test-module-generation", llvm::cl::init(false),

``````````

</details>


https://github.com/llvm/llvm-project/pull/67686


More information about the Mlir-commits mailing list