[Mlir-commits] [mlir] [MLIR][Transform] apply_registered_pass: support ListOptions (PR #144026)

Rolf Morel llvmlistbot at llvm.org
Mon Jun 16 03:47:35 PDT 2025


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/144026

>From ec61920910979e53579ec144e92b677504c0254b Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 00:13:42 -0700
Subject: [PATCH 1/7] [MLIR][Transform] apply_registered_pass: support
 ListOptions as params

Interpret the multiple values associated with a param as a
comma-separated list, i.e., as the analog of a ListOption on a pass.
---
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 28 +++++-----
 .../Transform/test-pass-application.mlir      | 52 ++++++++++++-------
 2 files changed, 45 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 582d082153bef..bfe6416987629 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -791,6 +791,12 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
 
   std::string options;
   llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
+  auto appendValueAttr = [&](Attribute valueAttr) {
+    if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
+      optionsStream << strAttr.getValue().str();
+    else
+      valueAttr.print(optionsStream, /*elideType=*/true);
+  };
 
   OperandRange dynamicOptions = getDynamicOptions();
   for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
@@ -799,7 +805,6 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
     optionsStream << namedAttribute.getName().str(); // Append the key.
     optionsStream << "="; // And the key-value separator.
 
-    Attribute valueAttrToAppend;
     if (auto paramOperandIndex =
             dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
       // The corresponding value attribute is passed in via a param.
@@ -810,22 +815,15 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
              "should be the same as the number of options passed as params");
       ArrayRef<Attribute> dynamicOption =
           state.getParams(dynamicOptions[dynamicOptionIdx]);
-      if (dynamicOption.size() != 1)
-        return emitSilenceableError()
-               << "options passed as a param must have "
-                  "a single value associated, param "
-               << dynamicOptionIdx << " associates " << dynamicOption.size();
-      valueAttrToAppend = dynamicOption[0];
+      // Append all attributes associated to the param, separated by commas.
+      for (auto [idx, associatedAttr] : llvm::enumerate(dynamicOption)) {
+        if (idx > 0)
+          optionsStream << ",";
+        appendValueAttr(associatedAttr);
+      }
     } else {
       // Value is a static attribute.
-      valueAttrToAppend = namedAttribute.getValue();
-    }
-
-    // Append string representation of value attribute.
-    if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
-      optionsStream << strAttr.getValue().str();
-    } else {
-      valueAttrToAppend.print(optionsStream, /*elideType=*/true);
+      appendValueAttr(namedAttribute.getValue());
     }
   }
   optionsStream.flush();
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 1d1be9eda3496..407dfa3823436 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -164,6 +164,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func private @valid_dynamic_pass_list_option()
+module {
+  func.func @valid_dynamic_pass_list_option() {
+    return
+  }
+
+  // CHECK: func @a()
+  func.func @a() {
+    return
+  }
+  // CHECK: func @b()
+  func.func @b() {
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
+    %symbol_a = transform.param.constant "a" -> !transform.any_param
+    %symbol_b = transform.param.constant "b" -> !transform.any_param
+    %multiple_symbol_names = transform.merge_handles %symbol_a, %symbol_b : !transform.any_param
+    transform.apply_registered_pass "symbol-privatize"
+        with options = { exclude = %multiple_symbol_names } to %2
+        : (!transform.any_op, !transform.any_param) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @invalid_options_as_str() {
   return
 }
@@ -262,26 +294,6 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @too_many_pass_option_params() {
-  return
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
-    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %x = transform.param.constant true -> !transform.any_param
-    %y = transform.param.constant false -> !transform.any_param
-    %topdown_options = transform.merge_handles %x, %y : !transform.any_param
-    // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
-    transform.apply_registered_pass "canonicalize"
-        with options = { "top-down" = %topdown_options } to %1
-        : (!transform.any_op, !transform.any_param) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
 module attributes {transform.with_named_sequence} {
   // expected-error @below {{trying to schedule a pass on an unsupported operation}}
   // expected-note @below {{target op}}

>From ff632049e28e0d2ae63dabe21283876d1c0467e7 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 00:45:10 -0700
Subject: [PATCH 2/7] Also support passing an ArrayAttr as a ListOption,
 including through params

---
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 12 +++-
 .../Transform/test-pass-application.mlir      | 63 ++++++++++++++++++-
 2 files changed, 70 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index bfe6416987629..0538faf5b3ba8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -791,11 +791,17 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
 
   std::string options;
   llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
-  auto appendValueAttr = [&](Attribute valueAttr) {
-    if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
+  std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
+    if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+      for (auto [idx, eltAttr] : llvm::enumerate(arrayAttr)) {
+        appendValueAttr(eltAttr);
+        optionsStream << ",";
+      }
+    } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
       optionsStream << strAttr.getValue().str();
-    else
+    } else {
       valueAttr.print(optionsStream, /*elideType=*/true);
+    }
   };
 
   OperandRange dynamicOptions = getDynamicOptions();
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 407dfa3823436..f7909f4c035d9 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -164,9 +164,9 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func private @valid_dynamic_pass_list_option()
+// CHECK-LABEL: func private @valid_multiple_params_as_list_option()
 module {
-  func.func @valid_dynamic_pass_list_option() {
+  func.func @valid_multiple_params_as_list_option() {
     return
   }
 
@@ -196,6 +196,65 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func private @valid_array_attr_as_list_option()
+module {
+  func.func @valid_array_attr_param_as_list_option() {
+    return
+  }
+
+  // CHECK: func @a()
+  func.func @a() {
+    return
+  }
+  // CHECK: func @b()
+  func.func @b() {
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "symbol-privatize"
+        with options = { exclude = ["a", "b"] } to %2
+        : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func private @valid_array_attr_param_as_list_option()
+module {
+  func.func @valid_array_attr_param_as_list_option() {
+    return
+  }
+
+  // CHECK: func @a()
+  func.func @a() {
+    return
+  }
+  // CHECK: func @b()
+  func.func @b() {
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
+    %multiple_symbol_names = transform.param.constant ["a","b"] -> !transform.any_param
+    transform.apply_registered_pass "symbol-privatize"
+        with options = { exclude = %multiple_symbol_names } to %2
+        : (!transform.any_op, !transform.any_param) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @invalid_options_as_str() {
   return
 }

>From 7767260b9f81d664d064c1579acd7d41cb32b665 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 01:02:21 -0700
Subject: [PATCH 3/7] Minor clean-up

---
 mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 18 +++++-------------
 .../Transform/test-pass-application.mlir       |  2 +-
 2 files changed, 6 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 0538faf5b3ba8..651462ee6ad03 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -792,16 +792,12 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
   std::string options;
   llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
   std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
-    if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
-      for (auto [idx, eltAttr] : llvm::enumerate(arrayAttr)) {
-        appendValueAttr(eltAttr);
-        optionsStream << ",";
-      }
-    } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
+    if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr))
+      llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
+    else if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
       optionsStream << strAttr.getValue().str();
-    } else {
+    else
       valueAttr.print(optionsStream, /*elideType=*/true);
-    }
   };
 
   OperandRange dynamicOptions = getDynamicOptions();
@@ -822,11 +818,7 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
       ArrayRef<Attribute> dynamicOption =
           state.getParams(dynamicOptions[dynamicOptionIdx]);
       // Append all attributes associated to the param, separated by commas.
-      for (auto [idx, associatedAttr] : llvm::enumerate(dynamicOption)) {
-        if (idx > 0)
-          optionsStream << ",";
-        appendValueAttr(associatedAttr);
-      }
+      llvm::interleave(dynamicOption, optionsStream, appendValueAttr, ",");
     } else {
       // Value is a static attribute.
       appendValueAttr(namedAttribute.getValue());
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index f7909f4c035d9..7262a8fe9faee 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -198,7 +198,7 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: func private @valid_array_attr_as_list_option()
 module {
-  func.func @valid_array_attr_param_as_list_option() {
+  func.func @valid_array_attr_as_list_option() {
     return
   }
 

>From fda3f63c7166d14b7505895ab25d9733d9bcec31 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 07:17:32 -0700
Subject: [PATCH 4/7] Allow ArrayAttr options to contain param-operands

---
 .../mlir/Dialect/Transform/IR/TransformOps.td |   2 +-
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 164 ++++++++++++------
 .../Transform/test-pass-application.mlir      |  39 ++++-
 3 files changed, 148 insertions(+), 57 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 0aa750e625436..140c9c66f3918 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -418,7 +418,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
         with options = { "top-down" = false,
                          "max-iterations" = %max_iter,
                          "test-convergence" = true,
-                         "max-num-rewrites" =  %max_rewrites }
+                         "max-num-rewrites" = %max_rewrites }
         to %module
     : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
     ```
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 651462ee6ad03..bb9bdd70625e4 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -788,42 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
   // Obtain a single options-string to pass to the pass(-pipeline) from options
   // passed in as a dictionary of keys mapping to values which are either
   // attributes or param-operands pointing to attributes.
+  OperandRange dynamicOptions = getDynamicOptions();
 
   std::string options;
   llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
-  std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
-    if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr))
-      llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
-    else if (auto strAttr = dyn_cast<StringAttr>(valueAttr))
-      optionsStream << strAttr.getValue().str();
-    else
-      valueAttr.print(optionsStream, /*elideType=*/true);
-  };
 
-  OperandRange dynamicOptions = getDynamicOptions();
-  for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
-    if (idx > 0)
-      optionsStream << " "; // Interleave options separator.
-    optionsStream << namedAttribute.getName().str(); // Append the key.
-    optionsStream << "="; // And the key-value separator.
-
-    if (auto paramOperandIndex =
-            dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
-      // The corresponding value attribute is passed in via a param.
+  // A helper to convert an option's attribute value into a corresponding
+  // string representation, with the ability to obtain the attr(s) from a param.
+  std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
+    if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
+      // The corresponding value attribute(s) is/are passed in via a param.
       // Obtain the param-operand via its specified index.
-      size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
+      size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
       assert(dynamicOptionIdx < dynamicOptions.size() &&
-             "number of dynamic option markers (UnitAttr) in options ArrayAttr "
+             "the number of ParamOperandAttrs in the options DictionaryAttr"
              "should be the same as the number of options passed as params");
-      ArrayRef<Attribute> dynamicOption =
+      ArrayRef<Attribute> attrsAssociatedToParam =
           state.getParams(dynamicOptions[dynamicOptionIdx]);
-      // Append all attributes associated to the param, separated by commas.
-      llvm::interleave(dynamicOption, optionsStream, appendValueAttr, ",");
+      // Recursive so as to append all attrs associated to the param.
+      llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
+                       ",");
+    } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+      // Recursive so as to append all nested attrs of the array.
+      llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
+    } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
+      // Convert to unquoted string.
+      optionsStream << strAttr.getValue().str();
     } else {
-      // Value is a static attribute.
-      appendValueAttr(namedAttribute.getValue());
+      // For all other attributes, ask the attr to print itself (without type).
+      valueAttr.print(optionsStream, /*elideType=*/true);
     }
-  }
+  };
+
+  // Convert the options DictionaryAttr into a single string.
+  llvm::interleave(
+      getOptions(), optionsStream,
+      [&](auto namedAttribute) {
+        optionsStream << namedAttribute.getName().str(); // Append the key.
+        optionsStream << "="; // And the key-value separator.
+        appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
+      },
+      " ");
   optionsStream.flush();
 
   // Get pass or pass pipeline from registry.
@@ -874,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
   // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
   SmallVector<NamedAttribute> keyValuePairs;
-
   size_t dynamicOptionsIdx = 0;
-  auto parseKeyValuePair = [&]() -> ParseResult {
-    // Parse items of the form `key = value` where `key` is a bare identifier or
-    // a string and `value` is either an attribute or an operand.
 
-    std::string key;
-    Attribute valueAttr;
-    if (parser.parseOptionalKeywordOrString(&key))
-      return parser.emitError(parser.getCurrentLocation())
-             << "expected key to either be an identifier or a string";
-    if (key.empty())
-      return failure();
+  // Helper for allowing parsing of option values which can be of the form:
+  // - a normal attribute
+  // - an operand (which would be converted to an attr referring to the operand)
+  // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
+  std::function<ParseResult(Attribute &)> parseValue =
+      [&](Attribute &valueAttr) -> ParseResult {
+    // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
+    if (succeeded(parser.parseOptionalLSquare())) {
+      SmallVector<Attribute> attrs;
 
-    if (parser.parseEqual())
-      return parser.emitError(parser.getCurrentLocation())
-             << "expected '=' after key in key-value pair";
+      // Recursively parse the array's elements, which might be operands.
+      if (parser.parseCommaSeparatedList(
+              AsmParser::Delimiter::None,
+              [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
+              " in options dictionary") ||
+          parser.parseRSquare())
+        return failure(); // NB: Attempted parse should've output error message.
+
+      valueAttr = ArrayAttr::get(parser.getContext(), attrs);
+
+      return success();
+    }
 
     // Parse the value, which can be either an attribute or an operand.
     OptionalParseResult parsedValueAttr =
@@ -899,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions(
       OpAsmParser::UnresolvedOperand operand;
       ParseResult parsedOperand = parser.parseOperand(operand);
       if (failed(parsedOperand))
-        return parser.emitError(parser.getCurrentLocation())
-               << "expected a valid attribute or operand as value associated "
-               << "to key '" << key << "'";
+        return failure(); // NB: Attempted parse should've output error message.
       // To make use of the operand, we need to store it in the options dict.
       // As SSA-values cannot occur in attributes, what we do instead is store
       // an attribute in its place that contains the index of the param-operand,
@@ -920,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions(
              << "in the generic print format";
     }
 
+    return success();
+  };
+
+  // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
+  // string and `value` looks like either an attribute or an operand-in-an-attr.
+  std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
+    std::string key;
+    Attribute valueAttr;
+
+    if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected key to either be an identifier or a string";
+
+    if (failed(parser.parseEqual()))
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected '=' after key in key-value pair";
+
+    if (failed(parseValue(valueAttr)))
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected a valid attribute or operand as value associated "
+             << "to key '" << key << "'";
+
     keyValuePairs.push_back(NamedAttribute(key, valueAttr));
+
     return success();
   };
 
@@ -947,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
   if (options.empty())
     return;
 
-  printer << "{";
-  llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
-    printer << namedAttribute.getName() << " = ";
-    Attribute value = namedAttribute.getValue();
-    if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
+  std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
+    if (auto paramOperandAttr =
+            dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
       // Resolve index of param-operand to its actual SSA-value and print that.
-      printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
+      printer.printOperand(
+          dynamicOptions[paramOperandAttr.getIndex().getInt()]);
+    } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+      // This case is so that ArrayAttr-contained operands are pretty-printed.
+      printer << "[";
+      llvm::interleaveComma(arrayAttr, printer, printOptionValue);
+      printer << "]";
     } else {
-      printer.printAttribute(value);
+      printer.printAttribute(valueAttr);
     }
+  };
+
+  printer << "{";
+  llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
+    printer << namedAttribute.getName();
+    printer << " = ";
+    printOptionValue(namedAttribute.getValue());
   });
   printer << "}";
 }
@@ -966,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
   // and references to dynamic options in the options dictionary.
 
   auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
-  for (NamedAttribute namedAttr : getOptions())
-    if (auto paramOperand =
-            dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
+
+  // Helper for option values to mark seen operands as having been seen (once).
+  std::function<LogicalResult(Attribute)> checkOptionValue =
+      [&](Attribute valueAttr) -> LogicalResult {
+    if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
       size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
       if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
         return emitOpError()
@@ -979,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
         return emitOpError() << "dynamic option index " << dynamicOptionIdx
                              << " is already used in options";
       dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
+    } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+      // Recurse into ArrayAttrs as they may contain references to operands.
+      for (auto eltAttr : arrayAttr)
+        if (failed(checkOptionValue(eltAttr)))
+          return failure();
     }
+    return success();
+  };
+
+  for (NamedAttribute namedAttr : getOptions())
+    if (failed(checkOptionValue(namedAttr.getValue())))
+      return failure();
 
+  // All dynamicOptions-params seen in the dict will have been set to null.
   for (Value dynamicOption : dynamicOptions)
     if (dynamicOption)
       return emitOpError() << "a param operand does not have a corresponding "
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 7262a8fe9faee..e21e750011ce7 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -164,9 +164,9 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: func private @valid_multiple_params_as_list_option()
+// CHECK-LABEL: func private @valid_multiple_values_as_list_option_single_param()
 module {
-  func.func @valid_multiple_params_as_list_option() {
+  func.func @valid_multiple_values_as_list_option_single_param() {
     return
   }
 
@@ -253,6 +253,38 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+// CHECK-LABEL: func private @valid_multiple_params_as_single_list_option()
+module {
+  func.func @valid_multiple_params_as_single_list_option() {
+    return
+  }
+
+  // CHECK: func @a()
+  func.func @a() {
+    return
+  }
+  // CHECK: func @b()
+  func.func @b() {
+    return
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
+    %symbol_a = transform.param.constant "a" -> !transform.any_param
+    %symbol_b = transform.param.constant "b" -> !transform.any_param
+    transform.apply_registered_pass "symbol-privatize"
+        with options = { exclude = [%symbol_a, %symbol_b] } to %2
+        : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 func.func @invalid_options_as_str() {
@@ -294,7 +326,8 @@ func.func @invalid_options_due_to_reserved_attr() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
+    // expected-error @+3 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
+    // expected-error @+2 {{expected a valid attribute or operand as value associated to key 'top-down'}}
     %2 = transform.apply_registered_pass "canonicalize"
         with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
     transform.yield

>From 01ee8bbd9706432ac47bdf630aefcd472b96e93c Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 07:59:32 -0700
Subject: [PATCH 5/7] Update Python bindings

---
 .../mlir/dialects/transform/__init__.py       | 41 ++++++-------
 .../Transform/test-pass-application.mlir      |  1 -
 mlir/test/python/dialects/transform.py        | 60 ++++++++++++++-----
 3 files changed, 65 insertions(+), 37 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index bfe96b1b3e5d4..b075919d1ef0f 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -219,6 +219,11 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
+OptionValueTypes = Union[
+    Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool
+]
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
     def __init__(
@@ -227,12 +232,7 @@ def __init__(
         target: Union[Operation, Value, OpView],
         pass_name: Union[str, StringAttr],
         *,
-        options: Optional[
-            Dict[
-                Union[str, StringAttr],
-                Union[Attribute, Value, Operation, OpView, str, int, bool],
-            ]
-        ] = None,
+        options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
         loc=None,
         ip=None,
     ):
@@ -243,26 +243,32 @@ def __init__(
         context = (loc and loc.context) or Context.current
 
         cur_param_operand_idx = 0
-        for key, value in options.items() if options is not None else {}:
-            if isinstance(key, StringAttr):
-                key = key.value
 
+        def option_value_to_attr(value):
+            nonlocal cur_param_operand_idx
             if isinstance(value, (Value, Operation, OpView)):
                 dynamic_options.append(_get_op_result_or_value(value))
-                options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context)
                 cur_param_operand_idx += 1
+                return ParamOperandAttr(cur_param_operand_idx - 1, context)
             elif isinstance(value, Attribute):
-                options_dict[key] = value
+                return value
             # The following cases auto-convert Python values to attributes.
             elif isinstance(value, bool):
-                options_dict[key] = BoolAttr.get(value)
+                return BoolAttr.get(value)
             elif isinstance(value, int):
                 default_int_type = IntegerType.get_signless(64, context)
-                options_dict[key] = IntegerAttr.get(default_int_type, value)
+                return IntegerAttr.get(default_int_type, value)
             elif isinstance(value, str):
-                options_dict[key] = StringAttr.get(value)
+                return StringAttr.get(value)
+            elif isinstance(value, Sequence):
+                return ArrayAttr.get([option_value_to_attr(elt) for elt in value])
             else:
                 raise TypeError(f"Unsupported option type: {type(value)}")
+
+        for key, value in options.items() if options is not None else {}:
+            if isinstance(key, StringAttr):
+                key = key.value
+            options_dict[key] = option_value_to_attr(value)
         super().__init__(
             result,
             _get_op_result_or_value(target),
@@ -279,12 +285,7 @@ def apply_registered_pass(
     target: Union[Operation, Value, OpView],
     pass_name: Union[str, StringAttr],
     *,
-    options: Optional[
-        Dict[
-            Union[str, StringAttr],
-            Union[Attribute, Value, Operation, OpView, str, int, bool],
-        ]
-    ] = None,
+    options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
     loc=None,
     ip=None,
 ) -> Value:
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index e21e750011ce7..ce8f69c58701d 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -284,7 +284,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-
 // -----
 
 func.func @invalid_options_as_str() {
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index eeb95605d7a9a..aeadfcb596526 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -256,30 +256,45 @@ def testReplicateOp(module: Module):
     # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
 
 
+# CHECK-LABEL: TEST: testApplyRegisteredPassOp
 @run
 def testApplyRegisteredPassOp(module: Module):
+    # CHECK: transform.sequence
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
     )
     with InsertionPoint(sequence.body):
+        # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
         mod = transform.ApplyRegisteredPassOp(
             transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
         )
+        # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
+        # CHECK-SAME:    with options = {"top-down" = false}
+        # CHECK-SAME:    to {{.*}} : (!transform.any_op) -> !transform.any_op
         mod = transform.ApplyRegisteredPassOp(
             transform.AnyOpType.get(),
             mod.result,
             "canonicalize",
             options={"top-down": BoolAttr.get(False)},
         )
+        # CHECK:   %[[MAX_ITER:.+]] = transform.param.constant
         max_iter = transform.param_constant(
             transform.AnyParamType.get(),
             IntegerAttr.get(IntegerType.get_signless(64), 10),
         )
+        # CHECK:   %[[MAX_REWRITE:.+]] = transform.param.constant
         max_rewrites = transform.param_constant(
             transform.AnyParamType.get(),
             IntegerAttr.get(IntegerType.get_signless(64), 1),
         )
-        transform.apply_registered_pass(
+        # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
+        # NB: MLIR has sorted the dict lexicographically by key:
+        # CHECK-SAME:    with options = {"max-iterations" = %[[MAX_ITER]],
+        # CHECK-SAME:                    "max-rewrites" =  %[[MAX_REWRITE]],
+        # CHECK-SAME:                    "test-convergence" = true,
+        # CHECK-SAME:                    "top-down" = false}
+        # CHECK-SAME:    to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
+        mod = transform.apply_registered_pass(
             transform.AnyOpType.get(),
             mod,
             "canonicalize",
@@ -290,19 +305,32 @@ def testApplyRegisteredPassOp(module: Module):
                 "max-rewrites": max_rewrites,
             },
         )
+        # CHECK:   %{{.*}} = apply_registered_pass "symbol-privatize"
+        # CHECK-SAME:    with options = {"exclude" = ["a", "b"]}
+        # CHECK-SAME:    to %{{.*}} : (!transform.any_op) -> !transform.any_op
+        mod = transform.apply_registered_pass(
+            transform.AnyOpType.get(),
+            mod,
+            "symbol-privatize",
+            options={ "exclude": ("a", "b") },
+        )
+        # CHECK:   %[[SYMBOL_A:.+]] = transform.param.constant
+        symbol_a = transform.param_constant(
+            transform.AnyParamType.get(),
+            StringAttr.get("a")
+        )
+        # CHECK:   %[[SYMBOL_B:.+]] = transform.param.constant
+        symbol_b = transform.param_constant(
+            transform.AnyParamType.get(),
+            StringAttr.get("b")
+        )
+        # CHECK:   %{{.*}} = apply_registered_pass "symbol-privatize"
+        # CHECK-SAME:    with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]}
+        # CHECK-SAME:    to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
+        mod = transform.apply_registered_pass(
+            transform.AnyOpType.get(),
+            mod,
+            "symbol-privatize",
+            options={ "exclude": (symbol_a, symbol_b) },
+        )
         transform.YieldOp()
-    # CHECK-LABEL: TEST: testApplyRegisteredPassOp
-    # CHECK: transform.sequence
-    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
-    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
-    # CHECK-SAME:    with options = {"top-down" = false}
-    # CHECK-SAME:    to {{.*}} : (!transform.any_op) -> !transform.any_op
-    # CHECK:   %[[MAX_ITER:.+]] = transform.param.constant
-    # CHECK:   %[[MAX_REWRITE:.+]] = transform.param.constant
-    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
-    # NB: MLIR has sorted the dict lexicographically by key:
-    # CHECK-SAME:    with options = {"max-iterations" = %[[MAX_ITER]],
-    # CHECK-SAME:                    "max-rewrites" =  %[[MAX_REWRITE]],
-    # CHECK-SAME:                    "test-convergence" = true,
-    # CHECK-SAME:                    "top-down" = false}
-    # CHECK-SAME:    to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op

>From abbac927d003953da75b6997d076b0f730873ab1 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 13 Jun 2025 08:15:03 -0700
Subject: [PATCH 6/7] Fix Python formatting

---
 mlir/test/python/dialects/transform.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index aeadfcb596526..6c5e4e5505b1c 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -312,17 +312,15 @@ def testApplyRegisteredPassOp(module: Module):
             transform.AnyOpType.get(),
             mod,
             "symbol-privatize",
-            options={ "exclude": ("a", "b") },
+            options={"exclude": ("a", "b")},
         )
         # CHECK:   %[[SYMBOL_A:.+]] = transform.param.constant
         symbol_a = transform.param_constant(
-            transform.AnyParamType.get(),
-            StringAttr.get("a")
+            transform.AnyParamType.get(), StringAttr.get("a")
         )
         # CHECK:   %[[SYMBOL_B:.+]] = transform.param.constant
         symbol_b = transform.param_constant(
-            transform.AnyParamType.get(),
-            StringAttr.get("b")
+            transform.AnyParamType.get(), StringAttr.get("b")
         )
         # CHECK:   %{{.*}} = apply_registered_pass "symbol-privatize"
         # CHECK-SAME:    with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]}
@@ -331,6 +329,6 @@ def testApplyRegisteredPassOp(module: Module):
             transform.AnyOpType.get(),
             mod,
             "symbol-privatize",
-            options={ "exclude": (symbol_a, symbol_b) },
+            options={"exclude": (symbol_a, symbol_b)},
         )
         transform.YieldOp()

>From bcc104af4abbd5a2cb5e69f6e60c556d95011664 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 16 Jun 2025 03:46:20 -0700
Subject: [PATCH 7/7] Update docs

---
 mlir/include/mlir/Dialect/Transform/IR/TransformOps.td | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 140c9c66f3918..62e66b3dabee8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -423,6 +423,9 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
     ```
 
+    Option values that are `ArrayAttr`s are converted to comma-separated
+    lists. The same holds for params that are associated with multiple values.
+
     This op first looks for a pass pipeline with the specified name. If no such
     pipeline exists, it looks for a pass with the specified name. If no such
     pass exists either, this op fails definitely.



More information about the Mlir-commits mailing list