[Mlir-commits] [mlir] [MLIR][Transform] apply_registered_pass op's options as a dict (PR #143159)

Rolf Morel llvmlistbot at llvm.org
Wed Jun 11 06:55:13 PDT 2025


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/143159

>From 461d7cfaf359ee07f34c9b3eb91f402b97afe312 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:13:44 -0700
Subject: [PATCH 1/7] [MLIR][Transform] friendlier Python-bindings
 apply_registered_pass op

In particular, use similar syntax for providing options as in the
(pretty-)printed IR.
---
 .../mlir/dialects/transform/__init__.py       | 35 ++++++++++++++++++
 mlir/test/python/dialects/transform.py        | 36 +++++++++++++++++++
 2 files changed, 71 insertions(+)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+    def __init__(
+        self,
+        result: Type,
+        pass_name: Union[str, StringAttr],
+        target: Value,
+        *,
+        options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+        loc=None,
+        ip=None,
+    ):
+        static_options = []
+        dynamic_options = []
+        for opt in options:
+            if isinstance(opt, str):
+                static_options.append(StringAttr.get(opt))
+            elif isinstance(opt, StringAttr):
+                static_options.append(opt)
+            elif isinstance(opt, Value):
+                static_options.append(UnitAttr.get())
+                dynamic_options.append(_get_op_result_or_value(opt))
+            else:
+                raise TypeError(f"Unsupported option type: {type(opt)}")
+        super().__init__(
+            result,
+            pass_name,
+            dynamic_options,
+            target=_get_op_result_or_value(target),
+            options=static_options,
+            loc=loc,
+            ip=ip,
+        )
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
     # CHECK: %[[FIRST:.+]] = pdl_match
     # CHECK: %[[SECOND:.+]] = pdl_match
     # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+@run
+def testApplyRegisteredPassOp(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+        )
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+        )
+        max_iter = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+        )
+        max_rewrites = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+        )
+        transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(),
+            "canonicalize",
+            mod,
+            options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+    # CHECK: transform.sequence
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %[[MAX_ITER:.+]] = transform.param.constant
+    # CHECK:   %[[MAX_REWRITE:.+]] = transform.param.constant
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
+    # CHECK-SAME:    with options = "top-down=false" %[[MAX_ITER]]
+    # CHECK-SAME:   "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

>From 18360e7f5279bc89d35cef81be22b579faf0fb28 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:26:39 -0700
Subject: [PATCH 2/7] snake_case_helper

---
 mlir/python/mlir/dialects/transform/__init__.py | 4 ++++
 mlir/test/python/dialects/transform.py          | 7 +++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index cdcdeadd54cd3..90282df49fb7d 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -249,6 +249,10 @@ def __init__(
         )
 
 
+def apply_registered_pass(result, pass_name, target, *, options=[], loc=None, ip=None) -> Value:
+  return ApplyRegisteredPassOp(result=result, pass_name=pass_name, target=target, options=options, loc=loc, ip=ip).result
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index dc0987e769a09..6492b58570814 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -266,7 +266,10 @@ def testApplyRegisteredPassOp(module: Module):
             transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
         )
         mod = transform.ApplyRegisteredPassOp(
-            transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+            transform.AnyOpType.get(),
+            "canonicalize",
+            mod.result,
+            options=("top-down=false",),
         )
         max_iter = transform.param_constant(
             transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
@@ -274,7 +277,7 @@ def testApplyRegisteredPassOp(module: Module):
         max_rewrites = transform.param_constant(
             transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
         )
-        transform.ApplyRegisteredPassOp(
+        transform.apply_registered_pass(
             transform.AnyOpType.get(),
             "canonicalize",
             mod,

>From a96f40d3a7eff737892265112bd8378ba7a434bd Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 6 Jun 2025 08:31:07 -0700
Subject: [PATCH 3/7] Formatting

---
 mlir/python/mlir/dialects/transform/__init__.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 90282df49fb7d..176361ca32b35 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -249,8 +249,17 @@ def __init__(
         )
 
 
-def apply_registered_pass(result, pass_name, target, *, options=[], loc=None, ip=None) -> Value:
-  return ApplyRegisteredPassOp(result=result, pass_name=pass_name, target=target, options=options, loc=loc, ip=ip).result
+def apply_registered_pass(
+    result, pass_name, target, *, options=[], loc=None, ip=None
+) -> Value:
+    return ApplyRegisteredPassOp(
+        result=result,
+        pass_name=pass_name,
+        target=target,
+        options=options,
+        loc=loc,
+        ip=ip,
+    ).result
 
 
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)

>From 9890aee49a1305be82e414bac90f1468afa13bb8 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sat, 7 Jun 2025 13:16:35 -0700
Subject: [PATCH 4/7] Allow passing a dictionary of options, including params
 as values

---
 .../mlir/Dialect/Transform/IR/CMakeLists.txt  |   4 +
 .../Dialect/Transform/IR/TransformAttrs.h     |   3 +
 .../Dialect/Transform/IR/TransformAttrs.td    |  21 ++
 .../Dialect/Transform/IR/TransformDialect.td  |   1 +
 .../mlir/Dialect/Transform/IR/TransformOps.td |  23 +-
 .../Dialect/Transform/IR/TransformDialect.cpp |   9 +
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 218 +++++++++++-------
 .../mlir/dialects/transform/__init__.py       |  64 +++--
 .../Transform/test-pass-application.mlir      | 152 ++++++++++--
 mlir/test/python/dialects/transform.py        |  27 ++-
 10 files changed, 385 insertions(+), 137 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
index df5af7ae710da..9acab9228f100 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt
@@ -20,6 +20,10 @@ mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls)
 mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRTransformDialectEnumIncGen)
 add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen)
+mlir_tablegen(TransformAttrs.h.inc -gen-attrdef-decls)
+mlir_tablegen(TransformAttrs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRTransformDialectAttributesIncGen)
+add_dependencies(mlir-headers MLIRTransformDialectAttributesIncGen)
 
 add_mlir_dialect(TransformOps transform)
 add_mlir_doc(TransformOps TransformOps Dialects/ -gen-op-doc -dialect=transform)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
index 3cb935003b4c4..379af932ca484 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h
@@ -17,4 +17,7 @@
 
 #include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h.inc"
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
index ebad2994880e7..c25b472c10e75 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
@@ -10,6 +10,14 @@
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
 
 include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+
+class Transform_Attr<string name, string attrMnemonic,
+                     list<Trait> traits = [],
+                     string baseCppClass = "::mlir::Attribute">
+    : AttrDef<Transform_Dialect, name, traits, baseCppClass> {
+  let mnemonic = attrMnemonic;
+}
 
 def PropagateFailuresCase : I32EnumAttrCase<"Propagate", 1, "propagate">;
 def SuppressFailuresCase : I32EnumAttrCase<"Suppress", 2, "suppress">;
@@ -33,4 +41,17 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
   let cppNamespace = "::mlir::transform";
 }
 
+def ParamOperandIndexAttr : Transform_Attr<"ParamOperandIndex",
+                                           "param_operand_index" > {
+  let mnemonic = "param_operand_index";
+  let description = [{
+    Used to refer to a specific param-operand (via its index) from within an
+    attribute on a transform operation.
+  }];
+  let parameters = (ins
+    "IntegerAttr":$index
+  );
+  let assemblyFormat = "`<` $index `>`";
+}
+
 #endif  // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index d03049e186f94..c7ea5ade72ace 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -19,6 +19,7 @@ def Transform_Dialect : Dialect {
   let cppNamespace = "::mlir::transform";
 
   let hasOperationAttrVerify = 1;
+  let useDefaultAttributePrinterParser = 1;
   let extraClassDeclaration = [{
     /// Symbol name for the default entry point "named sequence".
     constexpr const static ::llvm::StringLiteral
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e864a65f8ceac..f75ba27e58e76 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -405,10 +405,23 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
   let description = [{
     This transform applies the specified pass or pass pipeline to the targeted
     ops. The name of the pass/pipeline is specified as a string attribute, as
-    set during pass/pipeline registration. Optionally, pass options may be
-    specified as (space-separated) string attributes with the option to pass
-    these attributes via params. The pass options syntax is identical to the one
-    used with "mlir-opt".
+    set during pass/pipeline registration.
+
+    Optionally, pass options may be specified via a DictionaryAttr. This
+    dictionary is converted to a string -- formatted `key=value ...` -- which
+    is expected to be in the exact format used by the pass on the commandline.
+    Values are either attributes or (SSA-values of) Transform Dialect params.
+    For example:
+
+    ```mlir
+    transform.apply_registered_pass "canonicalize"
+        with options = { "top-down" = false,
+                         "max-iterations" = %max_iter,
+                         "test-convergence" = true,
+                         "max-num-rewrites" =  %max_rewrites }
+        to %module
+    : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    ```
 
     This op first looks for a pass pipeline with the specified name. If no such
     pipeline exists, it looks for a pass with the specified name. If no such
@@ -422,7 +435,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
   }];
 
   let arguments = (ins StrAttr:$pass_name,
-                       DefaultValuedAttr<ArrayAttr, "{}">:$options,
+                       DefaultValuedAttr<DictionaryAttr, "{}">:$options,
                        Variadic<TransformParamTypeInterface>:$dynamic_options,
                        TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$result);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 497ceb19f1a21..4a95fe7459e8c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -8,17 +8,22 @@
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Analysis/CallGraph.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/IR/Utils.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
+
 #ifndef NDEBUG
 void transform::detail::checkImplementsTransformOpInterface(
     StringRef name, MLIRContext *context) {
@@ -66,6 +71,10 @@ void transform::TransformDialect::initialize() {
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();
   initializeTypes();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
+      >();
   initializeLibraryModule();
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index a0f9518e3d12f..322c89ec7d7f6 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -54,10 +54,11 @@
 using namespace mlir;
 
 static ParseResult parseApplyRegisteredPassOptions(
-    OpAsmParser &parser, ArrayAttr &options,
+    OpAsmParser &parser, DictionaryAttr &options,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
 static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
-                                            Operation *op, ArrayAttr options,
+                                            Operation *op,
+                                            DictionaryAttr options,
                                             ValueRange dynamicOptions);
 static ParseResult parseSequenceOpOperands(
     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
@@ -784,41 +785,50 @@ DiagnosedSilenceableFailure
 transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
                                         transform::TransformResults &results,
                                         transform::TransformState &state) {
-  // Obtain a single options-string from options passed statically as
-  // string attributes as well as "dynamically" through params.
+  // Obtain a single options-string to pass to the pass(-pipeline) from options
+  // passed in as a dictionary of keys mapping to values which are either
+  // attributes or param-operands pointing to attributes.
+
   std::string options;
+  llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
+
   OperandRange dynamicOptions = getDynamicOptions();
-  size_t dynamicOptionsIdx = 0;
-  for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) {
+  for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
     if (idx > 0)
-      options += " "; // Interleave options seperator.
-
-    if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) {
-      options += strAttr.getValue();
-    } else if (isa<UnitAttr>(optionAttr)) {
-      assert(dynamicOptionsIdx < dynamicOptions.size() &&
+      optionsStream << " "; // Interleave options separator.
+    optionsStream << namedAttribute.getName().str(); // Append the key.
+    optionsStream << "="; // And the key-value separator.
+
+    Attribute valueAttrToAppend;
+    if (auto paramOperandIndex = dyn_cast<transform::ParamOperandIndexAttr>(
+            namedAttribute.getValue())) {
+      // The corresponding value attribute is passed in via a param.
+      // Obtain the param-operand via its specified index.
+      size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
+      assert(dynamicOptionIdx < dynamicOptions.size() &&
              "number of dynamic option markers (UnitAttr) in options ArrayAttr "
              "should be the same as the number of options passed as params");
       ArrayRef<Attribute> dynamicOption =
-          state.getParams(dynamicOptions[dynamicOptionsIdx++]);
+          state.getParams(dynamicOptions[dynamicOptionIdx]);
       if (dynamicOption.size() != 1)
-        return emitSilenceableError() << "options passed as a param must have "
-                                         "a single value associated, param "
-                                      << dynamicOptionsIdx - 1 << " associates "
-                                      << dynamicOption.size();
-
-      if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0])) {
-        options += dynamicOptionStr.getValue();
-      } else {
         return emitSilenceableError()
-               << "options passed as a param must be a string, got "
-               << dynamicOption[0];
-      }
+               << "options passed as a param must have "
+                  "a single value associated, param "
+               << dynamicOptionIdx << " associates " << dynamicOption.size();
+      valueAttrToAppend = dynamicOption[0];
+    } else {
+      // Value is a static attribute.
+      valueAttrToAppend = namedAttribute.getValue();
+    }
+
+    // Append string representation of value attribute.
+    if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
+      optionsStream << strAttr.getValue().str();
     } else {
-      llvm_unreachable(
-          "expected options element to be either StringAttr or UnitAttr");
+      valueAttrToAppend.print(optionsStream, /*elideType=*/true);
     }
   }
+  optionsStream.flush();
 
   // Get pass or pass pipeline from registry.
   const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -864,84 +874,116 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
 }
 
 static ParseResult parseApplyRegisteredPassOptions(
-    OpAsmParser &parser, ArrayAttr &options,
+    OpAsmParser &parser, DictionaryAttr &options,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
-  auto dynamicOptionMarker = UnitAttr::get(parser.getContext());
-  SmallVector<Attribute> optionsArray;
-
-  auto parseOperandOrString = [&]() -> OptionalParseResult {
-    OpAsmParser::UnresolvedOperand operand;
-    OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand);
-    if (parsedOperand.has_value()) {
-      if (failed(parsedOperand.value()))
-        return failure();
-
-      dynamicOptions.push_back(operand);
-      optionsArray.push_back(
-          dynamicOptionMarker); // Placeholder for knowing where to
-                                // inject the dynamic option-as-param.
-      return success();
-    }
+  // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
+  SmallVector<NamedAttribute> keyValuePairs;
 
-    StringAttr stringAttr;
-    OptionalParseResult parsedStringAttr =
-        parser.parseOptionalAttribute(stringAttr);
-    if (parsedStringAttr.has_value()) {
-      if (failed(parsedStringAttr.value()))
-        return failure();
-      optionsArray.push_back(stringAttr);
-      return success();
-    }
+  size_t dynamicOptionsIdx = 0;
+  auto parseKeyValuePair = [&]() -> ParseResult {
+    // Parse items of the form `key = value` where `key` is a bare identifier or
+    // a string and `value` is either an attribute or an operand.
+
+    std::string key;
+    Attribute valueAttr;
+    if (parser.parseOptionalKeywordOrString(&key))
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected key to either be an identifier or a string";
+    if (key.empty())
+      return failure();
 
-    return std::nullopt;
+    if (parser.parseEqual())
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected '=' after key in key-value pair";
+
+    // Parse the value, which can be either an attribute or an operand.
+    OptionalParseResult parsedValueAttr =
+        parser.parseOptionalAttribute(valueAttr);
+    if (!parsedValueAttr.has_value()) {
+      OpAsmParser::UnresolvedOperand operand;
+      ParseResult parsedOperand = parser.parseOperand(operand);
+      if (failed(parsedOperand))
+        return parser.emitError(parser.getCurrentLocation())
+               << "expected a valid attribute or operand as value associated "
+               << "to key '" << key << "'";
+      dynamicOptions.push_back(operand);
+      auto wrappedIndex = IntegerAttr::get(
+          IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
+      valueAttr = transform::ParamOperandIndexAttr::get(parser.getContext(),
+                                                        wrappedIndex);
+    } else if (failed(parsedValueAttr.value())) {
+      return failure(); // NB: Attempted parse should have output error message.
+    } else if (isa<transform::ParamOperandIndexAttr>(valueAttr)) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "the param_operand_index attribute is a marker reserved for "
+             << "indicating a value will be passed via params and is only used "
+             << "in the generic print format";
+    }
+
+    keyValuePairs.push_back(NamedAttribute(key, valueAttr));
+    return success();
   };
 
-  OptionalParseResult parsedOptionsElement = parseOperandOrString();
-  while (parsedOptionsElement.has_value()) {
-    if (failed(parsedOptionsElement.value()))
-      return failure();
-    parsedOptionsElement = parseOperandOrString();
-  }
+  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Braces,
+                                     parseKeyValuePair,
+                                     " in options dictionary"))
+    return failure(); // NB: Attempted parse should have output error message.
 
-  if (optionsArray.empty()) {
+  if (DictionaryAttr::findDuplicate(
+          keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
+          .has_value())
     return parser.emitError(parser.getCurrentLocation())
-           << "expected at least one option (either a string or a param)";
-  }
-  options = parser.getBuilder().getArrayAttr(optionsArray);
+           << "duplicate keys found in options dictionary";
+
+  options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
+
   return success();
 }
 
 static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
-                                            Operation *op, ArrayAttr options,
+                                            Operation *op,
+                                            DictionaryAttr options,
                                             ValueRange dynamicOptions) {
-  size_t currentDynamicOptionIdx = 0;
-  for (auto [idx, optionAttr] : llvm::enumerate(options)) {
-    if (idx > 0)
-      printer << " "; // Interleave options separator.
+  if (options.empty())
+    return;
 
-    if (isa<UnitAttr>(optionAttr))
-      printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]);
-    else if (auto strAttr = dyn_cast<StringAttr>(optionAttr))
-      printer.printAttribute(strAttr);
-    else
-      llvm_unreachable("each option should be either a StringAttr or UnitAttr");
-  }
+  printer << "{";
+  llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
+    printer << namedAttribute.getName() << " = ";
+    Attribute value = namedAttribute.getValue();
+    if (auto indexAttr = dyn_cast<transform::ParamOperandIndexAttr>(value)) {
+      printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
+    } else {
+      printer.printAttribute(value);
+    }
+  });
+  printer << "}";
 }
 
 LogicalResult transform::ApplyRegisteredPassOp::verify() {
-  size_t numUnitsInOptions = 0;
-  for (Attribute optionsElement : getOptions()) {
-    if (isa<UnitAttr>(optionsElement))
-      numUnitsInOptions++;
-    else if (!isa<StringAttr>(optionsElement))
-      return emitOpError() << "expected each option to be either a StringAttr "
-                           << "or a UnitAttr, got " << optionsElement;
-  }
-
-  if (getDynamicOptions().size() != numUnitsInOptions)
-    return emitOpError()
-           << "expected the same number of options passed as params as "
-           << "UnitAttr elements in options ArrayAttr";
+  // Check that there is a one-to-one correspondence between param operands
+  // and references to dynamic options in the options dictionary.
+
+  auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
+  for (NamedAttribute namedAttr : getOptions())
+    if (auto paramOperandIndex =
+            dyn_cast<transform::ParamOperandIndexAttr>(namedAttr.getValue())) {
+      size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
+      if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
+        return emitOpError()
+               << "dynamic option index " << dynamicOptionIdx
+               << " is out of bounds for the number of dynamic options: "
+               << dynamicOptions.size();
+      if (dynamicOptions[dynamicOptionIdx] == nullptr)
+        return emitOpError() << "dynamic option index " << dynamicOptionIdx
+                             << " is already used in options";
+      dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
+    }
+
+  for (Value dynamicOption : dynamicOptions)
+    if (dynamicOption)
+      return emitOpError() << "a param operand does not have a corresponding "
+                           << "param_operand_index attr in the options dict";
 
   return success();
 }
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 176361ca32b35..4a97dcebd3cdb 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -18,7 +18,7 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Optional, Sequence, Union, NewType
+from typing import Dict, Optional, Sequence, Union, NewType
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -214,43 +214,77 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
+@register_attribute_builder("ParamOperandIndexAttr")
+def _paramOperandIndexAttr(x: int, context) -> Attribute:
+    return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
     def __init__(
         self,
         result: Type,
         pass_name: Union[str, StringAttr],
-        target: Value,
+        target: Union[Operation, Value, OpView],
         *,
-        options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+        options: Dict[
+            Union[str, StringAttr],
+            Union[Attribute, Value, Operation, OpView],
+        ] = {},
         loc=None,
         ip=None,
     ):
-        static_options = []
+        options_dict = {}
         dynamic_options = []
-        for opt in options:
-            if isinstance(opt, str):
-                static_options.append(StringAttr.get(opt))
-            elif isinstance(opt, StringAttr):
-                static_options.append(opt)
-            elif isinstance(opt, Value):
-                static_options.append(UnitAttr.get())
-                dynamic_options.append(_get_op_result_or_value(opt))
+
+        ParamOperandIndexAttr = AttrBuilder.get("ParamOperandIndexAttr")
+        context = (loc and loc.context) or Context.current
+
+        cur_param_operand_idx = 0
+        for key, value in options.items():
+            if isinstance(key, StringAttr):
+                key = key.value
+
+            if isinstance(value, (Value, Operation, OpView)):
+                value = _get_op_result_or_value(value)
+                # v = Attribute.parse(
+                #    f"#transform.param_operand_index<{cur_param_operand_idx}>",
+                #    context=context,
+                # )
+                v = _paramOperandIndexAttr(cur_param_operand_idx, context)
+                options_dict[key] = v
+                cur_param_operand_idx += 1
+                dynamic_options.append(value)
+            elif isinstance(value, Attribute):
+                options_dict[key] = value
+            elif isinstance(value, str):
+                options_dict[key] = StringAttr.get(value)
             else:
-                raise TypeError(f"Unsupported option type: {type(opt)}")
+                raise TypeError(f"Unsupported option type: {type(value)}")
+        if len(options_dict) > 0:
+            print(options_dict, cur_param_operand_idx)
         super().__init__(
             result,
             pass_name,
             dynamic_options,
             target=_get_op_result_or_value(target),
-            options=static_options,
+            options=DictAttr.get(options_dict),
             loc=loc,
             ip=ip,
         )
 
 
 def apply_registered_pass(
-    result, pass_name, target, *, options=[], loc=None, ip=None
+    result: Type,
+    pass_name: Union[str, StringAttr],
+    target: Union[Operation, Value, OpView],
+    *,
+    options: Dict[
+        Union[str, StringAttr],
+        Union[Attribute, Value, Operation, OpView],
+    ] = {},
+    loc=None,
+    ip=None,
 ) -> Value:
     return ApplyRegisteredPassOp(
         result=result,
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 463fd98afa65c..ea1dd4abd4db2 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -80,7 +80,7 @@ module attributes {transform.with_named_sequence} {
     // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
     // expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
     transform.apply_registered_pass "canonicalize"
-        with options = "invalid-option=1" to %1
+        with options = { "invalid-option" = 1 } to %1
         : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
@@ -97,7 +97,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_registered_pass "canonicalize"
-        with options = "top-down=false" to %1
+        with options = { "top-down" = false } to %1
         : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
@@ -115,7 +115,7 @@ module attributes {transform.with_named_sequence} {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     //transform.apply_registered_pass "canonicalize" with options = "top-down=false,max-iterations=10" to %1 : (!transform.any_op) -> !transform.any_op
     transform.apply_registered_pass "canonicalize"
-        with options = "top-down=false test-convergence=true" to %1
+        with options = { "top-down" = false, "test-convergence" =true } to %1
         : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
@@ -132,7 +132,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.apply_registered_pass "canonicalize"
-        with options = "top-down=false" "max-iterations=0" to %1
+        with options = { "top-down" = false, "max-iterations" = 0 } to %1
         : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
@@ -148,10 +148,15 @@ func.func @valid_dynamic_pass_options() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
-    %max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param
-    %2 = transform.apply_registered_pass "canonicalize"
-        with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1
+    %max_iter = transform.param.constant 10 -> !transform.any_param
+    %max_rewrites = transform.param.constant 1 -> !transform.any_param
+    %2 = transform.apply_registered_pass
+        "canonicalize"
+        with options = { "top-down" = false,
+                         "max-iterations" = %max_iter,
+                         "test-convergence" = true,
+                         "max-num-rewrites" =  %max_rewrites }
+        to %1
         : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
     transform.yield
   }
@@ -159,7 +164,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @invalid_dynamic_options_as_array() {
+func.func @invalid_options_as_str() {
   return
 }
 
@@ -167,34 +172,80 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
-    // expected-error @+2 {{expected at least one option (either a string or a param)}}
+    // expected-error @+2 {{expected '{' in options dictionary}}
     %2 = transform.apply_registered_pass "canonicalize"
-        with options = ["top-down=false" %max_iter] to %1
-        : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+        with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
 // -----
 
-func.func @invalid_options_as_pairs() {
+func.func @invalid_options_as_pairs_without_braces() {
   return
 }
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @+2 {{expected 'to'}}
+    // expected-error @+2 {{expected '{' in options dictionary}}
     %2 = transform.apply_registered_pass "canonicalize"
-        with options = "top-down=" false to %1
-        : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+        with options = "top-down"=false to %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
 
 // -----
 
-func.func @invalid_pass_option_param() {
+func.func @invalid_options_due_to_reserved_attr() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+2 {{the param_operand_index attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
+    %2 = transform.apply_registered_pass "canonicalize"
+        with options = { "top-down" = #transform.param_operand_index<0> } to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @invalid_options_due_duplicated_key() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+2 {{duplicate keys found in options dictionary}}
+    %2 = transform.apply_registered_pass "canonicalize"
+        with options = {"top-down"=false,"top-down"=true} to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @invalid_options_due_invalid_key() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+2 {{expected key to either be an identifier or a string}}
+    %2 = transform.apply_registered_pass "canonicalize"
+        with options = { @label = 0 } to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @invalid_pass_option_bare_param() {
   return
 }
 
@@ -202,7 +253,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %pass_options = transform.param.constant 42 -> !transform.any_param
-    // expected-error @below {{options passed as a param must be a string, got 42}}
+    // expected-error @+2 {{expected '{' in options dictionary}}
     transform.apply_registered_pass "canonicalize"
         with options = %pass_options to %1
         : (!transform.any_param, !transform.any_op) -> !transform.any_op
@@ -219,12 +270,12 @@ func.func @too_many_pass_option_params() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %x = transform.param.constant "x" -> !transform.any_param
-    %y = transform.param.constant "y" -> !transform.any_param
-    %pass_options = transform.merge_handles %x, %y : !transform.any_param
+    %x = transform.param.constant true -> !transform.any_param
+    %y = transform.param.constant false -> !transform.any_param
+    %topdown_options = transform.merge_handles %x, %y : !transform.any_param
     // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
     transform.apply_registered_pass "canonicalize"
-        with options = %pass_options to %1
+        with options = { "top-down" = %topdown_options } to %1
         : (!transform.any_param, !transform.any_op) -> !transform.any_op
     transform.yield
   }
@@ -248,3 +299,60 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+"builtin.module"() ({
+  "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
+  ^bb0(%arg0: !transform.any_op):
+    %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
+    %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
+    // expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
+    %2 = "transform.apply_registered_pass"(%1, %0) <{
+      options = {"max-iterations" = #transform.param_operand_index<1 : i64>,
+                 "test-convergence" = true,
+                 "top-down" = false},
+      pass_name = "canonicalize"}>
+    : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    "transform.yield"() : () -> ()
+  }) : () -> ()
+}) {transform.with_named_sequence} : () -> ()
+
+// -----
+
+"builtin.module"() ({
+  "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
+  ^bb0(%arg0: !transform.any_op):
+    %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
+    %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
+    %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
+    // expected-error @below {{dynamic option index 0 is already used in options}}
+    %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+      options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
+                 "max-num-rewrites" = #transform.param_operand_index<0 : i64>,
+                 "test-convergence" = true,
+                 "top-down" = false},
+      pass_name = "canonicalize"}>
+    : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    "transform.yield"() : () -> ()
+  }) : () -> ()
+}) {transform.with_named_sequence} : () -> ()
+
+// -----
+
+"builtin.module"() ({
+  "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
+  ^bb0(%arg0: !transform.any_op):
+    %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
+    %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
+    %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
+    // expected-error @below {{a param operand does not have a corresponding param_operand_index attr in the options dict}}
+    %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
+      options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
+                 "test-convergence" = true,
+                 "top-down" = false},
+      pass_name = "canonicalize"}>
+    : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    "transform.yield"() : () -> ()
+  }) : () -> ()
+}) {transform.with_named_sequence} : () -> ()
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6492b58570814..48bc9bad37a1e 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -269,27 +269,40 @@ def testApplyRegisteredPassOp(module: Module):
             transform.AnyOpType.get(),
             "canonicalize",
             mod.result,
-            options=("top-down=false",),
+            options={"top-down": BoolAttr.get(False)},
         )
         max_iter = transform.param_constant(
-            transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+            transform.AnyParamType.get(),
+            IntegerAttr.get(IntegerType.get_signless(64), 10),
         )
         max_rewrites = transform.param_constant(
-            transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+            transform.AnyParamType.get(),
+            IntegerAttr.get(IntegerType.get_signless(64), 1),
         )
         transform.apply_registered_pass(
             transform.AnyOpType.get(),
             "canonicalize",
             mod,
-            options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+            options={
+                "top-down": BoolAttr.get(False),
+                "max-iterations": max_iter,
+                "test-convergence": BoolAttr.get(True),
+                "max-rewrites": max_rewrites,
+            },
         )
         transform.YieldOp()
     # CHECK-LABEL: TEST: testApplyRegisteredPassOp
     # CHECK: transform.sequence
     # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
-    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
+    # CHECK-SAME:    with options = {"top-down" = false}
+    # CHECK-SAME:    to {{.*}} : (!transform.any_op) -> !transform.any_op
     # CHECK:   %[[MAX_ITER:.+]] = transform.param.constant
     # CHECK:   %[[MAX_REWRITE:.+]] = transform.param.constant
     # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
-    # CHECK-SAME:    with options = "top-down=false" %[[MAX_ITER]]
-    # CHECK-SAME:   "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
+    # NB: MLIR has sorted the dict lexicographically by key:
+    # CHECK-SAME:    with options = {"max-iterations" = %[[MAX_ITER]],
+    # CHECK-SAME:                    "max-rewrites" =  %[[MAX_REWRITE]],
+    # CHECK-SAME:                    "test-convergence" = true,
+    # CHECK-SAME:                    "top-down" = false}
+    # CHECK-SAME:    to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

>From e1a7803942373b84b20c6edf4538e4d9c0302aeb Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Sun, 8 Jun 2025 13:10:49 -0700
Subject: [PATCH 5/7] Fix Python signature

---
 .../mlir/dialects/transform/__init__.py       | 44 +++++++++----------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 4a97dcebd3cdb..6d4abe8882e9d 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -21,6 +21,11 @@
 from typing import Dict, Optional, Sequence, Union, NewType
 
 
+ at register_attribute_builder("ParamOperandIndexAttr")
+def _paramOperandIndexAttr(x: int, context) -> Attribute:
+    return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class CastOp(CastOp):
     def __init__(
@@ -214,11 +219,6 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
- at register_attribute_builder("ParamOperandIndexAttr")
-def _paramOperandIndexAttr(x: int, context) -> Attribute:
-    return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
     def __init__(
@@ -227,10 +227,12 @@ def __init__(
         pass_name: Union[str, StringAttr],
         target: Union[Operation, Value, OpView],
         *,
-        options: Dict[
-            Union[str, StringAttr],
-            Union[Attribute, Value, Operation, OpView],
-        ] = {},
+        options: Optional[
+            Dict[
+                Union[str, StringAttr],
+                Union[Attribute, Value, Operation, OpView],
+            ]
+        ] = None,
         loc=None,
         ip=None,
     ):
@@ -241,20 +243,16 @@ def __init__(
         context = (loc and loc.context) or Context.current
 
         cur_param_operand_idx = 0
-        for key, value in options.items():
+        for key, value in options.items() if options is not None else {}:
             if isinstance(key, StringAttr):
                 key = key.value
 
             if isinstance(value, (Value, Operation, OpView)):
-                value = _get_op_result_or_value(value)
-                # v = Attribute.parse(
-                #    f"#transform.param_operand_index<{cur_param_operand_idx}>",
-                #    context=context,
-                # )
-                v = _paramOperandIndexAttr(cur_param_operand_idx, context)
-                options_dict[key] = v
+                dynamic_options.append(_get_op_result_or_value(value))
+                options_dict[key] = ParamOperandIndexAttr(
+                    cur_param_operand_idx, context
+                )
                 cur_param_operand_idx += 1
-                dynamic_options.append(value)
             elif isinstance(value, Attribute):
                 options_dict[key] = value
             elif isinstance(value, str):
@@ -279,10 +277,12 @@ def apply_registered_pass(
     pass_name: Union[str, StringAttr],
     target: Union[Operation, Value, OpView],
     *,
-    options: Dict[
-        Union[str, StringAttr],
-        Union[Attribute, Value, Operation, OpView],
-    ] = {},
+    options: Optional[
+        Dict[
+            Union[str, StringAttr],
+            Union[Attribute, Value, Operation, OpView],
+        ]
+    ] = None,
     loc=None,
     ip=None,
 ) -> Value:

>From 079b3dbd75d791221155081e07061f6910360cf1 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 9 Jun 2025 16:22:10 -0700
Subject: [PATCH 6/7] Add comments to generic format tests

---
 .../Transform/test-pass-application.mlir        | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index ea1dd4abd4db2..86615f05e4a01 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -302,6 +302,14 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+/////////////////////////////////////////////////////////////////////
+// Check that the following cases are caugh in the generic format. //
+/////////////////////////////////////////////////////////////////////
+
+// Invalid due to param_operand_index occurences in options dict not being
+// one-to-one with the dynamic options provided as params:
+//   param_operand_index out of bounds w.r.t. the number of options provided via params.
+
 "builtin.module"() ({
   "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
   ^bb0(%arg0: !transform.any_op):
@@ -320,6 +328,11 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Invalid due to param_operand_index occurences in options dict not being
+// one-to-one with the dynamic options provided as params:
+//   the first option-param is referred to twice and the second one not at all.
+// (The pretty-printed format supports this by passing in the same param twice.)
+
 "builtin.module"() ({
   "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
   ^bb0(%arg0: !transform.any_op):
@@ -340,6 +353,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Invalid due to param_operand_index occurences in options dict not being
+// one-to-one with the dynamic options provided as params:
+//   two option-params are provide though only the first one is referred to from the options-dict.
+
 "builtin.module"() ({
   "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
   ^bb0(%arg0: !transform.any_op):

>From a8d0b3355ead35f9f5810fc53fbc2a3d9eea8480 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 11 Jun 2025 06:54:04 -0700
Subject: [PATCH 7/7] Address @adam-smnk's review

---
 .../Dialect/Transform/IR/TransformAttrs.td    |  6 ++---
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 27 +++++++++++--------
 .../mlir/dialects/transform/__init__.py       | 10 +++----
 .../Transform/test-pass-application.mlir      | 22 +++++++--------
 4 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
index c25b472c10e75..e67a9444c24a8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td
@@ -41,9 +41,7 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
   let cppNamespace = "::mlir::transform";
 }
 
-def ParamOperandIndexAttr : Transform_Attr<"ParamOperandIndex",
-                                           "param_operand_index" > {
-  let mnemonic = "param_operand_index";
+def ParamOperandAttr : Transform_Attr<"ParamOperand", "param_operand"> {
   let description = [{
     Used to refer to a specific param-operand (via its index) from within an
     attribute on a transform operation.
@@ -51,7 +49,7 @@ def ParamOperandIndexAttr : Transform_Attr<"ParamOperandIndex",
   let parameters = (ins
     "IntegerAttr":$index
   );
-  let assemblyFormat = "`<` $index `>`";
+  let assemblyFormat = "`<` `index` `=` $index `>`";
 }
 
 #endif  // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 322c89ec7d7f6..582d082153bef 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -800,8 +800,8 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
     optionsStream << "="; // And the key-value separator.
 
     Attribute valueAttrToAppend;
-    if (auto paramOperandIndex = dyn_cast<transform::ParamOperandIndexAttr>(
-            namedAttribute.getValue())) {
+    if (auto paramOperandIndex =
+            dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
       // The corresponding value attribute is passed in via a param.
       // Obtain the param-operand via its specified index.
       size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
@@ -906,16 +906,20 @@ static ParseResult parseApplyRegisteredPassOptions(
         return parser.emitError(parser.getCurrentLocation())
                << "expected a valid attribute or operand as value associated "
                << "to key '" << key << "'";
+      // To make use of the operand, we need to store it in the options dict.
+      // As SSA-values cannot occur in attributes, what we do instead is store
+      // an attribute in its place that contains the index of the param-operand,
+      // so that an attr-value associated to the param can be resolved later on.
       dynamicOptions.push_back(operand);
       auto wrappedIndex = IntegerAttr::get(
           IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
-      valueAttr = transform::ParamOperandIndexAttr::get(parser.getContext(),
-                                                        wrappedIndex);
+      valueAttr =
+          transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
     } else if (failed(parsedValueAttr.value())) {
       return failure(); // NB: Attempted parse should have output error message.
-    } else if (isa<transform::ParamOperandIndexAttr>(valueAttr)) {
+    } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
       return parser.emitError(parser.getCurrentLocation())
-             << "the param_operand_index attribute is a marker reserved for "
+             << "the param_operand attribute is a marker reserved for "
              << "indicating a value will be passed via params and is only used "
              << "in the generic print format";
     }
@@ -951,7 +955,8 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
   llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
     printer << namedAttribute.getName() << " = ";
     Attribute value = namedAttribute.getValue();
-    if (auto indexAttr = dyn_cast<transform::ParamOperandIndexAttr>(value)) {
+    if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
+      // Resolve index of param-operand to its actual SSA-value and print that.
       printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
     } else {
       printer.printAttribute(value);
@@ -966,9 +971,9 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
 
   auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
   for (NamedAttribute namedAttr : getOptions())
-    if (auto paramOperandIndex =
-            dyn_cast<transform::ParamOperandIndexAttr>(namedAttr.getValue())) {
-      size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
+    if (auto paramOperand =
+            dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
+      size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
       if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
         return emitOpError()
                << "dynamic option index " << dynamicOptionIdx
@@ -983,7 +988,7 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
   for (Value dynamicOption : dynamicOptions)
     if (dynamicOption)
       return emitOpError() << "a param operand does not have a corresponding "
-                           << "param_operand_index attr in the options dict";
+                           << "param_operand attr in the options dict";
 
   return success();
 }
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 6d4abe8882e9d..ed53c6af5086a 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -21,9 +21,9 @@
 from typing import Dict, Optional, Sequence, Union, NewType
 
 
- at register_attribute_builder("ParamOperandIndexAttr")
-def _paramOperandIndexAttr(x: int, context) -> Attribute:
-    return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
+ at register_attribute_builder("ParamOperandAttr")
+def _paramOperandAttr(x: int, context) -> Attribute:
+    return Attribute.parse(f"#transform.param_operand<index={x}>", context=context)
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -239,7 +239,7 @@ def __init__(
         options_dict = {}
         dynamic_options = []
 
-        ParamOperandIndexAttr = AttrBuilder.get("ParamOperandIndexAttr")
+        ParamOperandAttr = AttrBuilder.get("ParamOperandAttr")
         context = (loc and loc.context) or Context.current
 
         cur_param_operand_idx = 0
@@ -249,7 +249,7 @@ def __init__(
 
             if isinstance(value, (Value, Operation, OpView)):
                 dynamic_options.append(_get_op_result_or_value(value))
-                options_dict[key] = ParamOperandIndexAttr(
+                options_dict[key] = ParamOperandAttr(
                     cur_param_operand_idx, context
                 )
                 cur_param_operand_idx += 1
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 86615f05e4a01..6e6d4eb7e249f 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -204,9 +204,9 @@ func.func @invalid_options_due_to_reserved_attr() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @+2 {{the param_operand_index attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
+    // expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
     %2 = transform.apply_registered_pass "canonicalize"
-        with options = { "top-down" = #transform.param_operand_index<0> } to %1 : (!transform.any_op) -> !transform.any_op
+        with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
@@ -306,7 +306,7 @@ module attributes {transform.with_named_sequence} {
 // Check that the following cases are caugh in the generic format. //
 /////////////////////////////////////////////////////////////////////
 
-// Invalid due to param_operand_index occurences in options dict not being
+// Invalid due to param_operand occurrences in options dict not being
 // one-to-one with the dynamic options provided as params:
 //   param_operand_index out of bounds w.r.t. the number of options provided via params.
 
@@ -317,7 +317,7 @@ module attributes {transform.with_named_sequence} {
     %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
     // expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
     %2 = "transform.apply_registered_pass"(%1, %0) <{
-      options = {"max-iterations" = #transform.param_operand_index<1 : i64>,
+      options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>
@@ -328,10 +328,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// Invalid due to param_operand_index occurences in options dict not being
+// Invalid due to param_operand occurrences in options dict not being
 // one-to-one with the dynamic options provided as params:
 //   the first option-param is referred to twice and the second one not at all.
-// (The pretty-printed format supports this by passing in the same param twice.)
+// (In the pretty-printed format, if you want to refer to a param SSA-value twice, it counts as two param arguments.)
 
 "builtin.module"() ({
   "transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
@@ -341,8 +341,8 @@ module attributes {transform.with_named_sequence} {
     %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
     // expected-error @below {{dynamic option index 0 is already used in options}}
     %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
-      options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
-                 "max-num-rewrites" = #transform.param_operand_index<0 : i64>,
+      options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
+                 "max-num-rewrites" = #transform.param_operand<index=0 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>
@@ -353,7 +353,7 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// Invalid due to param_operand_index occurences in options dict not being
+// Invalid due to param_operand occurrences in options dict not being
 // one-to-one with the dynamic options provided as params:
 //   two option-params are provide though only the first one is referred to from the options-dict.
 
@@ -363,9 +363,9 @@ module attributes {transform.with_named_sequence} {
     %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
     %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
     %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
-    // expected-error @below {{a param operand does not have a corresponding param_operand_index attr in the options dict}}
+    // expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
     %3 = "transform.apply_registered_pass"(%1, %2, %0) <{
-      options = {"max-iterations" = #transform.param_operand_index<0 : i64>,
+      options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
                  "test-convergence" = true,
                  "top-down" = false},
       pass_name = "canonicalize"}>



More information about the Mlir-commits mailing list