[Mlir-commits] [mlir] [MLIR][Transform] Allow ApplyRegisteredPassOp to take options as a param (PR #142683)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 3 15:26:55 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Rolf Morel (rolfmorel)
<details>
<summary>Changes</summary>
Makes it possible to provide a pass's options as a transform param, so that they can be passed around inside a schedule.
The refactoring also ensures that the pass manager and pass are constructed only
once per `apply()` of the transform op, rather than once for each target
payload op the transform op is applied to.
---
Full diff: https://github.com/llvm/llvm-project/pull/142683.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+10-15)
- (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+99-18)
- (modified) mlir/test/Dialect/Transform/test-pass-application.mlir (+51-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e4eb67c8e14ce..b042f5e436185 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -399,15 +399,15 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
}
def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
- [TransformOpInterface, TransformEachOpTrait,
- FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Applies the specified registered pass or pass pipeline";
let description = [{
This transform applies the specified pass or pass pipeline to the targeted
ops. The name of the pass/pipeline is specified as a string attribute, as
set during pass/pipeline registration. Optionally, pass options may be
- specified as a string attribute. The pass options syntax is identical to the
- one used with "mlir-opt".
+ specified either statically, as a string attribute, or dynamically, as a
+ param holding one string attribute. The syntax is identical to "mlir-opt".
This op first looks for a pass pipeline with the specified name. If no such
pipeline exists, it looks for a pass with the specified name. If no such
@@ -420,20 +420,15 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
of targeted ops.
}];
- let arguments = (ins TransformHandleTypeInterface:$target,
+ let arguments = (ins Optional<TransformParamTypeInterface>:$dynamic_options,
+ TransformHandleTypeInterface:$target,
StrAttr:$pass_name,
- DefaultValuedAttr<StrAttr, "\"\"">:$options);
+ DefaultValuedAttr<StrAttr, "\"\"">:$static_options);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
- $pass_name `to` $target attr-dict `:` functional-type(operands, results)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
+ $pass_name (`with` `options` `=`
+ custom<ApplyRegisteredPassOptions>($dynamic_options, $static_options)^)?
+ `to` $target attr-dict `:` functional-type(operands, results)
}];
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 673743f22249a..536c3e14fe5c0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -53,6 +53,13 @@
using namespace mlir;
+// Custom assembly directive for the options of ApplyRegisteredPassOp: either
+// a param operand (dynamic options) or a string attribute (static options).
+static ParseResult parseApplyRegisteredPassOptions(
+ OpAsmParser &parser,
+ std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+ StringAttr &staticOptions);
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+ Operation *op, Value dynamicOptions,
+ StringAttr staticOptions);
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
Type &rootType,
@@ -766,17 +773,38 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
// ApplyRegisteredPassOp
//===----------------------------------------------------------------------===//
-DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- ApplyToEachResultList &results, transform::TransformState &state) {
- // Make sure that this transform is not applied to itself. Modifying the
- // transform IR while it is being interpreted is generally dangerous. Even
- // more so when applying passes because they may perform a wide range of IR
- // modifications.
- DiagnosedSilenceableFailure payloadCheck =
- ensurePayloadIsSeparateFromTransform(*this, target);
- if (!payloadCheck.succeeded())
- return payloadCheck;
+/// Consumes the target handle, only reads the optional options param,
+/// produces the result handle and modifies the payload IR.
+void transform::ApplyRegisteredPassOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getDynamicOptionsMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+/// Resolves the pass options (from the param if one is given, otherwise from
+/// the static attribute), builds a single pass manager for the named pass or
+/// pipeline, and runs it on every payload op of the target handle. The result
+/// handle is mapped to the same payload ops.
+DiagnosedSilenceableFailure
+transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ // Check whether pass options are specified, either as a dynamic param or
+ // a static attribute. In either case, options are passed as a single string.
+ StringRef options;
+ if (auto dynamicOptions = getDynamicOptions()) {
+ ArrayRef<Attribute> dynamicOptionsParam = state.getParams(dynamicOptions);
+ if (dynamicOptionsParam.size() != 1) {
+ return emitSilenceableError()
+ << "options passed as a param must be a single value, got "
+ << dynamicOptionsParam.size();
+ }
+ if (auto optionsStrAttr = dyn_cast<StringAttr>(dynamicOptionsParam[0])) {
+ options = optionsStrAttr.getValue();
+ } else {
+ return emitSilenceableError()
+ << "options passed as a param must be a string, got "
+ << dynamicOptionsParam[0];
+ }
+ } else {
+ options = getStaticOptions();
+ }
// Get pass or pass pipeline from registry.
const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -786,9 +814,9 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
return emitDefiniteFailure()
<< "unknown pass or pass pipeline: " << getPassName();
- // Create pass manager and run the pass or pass pipeline.
+ // Create pass manager and add the pass or pass pipeline.
PassManager pm(getContext());
- if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
+ if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
emitError(msg);
return failure();
}))) {
@@ -796,16 +824,69 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
<< "failed to add pass or pass pipeline to pipeline: "
<< getPassName();
}
- if (failed(pm.run(target))) {
- auto diag = emitSilenceableError() << "pass pipeline failed";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
+
+ auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
+
+ // Make sure that this transform is not applied to itself. Modifying the
+ // transform IR while it is being interpreted is generally dangerous. Even
+ // more so when applying passes because they may perform a wide range of IR
+ // modifications. Check all targets up front so that a failed check does not
+ // leave some targets already transformed by the pass.
+ for (Operation *target : targets) {
+ DiagnosedSilenceableFailure payloadCheck =
+ ensurePayloadIsSeparateFromTransform(*this, target);
+ if (!payloadCheck.succeeded())
+ return payloadCheck;
+ }
+
+ // Run the pass or pass pipeline on each target operation.
+ for (Operation *target : targets) {
+ if (failed(pm.run(target))) {
+ auto diag = emitSilenceableError() << "pass pipeline failed";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
}
- results.push_back(target);
+ // The applied pass will have directly modified the payload IR(s).
+ results.set(llvm::cast<OpResult>(getResult()), targets);
return DiagnosedSilenceableFailure::success();
}
+/// Parses the options of ApplyRegisteredPassOp that follow `with options =`:
+/// either an SSA operand (a param holding the options, i.e. dynamic options)
+/// or a string attribute (static options). Exactly one of the two is required.
+static ParseResult parseApplyRegisteredPassOptions(
+ OpAsmParser &parser,
+ std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+ StringAttr &staticOptions) {
+ dynamicOptions = std::nullopt;
+
+ // Dynamic options: an SSA operand referring to a transform param.
+ OpAsmParser::UnresolvedOperand dynamicOptionsOperand;
+ OptionalParseResult hasDynamicOptions =
+ parser.parseOptionalOperand(dynamicOptionsOperand);
+ if (hasDynamicOptions.has_value()) {
+ if (failed(hasDynamicOptions.value()))
+ return failure();
+ dynamicOptions = dynamicOptionsOperand;
+ return success();
+ }
+
+ // Static options: a string attribute. This directive is only invoked after
+ // `with options =` has been consumed, so finding neither form is malformed
+ // input rather than an absent (default) value.
+ OptionalParseResult hasStaticOptions =
+ parser.parseOptionalAttribute(staticOptions);
+ if (!hasStaticOptions.has_value())
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected pass options: a param operand or a "
+ "string attribute");
+ return hasStaticOptions.value();
+}
+
+/// Prints the options of ApplyRegisteredPassOp: the param operand when one is
+/// present, otherwise the static string attribute unless it is empty.
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+ Operation *op, Value dynamicOptions,
+ StringAttr staticOptions) {
+ if (!dynamicOptions) {
+ if (staticOptions.getValue().empty())
+ return;
+ printer.printAttribute(staticOptions);
+ return;
+ }
+ printer.printOperand(dynamicOptions);
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 3a40b462b8270..e8e0f63b28096 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
// expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
- transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op
+ transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
@@ -94,7 +94,56 @@ func.func @valid_pass_option() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op
+ transform.apply_registered_pass "canonicalize" with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @valid_dynamic_pass_option()
+func.func @valid_dynamic_pass_option() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %pass_options = transform.param.constant "top-down=false" -> !transform.any_param
+ transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @invalid_pass_option_param() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %pass_options = transform.param.constant 42 -> !transform.any_param
+ // expected-error @below {{options passed as a param must be a string, got 42}}
+ transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @too_many_pass_option_params() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %x = transform.param.constant "x" -> !transform.any_param
+ %pass_options = transform.merge_handles %x, %x : !transform.any_param
+ // expected-error @below {{options passed as a param must be a single value, got 2}}
+ transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
transform.yield
}
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/142683
More information about the Mlir-commits
mailing list