[Mlir-commits] [mlir] [MLIR][Transform] Allow ApplyRegisteredPassOp to take options as a param (PR #142683)

Rolf Morel llvmlistbot at llvm.org
Tue Jun 3 15:26:16 PDT 2025


https://github.com/rolfmorel created https://github.com/llvm/llvm-project/pull/142683

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also means that the pass manager and pass are constructed only
once per `apply()` of the transform op, rather than once for each target
payload op given to that `apply()`.

>From 897730627e603729dfaee5aa67eb6db7f488074f Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 3 Jun 2025 15:08:55 -0700
Subject: [PATCH] [MLIR][Transform] Allow ApplyRegisteredPassOp to take options
 as a param

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also means that the pass manager and pass are constructed only
once per apply of the transform op, rather than once for each target payload
op given to the op.
---
 .../mlir/Dialect/Transform/IR/TransformOps.td |  25 ++--
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 117 +++++++++++++++---
 .../Transform/test-pass-application.mlir      |  53 +++++++-
 3 files changed, 160 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e4eb67c8e14ce..b042f5e436185 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -399,15 +399,15 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
 }
 
 def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
-    [TransformOpInterface, TransformEachOpTrait,
-     FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Applies the specified registered pass or pass pipeline";
   let description = [{
     This transform applies the specified pass or pass pipeline to the targeted
     ops. The name of the pass/pipeline is specified as a string attribute, as
     set during pass/pipeline registration. Optionally, pass options may be
-    specified as a string attribute. The pass options syntax is identical to the
-    one used with "mlir-opt".
+    specified as a string attribute or as a param holding one string attribute.
+    The pass options syntax is identical to the one used with "mlir-opt".
 
     This op first looks for a pass pipeline with the specified name. If no such
     pipeline exists, it looks for a pass with the specified name. If no such
@@ -420,20 +420,15 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     of targeted ops.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target,
+  let arguments = (ins Optional<TransformParamTypeInterface>:$dynamic_options,
+                       TransformHandleTypeInterface:$target,
                        StrAttr:$pass_name,
-                       DefaultValuedAttr<StrAttr, "\"\"">:$options);
+                       DefaultValuedAttr<StrAttr, "\"\"">:$static_options);
   let results = (outs TransformHandleTypeInterface:$result);
   let assemblyFormat = [{
-    $pass_name `to` $target attr-dict `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-      ::mlir::transform::TransformRewriter &rewriter,
-      ::mlir::Operation *target,
-      ::mlir::transform::ApplyToEachResultList &results,
-      ::mlir::transform::TransformState &state);
+    $pass_name (`with` `options` `=`
+      custom<ApplyRegisteredPassOptions>($dynamic_options, $static_options)^)?
+      `to` $target attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 673743f22249a..536c3e14fe5c0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -53,6 +53,13 @@
 
 using namespace mlir;
 
+static ParseResult parseApplyRegisteredPassOptions(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+    StringAttr &staticOptions);
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+                                            Operation *op, Value dynamicOptions,
+                                            StringAttr staticOptions);
 static ParseResult parseSequenceOpOperands(
     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
     Type &rootType,
@@ -766,17 +773,38 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
 // ApplyRegisteredPassOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    ApplyToEachResultList &results, transform::TransformState &state) {
-  // Make sure that this transform is not applied to itself. Modifying the
-  // transform IR while it is being interpreted is generally dangerous. Even
-  // more so when applying passes because they may perform a wide range of IR
-  // modifications.
-  DiagnosedSilenceableFailure payloadCheck =
-      ensurePayloadIsSeparateFromTransform(*this, target);
-  if (!payloadCheck.succeeded())
-    return payloadCheck;
+void transform::ApplyRegisteredPassOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);           // Target payload ops.
+  onlyReadsHandle(getDynamicOptionsMutable(), effects);  // Param, if present.
+  producesHandle(getOperation()->getResults(), effects); // Aliases targets.
+  modifiesPayload(effects);                              // Passes mutate IR.
+}
+
+/// Applies the registered pass or pass pipeline named `pass_name` to each
+/// target payload op, building the pass manager only once for all targets.
+DiagnosedSilenceableFailure
+transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  // Pass options come either from a dynamic param or from a static attribute.
+  // In either case, they are handed to the pass as a single string.
+  StringRef options;
+  if (auto dynamicOptions = getDynamicOptions()) {
+    ArrayRef<Attribute> dynamicOptionsParam = state.getParams(dynamicOptions);
+    if (dynamicOptionsParam.size() != 1)
+      return emitSilenceableError()
+             << "options passed as a param must be a single value, got "
+             << dynamicOptionsParam.size();
+    auto optionsStrAttr = dyn_cast<StringAttr>(dynamicOptionsParam[0]);
+    if (!optionsStrAttr)
+      return emitSilenceableError()
+             << "options passed as a param must be a string, got "
+             << dynamicOptionsParam[0];
+    options = optionsStrAttr.getValue();
+  } else {
+    options = getStaticOptions();
+  }
 
   // Get pass or pass pipeline from registry.
   const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -786,9 +814,9 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
     return emitDefiniteFailure()
            << "unknown pass or pass pipeline: " << getPassName();
 
-  // Create pass manager and run the pass or pass pipeline.
+  // Create pass manager and add the pass or pass pipeline.
   PassManager pm(getContext());
-  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
+  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
         emitError(msg);
         return failure();
       }))) {
@@ -796,16 +824,69 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
            << "failed to add pass or pass pipeline to pipeline: "
            << getPassName();
   }
-  if (failed(pm.run(target))) {
-    auto diag = emitSilenceableError() << "pass pipeline failed";
-    diag.attachNote(target->getLoc()) << "target op";
-    return diag;
+
+  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
+  // Make sure that this transform is not applied to itself: passes may rewrite
+  // the transform IR being interpreted. Check all targets before running any
+  // pass so that a rejected target cannot leave earlier targets modified.
+  for (Operation *target : targets) {
+    DiagnosedSilenceableFailure payloadCheck =
+        ensurePayloadIsSeparateFromTransform(*this, target);
+    if (!payloadCheck.succeeded())
+      return payloadCheck;
+  }
+  // Run the pass or pass pipeline on each target operation in turn.
+  for (Operation *target : targets) {
+    if (failed(pm.run(target))) {
+      auto diag = emitSilenceableError() << "pass pipeline failed";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
   }
 
-  results.push_back(target);
+  // The pass ran in place, so the result handle maps to the same payload ops.
+  results.set(llvm::cast<OpResult>(getResult()), targets);
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Parses the options of `apply_registered_pass` that follow `with options =`:
+/// either a param operand or a string attribute.
+static ParseResult parseApplyRegisteredPassOptions(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+    StringAttr &staticOptions) {
+  dynamicOptions = std::nullopt;
+  OpAsmParser::UnresolvedOperand dynamicOptionsOperand;
+  OptionalParseResult hasDynamicOptions =
+      parser.parseOptionalOperand(dynamicOptionsOperand);
+  if (hasDynamicOptions.has_value()) {
+    if (failed(*hasDynamicOptions))
+      return failure();
+    dynamicOptions = dynamicOptionsOperand;
+    return success();
+  }
+
+  // Otherwise a string attribute is required; reject `with options =`
+  // followed by anything else rather than silently accepting it.
+  llvm::SMLoc optionsLoc = parser.getCurrentLocation();
+  OptionalParseResult hasStaticOptions =
+      parser.parseOptionalAttribute(staticOptions);
+  if (!hasStaticOptions.has_value())
+    return parser.emitError(optionsLoc)
+           << "expected pass options as a param operand or a string attribute";
+  return *hasStaticOptions;
+}
+
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+                                            Operation *op, Value dynamicOptions,
+                                            StringAttr staticOptions) {
+  // Print the param operand if present, else any non-empty static options.
+  if (dynamicOptions)
+    printer.printOperand(dynamicOptions);
+  else if (!staticOptions.getValue().empty())
+    printer.printAttribute(staticOptions);
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 3a40b462b8270..e8e0f63b28096 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
     // expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
-    transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
@@ -94,7 +94,56 @@ func.func @valid_pass_option() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @valid_dynamic_pass_option()
+func.func @valid_dynamic_pass_option() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %pass_options = transform.param.constant "top-down=false" -> !transform.any_param
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @invalid_pass_option_param() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %pass_options = transform.param.constant 42 -> !transform.any_param
+    // expected-error @below {{options passed as a param must be a string, got 42}}
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @too_many_pass_option_params() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %x = transform.param.constant "x" -> !transform.any_param
+    %pass_options = transform.merge_handles %x, %x : !transform.any_param
+    // expected-error @below {{options passed as a param must be a single value, got 2}}
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
     transform.yield
   }
 }



More information about the Mlir-commits mailing list