[Mlir-commits] [mlir] [MLIR][Transform] apply_registered_pass op's options as a dict (PR #143159)
Rolf Morel
llvmlistbot at llvm.org
Wed Jun 11 06:55:13 PDT 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/143159
>From 461d7cfaf359ee07f34c9b3eb91f402b97afe312 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:13:44 -0700
Subject: [PATCH 1/7] [MLIR][Transform] friendlier Python bindings for the
apply_registered_pass op
In particular, accept pass options using the same syntax as is used in the
(pretty-)printed IR.
---
.../mlir/dialects/transform/__init__.py | 35 ++++++++++++++++++
mlir/test/python/dialects/transform.py | 36 +++++++++++++++++++
2 files changed, 71 insertions(+)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+ def __init__(
+ self,
+ result: Type,
+ pass_name: Union[str, StringAttr],
+ target: Value,
+ *,
+ options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+ loc=None,
+ ip=None,
+ ):
+ static_options = []
+ dynamic_options = []
+ for opt in options:
+ if isinstance(opt, str):
+ static_options.append(StringAttr.get(opt))
+ elif isinstance(opt, StringAttr):
+ static_options.append(opt)
+ elif isinstance(opt, Value):
+ static_options.append(UnitAttr.get())
+ dynamic_options.append(_get_op_result_or_value(opt))
+ else:
+ raise TypeError(f"Unsupported option type: {type(opt)}")
+ super().__init__(
+ result,
+ pass_name,
+ dynamic_options,
+ target=_get_op_result_or_value(target),
+ options=static_options,
+ loc=loc,
+ ip=ip,
+ )
+
+
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+@run
+def testApplyRegisteredPassOp(module: Module):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ mod = transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+ )
+ mod = transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+ )
+ max_iter = transform.param_constant(
+ transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+ )
+ max_rewrites = transform.param_constant(
+ transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+ )
+ transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(),
+ "canonicalize",
+ mod,
+ options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+ # CHECK: transform.sequence
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[MAX_ITER:.+]] = transform.param.constant
+ # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
+ # CHECK-SAME: with options = "top-down=false" %[[MAX_ITER]]
+ # CHECK-SAME: "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
>From 18360e7f5279bc89d35cef81be22b579faf0fb28 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:26:39 -0700
Subject: [PATCH 2/7] Add snake_case apply_registered_pass helper
---
mlir/python/mlir/dialects/transform/__init__.py | 4 ++++
mlir/test/python/dialects/transform.py | 7 +++++--
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index cdcdeadd54cd3..90282df49fb7d 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -249,6 +249,10 @@ def __init__(
)
+def apply_registered_pass(result, pass_name, target, *, options=[], loc=None, ip=None) -> Value:
+ return ApplyRegisteredPassOp(result=result, pass_name=pass_name, target=target, options=options, loc=loc, ip=ip).result
+
+
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index dc0987e769a09..6492b58570814 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -266,7 +266,10 @@ def testApplyRegisteredPassOp(module: Module):
transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
)
mod = transform.ApplyRegisteredPassOp(
- transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+ transform.AnyOpType.get(),
+ "canonicalize",
+ mod.result,
+ options=("top-down=false",),
)
max_iter = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
@@ -274,7 +277,7 @@ def testApplyRegisteredPassOp(module: Module):
max_rewrites = transform.param_constant(
transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
)
- transform.ApplyRegisteredPassOp(
+ transform.apply_registered_pass(
transform.AnyOpType.get(),
"canonicalize",
mod,
>From a96f40d3a7eff737892265112bd8378ba7a434bd Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:31:07 -0700
Subject: [PATCH 3/7] Formatting
---
mlir/python/mlir/dialects/transform/__init__.py | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 90282df49fb7d..176361ca32b35 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -249,8 +249,17 @@ def __init__(
)
-def apply_registered_pass(result, pass_name, target, *, options=[], loc=None, ip=None) -> Value:
- return ApplyRegisteredPassOp(result=result, pass_name=pass_name, target=target, options=options, loc=loc, ip=ip).result
+def apply_registered_pass(
+ result, pass_name, target, *, options=[], loc=None, ip=None
+) -> Value:
+ return ApplyRegisteredPassOp(
+ result=result,
+ pass_name=pass_name,
+ target=target,
+ options=options,
+ loc=loc,
+ ip=ip,
+ ).result
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
>From 9890aee49a1305be82e414bac90f1468afa13bb8 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 7 Jun 2025 13:16:35 -0700
Subject: [PATCH 4/7] Allow passing a dictionary of options, including params
as values
---
.../mlir/Dialect/Transform/IR/CMakeLists.txt | 4 +
.../Dialect/Transform/IR/TransformAttrs.h | 3 +
.../Dialect/Transform/IR/TransformAttrs.td | 21 ++
.../Dialect/Transform/IR/TransformDialect.td | 1 +
.../mlir/Dialect/Transform/IR/TransformOps.td | 23 +-
.../Dialect/Transform/IR/TransformDialect.cpp | 9 +
.../lib/Dialect/Transform/IR/TransformOps.cpp | 218 +++++++++++-------
.../mlir/dialects/transform/__init__.py | 64 +++--
.../Transform/test-pass-application.mlir | 152 ++++++++++--
mlir/test/python/dialects/transform.py | 27 ++-
10 files changed, 385 insertions(+), 137 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
index df5af7ae710da..9acab9228f100 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
@@ -20,6 +20,10 @@ mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls)
mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTransformDialectEnumIncGen)
add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen)
+mlir_tablegen(TransformAttrs.h.inc -gen-attrdef-decls)
+mlir_tablegen(TransformAttrs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRTransformDialectAttributesIncGen)
+add_dependencies(mlir-headers MLIRTransformDialectAttributesIncGen)
add_mlir_dialect(TransformOps transform)
add_mlir_doc(TransformOps TransformOps Dialects/ -gen-op-doc -dialect=transform)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
index 3cb935003b4c4..379af932ca484 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
@@ -17,4 +17,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h.inc"
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
index ebad2994880e7..c25b472c10e75 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
@@ -10,6 +10,14 @@
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+
+class Transform_Attr<string name, string attrMnemonic,
+ list<Trait> traits = [],
+ string baseCppClass = "::mlir::Attribute">
+ : AttrDef<Transform_Dialect, name, traits, baseCppClass> {
+ let mnemonic = attrMnemonic;
+}
def PropagateFailuresCase : I32EnumAttrCase<"Propagate", 1, "propagate">;
def SuppressFailuresCase : I32EnumAttrCase<"Suppress", 2, "suppress">;
@@ -33,4 +41,17 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
let cppNamespace = "::mlir::transform";
}
+def ParamOperandIndexAttr : Transform_Attr<"ParamOperandIndex",
+ "param_operand_index" > {
+ let mnemonic = "param_operand_index";
+ let description = [{
+ Used to refer to a specific param-operand (via its index) from within an
+ attribute on a transform operation.
+ }];
+ let parameters = (ins
+ "IntegerAttr":$index
+ );
+ let assemblyFormat = "`<` $index `>`";
+}
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index d03049e186f94..c7ea5ade72ace 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -19,6 +19,7 @@ def Transform_Dialect : Dialect {
let cppNamespace = "::mlir::transform";
let hasOperationAttrVerify = 1;
+ let useDefaultAttributePrinterParser = 1;
let extraClassDeclaration = [{
/// Symbol name for the default entry point "named sequence".
constexpr const static ::llvm::StringLiteral
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e864a65f8ceac..f75ba27e58e76 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -405,10 +405,23 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
let description = [{
This transform applies the specified pass or pass pipeline to the targeted
ops. The name of the pass/pipeline is specified as a string attribute, as
- set during pass/pipeline registration. Optionally, pass options may be
- specified as (space-separated) string attributes with the option to pass
- these attributes via params. The pass options syntax is identical to the one
- used with "mlir-opt".
+ set during pass/pipeline registration.
+
+ Optionally, pass options may be specified via a DictionaryAttr. This
+ dictionary is converted to a string -- formatted `key=value ...` -- which
+ is expected to be in the exact format used by the pass on the commandline.
+ Values are either attributes or (SSA-values of) Transform Dialect params.
+ For example:
+
+ ```mlir
+ transform.apply_registered_pass "canonicalize"
+ with options = { "top-down" = false,
+ "max-iterations" = %max_iter,
+ "test-convergence" = true,
+ "max-num-rewrites" = %max_rewrites }
+ to %module
+ : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ ```
This op first looks for a pass pipeline with the specified name. If no such
pipeline exists, it looks for a pass with the specified name. If no such
@@ -422,7 +435,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
}];
let arguments = (ins StrAttr:$pass_name,
- DefaultValuedAttr<ArrayAttr, "{}">:$options,
+ DefaultValuedAttr<DictionaryAttr, "{}">:$options,
Variadic<TransformParamTypeInterface>:$dynamic_options,
TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 497ceb19f1a21..4a95fe7459e8c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -8,17 +8,22 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Analysis/CallGraph.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
+
#ifndef NDEBUG
void transform::detail::checkImplementsTransformOpInterface(
StringRef name, MLIRContext *context) {
@@ -66,6 +71,10 @@ void transform::TransformDialect::initialize() {
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
initializeTypes();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
+ >();
initializeLibraryModule();
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a0f9518e3d12f..322c89ec7d7f6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -54,10 +54,11 @@
using namespace mlir;
static ParseResult parseApplyRegisteredPassOptions(
- OpAsmParser &parser, ArrayAttr &options,
+ OpAsmParser &parser, DictionaryAttr &options,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
- Operation *op, ArrayAttr options,
+ Operation *op,
+ DictionaryAttr options,
ValueRange dynamicOptions);
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
@@ -784,41 +785,50 @@ DiagnosedSilenceableFailure
transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
- // Obtain a single options-string from options passed statically as
- // string attributes as well as "dynamically" through params.
+ // Obtain a single options-string to pass to the pass(-pipeline) from options
+ // passed in as a dictionary of keys mapping to values which are either
+ // attributes or param-operands pointing to attributes.
+
std::string options;
+ llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
+
OperandRange dynamicOptions = getDynamicOptions();
- size_t dynamicOptionsIdx = 0;
- for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) {
+ for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
if (idx > 0)
- options += " "; // Interleave options seperator.
-
- if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) {
- options += strAttr.getValue();
- } else if (isa<UnitAttr>(optionAttr)) {
- assert(dynamicOptionsIdx < dynamicOptions.size() &&
+ optionsStream << " "; // Interleave options separator.
+ optionsStream << namedAttribute.getName().str(); // Append the key.
+ optionsStream << "="; // And the key-value separator.
+
+ Attribute valueAttrToAppend;
+ if (auto paramOperandIndex = dyn_cast<transform::ParamOperandIndexAttr>(
+ namedAttribute.getValue())) {
+ // The corresponding value attribute is passed in via a param.
+ // Obtain the param-operand via its specified index.
+ size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
+ assert(dynamicOptionIdx < dynamicOptions.size() &&
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
"should be the same as the number of options passed as params");
ArrayRef<Attribute> dynamicOption =
- state.getParams(dynamicOptions[dynamicOptionsIdx++]);
+ state.getParams(dynamicOptions[dynamicOptionIdx]);
if (dynamicOption.size() != 1)
- return emitSilenceableError() << "options passed as a param must have "
- "a single value associated, param "
- << dynamicOptionsIdx - 1 << " associates "
- << dynamicOption.size();
-
- if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0])) {
- options += dynamicOptionStr.getValue();
- } else {
return emitSilenceableError()
- << "options passed as a param must be a string, got "
- << dynamicOption[0];
- }
+ << "options passed as a param must have "
+ "a single value associated, param "
+ << dynamicOptionIdx << " associates " << dynamicOption.size();
+ valueAttrToAppend = dynamicOption[0];
+ } else {
+ // Value is a static attribute.
+ valueAttrToAppend = namedAttribute.getValue();
+ }
+
+ // Append string representation of value attribute.
+ if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
+ optionsStream << strAttr.getValue().str();
} else {
- llvm_unreachable(
- "expected options element to be either StringAttr or UnitAttr");
+ valueAttrToAppend.print(optionsStream, /*elideType=*/true);
}
}
+ optionsStream.flush();
// Get pass or pass pipeline from registry.
const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -864,84 +874,116 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
}
static ParseResult parseApplyRegisteredPassOptions(
- OpAsmParser &parser, ArrayAttr &options,
+ OpAsmParser &parser, DictionaryAttr &options,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
- auto dynamicOptionMarker = UnitAttr::get(parser.getContext());
- SmallVector<Attribute> optionsArray;
-
- auto parseOperandOrString = [&]() -> OptionalParseResult {
- OpAsmParser::UnresolvedOperand operand;
- OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand);
- if (parsedOperand.has_value()) {
- if (failed(parsedOperand.value()))
- return failure();
-
- dynamicOptions.push_back(operand);
- optionsArray.push_back(
- dynamicOptionMarker); // Placeholder for knowing where to
- // inject the dynamic option-as-param.
- return success();
- }
+ // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
+ SmallVector<NamedAttribute> keyValuePairs;
- StringAttr stringAttr;
- OptionalParseResult parsedStringAttr =
- parser.parseOptionalAttribute(stringAttr);
- if (parsedStringAttr.has_value()) {
- if (failed(parsedStringAttr.value()))
- return failure();
- optionsArray.push_back(stringAttr);
- return success();
- }
+ size_t dynamicOptionsIdx = 0;
+ auto parseKeyValuePair = [&]() -> ParseResult {
+ // Parse items of the form `key = value` where `key` is a bare identifier or
+ // a string and `value` is either an attribute or an operand.
+
+ std::string key;
+ Attribute valueAttr;
+ if (parser.parseOptionalKeywordOrString(&key))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected key to either be an identifier or a string";
+ if (key.empty())
+ return failure();
- return std::nullopt;
+ if (parser.parseEqual())
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected '=' after key in key-value pair";
+
+ // Parse the value, which can be either an attribute or an operand.
+ OptionalParseResult parsedValueAttr =
+ parser.parseOptionalAttribute(valueAttr);
+ if (!parsedValueAttr.has_value()) {
+ OpAsmParser::UnresolvedOperand operand;
+ ParseResult parsedOperand = parser.parseOperand(operand);
+ if (failed(parsedOperand))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected a valid attribute or operand as value associated "
+ << "to key '" << key << "'";
+ dynamicOptions.push_back(operand);
+ auto wrappedIndex = IntegerAttr::get(
+ IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
+ valueAttr = transform::ParamOperandIndexAttr::get(parser.getContext(),
+ wrappedIndex);
+ } else if (failed(parsedValueAttr.value())) {
+ return failure(); // NB: Attempted parse should have output error message.
+ } else if (isa<transform::ParamOperandIndexAttr>(valueAttr)) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "the param_operand_index attribute is a marker reserved for "
+ << "indicating a value will be passed via params and is only used "
+ << "in the generic print format";
+ }
+
+ keyValuePairs.push_back(NamedAttribute(key, valueAttr));
+ return success();
};
- OptionalParseResult parsedOptionsElement = parseOperandOrString();
- while (parsedOptionsElement.has_value()) {
- if (failed(parsedOptionsElement.value()))
- return failure();
- parsedOptionsElement = parseOperandOrString();
- }
+ if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Braces,
+ parseKeyValuePair,
+ " in options dictionary"))
+ return failure(); // NB: Attempted parse should have output error message.
- if (optionsArray.empty()) {
+ if (DictionaryAttr::findDuplicate(
+ keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
+ .has_value())
return parser.emitError(parser.getCurrentLocation())
- << "expected at least one option (either a string or a param)";
- }
- options = parser.getBuilder().getArrayAttr(optionsArray);
+ << "duplicate keys found in options dictionary";
+
+ options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
+
return success();
}
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
- Operation *op, ArrayAttr options,
+ Operation *op,
+ DictionaryAttr options,
ValueRange dynamicOptions) {
- size_t currentDynamicOptionIdx = 0;
- for (auto [idx, optionAttr] : llvm::enumerate(options)) {
- if (idx > 0)
- printer << " "; // Interleave options separator.
+ if (options.empty())
+ return;
- if (isa<UnitAttr>(optionAttr))
- printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]);
- else if (auto strAttr = dyn_cast<StringAttr>(optionAttr))
- printer.printAttribute(strAttr);
- else
- llvm_unreachable("each option should be either a StringAttr or UnitAttr");
- }
+ printer << "{";
+ llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
+ printer << namedAttribute.getName() << " = ";
+ Attribute value = namedAttribute.getValue();
+ if (auto indexAttr = dyn_cast<transform::ParamOperandIndexAttr>(value)) {
+ printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
+ } else {
+ printer.printAttribute(value);
+ }
+ });
+ printer << "}";
}
LogicalResult transform::ApplyRegisteredPassOp::verify() {
- size_t numUnitsInOptions = 0;
- for (Attribute optionsElement : getOptions()) {
- if (isa<UnitAttr>(optionsElement))
- numUnitsInOptions++;
- else if (!isa<StringAttr>(optionsElement))
- return emitOpError() << "expected each option to be either a StringAttr "
- << "or a UnitAttr, got " << optionsElement;
- }
-
- if (getDynamicOptions().size() != numUnitsInOptions)
- return emitOpError()
- << "expected the same number of options passed as params as "
- << "UnitAttr elements in options ArrayAttr";
+ // Check that there is a one-to-one correspondence between param operands
+ // and references to dynamic options in the options dictionary.
+
+ auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
+ for (NamedAttribute namedAttr : getOptions())
+ if (auto paramOperandIndex =
+ dyn_cast<transform::ParamOperandIndexAttr>(namedAttr.getValue())) {
+ size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
+ if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
+ return emitOpError()
+ << "dynamic option index " << dynamicOptionIdx
+ << " is out of bounds for the number of dynamic options: "
+ << dynamicOptions.size();
+ if (dynamicOptions[dynamicOptionIdx] == nullptr)
+ return emitOpError() << "dynamic option index " << dynamicOptionIdx
+ << " is already used in options";
+ dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
+ }
+
+ for (Value dynamicOption : dynamicOptions)
+ if (dynamicOption)
+ return emitOpError() << "a param operand does not have a corresponding "
+ << "param_operand_index attr in the options dict";
return success();
}
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 176361ca32b35..4a97dcebd3cdb 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -18,7 +18,7 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Optional, Sequence, Union, NewType
+from typing import Dict, Optional, Sequence, Union, NewType
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -214,43 +214,77 @@ def __init__(
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+@register_attribute_builder("ParamOperandIndexAttr")
+def _paramOperandIndexAttr(x: int, context) -> Attribute:
+ return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
def __init__(
self,
result: Type,
pass_name: Union[str, StringAttr],
- target: Value,
+ target: Union[Operation, Value, OpView],
*,
- options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+ options: Dict[
+ Union[str, StringAttr],
+ Union[Attribute, Value, Operation, OpView],
+ ] = {},
loc=None,
ip=None,
):
- static_options = []
+ options_dict = {}
dynamic_options = []
- for opt in options:
- if isinstance(opt, str):
- static_options.append(StringAttr.get(opt))
- elif isinstance(opt, StringAttr):
- static_options.append(opt)
- elif isinstance(opt, Value):
- static_options.append(UnitAttr.get())
- dynamic_options.append(_get_op_result_or_value(opt))
+
+ ParamOperandIndexAttr = AttrBuilder.get("ParamOperandIndexAttr")
+ context = (loc and loc.context) or Context.current
+
+ cur_param_operand_idx = 0
+ for key, value in options.items():
+ if isinstance(key, StringAttr):
+ key = key.value
+
+ if isinstance(value, (Value, Operation, OpView)):
+ value = _get_op_result_or_value(value)
+ # v = Attribute.parse(
+ # f"#transform.param_operand_index<{cur_param_operand_idx}>",
+ # context=context,
+ # )
+ v = _paramOperandIndexAttr(cur_param_operand_idx, context)
+ options_dict[key] = v
+ cur_param_operand_idx += 1
+ dynamic_options.append(value)
+ elif isinstance(value, Attribute):
+ options_dict[key] = value
+ elif isinstance(value, str):
+ options_dict[key] = StringAttr.get(value)
else:
- raise TypeError(f"Unsupported option type: {type(opt)}")
+ raise TypeError(f"Unsupported option type: {type(value)}")
+ if len(options_dict) > 0:
+ print(options_dict, cur_param_operand_idx)
super().__init__(
result,
pass_name,
dynamic_options,
target=_get_op_result_or_value(target),
- options=static_options,
+ options=DictAttr.get(options_dict),
loc=loc,
ip=ip,
)
def apply_registered_pass(
- result, pass_name, target, *, options=[], loc=None, ip=None
+ result: Type,
+ pass_name: Union[str, StringAttr],
+ target: Union[Operation, Value, OpView],
+ *,
+ options: Dict[
+ Union[str, StringAttr],
+ Union[Attribute, Value, Operation, OpView],
+ ] = {},
+ loc=None,
+ ip=None,
) -> Value:
return ApplyRegisteredPassOp(
result=result,
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 463fd98afa65c..ea1dd4abd4db2 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -80,7 +80,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
// expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
transform.apply_registered_pass "canonicalize"
- with options = "invalid-option=1" to %1
+ with options = { "invalid-option" = 1 } to %1
: (!transform.any_op) -> !transform.any_op
transform.yield
}
@@ -97,7 +97,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize"
- with options = "top-down=false" to %1
+ with options = { "top-down" = false } to %1
: (!transform.any_op) -> !transform.any_op
transform.yield
}
@@ -115,7 +115,7 @@ module attributes {transform.with_named_sequence} {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
//transform.apply_registered_pass "canonicalize" with options = "top-down=false,max-iterations=10" to %1 : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize"
- with options = "top-down=false test-convergence=true" to %1
+ with options = { "top-down" = false, "test-convergence" =true } to %1
: (!transform.any_op) -> !transform.any_op
transform.yield
}
@@ -132,7 +132,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize"
- with options = "top-down=false" "max-iterations=0" to %1
+ with options = { "top-down" = false, "max-iterations" = 0 } to %1
: (!transform.any_op) -> !transform.any_op
transform.yield
}
@@ -148,10 +148,15 @@ func.func @valid_dynamic_pass_options() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
- %max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param
- %2 = transform.apply_registered_pass "canonicalize"
- with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1
+ %max_iter = transform.param.constant 10 -> !transform.any_param
+ %max_rewrites = transform.param.constant 1 -> !transform.any_param
+ %2 = transform.apply_registered_pass
+ "canonicalize"
+ with options = { "top-down" = false,
+ "max-iterations" = %max_iter,
+ "test-convergence" = true,
+ "max-num-rewrites" = %max_rewrites }
+ to %1
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
transform.yield
}
@@ -159,7 +164,7 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @invalid_dynamic_options_as_array() {
+func.func @invalid_options_as_str() {
return
}
@@ -167,34 +172,80 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
- // expected-error @+2 {{expected at least one option (either a string or a param)}}
+ // expected-error @+2 {{expected '{' in options dictionary}}
%2 = transform.apply_registered_pass "canonicalize"
- with options = ["top-down=false" %max_iter] to %1
- : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
-func.func @invalid_options_as_pairs() {
+func.func @invalid_options_as_pairs_without_braces() {
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @+2 {{expected 'to'}}
+ // expected-error @+2 {{expected '{' in options dictionary}}
%2 = transform.apply_registered_pass "canonicalize"
- with options = "top-down=" false to %1
- : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ with options = "top-down"=false to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
-func.func @invalid_pass_option_param() {
+func.func @invalid_options_due_to_reserved_attr() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @+2 {{the param_operand_index attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
+ %2 = transform.apply_registered_pass "canonicalize"
+ with options = { "top-down" = #transform.param_operand_index<0> } to %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @invalid_options_due_duplicated_key() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @+2 {{duplicate keys found in options dictionary}}
+ %2 = transform.apply_registered_pass "canonicalize"
+ with options = {"top-down"=false,"top-down"=true} to %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @invalid_options_due_invalid_key() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @+2 {{expected key to either be an identifier or a string}}
+ %2 = transform.apply_registered_pass "canonicalize"
+ with options = { @label = 0 } to %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @invalid_pass_option_bare_param() {
return
}
@@ -202,7 +253,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%pass_options = transform.param.constant 42 -> !transform.any_param
- // expected-error @below {{options passed as a param must be a string, got 42}}
+ // expected-error @+2 {{expected '{' in options dictionary}}
transform.apply_registered_pass "canonicalize"
with options = %pass_options to %1
: (!transform.any_param, !transform.any_op) -> !transform.any_op
@@ -219,12 +270,12 @@ func.func @too_many_pass_option_params() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %x = transform.param.constant "x" -> !transform.any_param
- %y = transform.param.constant "y" -> !transform.any_param
- %pass_options = transform.merge_handles %x, %y : !transform.any_param
+ %x = transform.param.constant true -> !transform.any_param
+ %y = transform.param.constant false -> !transform.any_param
+ %topdown_options = transform.merge_handles %x, %y : !transform.any_param
// expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
transform.apply_registered_pass "canonicalize"
- with options = %pass_options to %1
+ with options = { "top-down" = %topdown_options } to %1
: (!transform.any_param, !transform.any_op) -> !transform.any_op
transform.yield
}
@@ -248,3 +299,60 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+"builtin.module"() ({
+ "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
+ ^bb0(%arg0: !transform.any_op):
+ %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
+ %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
+ // expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
+ %2 = "transform.apply_registered_pass"(%1, %0) <{
+ options = {"max-iterations" = #transform.param_operand_index<1 : i64>,
+ "test-convergence" = true,
+ "top-down" = false},
+ pass_name = "canonicalize"}>
+ : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ "transform.yield"() : () -> ()
+ }) : () -> ()
+}) {transform.with_named_sequence} : () -> ()
+
+// -----
+
+"builtin.module"() ({
+ "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
+ ^bb0(%arg0: !transform.any_op):
+ %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
+ %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
+ %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
+ // expected-error @below {{dynamic option index 0 is already used in options}}
+ %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+ options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
+ "max-num-rewrites" = #transform.param_operand_index<0 : i64>,
+ "test-convergence" = true,
+ "top-down" = false},
+ pass_name = "canonicalize"}>
+ : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ "transform.yield"() : () -> ()
+ }) : () -> ()
+}) {transform.with_named_sequence} : () -> ()
+
+// -----
+
+"builtin.module"() ({
+ "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
+ ^bb0(%arg0: !transform.any_op):
+ %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
+ %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
+ %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
+ // expected-error @below {{a param operand does not have a corresponding param_operand_index attr in the options dict}}
+ %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+ options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
+ "test-convergence" = true,
+ "top-down" = false},
+ pass_name = "canonicalize"}>
+ : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ "transform.yield"() : () -> ()
+ }) : () -> ()
+}) {transform.with_named_sequence} : () -> ()
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6492b58570814..48bc9bad37a1e 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -269,27 +269,40 @@ def testApplyRegisteredPassOp(module: Module):
transform.AnyOpType.get(),
"canonicalize",
mod.result,
- options=("top-down=false",),
+ options={"top-down": BoolAttr.get(False)},
)
max_iter = transform.param_constant(
- transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+ transform.AnyParamType.get(),
+ IntegerAttr.get(IntegerType.get_signless(64), 10),
)
max_rewrites = transform.param_constant(
- transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+ transform.AnyParamType.get(),
+ IntegerAttr.get(IntegerType.get_signless(64), 1),
)
transform.apply_registered_pass(
transform.AnyOpType.get(),
"canonicalize",
mod,
- options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+ options={
+ "top-down": BoolAttr.get(False),
+ "max-iterations": max_iter,
+ "test-convergence": BoolAttr.get(True),
+ "max-rewrites": max_rewrites,
+ },
)
transform.YieldOp()
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
# CHECK: transform.sequence
# CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
- # CHECK: %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
+ # CHECK-SAME: with options = {"top-down" = false}
+ # CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
# CHECK: %[[MAX_ITER:.+]] = transform.param.constant
# CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
- # CHECK-SAME: with options = "top-down=false" %[[MAX_ITER]]
- # CHECK-SAME: "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+ # NB: MLIR has sorted the dict lexicographically by key:
+ # CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
+ # CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
+ # CHECK-SAME: "test-convergence" = true,
+ # CHECK-SAME: "top-down" = false}
+ # CHECK-SAME: to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
>From e1a7803942373b84b20c6edf4538e4d9c0302aeb Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 8 Jun 2025 13:10:49 -0700
Subject: [PATCH 5/7] Fix Python signature
---
.../mlir/dialects/transform/__init__.py | 44 +++++++++----------
1 file changed, 22 insertions(+), 22 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 4a97dcebd3cdb..6d4abe8882e9d 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -21,6 +21,11 @@
from typing import Dict, Optional, Sequence, Union, NewType
+@register_attribute_builder("ParamOperandIndexAttr")
+def _paramOperandIndexAttr(x: int, context) -> Attribute:
+ return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class CastOp(CastOp):
def __init__(
@@ -214,11 +219,6 @@ def __init__(
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
-@register_attribute_builder("ParamOperandIndexAttr")
-def _paramOperandIndexAttr(x: int, context) -> Attribute:
- return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
def __init__(
@@ -227,10 +227,12 @@ def __init__(
pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
*,
- options: Dict[
- Union[str, StringAttr],
- Union[Attribute, Value, Operation, OpView],
- ] = {},
+ options: Optional[
+ Dict[
+ Union[str, StringAttr],
+ Union[Attribute, Value, Operation, OpView],
+ ]
+ ] = None,
loc=None,
ip=None,
):
@@ -241,20 +243,16 @@ def __init__(
context = (loc and loc.context) or Context.current
cur_param_operand_idx = 0
- for key, value in options.items():
+ for key, value in options.items() if options is not None else {}:
if isinstance(key, StringAttr):
key = key.value
if isinstance(value, (Value, Operation, OpView)):
- value = _get_op_result_or_value(value)
- # v = Attribute.parse(
- # f"#transform.param_operand_index<{cur_param_operand_idx}>",
- # context=context,
- # )
- v = _paramOperandIndexAttr(cur_param_operand_idx, context)
- options_dict[key] = v
+ dynamic_options.append(_get_op_result_or_value(value))
+ options_dict[key] = ParamOperandIndexAttr(
+ cur_param_operand_idx, context
+ )
cur_param_operand_idx += 1
- dynamic_options.append(value)
elif isinstance(value, Attribute):
options_dict[key] = value
elif isinstance(value, str):
@@ -279,10 +277,12 @@ def apply_registered_pass(
pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
*,
- options: Dict[
- Union[str, StringAttr],
- Union[Attribute, Value, Operation, OpView],
- ] = {},
+ options: Optional[
+ Dict[
+ Union[str, StringAttr],
+ Union[Attribute, Value, Operation, OpView],
+ ]
+ ] = None,
loc=None,
ip=None,
) -> Value:
>From 079b3dbd75d791221155081e07061f6910360cf1 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 9 Jun 2025 16:22:10 -0700
Subject: [PATCH 6/7] Add comments to generic format tests
---
.../Transform/test-pass-application.mlir | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index ea1dd4abd4db2..86615f05e4a01 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -302,6 +302,14 @@ module attributes {transform.with_named_sequence} {
// -----
+/////////////////////////////////////////////////////////////////////
+// Check that the following cases are caugh in the generic format. //
+/////////////////////////////////////////////////////////////////////
+
+// Invalid due to param_operand_index occurences in options dict not being
+// one-to-one with the dynamic options provided as params:
+// param_operand_index out of bounds w.r.t. the number of options provided via params.
+
"builtin.module"() ({
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
@@ -320,6 +328,11 @@ module attributes {transform.with_named_sequence} {
// -----
+// Invalid due to param_operand_index occurences in options dict not being
+// one-to-one with the dynamic options provided as params:
+// the first option-param is referred to twice and the second one not at all.
+// (The pretty-printed format supports this by passing in the same param twice.)
+
"builtin.module"() ({
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
@@ -340,6 +353,10 @@ module attributes {transform.with_named_sequence} {
// -----
+// Invalid due to param_operand_index occurences in options dict not being
+// one-to-one with the dynamic options provided as params:
+// two option-params are provide though only the first one is referred to from the options-dict.
+
"builtin.module"() ({
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
>From a8d0b3355ead35f9f5810fc53fbc2a3d9eea8480 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 11 Jun 2025 06:54:04 -0700
Subject: [PATCH 7/7] Address @adam-smnk's review
---
.../Dialect/Transform/IR/TransformAttrs.td | 6 ++---
.../lib/Dialect/Transform/IR/TransformOps.cpp | 27 +++++++++++--------
.../mlir/dialects/transform/__init__.py | 10 +++----
.../Transform/test-pass-application.mlir | 22 +++++++--------
4 files changed, 34 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
index c25b472c10e75..e67a9444c24a8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
@@ -41,9 +41,7 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
let cppNamespace = "::mlir::transform";
}
-def ParamOperandIndexAttr : Transform_Attr<"ParamOperandIndex",
- "param_operand_index" > {
- let mnemonic = "param_operand_index";
+def ParamOperandAttr : Transform_Attr<"ParamOperand", "param_operand"> {
let description = [{
Used to refer to a specific param-operand (via its index) from within an
attribute on a transform operation.
@@ -51,7 +49,7 @@ def ParamOperandIndexAttr : Transform_Attr<"ParamOperandIndex",
let parameters = (ins
"IntegerAttr":$index
);
- let assemblyFormat = "`<` $index `>`";
+ let assemblyFormat = "`<` `index` `=` $index `>`";
}
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 322c89ec7d7f6..582d082153bef 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -800,8 +800,8 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
optionsStream << "="; // And the key-value separator.
Attribute valueAttrToAppend;
- if (auto paramOperandIndex = dyn_cast<transform::ParamOperandIndexAttr>(
- namedAttribute.getValue())) {
+ if (auto paramOperandIndex =
+ dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
// The corresponding value attribute is passed in via a param.
// Obtain the param-operand via its specified index.
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
@@ -906,16 +906,20 @@ static ParseResult parseApplyRegisteredPassOptions(
return parser.emitError(parser.getCurrentLocation())
<< "expected a valid attribute or operand as value associated "
<< "to key '" << key << "'";
+ // To make use of the operand, we need to store it in the options dict.
+ // As SSA-values cannot occur in attributes, what we do instead is store
+ // an attribute in its place that contains the index of the param-operand,
+ // so that an attr-value associated to the param can be resolved later on.
dynamicOptions.push_back(operand);
auto wrappedIndex = IntegerAttr::get(
IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
- valueAttr = transform::ParamOperandIndexAttr::get(parser.getContext(),
- wrappedIndex);
+ valueAttr =
+ transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
} else if (failed(parsedValueAttr.value())) {
return failure(); // NB: Attempted parse should have output error message.
- } else if (isa<transform::ParamOperandIndexAttr>(valueAttr)) {
+ } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
return parser.emitError(parser.getCurrentLocation())
- << "the param_operand_index attribute is a marker reserved for "
+ << "the param_operand attribute is a marker reserved for "
<< "indicating a value will be passed via params and is only used "
<< "in the generic print format";
}
@@ -951,7 +955,8 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
printer << namedAttribute.getName() << " = ";
Attribute value = namedAttribute.getValue();
- if (auto indexAttr = dyn_cast<transform::ParamOperandIndexAttr>(value)) {
+ if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
+ // Resolve index of param-operand to its actual SSA-value and print that.
printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
} else {
printer.printAttribute(value);
@@ -966,9 +971,9 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
for (NamedAttribute namedAttr : getOptions())
- if (auto paramOperandIndex =
- dyn_cast<transform::ParamOperandIndexAttr>(namedAttr.getValue())) {
- size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
+ if (auto paramOperand =
+ dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
+ size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
return emitOpError()
<< "dynamic option index " << dynamicOptionIdx
@@ -983,7 +988,7 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
for (Value dynamicOption : dynamicOptions)
if (dynamicOption)
return emitOpError() << "a param operand does not have a corresponding "
- << "param_operand_index attr in the options dict";
+ << "param_operand attr in the options dict";
return success();
}
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 6d4abe8882e9d..ed53c6af5086a 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -21,9 +21,9 @@
from typing import Dict, Optional, Sequence, Union, NewType
-@register_attribute_builder("ParamOperandIndexAttr")
-def _paramOperandIndexAttr(x: int, context) -> Attribute:
- return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
+@register_attribute_builder("ParamOperandAttr")
+def _paramOperandAttr(x: int, context) -> Attribute:
+ return Attribute.parse(f"#transform.param_operand<index={x}>", context=context)
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -239,7 +239,7 @@ def __init__(
options_dict = {}
dynamic_options = []
- ParamOperandIndexAttr = AttrBuilder.get("ParamOperandIndexAttr")
+ ParamOperandAttr = AttrBuilder.get("ParamOperandAttr")
context = (loc and loc.context) or Context.current
cur_param_operand_idx = 0
@@ -249,7 +249,7 @@ def __init__(
if isinstance(value, (Value, Operation, OpView)):
dynamic_options.append(_get_op_result_or_value(value))
- options_dict[key] = ParamOperandIndexAttr(
+ options_dict[key] = ParamOperandAttr(
cur_param_operand_idx, context
)
cur_param_operand_idx += 1
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 86615f05e4a01..6e6d4eb7e249f 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -204,9 +204,9 @@ func.func @invalid_options_due_to_reserved_attr() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @+2 {{the param_operand_index attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
+ // expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
%2 = transform.apply_registered_pass "canonicalize"
- with options = { "top-down" = #transform.param_operand_index<0> } to %1 : (!transform.any_op) -> !transform.any_op
+ with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
@@ -306,7 +306,7 @@ module attributes {transform.with_named_sequence} {
// Check that the following cases are caugh in the generic format. //
/////////////////////////////////////////////////////////////////////
-// Invalid due to param_operand_index occurences in options dict not being
+// Invalid due to param_operand occurrences in options dict not being
// one-to-one with the dynamic options provided as params:
// param_operand_index out of bounds w.r.t. the number of options provided via params.
@@ -317,7 +317,7 @@ module attributes {transform.with_named_sequence} {
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
%2 = "transform.apply_registered_pass"(%1, %0) <{
- options = {"max-iterations" = #transform.param_operand_index<1 : i64>,
+ options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
@@ -328,10 +328,10 @@ module attributes {transform.with_named_sequence} {
// -----
-// Invalid due to param_operand_index occurences in options dict not being
+// Invalid due to param_operand occurrences in options dict not being
// one-to-one with the dynamic options provided as params:
// the first option-param is referred to twice and the second one not at all.
-// (The pretty-printed format supports this by passing in the same param twice.)
+// (In the pretty-printed format, if you want to refer to a param SSA-value twice, it counts as two param arguments.)
"builtin.module"() ({
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
@@ -341,8 +341,8 @@ module attributes {transform.with_named_sequence} {
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 0 is already used in options}}
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
- options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
- "max-num-rewrites" = #transform.param_operand_index<0 : i64>,
+ options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
+ "max-num-rewrites" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
@@ -353,7 +353,7 @@ module attributes {transform.with_named_sequence} {
// -----
-// Invalid due to param_operand_index occurences in options dict not being
+// Invalid due to param_operand occurrences in options dict not being
// one-to-one with the dynamic options provided as params:
// two option-params are provide though only the first one is referred to from the options-dict.
@@ -363,9 +363,9 @@ module attributes {transform.with_named_sequence} {
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
- // expected-error @below {{a param operand does not have a corresponding param_operand_index attr in the options dict}}
+ // expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
- options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
+ options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
More information about the Mlir-commits
mailing list