[Mlir-commits] [mlir] 2f8690b - [mlir][transform] ApplyRegisteredPassOp: Support pass pipelines

Matthias Springer llvmlistbot at llvm.org
Mon Sep 4 06:11:40 PDT 2023


Author: Matthias Springer
Date: 2023-09-04T15:11:24+02:00
New Revision: 2f8690b1e2bf09174660cd66ef8a27066848555e

URL: https://github.com/llvm/llvm-project/commit/2f8690b1e2bf09174660cd66ef8a27066848555e
DIFF: https://github.com/llvm/llvm-project/commit/2f8690b1e2bf09174660cd66ef8a27066848555e.diff

LOG: [mlir][transform] ApplyRegisteredPassOp: Support pass pipelines

The same transform op can now be used to apply registered pass pipelines.

This revision also adds a static helper, `PassPipelineInfo::lookup`, for querying registered `PassPipelineInfo` objects by name, and moves the corresponding lookup for `PassInfo` objects from the static `Pass::lookupPassInfo` to `PassInfo::lookup`.

Differential Revision: https://reviews.llvm.org/D159211

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/include/mlir/Pass/Pass.h
    mlir/include/mlir/Pass/PassRegistry.h
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/lib/Pass/PassRegistry.cpp
    mlir/test/Dialect/Transform/test-pass-application.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 125c7ec3c66070..ca5c915ef8c2ca 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -385,13 +385,17 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
 def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     [TransformOpInterface, TransformEachOpTrait,
      FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
-  let summary = "Applies the specified registered pass";
+  let summary = "Applies the specified registered pass or pass pipeline";
   let description = [{
-    This transform applies the specified pass to the targeted ops. The name of
-    the pass is specified as a string attribute, as set during pass
-    registration. Optionally, pass options may be specified as a string
-    attribute. The pass options syntax is identical to the one used with
-    "mlir-opt".
+    This transform applies the specified pass or pass pipeline to the targeted
+    ops. The name of the pass/pipeline is specified as a string attribute, as
+    set during pass/pipeline registration. Optionally, pass options may be
+    specified as a string attribute. The pass options syntax is identical to the
+    one used with "mlir-opt".
+
+    This op first looks for a pass pipeline with the specified name. If no such
+    pipeline exists, it looks for a pass with the specified name. If no such
+    pass exists either, this op fails definitely.
 
     This transform consumes the target handle and produces a new handle that is
     mapped to the same op. Passes are not allowed to remove/modify the operation

diff  --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 1562351e623144..5a4df4324ecd1e 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -55,12 +55,9 @@ class Pass {
   /// Returns the unique identifier that corresponds to this pass.
   TypeID getTypeID() const { return passID; }
 
-  /// Returns the pass info for the specified pass class or null if unknown.
-  static const PassInfo *lookupPassInfo(StringRef passArg);
-
   /// Returns the pass info for this pass, or null if unknown.
   const PassInfo *lookupPassInfo() const {
-    return lookupPassInfo(getArgument());
+    return PassInfo::lookup(getArgument());
   }
 
   /// Returns the derived pass name.

diff  --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
index db6d9dcfed672a..08874f0121991f 100644
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -105,6 +105,10 @@ class PassPipelineInfo : public PassRegistryEntry {
       std::function<void(function_ref<void(const detail::PassOptions &)>)>
           optHandler)
       : PassRegistryEntry(arg, description, builder, std::move(optHandler)) {}
+
+  /// Returns the pass pipeline info for the specified pass pipeline or null if
+  /// unknown.
+  static const PassPipelineInfo *lookup(StringRef pipelineArg);
 };
 
 /// A structure to represent the information for a derived pass class.
@@ -114,6 +118,9 @@ class PassInfo : public PassRegistryEntry {
   /// PassRegistration or registerPass.
   PassInfo(StringRef arg, StringRef description,
            const PassAllocatorFunction &allocator);
+
+  /// Returns the pass info for the specified pass class or null if unknown.
+  static const PassInfo *lookup(StringRef passArg);
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 54217ffbf9eb77..3c563c0a36c3bb 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -720,20 +720,23 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
   if (!payloadCheck.succeeded())
     return payloadCheck;
 
-  // Get pass from registry.
-  const PassInfo *passInfo = Pass::lookupPassInfo(getPassName());
-  if (!passInfo) {
-    return emitDefiniteFailure() << "unknown pass: " << getPassName();
-  }
+  // Get pass or pass pipeline from registry.
+  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
+  if (!info)
+    info = PassInfo::lookup(getPassName());
+  if (!info)
+    return emitDefiniteFailure()
+           << "unknown pass or pass pipeline: " << getPassName();
 
-  // Create pass manager with a single pass and run it.
+  // Create pass manager and run the pass or pass pipeline.
   PassManager pm(getContext());
-  if (failed(passInfo->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
+  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
         emitError(msg);
         return failure();
       }))) {
     return emitDefiniteFailure()
-           << "failed to add pass to pipeline: " << getPassName();
+           << "failed to add pass or pass pipeline to pipeline: "
+           << getPassName();
   }
   if (failed(pm.run(target))) {
     auto diag = emitSilenceableError() << "pass pipeline failed";

diff  --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 42d65418344eea..e3ea5704f61d16 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -139,11 +139,18 @@ void mlir::registerPass(const PassAllocatorFunction &function) {
 }
 
 /// Returns the pass info for the specified pass argument or null if unknown.
-const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
+const PassInfo *mlir::PassInfo::lookup(StringRef passArg) {
   auto it = passRegistry->find(passArg);
   return it == passRegistry->end() ? nullptr : &it->second;
 }
 
+/// Returns the pass pipeline info for the specified pass pipeline argument or
+/// null if unknown.
+const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) {
+  auto it = passPipelineRegistry->find(pipelineArg);
+  return it == passPipelineRegistry->end() ? nullptr : &it->second;
+}
+
 //===----------------------------------------------------------------------===//
 // PassOptions
 //===----------------------------------------------------------------------===//
@@ -653,16 +660,14 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
   // pipeline.
   if (!element.innerPipeline.empty())
     return resolvePipelineElements(element.innerPipeline, errorHandler);
+
   // Otherwise, this must be a pass or pass pipeline.
   // Check to see if a pipeline was registered with this name.
-  auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
-  if (pipelineRegistryIt != passPipelineRegistry->end()) {
-    element.registryEntry = &pipelineRegistryIt->second;
+  if ((element.registryEntry = PassPipelineInfo::lookup(element.name)))
     return success();
-  }
 
   // If not, then this must be a specific pass name.
-  if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
+  if ((element.registryEntry = PassInfo::lookup(element.name)))
     return success();
 
   // Emit an error for the unknown pass.

diff  --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 6342657b612801..65625457c86898 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -17,6 +17,21 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// CHECK-LABEL: func @pass_pipeline(
+func.func @pass_pipeline() {
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // This pipeline does not do anything. Just make sure that the pipeline is
+  // found and no error is produced.
+  transform.apply_registered_pass "test-options-pass-pipeline" to %1 : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
 func.func @invalid_pass_name() {
   return
 }
@@ -24,7 +39,7 @@ func.func @invalid_pass_name() {
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{unknown pass: non-existing-pass}}
+  // expected-error @below {{unknown pass or pass pipeline: non-existing-pass}}
   transform.apply_registered_pass "non-existing-pass" to %1 : (!transform.any_op) -> !transform.any_op
 }
 
@@ -54,7 +69,7 @@ func.func @invalid_pass_option() {
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{failed to add pass to pipeline: canonicalize}}
+  // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
   transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op
 }
 


        


More information about the Mlir-commits mailing list