[Mlir-commits] [mlir] [MLIR][Transform] apply_registered_pass: support ListOptions (PR #144026)
Rolf Morel
llvmlistbot at llvm.org
Mon Jun 16 03:47:35 PDT 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/144026
>From ec61920910979e53579ec144e92b677504c0254b Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 00:13:42 -0700
Subject: [PATCH 1/7] [MLIR][Transform] apply_registered_pass: support
ListOptions as params
Interpret the multiple values associated with a param as a
comma-separated list, i.e., as the analog of a ListOption on a pass.
---
.../lib/Dialect/Transform/IR/TransformOps.cpp | 28 +++++-----
.../Transform/test-pass-application.mlir | 52 ++++++++++++-------
2 files changed, 45 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 582d082153bef..bfe6416987629 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -791,6 +791,12 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
std::string options;
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
+ auto appendValueAttr = [&](Attribute valueAttr) {
+ if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
+ optionsStream << strAttr.getValue().str();
+ else
+ valueAttr.print(optionsStream, /*elideType=*/true);
+ };
OperandRange dynamicOptions = getDynamicOptions();
for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
@@ -799,7 +805,6 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
optionsStream << namedAttribute.getName().str(); // Append the key.
optionsStream << "="; // And the key-value separator.
- Attribute valueAttrToAppend;
if (auto paramOperandIndex =
dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
// The corresponding value attribute is passed in via a param.
@@ -810,22 +815,15 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
"should be the same as the number of options passed as params");
ArrayRef<Attribute> dynamicOption =
state.getParams(dynamicOptions[dynamicOptionIdx]);
- if (dynamicOption.size() != 1)
- return emitSilenceableError()
- << "options passed as a param must have "
- "a single value associated, param "
- << dynamicOptionIdx << " associates " << dynamicOption.size();
- valueAttrToAppend = dynamicOption[0];
+ // Append all attributes associated to the param, separated by commas.
+ for (auto [idx, associatedAttr] : llvm::enumerate(dynamicOption)) {
+ if (idx > 0)
+ optionsStream << ",";
+ appendValueAttr(associatedAttr);
+ }
} else {
// Value is a static attribute.
- valueAttrToAppend = namedAttribute.getValue();
- }
-
- // Append string representation of value attribute.
- if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
- optionsStream << strAttr.getValue().str();
- } else {
- valueAttrToAppend.print(optionsStream, /*elideType=*/true);
+ appendValueAttr(namedAttribute.getValue());
}
}
optionsStream.flush();
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 1d1be9eda3496..407dfa3823436 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -164,6 +164,38 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func private @valid_dynamic_pass_list_option()
+module {
+ func.func @valid_dynamic_pass_list_option() {
+ return
+ }
+
+ // CHECK: func @a()
+ func.func @a() {
+ return
+ }
+ // CHECK: func @b()
+ func.func @b() {
+ return
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
+ %symbol_a = transform.param.constant "a" -> !transform.any_param
+ %symbol_b = transform.param.constant "b" -> !transform.any_param
+ %multiple_symbol_names = transform.merge_handles %symbol_a, %symbol_b : !transform.any_param
+ transform.apply_registered_pass "symbol-privatize"
+ with options = { exclude = %multiple_symbol_names } to %2
+ : (!transform.any_op, !transform.any_param) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @invalid_options_as_str() {
return
}
@@ -262,26 +294,6 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @too_many_pass_option_params() {
- return
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
- %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %x = transform.param.constant true -> !transform.any_param
- %y = transform.param.constant false -> !transform.any_param
- %topdown_options = transform.merge_handles %x, %y : !transform.any_param
- // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
- transform.apply_registered_pass "canonicalize"
- with options = { "top-down" = %topdown_options } to %1
- : (!transform.any_op, !transform.any_param) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
module attributes {transform.with_named_sequence} {
// expected-error @below {{trying to schedule a pass on an unsupported operation}}
// expected-note @below {{target op}}
>From ff632049e28e0d2ae63dabe21283876d1c0467e7 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 00:45:10 -0700
Subject: [PATCH 2/7] Support passing an ArrayAttr as a ListOption,
including through params
---
.../lib/Dialect/Transform/IR/TransformOps.cpp | 12 +++-
.../Transform/test-pass-application.mlir | 63 ++++++++++++++++++-
2 files changed, 70 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index bfe6416987629..0538faf5b3ba8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -791,11 +791,17 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
std::string options;
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
- auto appendValueAttr = [&](Attribute valueAttr) {
- if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
+ std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+ for (auto [idx, eltAttr] : llvm::enumerate(arrayAttr)) {
+ appendValueAttr(eltAttr);
+ optionsStream << ",";
+ }
+ } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
optionsStream << strAttr.getValue().str();
- else
+ } else {
valueAttr.print(optionsStream, /*elideType=*/true);
+ }
};
OperandRange dynamicOptions = getDynamicOptions();
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 407dfa3823436..f7909f4c035d9 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -164,9 +164,9 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func private @valid_dynamic_pass_list_option()
+// CHECK-LABEL: func private @valid_multiple_params_as_list_option()
module {
- func.func @valid_dynamic_pass_list_option() {
+ func.func @valid_multiple_params_as_list_option() {
return
}
@@ -196,6 +196,65 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: func private @valid_array_attr_as_list_option()
+module {
+ func.func @valid_array_attr_param_as_list_option() {
+ return
+ }
+
+ // CHECK: func @a()
+ func.func @a() {
+ return
+ }
+ // CHECK: func @b()
+ func.func @b() {
+ return
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
+ transform.apply_registered_pass "symbol-privatize"
+ with options = { exclude = ["a", "b"] } to %2
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func private @valid_array_attr_param_as_list_option()
+module {
+ func.func @valid_array_attr_param_as_list_option() {
+ return
+ }
+
+ // CHECK: func @a()
+ func.func @a() {
+ return
+ }
+ // CHECK: func @b()
+ func.func @b() {
+ return
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
+ %multiple_symbol_names = transform.param.constant ["a","b"] -> !transform.any_param
+ transform.apply_registered_pass "symbol-privatize"
+ with options = { exclude = %multiple_symbol_names } to %2
+ : (!transform.any_op, !transform.any_param) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @invalid_options_as_str() {
return
}
>From 7767260b9f81d664d064c1579acd7d41cb32b665 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 01:02:21 -0700
Subject: [PATCH 3/7] Minor clean-up
---
mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 18 +++++-------------
.../Transform/test-pass-application.mlir | 2 +-
2 files changed, 6 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 0538faf5b3ba8..651462ee6ad03 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -792,16 +792,12 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
std::string options;
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
- if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
- for (auto [idx, eltAttr] : llvm::enumerate(arrayAttr)) {
- appendValueAttr(eltAttr);
- optionsStream << ",";
- }
- } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr))
+ llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
+ else if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
optionsStream << strAttr.getValue().str();
- } else {
+ else
valueAttr.print(optionsStream, /*elideType=*/true);
- }
};
OperandRange dynamicOptions = getDynamicOptions();
@@ -822,11 +818,7 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
ArrayRef<Attribute> dynamicOption =
state.getParams(dynamicOptions[dynamicOptionIdx]);
// Append all attributes associated to the param, separated by commas.
- for (auto [idx, associatedAttr] : llvm::enumerate(dynamicOption)) {
- if (idx > 0)
- optionsStream << ",";
- appendValueAttr(associatedAttr);
- }
+ llvm::interleave(dynamicOption, optionsStream, appendValueAttr, ",");
} else {
// Value is a static attribute.
appendValueAttr(namedAttribute.getValue());
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index f7909f4c035d9..7262a8fe9faee 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -198,7 +198,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func private @valid_array_attr_as_list_option()
module {
- func.func @valid_array_attr_param_as_list_option() {
+ func.func @valid_array_attr_as_list_option() {
return
}
>From fda3f63c7166d14b7505895ab25d9733d9bcec31 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 07:17:32 -0700
Subject: [PATCH 4/7] Allow ArrayAttr options to contain param-operands
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 2 +-
.../lib/Dialect/Transform/IR/TransformOps.cpp | 164 ++++++++++++------
.../Transform/test-pass-application.mlir | 39 ++++-
3 files changed, 148 insertions(+), 57 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 0aa750e625436..140c9c66f3918 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -418,7 +418,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
with options = { "top-down" = false,
"max-iterations" = %max_iter,
"test-convergence" = true,
- "max-num-rewrites" = %max_rewrites }
+ "max-num-rewrites" = %max_rewrites }
to %module
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
```
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 651462ee6ad03..bb9bdd70625e4 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -788,42 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
// Obtain a single options-string to pass to the pass(-pipeline) from options
// passed in as a dictionary of keys mapping to values which are either
// attributes or param-operands pointing to attributes.
+ OperandRange dynamicOptions = getDynamicOptions();
std::string options;
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
- std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
- if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr))
- llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
- else if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
- optionsStream << strAttr.getValue().str();
- else
- valueAttr.print(optionsStream, /*elideType=*/true);
- };
- OperandRange dynamicOptions = getDynamicOptions();
- for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
- if (idx > 0)
- optionsStream << " "; // Interleave options separator.
- optionsStream << namedAttribute.getName().str(); // Append the key.
- optionsStream << "="; // And the key-value separator.
-
- if (auto paramOperandIndex =
- dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
- // The corresponding value attribute is passed in via a param.
+ // A helper to convert an option's attribute value into a corresponding
+ // string representation, with the ability to obtain the attr(s) from a param.
+ std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
+ if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
+ // The corresponding value attribute(s) is/are passed in via a param.
// Obtain the param-operand via its specified index.
- size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
+ size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
assert(dynamicOptionIdx < dynamicOptions.size() &&
- "number of dynamic option markers (UnitAttr) in options ArrayAttr "
+ "the number of ParamOperandAttrs in the options DictionaryAttr"
"should be the same as the number of options passed as params");
- ArrayRef<Attribute> dynamicOption =
+ ArrayRef<Attribute> attrsAssociatedToParam =
state.getParams(dynamicOptions[dynamicOptionIdx]);
- // Append all attributes associated to the param, separated by commas.
- llvm::interleave(dynamicOption, optionsStream, appendValueAttr, ",");
+ // Recursive so as to append all attrs associated to the param.
+ llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
+ ",");
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+ // Recursive so as to append all nested attrs of the array.
+ llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
+ } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
+ // Convert to unquoted string.
+ optionsStream << strAttr.getValue().str();
} else {
- // Value is a static attribute.
- appendValueAttr(namedAttribute.getValue());
+ // For all other attributes, ask the attr to print itself (without type).
+ valueAttr.print(optionsStream, /*elideType=*/true);
}
- }
+ };
+
+ // Convert the options DictionaryAttr into a single string.
+ llvm::interleave(
+ getOptions(), optionsStream,
+ [&](auto namedAttribute) {
+ optionsStream << namedAttribute.getName().str(); // Append the key.
+ optionsStream << "="; // And the key-value separator.
+ appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
+ },
+ " ");
optionsStream.flush();
// Get pass or pass pipeline from registry.
@@ -874,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
SmallVector<NamedAttribute> keyValuePairs;
-
size_t dynamicOptionsIdx = 0;
- auto parseKeyValuePair = [&]() -> ParseResult {
- // Parse items of the form `key = value` where `key` is a bare identifier or
- // a string and `value` is either an attribute or an operand.
- std::string key;
- Attribute valueAttr;
- if (parser.parseOptionalKeywordOrString(&key))
- return parser.emitError(parser.getCurrentLocation())
- << "expected key to either be an identifier or a string";
- if (key.empty())
- return failure();
+ // Helper for allowing parsing of option values which can be of the form:
+ // - a normal attribute
+ // - an operand (which would be converted to an attr referring to the operand)
+ // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
+ std::function<ParseResult(Attribute &)> parseValue =
+ [&](Attribute &valueAttr) -> ParseResult {
+ // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
+ if (succeeded(parser.parseOptionalLSquare())) {
+ SmallVector<Attribute> attrs;
- if (parser.parseEqual())
- return parser.emitError(parser.getCurrentLocation())
- << "expected '=' after key in key-value pair";
+ // Recursively parse the array's elements, which might be operands.
+ if (parser.parseCommaSeparatedList(
+ AsmParser::Delimiter::None,
+ [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
+ " in options dictionary") ||
+ parser.parseRSquare())
+ return failure(); // NB: Attempted parse should've output error message.
+
+ valueAttr = ArrayAttr::get(parser.getContext(), attrs);
+
+ return success();
+ }
// Parse the value, which can be either an attribute or an operand.
OptionalParseResult parsedValueAttr =
@@ -899,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser::UnresolvedOperand operand;
ParseResult parsedOperand = parser.parseOperand(operand);
if (failed(parsedOperand))
- return parser.emitError(parser.getCurrentLocation())
- << "expected a valid attribute or operand as value associated "
- << "to key '" << key << "'";
+ return failure(); // NB: Attempted parse should've output error message.
// To make use of the operand, we need to store it in the options dict.
// As SSA-values cannot occur in attributes, what we do instead is store
// an attribute in its place that contains the index of the param-operand,
@@ -920,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions(
<< "in the generic print format";
}
+ return success();
+ };
+
+ // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
+ // string and `value` looks like either an attribute or an operand-in-an-attr.
+ std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
+ std::string key;
+ Attribute valueAttr;
+
+ if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected key to either be an identifier or a string";
+
+ if (failed(parser.parseEqual()))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected '=' after key in key-value pair";
+
+ if (failed(parseValue(valueAttr)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected a valid attribute or operand as value associated "
+ << "to key '" << key << "'";
+
keyValuePairs.push_back(NamedAttribute(key, valueAttr));
+
return success();
};
@@ -947,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
if (options.empty())
return;
- printer << "{";
- llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
- printer << namedAttribute.getName() << " = ";
- Attribute value = namedAttribute.getValue();
- if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
+ std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
+ if (auto paramOperandAttr =
+ dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
// Resolve index of param-operand to its actual SSA-value and print that.
- printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
+ printer.printOperand(
+ dynamicOptions[paramOperandAttr.getIndex().getInt()]);
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+ // This case is so that ArrayAttr-contained operands are pretty-printed.
+ printer << "[";
+ llvm::interleaveComma(arrayAttr, printer, printOptionValue);
+ printer << "]";
} else {
- printer.printAttribute(value);
+ printer.printAttribute(valueAttr);
}
+ };
+
+ printer << "{";
+ llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
+ printer << namedAttribute.getName();
+ printer << " = ";
+ printOptionValue(namedAttribute.getValue());
});
printer << "}";
}
@@ -966,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
// and references to dynamic options in the options dictionary.
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
- for (NamedAttribute namedAttr : getOptions())
- if (auto paramOperand =
- dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
+
+ // Helper for option values to mark seen operands as having been seen (once).
+ std::function<LogicalResult(Attribute)> checkOptionValue =
+ [&](Attribute valueAttr) -> LogicalResult {
+ if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
return emitOpError()
@@ -979,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
return emitOpError() << "dynamic option index " << dynamicOptionIdx
<< " is already used in options";
dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
+ } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+ // Recurse into ArrayAttrs as they may contain references to operands.
+ for (auto eltAttr : arrayAttr)
+ if (failed(checkOptionValue(eltAttr)))
+ return failure();
}
+ return success();
+ };
+
+ for (NamedAttribute namedAttr : getOptions())
+ if (failed(checkOptionValue(namedAttr.getValue())))
+ return failure();
+ // All dynamicOptions-params seen in the dict will have been set to null.
for (Value dynamicOption : dynamicOptions)
if (dynamicOption)
return emitOpError() << "a param operand does not have a corresponding "
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 7262a8fe9faee..e21e750011ce7 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -164,9 +164,9 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func private @valid_multiple_params_as_list_option()
+// CHECK-LABEL: func private @valid_multiple_values_as_list_option_single_param()
module {
- func.func @valid_multiple_params_as_list_option() {
+ func.func @valid_multiple_values_as_list_option_single_param() {
return
}
@@ -253,6 +253,38 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-LABEL: func private @valid_multiple_params_as_single_list_option()
+module {
+ func.func @valid_multiple_params_as_single_list_option() {
+ return
+ }
+
+ // CHECK: func @a()
+ func.func @a() {
+ return
+ }
+ // CHECK: func @b()
+ func.func @b() {
+ return
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
+ %symbol_a = transform.param.constant "a" -> !transform.any_param
+ %symbol_b = transform.param.constant "b" -> !transform.any_param
+ transform.apply_registered_pass "symbol-privatize"
+ with options = { exclude = [%symbol_a, %symbol_b] } to %2
+ : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
+ transform.yield
+ }
+}
+
+
// -----
func.func @invalid_options_as_str() {
@@ -294,7 +326,8 @@ func.func @invalid_options_due_to_reserved_attr() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
+ // expected-error @+3 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
+ // expected-error @+2 {{expected a valid attribute or operand as value associated to key 'top-down'}}
%2 = transform.apply_registered_pass "canonicalize"
with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
>From 01ee8bbd9706432ac47bdf630aefcd472b96e93c Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 07:59:32 -0700
Subject: [PATCH 5/7] Update Python bindings
---
.../mlir/dialects/transform/__init__.py | 41 ++++++-------
.../Transform/test-pass-application.mlir | 1 -
mlir/test/python/dialects/transform.py | 60 ++++++++++++++-----
3 files changed, 65 insertions(+), 37 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index bfe96b1b3e5d4..b075919d1ef0f 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -219,6 +219,11 @@ def __init__(
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+OptionValueTypes = Union[
+ Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool
+]
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
def __init__(
@@ -227,12 +232,7 @@ def __init__(
target: Union[Operation, Value, OpView],
pass_name: Union[str, StringAttr],
*,
- options: Optional[
- Dict[
- Union[str, StringAttr],
- Union[Attribute, Value, Operation, OpView, str, int, bool],
- ]
- ] = None,
+ options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
loc=None,
ip=None,
):
@@ -243,26 +243,32 @@ def __init__(
context = (loc and loc.context) or Context.current
cur_param_operand_idx = 0
- for key, value in options.items() if options is not None else {}:
- if isinstance(key, StringAttr):
- key = key.value
+ def option_value_to_attr(value):
+ nonlocal cur_param_operand_idx
if isinstance(value, (Value, Operation, OpView)):
dynamic_options.append(_get_op_result_or_value(value))
- options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context)
cur_param_operand_idx += 1
+ return ParamOperandAttr(cur_param_operand_idx - 1, context)
elif isinstance(value, Attribute):
- options_dict[key] = value
+ return value
# The following cases auto-convert Python values to attributes.
elif isinstance(value, bool):
- options_dict[key] = BoolAttr.get(value)
+ return BoolAttr.get(value)
elif isinstance(value, int):
default_int_type = IntegerType.get_signless(64, context)
- options_dict[key] = IntegerAttr.get(default_int_type, value)
+ return IntegerAttr.get(default_int_type, value)
elif isinstance(value, str):
- options_dict[key] = StringAttr.get(value)
+ return StringAttr.get(value)
+ elif isinstance(value, Sequence):
+ return ArrayAttr.get([option_value_to_attr(elt) for elt in value])
else:
raise TypeError(f"Unsupported option type: {type(value)}")
+
+ for key, value in options.items() if options is not None else {}:
+ if isinstance(key, StringAttr):
+ key = key.value
+ options_dict[key] = option_value_to_attr(value)
super().__init__(
result,
_get_op_result_or_value(target),
@@ -279,12 +285,7 @@ def apply_registered_pass(
target: Union[Operation, Value, OpView],
pass_name: Union[str, StringAttr],
*,
- options: Optional[
- Dict[
- Union[str, StringAttr],
- Union[Attribute, Value, Operation, OpView, str, int, bool],
- ]
- ] = None,
+ options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
loc=None,
ip=None,
) -> Value:
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index e21e750011ce7..ce8f69c58701d 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -284,7 +284,6 @@ module attributes {transform.with_named_sequence} {
}
}
-
// -----
func.func @invalid_options_as_str() {
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index eeb95605d7a9a..aeadfcb596526 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -256,30 +256,45 @@ def testReplicateOp(module: Module):
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+# CHECK-LABEL: TEST: testApplyRegisteredPassOp
@run
def testApplyRegisteredPassOp(module: Module):
+ # CHECK: transform.sequence
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
mod = transform.ApplyRegisteredPassOp(
transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
)
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
+ # CHECK-SAME: with options = {"top-down" = false}
+ # CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
mod = transform.ApplyRegisteredPassOp(
transform.AnyOpType.get(),
mod.result,
"canonicalize",
options={"top-down": BoolAttr.get(False)},
)
+ # CHECK: %[[MAX_ITER:.+]] = transform.param.constant
max_iter = transform.param_constant(
transform.AnyParamType.get(),
IntegerAttr.get(IntegerType.get_signless(64), 10),
)
+ # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
max_rewrites = transform.param_constant(
transform.AnyParamType.get(),
IntegerAttr.get(IntegerType.get_signless(64), 1),
)
- transform.apply_registered_pass(
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
+ # NB: MLIR has sorted the dict lexicographically by key:
+ # CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
+ # CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
+ # CHECK-SAME: "test-convergence" = true,
+ # CHECK-SAME: "top-down" = false}
+ # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
+ mod = transform.apply_registered_pass(
transform.AnyOpType.get(),
mod,
"canonicalize",
@@ -290,19 +305,32 @@ def testApplyRegisteredPassOp(module: Module):
"max-rewrites": max_rewrites,
},
)
+ # CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
+ # CHECK-SAME: with options = {"exclude" = ["a", "b"]}
+ # CHECK-SAME: to %{{.*}} : (!transform.any_op) -> !transform.any_op
+ mod = transform.apply_registered_pass(
+ transform.AnyOpType.get(),
+ mod,
+ "symbol-privatize",
+ options={ "exclude": ("a", "b") },
+ )
+ # CHECK: %[[SYMBOL_A:.+]] = transform.param.constant
+ symbol_a = transform.param_constant(
+ transform.AnyParamType.get(),
+ StringAttr.get("a")
+ )
+ # CHECK: %[[SYMBOL_B:.+]] = transform.param.constant
+ symbol_b = transform.param_constant(
+ transform.AnyParamType.get(),
+ StringAttr.get("b")
+ )
+ # CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
+ # CHECK-SAME: with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]}
+ # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
+ mod = transform.apply_registered_pass(
+ transform.AnyOpType.get(),
+ mod,
+ "symbol-privatize",
+ options={ "exclude": (symbol_a, symbol_b) },
+ )
transform.YieldOp()
- # CHECK-LABEL: TEST: testApplyRegisteredPassOp
- # CHECK: transform.sequence
- # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
- # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
- # CHECK-SAME: with options = {"top-down" = false}
- # CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
- # CHECK: %[[MAX_ITER:.+]] = transform.param.constant
- # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
- # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
- # NB: MLIR has sorted the dict lexicographically by key:
- # CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
- # CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
- # CHECK-SAME: "test-convergence" = true,
- # CHECK-SAME: "top-down" = false}
- # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
>From abbac927d003953da75b6997d076b0f730873ab1 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 08:15:03 -0700
Subject: [PATCH 6/7] Fix Python formatting
---
mlir/test/python/dialects/transform.py | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index aeadfcb596526..6c5e4e5505b1c 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -312,17 +312,15 @@ def testApplyRegisteredPassOp(module: Module):
transform.AnyOpType.get(),
mod,
"symbol-privatize",
- options={ "exclude": ("a", "b") },
+ options={"exclude": ("a", "b")},
)
# CHECK: %[[SYMBOL_A:.+]] = transform.param.constant
symbol_a = transform.param_constant(
- transform.AnyParamType.get(),
- StringAttr.get("a")
+ transform.AnyParamType.get(), StringAttr.get("a")
)
# CHECK: %[[SYMBOL_B:.+]] = transform.param.constant
symbol_b = transform.param_constant(
- transform.AnyParamType.get(),
- StringAttr.get("b")
+ transform.AnyParamType.get(), StringAttr.get("b")
)
# CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
# CHECK-SAME: with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]}
@@ -331,6 +329,6 @@ def testApplyRegisteredPassOp(module: Module):
transform.AnyOpType.get(),
mod,
"symbol-privatize",
- options={ "exclude": (symbol_a, symbol_b) },
+ options={"exclude": (symbol_a, symbol_b)},
)
transform.YieldOp()
>From bcc104af4abbd5a2cb5e69f6e60c556d95011664 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 16 Jun 2025 03:46:20 -0700
Subject: [PATCH 7/7] Update docs
---
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 140c9c66f3918..62e66b3dabee8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -423,6 +423,9 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
```
+ Option values that are `ArrayAttr`s are converted to comma-separated
+ lists of options. The same applies to params associated with multiple values.
+
This op first looks for a pass pipeline with the specified name. If no such
pipeline exists, it looks for a pass with the specified name. If no such
pass exists either, this op fails definitely.
More information about the Mlir-commits
mailing list