[Mlir-commits] [mlir] [MLIR][Transform] Allow ApplyRegisteredPassOp to take options as a param (PR #142683)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 3 15:26:55 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

<details>
<summary>Changes</summary>

Makes it possible to supply a pass's options as a transform param, so that the options can be passed around inside a schedule.

The refactoring also ensures that the pass manager and pass are constructed
only once per `apply()` of the transform op, rather than once for each target
payload op given to the op's `apply()`.

---
Full diff: https://github.com/llvm/llvm-project/pull/142683.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+10-15) 
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+99-18) 
- (modified) mlir/test/Dialect/Transform/test-pass-application.mlir (+51-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e4eb67c8e14ce..b042f5e436185 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -399,15 +399,15 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
 }
 
 def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
-    [TransformOpInterface, TransformEachOpTrait,
-     FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Applies the specified registered pass or pass pipeline";
   let description = [{
     This transform applies the specified pass or pass pipeline to the targeted
     ops. The name of the pass/pipeline is specified as a string attribute, as
     set during pass/pipeline registration. Optionally, pass options may be
-    specified as a string attribute. The pass options syntax is identical to the
-    one used with "mlir-opt".
+    specified as a string attribute with the option to pass the attribute as a
+    param. The pass options syntax is identical to the one used with "mlir-opt".
 
     This op first looks for a pass pipeline with the specified name. If no such
     pipeline exists, it looks for a pass with the specified name. If no such
@@ -420,20 +420,15 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     of targeted ops.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target,
+  let arguments = (ins Optional<TransformParamTypeInterface>:$dynamic_options,
+                       TransformHandleTypeInterface:$target,
                        StrAttr:$pass_name,
-                       DefaultValuedAttr<StrAttr, "\"\"">:$options);
+                       DefaultValuedAttr<StrAttr, "\"\"">:$static_options);
   let results = (outs TransformHandleTypeInterface:$result);
   let assemblyFormat = [{
-    $pass_name `to` $target attr-dict `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-      ::mlir::transform::TransformRewriter &rewriter,
-      ::mlir::Operation *target,
-      ::mlir::transform::ApplyToEachResultList &results,
-      ::mlir::transform::TransformState &state);
+    $pass_name (`with` `options` `=`
+      custom<ApplyRegisteredPassOptions>($dynamic_options, $static_options)^)?
+      `to` $target attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 673743f22249a..536c3e14fe5c0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -53,6 +53,13 @@
 
 using namespace mlir;
 
+static ParseResult parseApplyRegisteredPassOptions(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+    StringAttr &staticOptions);
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+                                            Operation *op, Value dynamicOptions,
+                                            StringAttr staticOptions);
 static ParseResult parseSequenceOpOperands(
     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
     Type &rootType,
@@ -766,17 +773,38 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
 // ApplyRegisteredPassOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    ApplyToEachResultList &results, transform::TransformState &state) {
-  // Make sure that this transform is not applied to itself. Modifying the
-  // transform IR while it is being interpreted is generally dangerous. Even
-  // more so when applying passes because they may perform a wide range of IR
-  // modifications.
-  DiagnosedSilenceableFailure payloadCheck =
-      ensurePayloadIsSeparateFromTransform(*this, target);
-  if (!payloadCheck.succeeded())
-    return payloadCheck;
+void transform::ApplyRegisteredPassOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getDynamicOptionsMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  // Check whether pass options are specified, either as a dynamic param or
+  // a static attribute. In either case, options are passed as a single string.
+  StringRef options;
+  if (auto dynamicOptions = getDynamicOptions()) {
+    ArrayRef<Attribute> dynamicOptionsParam = state.getParams(dynamicOptions);
+    if (dynamicOptionsParam.size() != 1) {
+      return emitSilenceableError()
+             << "options passed as a param must be a single value, got "
+             << dynamicOptionsParam.size();
+    }
+    if (auto optionsStrAttr = dyn_cast<StringAttr>(dynamicOptionsParam[0])) {
+      options = optionsStrAttr.getValue();
+    } else {
+      return emitSilenceableError()
+             << "options passed as a param must be a string, got "
+             << dynamicOptionsParam[0];
+    }
+  } else {
+    options = getStaticOptions();
+  }
 
   // Get pass or pass pipeline from registry.
   const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -786,9 +814,9 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
     return emitDefiniteFailure()
            << "unknown pass or pass pipeline: " << getPassName();
 
-  // Create pass manager and run the pass or pass pipeline.
+  // Create pass manager and add the pass or pass pipeline.
   PassManager pm(getContext());
-  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
+  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
         emitError(msg);
         return failure();
       }))) {
@@ -796,16 +824,69 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
            << "failed to add pass or pass pipeline to pipeline: "
            << getPassName();
   }
-  if (failed(pm.run(target))) {
-    auto diag = emitSilenceableError() << "pass pipeline failed";
-    diag.attachNote(target->getLoc()) << "target op";
-    return diag;
+
+  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
+  for (Operation *target : targets) {
+    // Make sure that this transform is not applied to itself. Modifying the
+    // transform IR while it is being interpreted is generally dangerous. Even
+    // more so when applying passes because they may perform a wide range of IR
+    // modifications.
+    DiagnosedSilenceableFailure payloadCheck =
+        ensurePayloadIsSeparateFromTransform(*this, target);
+    if (!payloadCheck.succeeded())
+      return payloadCheck;
+
+    // Run the pass or pass pipeline on the current target operation.
+    if (failed(pm.run(target))) {
+      auto diag = emitSilenceableError() << "pass pipeline failed";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
   }
 
-  results.push_back(target);
+  // The applied pass will have directly modified the payload IR(s).
+  results.set(llvm::cast<OpResult>(getResult()), targets);
   return DiagnosedSilenceableFailure::success();
 }
 
+static ParseResult parseApplyRegisteredPassOptions(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+    StringAttr &staticOptions) {
+  dynamicOptions = std::nullopt;
+  OpAsmParser::UnresolvedOperand dynamicOptionsOperand;
+  OptionalParseResult hasDynamicOptions =
+      parser.parseOptionalOperand(dynamicOptionsOperand);
+
+  if (hasDynamicOptions.has_value()) {
+    if (failed(hasDynamicOptions.value()))
+      return failure();
+
+    dynamicOptions = dynamicOptionsOperand;
+    return success();
+  }
+
+  OptionalParseResult hasStaticOptions =
+      parser.parseOptionalAttribute(staticOptions);
+  if (hasStaticOptions.has_value()) {
+    if (failed(hasStaticOptions.value()))
+      return failure();
+    return success();
+  }
+
+  return success();
+}
+
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+                                            Operation *op, Value dynamicOptions,
+                                            StringAttr staticOptions) {
+  if (dynamicOptions) {
+    printer.printOperand(dynamicOptions);
+  } else if (!staticOptions.getValue().empty()) {
+    printer.printAttribute(staticOptions);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 3a40b462b8270..e8e0f63b28096 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
     // expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
-    transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
@@ -94,7 +94,56 @@ func.func @valid_pass_option() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @valid_dynamic_pass_option()
+func.func @valid_dynamic_pass_option() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %pass_options = transform.param.constant "top-down=false" -> !transform.any_param
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+// -----
+
+func.func @invalid_pass_option_param() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %pass_options = transform.param.constant 42 -> !transform.any_param
+    // expected-error @below {{options passed as a param must be a string, got 42}}
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @too_many_pass_option_params() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %x = transform.param.constant "x" -> !transform.any_param
+    %pass_options = transform.merge_handles %x, %x : !transform.any_param
+    // expected-error @below {{options passed as a param must be a single value, got 2}}
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
     transform.yield
   }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/142683


More information about the Mlir-commits mailing list