[Mlir-commits] [mlir] 9d30c6a - [mlir][transform] generate transform module on-the-fly

Alex Zinenko llvmlistbot at llvm.org
Tue Jun 6 02:35:02 PDT 2023


Author: Alex Zinenko
Date: 2023-06-06T09:34:54Z
New Revision: 9d30c6a721edf75d0726e07fb82cc5538fb95c16

URL: https://github.com/llvm/llvm-project/commit/9d30c6a721edf75d0726e07fb82cc5538fb95c16
DIFF: https://github.com/llvm/llvm-project/commit/9d30c6a721edf75d0726e07fb82cc5538fb95c16.diff

LOG: [mlir][transform] generate transform module on-the-fly

Add a TransformInterpreterPassBase capability to generate the (shared)
module containing the transform script during pass initialization.
This makes it possible to construct the script programmatically instead
of parsing it from a textual module.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D152185

Added: 
    mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
    mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index d7c61ef9cf9cf..91903e254b0d5 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -35,7 +35,9 @@ LogicalResult interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule);
+    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+        moduleBuilder = nullptr);
 
 /// Template-free implementation of
 /// TransformInterpreterPassBase::runOnOperation.
@@ -123,7 +125,11 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
         static_cast<Concrete *>(this)->transformLibraryFileName;
     return detail::interpreterBaseInitializeImpl(
         context, transformFileName, transformLibraryFileName,
-        sharedTransformModule, transformLibraryModule);
+        sharedTransformModule, transformLibraryModule,
+        [this](OpBuilder &builder, Location loc) {
+          return static_cast<Concrete *>(this)->constructTransformModule(
+              builder, loc);
+        });
   }
 
   /// Hook for passes to run additional logic in the pass before the
@@ -136,6 +142,14 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
   /// fails.
   LogicalResult runAfterInterpreter(Operation *) { return success(); }
 
+  /// Hook for passes to run custom logic to construct the transform module.
+  /// This will run during initialization. If the external script is provided,
+  /// it overrides the construction, which will not be called.
+  std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
+                                                        Location loc) {
+    return std::nullopt;
+  }
+
   void runOnOperation() override {
     auto *pass = static_cast<Concrete *>(this);
     Operation *op = pass->getOperation();

diff  --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 9dc91612fa434..b9380f52072f3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -462,7 +462,9 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule) {
+    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+        moduleBuilder) {
   OwningOpRef<ModuleOp> parsed;
   if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
     return failure();
@@ -476,7 +478,23 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
   if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
     return failure();
 
-  module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  if (parsed) {
+    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  } else if (moduleBuilder) {
+    // TODO: better location story.
+    auto location = UnknownLoc::get(context);
+    auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        ModuleOp::create(location, "__transform"));
+
+    OpBuilder b(context);
+    b.setInsertionPointToEnd(localModule->get().getBody());
+    if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
+      if (failed(*result))
+        return failure();
+      module = std::move(localModule);
+    }
+  }
+
   if (!parsedLibrary || !*parsedLibrary)
     return success();
 

diff  --git a/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir b/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir
new file mode 100644
index 0000000000000..159aed720964d
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter=test-module-generation=1 --verify-diagnostics
+
+// expected-remark @below {{remark from generated}}
+module {}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index b7e9d08effb7e..f73deef9d5fd4 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -11,7 +11,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TestTransformDialectExtension.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -46,6 +48,10 @@ class TestTransformDialectInterpreterPass
     return "apply transform dialect operations one by one";
   }
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<transform::TransformDialect>();
+  }
+
   void findOperationsByName(Operation *root, StringRef name,
                             SmallVectorImpl<Operation *> &operations) {
     root->walk([&](Operation *op) {
@@ -86,6 +92,22 @@ class TestTransformDialectInterpreterPass
     return numSetValues;
   }
 
+  std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
+                                                        Location loc) {
+    if (!testModuleGeneration)
+      return std::nullopt;
+
+    builder.create<transform::SequenceOp>(
+        loc, TypeRange(), transform::FailurePropagationMode::Propagate,
+        builder.getType<transform::AnyOpType>(),
+        [](OpBuilder &b, Location nested, Value rootH) {
+          b.create<mlir::test::TestPrintRemarkAtOperandOp>(
+              nested, rootH, "remark from generated");
+          b.create<transform::YieldOp>(nested, ValueRange());
+        });
+    return success();
+  }
+
   void runOnOperation() override {
     unsigned firstSetOptions =
         numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,
@@ -199,6 +221,11 @@ class TestTransformDialectInterpreterPass
       llvm::cl::desc(
           "Optional name of the file containing transform dialect symbol "
           "definitions to be injected into the transform module.")};
+
+  Option<bool> testModuleGeneration{
+      *this, "test-module-generation", llvm::cl::init(false),
+      llvm::cl::desc("test the generation of the transform module during pass "
+                     "initialization, overridden by parsing")};
 };
 
 struct TestTransformDialectEraseSchedulePass


        


More information about the Mlir-commits mailing list