[Mlir-commits] [mlir] [MLIR][Transform][Tune] Introduce `transform.tune.alternatives` op (PR #160724)

Rolf Morel llvmlistbot at llvm.org
Wed Oct 1 06:36:08 PDT 2025


https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/160724

>From 95b5a10c62f14f9d5b261b0a2d8883f03ca8ac64 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 25 Sep 2025 08:02:37 -0700
Subject: [PATCH 1/7] [MLIR][Transform][Tune] Introduce
 `transform.tune.alternatives` op

This op enables expressing uncertainty regarding what should be at
particular places in transform-dialect schedules. In particular, it
enables representing a choice among alternative regions. A choice is
resolved by providing a `selected_region` argument. When this argument
is provided, it is valid to rewrite the op by substituting in the
selected region, and the op's interpreted semantics correspond exactly
to this rewrite.

This op represents another piece of the puzzle w.r.t. a toolkit for
expressing autotuning problems with the transform dialect. Note that
this goes beyond tuning knobs _on_ transforms by also making it tunable
which (sequences of) transforms are to be applied.
---
 .../TuneExtension/TuneExtensionOps.h          |   1 +
 .../TuneExtension/TuneExtensionOps.td         |  54 +++++
 .../TuneExtension/TuneExtensionOps.cpp        | 185 ++++++++++++++++++
 mlir/python/mlir/dialects/transform/tune.py   |  66 ++++++-
 .../test-tune-extension-invalid.mlir          |  85 ++++++++
 .../Transform/test-tune-extension.mlir        |  99 ++++++++++
 .../python/dialects/transform_tune_ext.py     |  87 ++++++--
 7 files changed, 560 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
index 74e1d28ffac82..ba11259790676 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
 #define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
 
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index d68d451afac40..d095659fc4838 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -11,10 +11,15 @@
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/CommonAttrConstraints.td"
 
+//===----------------------------------------------------------------------===//
+// KnobOp
+//===----------------------------------------------------------------------===//
+
 def KnobOp : Op<Transform_Dialect, "tune.knob", [
   DeclareOpInterfaceMethods<TransformOpInterface>,
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -52,4 +57,53 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
       "`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
 }
 
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+// `transform.tune.alternatives`: a choice among alternative sub-schedules, one
+// per region. Region-branch modelling, handle effects, transform application
+// and verification are implemented in C++ (see the DeclareOpInterfaceMethods
+// entries and `hasVerifier` below).
+def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
+  DeclareOpInterfaceMethods<RegionBranchOpInterface,
+        ["getEntrySuccessorOperands", "getSuccessorRegions",
+         "getRegionInvocationBounds"]>,
+  DeclareOpInterfaceMethods<TransformOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  // Each alternative is a single block terminated by `transform.yield`
+  // (implicit-terminator trait).
+  SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+  // Alternatives take no block arguments; they use values from the enclosing
+  // scope directly.
+  NoRegionArguments
+]> {
+  let summary = "Represents a choice among its regions, i.e. sub-schedules";
+
+  let description = [{
+    This op represents a choice over which of its regions is to be used.
+
+    When `selected_region` is provided, the semantics are that this op is to be
+    substituted for by the selected region, meaning the region's results become
+    the results of this op. Without a provided `selected_region`, the semantics
+    are that this non-deterministic choice is yet to be resolved -- which in
+    terms of the op's interpreted semantics is a failure.
+
+    The `selected_region` argument is either an `IntegerAttr` or a param holding
+    an `IntegerAttr`, which should provide a valid zero-based index with respect
+    to the number of alternatives, i.e. regions.
+  }];
+  let cppNamespace = [{ mlir::transform::tune }];
+
+  let arguments = (ins Builtin_StringAttr:$name,
+                       // Static choice, written as `selected_region = <int>`.
+                       OptionalAttr<APIntAttr>:$selected_region_attr,
+                       // Dynamic choice, written as `selected_region = %param`.
+                       // The custom parser sets at most one of the two; if both
+                       // are set, the C++ `apply` uses the param's value.
+                       Optional<TransformParamTypeInterface>:$selected_region_param);
+  // Values yielded by the applied alternative's terminator.
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+  // One single-block region per alternative; `selected_region` indexes these.
+  let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);
+
+  // Example:
+  //   transform.tune.alternatives<"name"> selected_region = 0 -> !transform.any_op
+  //   { ... }, { ... }
+  let assemblyFormat = [{
+    `<` $name `>`
+    (`selected_region` `=` custom<AlternativesOpSelectedRegion>(
+        $selected_region_attr, $selected_region_param)^)?
+    attr-dict-with-keyword
+    (`:` type($selected_region_param)^)?
+    (`->` type($results)^)?
+    regions
+  }];
+
+  // Checks yield/result agreement and the range of a static region index.
+  let hasVerifier = 1;
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index 842e880ca9150..dad63586bc8d4 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -6,13 +6,25 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
 #include "llvm/Support/Debug.h"
+#include <cstddef>
 
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
 
 using namespace mlir;
 
+// Hooks for the `custom<AlternativesOpSelectedRegion>` directive used in
+// AlternativesOp's assemblyFormat. They must be declared before the generated
+// op definitions are included below, because that generated code calls them;
+// their definitions are in the AlternativesOp section further down.
+static ParseResult parseAlternativesOpSelectedRegion(
+    OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+    std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+                                              Operation *op,
+                                              IntegerAttr selectedRegionAttr,
+                                              Value selectedRegionParam);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
 
@@ -57,3 +69,176 @@ LogicalResult transform::tune::KnobOp::verify() {
 
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the value following `selected_region =`. This is either an integer
+/// literal, stored as an index-typed attribute, or an SSA operand naming a
+/// transform param. On success exactly one of the two outputs is populated.
+static ParseResult parseAlternativesOpSelectedRegion(
+    OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+    std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
+  // Static choice: a plain integer literal.
+  size_t literalIdx;
+  OptionalParseResult intResult = parser.parseOptionalInteger(literalIdx);
+  if (intResult.has_value()) {
+    if (succeeded(*intResult))
+      selectedRegionAttr = parser.getBuilder().getIndexAttr(literalIdx);
+    return *intResult;
+  }
+
+  // Dynamic choice: an operand that carries a transform param.
+  OpAsmParser::UnresolvedOperand operand;
+  OptionalParseResult operandResult = parser.parseOptionalOperand(operand);
+  if (!operandResult.has_value())
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected either an integer attribute or a transform.param operand";
+  if (succeeded(*operandResult))
+    selectedRegionParam = operand;
+  return *operandResult;
+}
+
+/// Custom printer paired with parseAlternativesOpSelectedRegion. It emits the
+/// static index as a bare integer and/or the param operand's SSA name.
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+                                              Operation *op,
+                                              IntegerAttr selectedRegionAttr,
+                                              Value selectedRegionParam) {
+  // Stream the APInt rather than the attribute itself, so that only the plain
+  // integer literal the parser reads back is printed.
+  if (IntegerAttr staticIdx = selectedRegionAttr)
+    printer << staticIdx.getValue();
+  if (Value dynamicIdx = selectedRegionParam)
+    printer << dynamicIdx;
+}
+
+/// The alternatives take no block arguments (NoRegionArguments), so nothing
+/// is forwarded on entry regardless of `point`: returns an empty range.
+OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
+    RegionBranchPoint point) {
+  return getOperands().take_front(0);
+}
+
+/// RegionBranchOpInterface control flow.
+///
+/// From the parent op, control enters the statically selected alternative if
+/// `selected_region` is an attribute; otherwise any alternative may be
+/// entered. Leaving a region returns to the parent, with the yielded values
+/// becoming the op's results.
+void transform::tune::AlternativesOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Coming back from one of the alternatives: control returns to the parent.
+  if (!point.isParent()) {
+    regions.emplace_back(getOperation()->getResults());
+    return;
+  }
+
+  // A static choice pins the single region that will be entered.
+  if (auto selectedRegionIdx = getSelectedRegionAttr()) {
+    regions.emplace_back(&getAlternatives()[selectedRegionIdx->getSExtValue()],
+                         Block::BlockArgListType());
+    return;
+  }
+
+  // Unresolved or param-driven choice: every alternative is a possible entry.
+  for (Region &alternative : getAlternatives())
+    regions.emplace_back(&alternative, Block::BlockArgListType());
+}
+
+/// Reports how often each alternative may run. With a static
+/// `selected_region`, the chosen region runs exactly once and the others
+/// never; otherwise each region runs at most once.
+void transform::tune::AlternativesOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+  (void)operands;
+  unsigned numRegions = getNumRegions();
+  bounds.reserve(numRegions);
+
+  auto selected = getSelectedRegionAttr();
+  if (!selected) {
+    bounds.resize(numRegions, InvocationBounds(0, 1));
+    return;
+  }
+  bounds.resize(numRegions, InvocationBounds(0, 0));
+  bounds[selected->getSExtValue()] = InvocationBounds(1, 1);
+}
+
+/// Declares handle effects: the optional `selected_region` param is only read,
+/// and every result is a produced handle/param. Effects of the transforms
+/// nested in the alternatives are not reported here.
+/// NOTE(review): nested transforms (e.g. the tiling ops in the tests) operate
+/// on handles defined outside this op; confirm whether their effects must be
+/// forwarded.
+void transform::tune::AlternativesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  Operation *op = getOperation();
+  onlyReadsHandle(getSelectedRegionParamMutable(), effects);
+  producesHandle(op->getOpResults(), effects);
+  // TODO: should effects from regions be forwarded?
+}
+
+/// Applies the op by interpreting the ops of the selected region in order and
+/// forwarding the region's yielded values as this op's results.
+///
+/// The region index comes from the `selected_region` attribute or, if
+/// present, the `selected_region` param (the param takes precedence).
+///
+/// Returns a definite failure if:
+///   - no choice was provided,
+///   - the param does not hold exactly one integer attribute, or
+///   - the index is out of range.
+///
+/// A silenceable failure from a nested transform is propagated after all of
+/// this op's results are mapped to empty payloads.
+DiagnosedSilenceableFailure
+transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  // Signed on purpose: negative indices must be diagnosed as negative rather
+  // than wrapping to huge unsigned values (which also made `< 0` dead code).
+  std::optional<int64_t> selectedRegionIdx;
+
+  if (auto selectedRegionAttr = getSelectedRegionAttr())
+    selectedRegionIdx = selectedRegionAttr->getSExtValue();
+
+  if (Value selectedRegionParam = getSelectedRegionParam()) {
+    ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
+    // Check the count first: building the diagnostic from associatedAttrs[0]
+    // would read out of bounds when the param is empty.
+    if (associatedAttrs.size() != 1)
+      return emitDefiniteFailure()
+             << "param should hold exactly one integer attribute, got "
+             << associatedAttrs.size() << " attributes";
+    auto selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs.front());
+    if (!selectedRegionAttr)
+      return emitDefiniteFailure()
+             << "param should hold exactly one integer attribute, got: "
+             << associatedAttrs.front();
+    selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
+  }
+
+  // Stream the StringAttr (not the StringRef) so the name is printed quoted.
+  if (!selectedRegionIdx)
+    return emitDefiniteFailure() << "non-deterministic choice " << getNameAttr()
+                                 << " is only resolved through providing a "
+                                    "`selected_region` attr/param";
+
+  int64_t numRegions = getNumRegions();
+  if (*selectedRegionIdx < 0 || *selectedRegionIdx >= numRegions)
+    return emitDefiniteFailure()
+           << "'selected_region' attribute/param specifies region at index "
+           << *selectedRegionIdx << " while op has only " << numRegions
+           << " regions";
+
+  Region &selectedRegion =
+      getRegion(static_cast<unsigned>(*selectedRegionIdx));
+  auto scope = state.make_region_scope(selectedRegion);
+  Block &block = selectedRegion.front();
+  // Apply the region's ops one by one.
+  for (Operation &transform : block.without_terminator()) {
+    DiagnosedSilenceableFailure result =
+        state.applyTransform(cast<transform::TransformOpInterface>(transform));
+    if (result.isDefiniteFailure())
+      return result;
+
+    if (result.isSilenceableFailure()) {
+      for (const auto &res : getResults())
+        results.set(res, {});
+      return result;
+    }
+  }
+  // Forward the operation mapping for values yielded from the region to the
+  // values produced by the alternatives op.
+  transform::detail::forwardTerminatorOperands(&block, state, results);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Verifies the op:
+///   - each alternative must end in a `transform.yield` that forwards exactly
+///     as many values as the op has results, with matching types;
+///   - a static `selected_region` index must name an existing region.
+LogicalResult transform::tune::AlternativesOp::verify() {
+  for (Region &region : getAlternatives()) {
+    // Guard against an empty block (e.g. one freshly created from Python before
+    // its terminator is inserted) so that `back()` is never called on it.
+    Block &body = region.front();
+    transform::YieldOp yieldTerminator;
+    if (!body.empty())
+      yieldTerminator = llvm::dyn_cast<transform::YieldOp>(&body.back());
+    if (!yieldTerminator)
+      return emitOpError() << "expected '"
+                           << transform::YieldOp::getOperationName()
+                           << "' as terminator";
+
+    if (yieldTerminator->getNumOperands() != getNumResults())
+      return yieldTerminator.emitOpError()
+             << "expected terminator to have as many operands as the parent op "
+                "has results";
+
+    for (auto [i, operandType, resultType] : llvm::zip_equal(
+             llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+             yieldTerminator->getOperands().getType(), getResultTypes())) {
+      if (operandType == resultType)
+        continue;
+      return yieldTerminator.emitOpError()
+             << "the type of the terminator operand #" << i
+             << " must match the type of the corresponding parent op result ("
+             << operandType << " vs " << resultType << ")";
+    }
+  }
+
+  if (auto selectedRegionAttr = getSelectedRegionAttr()) {
+    // Signed on purpose: with an unsigned index `< 0` could never hold and a
+    // negative index would be reported as a huge number.
+    int64_t regionIdx = selectedRegionAttr->getSExtValue();
+    int64_t numRegions = getNumRegions();
+    if (regionIdx < 0 || regionIdx >= numRegions)
+      return emitOpError()
+             << "'selected_region' attribute specifies region at index "
+             << regionIdx << " while op has only " << numRegions
+             << " regions";
+  }
+
+  return success();
+}
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py
index f63f88a382422..b3bfa8015c4d8 100644
--- a/mlir/python/mlir/dialects/transform/tune.py
+++ b/mlir/python/mlir/dialects/transform/tune.py
@@ -6,6 +6,9 @@
 
 from ...ir import (
     Type,
+    Value,
+    Operation,
+    OpView,
     Attribute,
     ArrayAttr,
     StringAttr,
@@ -19,7 +22,10 @@
 from .._transform_tune_extension_ops_gen import _Dialect
 
 try:
-    from .._ods_common import _cext as _ods_cext
+    from .._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        _cext as _ods_cext,
+    )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
@@ -36,7 +42,7 @@ def __init__(
             ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
         ],
         *,
-        selected: Optional[Attribute] = None,
+        selected: Optional[Union[Attribute, bool, int, float, str]] = None,
         loc=None,
         ip=None,
     ):
@@ -75,8 +81,62 @@ def knob(
         ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
     ],
     *,
-    selected: Optional[Attribute] = None,
+    selected: Optional[Union[Attribute, bool, int, float, str]] = None,
     loc=None,
     ip=None,
 ):
     return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class AlternativesOp(AlternativesOp):
+    def __init__(
+        self,
+        results: Sequence[Type],
+        name: Union[StringAttr, str],
+        num_alternatives: int,
+        *,
+        selected_region: Optional[
+            Union[int, IntegerAttr, Value, Operation, OpView]
+        ] = None,
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(name, str):
+            name = StringAttr.get(name)
+
+        selected_region_attr = selected_region_param = None
+        if isinstance(selected_region, IntegerAttr):
+            selected_region_attr = selected_region
+        elif isinstance(selected_region, int):
+            selected_region_attr = IntegerAttr.get(
+                IntegerType.get_signless(32), selected_region
+            )
+        elif isinstance(selected_region, (Value, Operation, OpView)):
+            selected_region_param = _get_op_result_or_value(selected_region)
+
+        super().__init__(
+            results,
+            name,
+            num_alternatives,
+            selected_region_attr=selected_region_attr,
+            selected_region_param=selected_region_param,
+            loc=loc,
+            ip=ip,
+        )
+        for region in self.regions:
+            region.blocks.append()
+
+
+def alternatives(
+    results: Sequence[Type],
+    name: Union[StringAttr, str],
+    num_alternatives: int,
+    *,
+    selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
+    loc=None,
+    ip=None,
+):
+    return AlternativesOp(
+        results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip
+    )
diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
index 2e5f433abeb71..efc3890288456 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
@@ -19,3 +19,88 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// A static `selected_region` that is out of range is rejected by the verifier.
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    // expected-error@below {{'selected_region' attribute specifies region at index 2 while op has only 2 regions}}
+    transform.tune.alternatives<"bifurcation"> selected_region = 2 {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// A param holding a non-integer attribute (here an array) cannot select a
+// region.
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %singleton_of_c0 = transform.param.constant [0] -> !transform.any_param
+    // expected-error@below {{param should hold exactly one integer attribute, got: [0]}}
+    transform.tune.alternatives<"bifurcation"> selected_region = %singleton_of_c0 : !transform.any_param {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// A param holding more than one attribute cannot select a region.
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %c0 = transform.param.constant 0 -> !transform.any_param
+    %c1 = transform.param.constant 1 -> !transform.any_param
+    %c0_and_c1 = transform.merge_handles %c0, %c1 : !transform.any_param
+    // expected-error@below {{param should hold exactly one integer attribute}}
+    transform.tune.alternatives<"bifurcation"> selected_region = %c0_and_c1 : !transform.any_param {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// An out-of-range index provided through a param is diagnosed when the op is
+// applied.
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %c2 = transform.param.constant 2 -> !transform.any_param
+    // expected-error@below {{'selected_region' attribute/param specifies region at index 2 while op has only 2 regions}}
+    transform.tune.alternatives<"bifurcation"> selected_region = %c2 : !transform.any_param {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// Applying a choice that was never resolved (no `selected_region`) fails.
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    // expected-error@below {{non-deterministic choice "bifurcation" is only resolved through providing a `selected_region` attr/param}}
+    transform.tune.alternatives<"bifurcation"> {
+      transform.yield
+    }, {
+      transform.yield
+    }
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir
index 0a253c6d5f837..80b7525136b33 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir
@@ -59,3 +59,102 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+// -----
+
+// Both choices are fixed by static `selected_region` indices: the outer tiling
+// picks region 0 (tile_using_for) and the inner tiling picks region 1
+// (tile_using_forall).
+// CHECK-LABEL: schedule_with_two_independent_choices_already_made
+func.func @schedule_with_two_independent_choices_already_made(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32> {
+//      CHECK-NOT: scf.forall
+//      CHECK:     scf.for
+//      CHECK-NOT:   scf.for
+//      CHECK:       scf.forall
+//      CHECK-NOT:   scf.for
+//      CHECK:         tensor.extract_slice
+//      CHECK:         tensor.extract_slice
+//      CHECK:         tensor.extract_slice
+//      CHECK:         linalg.matmul
+//      CHECK:         scf.forall.in_parallel
+//      CHECK:           tensor.parallel_insert_slice
+//      CHECK:       tensor.insert_slice
+//      CHECK:       scf.yield
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    // Outer tiling: region 0 is selected, so tile_using_for is applied.
+    %tiled_matmul = transform.tune.alternatives<"outer_par_or_seq_tiling"> selected_region = 0 -> !transform.any_op
+    { // First alternative/region, with index = 0
+      %contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield %contained_matmul : !transform.any_op
+    }, { // Second alternative/region, with index = 1
+      %contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield %contained_matmul : !transform.any_op
+    }
+
+    // Inner tiling: region 1 is selected, so tile_using_forall is applied.
+    transform.tune.alternatives<"inner_par_or_seq_tiling"> selected_region = 1 -> !transform.any_op {
+      %contained_matmul, %loop = transform.structured.tile_using_for %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield %contained_matmul : !transform.any_op
+    }, {
+      %contained_matmul, %loop = transform.structured.tile_using_forall %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield %contained_matmul : !transform.any_op
+    }
+
+    transform.yield
+  }
+}
+
+// -----
+
+// The choice lives in a reusable named sequence and is resolved by the param
+// passed at each include: the outer include passes 1 (tile_using_forall), the
+// inner include passes the knob's selected value 0 (tile_using_for).
+// CHECK-LABEL: subschedule_with_choice_resolved_in_main_schedule
+func.func @subschedule_with_choice_resolved_in_main_schedule(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32> {
+//      CHECK-NOT: scf.for
+//      CHECK:     scf.forall
+//      CHECK-NOT:   scf.forall
+//      CHECK:       scf.for
+//      CHECK-NOT:   scf.forall
+//      CHECK:         tensor.extract_slice
+//      CHECK:         tensor.extract_slice
+//      CHECK:         tensor.extract_slice
+//      CHECK:         linalg.matmul
+//      CHECK:         tensor.insert_slice
+//      CHECK:         scf.yield
+//      CHECK:       scf.forall.in_parallel
+//      CHECK:         tensor.parallel_insert_slice
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  // Tiles %matmul with either scf.for (region 0) or scf.forall (region 1),
+  // as selected by %par_or_seq, using %tile_size as the tile size.
+  transform.named_sequence @subschedule_with_embedded_choice(%matmul: !transform.any_op {transform.readonly},
+                                                             %par_or_seq: !transform.param<i64> {transform.readonly},
+                                                             %tile_size: !transform.param<i64> {transform.readonly}) -> !transform.any_op {
+    %tiled_matmul = transform.tune.alternatives<"par_or_seq_tiling"> selected_region = %par_or_seq : !transform.param<i64> -> !transform.any_op {
+      %contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+      transform.yield %contained_matmul : !transform.any_op
+    }, {
+      %contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+      transform.yield %contained_matmul : !transform.any_op
+    }
+    transform.yield %tiled_matmul : !transform.any_op
+  }
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %outer_par = transform.param.constant 1 -> !transform.param<i64>
+    %outer_tile_size = transform.param.constant 32 -> !transform.param<i64>
+    // A knob with an already-selected value acts as the region selector.
+    %inner_seq = transform.tune.knob<"inner_par_or_seq"> = 0 from options = [0, 1] -> !transform.param<i64>
+    %inner_tile_size = transform.param.constant 8 -> !transform.param<i64>
+    %tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%matmul, %outer_par, %outer_tile_size) : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
+    %tiled_tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%tiled_matmul, %inner_seq, %inner_tile_size) : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py
index dfb93594bca52..3245e65b861de 100644
--- a/mlir/test/python/dialects/transform_tune_ext.py
+++ b/mlir/test/python/dialects/transform_tune_ext.py
@@ -1,21 +1,21 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from mlir.ir import *
+from mlir import ir
 from mlir.dialects import transform
 from mlir.dialects.transform import tune, debug
 
 
 def run(f):
-    print("\nTEST:", f.__name__)
-    with Context(), Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
+    print("\n// TEST:", f.__name__)
+    with ir.Context(), ir.Location.unknown():
+        module = ir.Module.create()
+        with ir.InsertionPoint(module.body):
             sequence = transform.SequenceOp(
                 transform.FailurePropagationMode.Propagate,
                 [],
                 transform.AnyOpType.get(),
             )
-            with InsertionPoint(sequence.body):
+            with ir.InsertionPoint(sequence.body):
                 f(sequence.bodyTarget)
                 transform.YieldOp()
         print(module)
@@ -29,10 +29,10 @@ def testKnobOp(target):
 
     # CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
     heads_or_tails = tune.KnobOp(
-        result=any_param, name=StringAttr.get("coin"), options=[True, False]
+        result=any_param, name=ir.StringAttr.get("coin"), options=[True, False]
     )
     # CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
-    tune.KnobOp(any_param, name="animal", options=["cat", "dog", UnitAttr.get()])
+    tune.KnobOp(any_param, name="animal", options=["cat", "dog", ir.UnitAttr.get()])
     # CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
     tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32])
     # CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param
@@ -45,7 +45,7 @@ def testKnobOp(target):
     heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True)
     # CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param
     tune.KnobOp(
-        any_param, name="animal", options=["cat", "dog", UnitAttr.get()], selected="dog"
+        any_param, name="animal", options=["cat", "dog", ir.UnitAttr.get()], selected="dog"
     )
     # CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
     tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8)
@@ -57,16 +57,75 @@ def testKnobOp(target):
 
     # CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param
     # NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified.
-    i64 = IntegerType.get_signless(64)
+    i64 = ir.IntegerType.get_signless(64)
     tune.knob(
         any_param,
         "range_as_a_dict",
-        DictAttr.get(
+        ir.DictAttr.get(
             {
-                "start": IntegerAttr.get(i64, 2),
-                "stop": IntegerAttr.get(i64, 16),
-                "step": IntegerAttr.get(i64, 2),
+                "start": ir.IntegerAttr.get(i64, 2),
+                "stop": ir.IntegerAttr.get(i64, 16),
+                "step": ir.IntegerAttr.get(i64, 2),
             }
         ),
         selected=4,
     )
+
+# CHECK-LABEL: TEST: testAlternativesOp
+ at run
+def testAlternativesOp(target):
+    any_param = transform.AnyParamType.get()
+
+    # CHECK: %[[LEFT_OR_RIGHT_OUTCOME:.*]] = transform.tune.alternatives<"left_or_right"> -> !transform.any_param {
+    left_or_right = tune.AlternativesOp([transform.AnyParamType.get()], "left_or_right", 2)
+    with ir.InsertionPoint(left_or_right.alternatives[_left := 0].blocks[0]):
+        # CHECK: %[[C0:.*]] = transform.param.constant 0
+        i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
+        c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0)
+        # CHECK: transform.yield %[[C0]]
+        transform.yield_(c0)
+    # CHECK-NEXT: }, {
+    with ir.InsertionPoint(left_or_right.alternatives[_right := 1].blocks[0]):
+        # CHECK: %[[C1:.*]] = transform.param.constant 1
+        i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1)
+        c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1)
+        # CHECK: transform.yield %[[C1]]
+        transform.yield_(c1)
+    # CHECK-NEXT: }
+    outcome_of_left_or_right_decision = left_or_right.results[0]
+
+    # CHECK: transform.tune.alternatives<"fork_in_the_road"> selected_region = 0 -> !transform.any_param {
+    fork_in_the_road = tune.AlternativesOp([transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0)
+    with ir.InsertionPoint(fork_in_the_road.alternatives[_left := 0].blocks[0]):
+        # CHECK: %[[C0:.*]] = transform.param.constant 0
+        i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
+        c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0)
+        # CHECK: transform.yield %[[C0]]
+        transform.yield_(c0)
+    # CHECK-NEXT: }, {
+    with ir.InsertionPoint(fork_in_the_road.alternatives[_right := 1].blocks[0]):
+        # CHECK: %[[C1:.*]] = transform.param.constant 1
+        i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1)
+        c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1)
+        # CHECK: transform.yield %[[C1]]
+        transform.yield_(c1)
+    # CHECK-NEXT: }
+
+    # CHECK: transform.tune.alternatives<"left_or_right_as_before"> selected_region = %[[LEFT_OR_RIGHT_OUTCOME]] : !transform.any_param {
+    left_or_right_as_before = tune.AlternativesOp([], "left_or_right_as_before", 2, selected_region=outcome_of_left_or_right_decision)
+    with ir.InsertionPoint(left_or_right_as_before.alternatives[_left := 0].blocks[0]):
+        # CHECK: transform.param.constant 1337
+        i32_1337 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1337)
+        c1337 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1337)
+        # CHECK: transform.debug.emit_param_as_remark
+        debug.emit_param_as_remark(c1337)
+        transform.yield_([])
+    # CHECK-NEXT: }, {
+    with ir.InsertionPoint(left_or_right_as_before.alternatives[_right := 1].blocks[0]):
+        # CHECK: transform.param.constant 42
+        i32_42 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+        c42 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_42)
+        # CHECK: transform.debug.emit_param_as_remark
+        debug.emit_param_as_remark(c42)
+        transform.yield_([])
+    # CHECK-NEXT: }

>From 940a7197597435c106f287a364f34b37b34f6dbc Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 25 Sep 2025 08:44:48 -0700
Subject: [PATCH 2/7] Fix python formatting

---
 .../python/dialects/transform_tune_ext.py     | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py
index 3245e65b861de..afbe84fe2f8d8 100644
--- a/mlir/test/python/dialects/transform_tune_ext.py
+++ b/mlir/test/python/dialects/transform_tune_ext.py
@@ -45,7 +45,10 @@ def testKnobOp(target):
     heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True)
     # CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param
     tune.KnobOp(
-        any_param, name="animal", options=["cat", "dog", ir.UnitAttr.get()], selected="dog"
+        any_param,
+        name="animal",
+        options=["cat", "dog", ir.UnitAttr.get()],
+        selected="dog",
     )
     # CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
     tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8)
@@ -71,13 +74,16 @@ def testKnobOp(target):
         selected=4,
     )
 
+
 # CHECK-LABEL: TEST: testAlternativesOp
 @run
 def testAlternativesOp(target):
     any_param = transform.AnyParamType.get()
 
     # CHECK: %[[LEFT_OR_RIGHT_OUTCOME:.*]] = transform.tune.alternatives<"left_or_right"> -> !transform.any_param {
-    left_or_right = tune.AlternativesOp([transform.AnyParamType.get()], "left_or_right", 2)
+    left_or_right = tune.AlternativesOp(
+        [transform.AnyParamType.get()], "left_or_right", 2
+    )
     with ir.InsertionPoint(left_or_right.alternatives[_left := 0].blocks[0]):
         # CHECK: %[[C0:.*]] = transform.param.constant 0
         i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
@@ -95,7 +101,9 @@ def testAlternativesOp(target):
     outcome_of_left_or_right_decision = left_or_right.results[0]
 
     # CHECK: transform.tune.alternatives<"fork_in_the_road"> selected_region = 0 -> !transform.any_param {
-    fork_in_the_road = tune.AlternativesOp([transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0)
+    fork_in_the_road = tune.AlternativesOp(
+        [transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0
+    )
     with ir.InsertionPoint(fork_in_the_road.alternatives[_left := 0].blocks[0]):
         # CHECK: %[[C0:.*]] = transform.param.constant 0
         i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
@@ -112,7 +120,12 @@ def testAlternativesOp(target):
     # CHECK-NEXT: }
 
     # CHECK: transform.tune.alternatives<"left_or_right_as_before"> selected_region = %[[LEFT_OR_RIGHT_OUTCOME]] : !transform.any_param {
-    left_or_right_as_before = tune.AlternativesOp([], "left_or_right_as_before", 2, selected_region=outcome_of_left_or_right_decision)
+    left_or_right_as_before = tune.AlternativesOp(
+        [],
+        "left_or_right_as_before",
+        2,
+        selected_region=outcome_of_left_or_right_decision,
+    )
     with ir.InsertionPoint(left_or_right_as_before.alternatives[_left := 0].blocks[0]):
         # CHECK: transform.param.constant 1337
         i32_1337 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1337)

>From c82592c81165b2b572d2dcb21941cf8831b7706c Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 25 Sep 2025 09:10:56 -0700
Subject: [PATCH 3/7] Fix Python formatting (part 2)

---
 mlir/test/python/dialects/transform_tune_ext.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py
index afbe84fe2f8d8..01e20726f84ac 100644
--- a/mlir/test/python/dialects/transform_tune_ext.py
+++ b/mlir/test/python/dialects/transform_tune_ext.py
@@ -84,14 +84,14 @@ def testAlternativesOp(target):
     left_or_right = tune.AlternativesOp(
         [transform.AnyParamType.get()], "left_or_right", 2
     )
-    with ir.InsertionPoint(left_or_right.alternatives[_left := 0].blocks[0]):
+    with ir.InsertionPoint(left_or_right.alternatives[_left:=0].blocks[0]):
         # CHECK: %[[C0:.*]] = transform.param.constant 0
         i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
         c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0)
         # CHECK: transform.yield %[[C0]]
         transform.yield_(c0)
     # CHECK-NEXT: }, {
-    with ir.InsertionPoint(left_or_right.alternatives[_right := 1].blocks[0]):
+    with ir.InsertionPoint(left_or_right.alternatives[_right:=1].blocks[0]):
         # CHECK: %[[C1:.*]] = transform.param.constant 1
         i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1)
         c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1)
@@ -104,14 +104,14 @@ def testAlternativesOp(target):
     fork_in_the_road = tune.AlternativesOp(
         [transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0
     )
-    with ir.InsertionPoint(fork_in_the_road.alternatives[_left := 0].blocks[0]):
+    with ir.InsertionPoint(fork_in_the_road.alternatives[_left:=0].blocks[0]):
         # CHECK: %[[C0:.*]] = transform.param.constant 0
         i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
         c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0)
         # CHECK: transform.yield %[[C0]]
         transform.yield_(c0)
     # CHECK-NEXT: }, {
-    with ir.InsertionPoint(fork_in_the_road.alternatives[_right := 1].blocks[0]):
+    with ir.InsertionPoint(fork_in_the_road.alternatives[_right:=1].blocks[0]):
         # CHECK: %[[C1:.*]] = transform.param.constant 1
         i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1)
         c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1)
@@ -126,7 +126,7 @@ def testAlternativesOp(target):
         2,
         selected_region=outcome_of_left_or_right_decision,
     )
-    with ir.InsertionPoint(left_or_right_as_before.alternatives[_left := 0].blocks[0]):
+    with ir.InsertionPoint(left_or_right_as_before.alternatives[_left:=0].blocks[0]):
         # CHECK: transform.param.constant 1337
         i32_1337 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1337)
         c1337 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1337)
@@ -134,7 +134,7 @@ def testAlternativesOp(target):
         debug.emit_param_as_remark(c1337)
         transform.yield_([])
     # CHECK-NEXT: }, {
-    with ir.InsertionPoint(left_or_right_as_before.alternatives[_right := 1].blocks[0]):
+    with ir.InsertionPoint(left_or_right_as_before.alternatives[_right:=1].blocks[0]):
         # CHECK: transform.param.constant 42
         i32_42 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
         c42 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_42)

>From e7b217dcc2ad442cd408b5215a29a11b18237bb2 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 29 Sep 2025 14:12:26 -0700
Subject: [PATCH 4/7] Add 4-way test case

---
 .../Transform/test-tune-extension.mlir        | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir
index 80b7525136b33..5da48a2218ec6 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir
@@ -158,3 +158,30 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: eeny_meeny_miny_moe
+func.func private @eeny_meeny_miny_moe()
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    %tiled_matmul = transform.tune.alternatives<"4way"> selected_region = 3 -> !transform.any_param
+    { // First alternative/region, with index = 0
+      %out = transform.param.constant "eeny" -> !transform.any_param
+      transform.yield %out : !transform.any_param
+    }, { // Second alternative/region, with index = 1
+      %out = transform.param.constant "meeny" -> !transform.any_param
+      transform.yield %out : !transform.any_param
+    }, { // Third alternative/region, with index = 2
+      %out = transform.param.constant "miny" -> !transform.any_param
+      transform.yield %out : !transform.any_param
+    }, { // Fourth alternative/region, with index = 3
+      %out = transform.param.constant "moe" -> !transform.any_param
+      transform.yield %out : !transform.any_param
+    }
+    transform.yield
+  }
+}
\ No newline at end of file

>From 085bdf33444d81cc5456c07228124a1280a16c61 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Mon, 29 Sep 2025 14:18:01 -0700
Subject: [PATCH 5/7] Remove unused include

---
 mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index dad63586bc8d4..c627158e999ed 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/Support/Debug.h"
-#include <cstddef>
 
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
 

>From cbca7b72bccb8f8375a2cbcf115b71c7db1c6763 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 1 Oct 2025 06:32:48 -0700
Subject: [PATCH 6/7] Fix python syntax error that only occurs on Windows

I wish I were kidding...
---
 mlir/test/python/dialects/transform_tune_ext.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py
index 01e20726f84ac..77a8794197758 100644
--- a/mlir/test/python/dialects/transform_tune_ext.py
+++ b/mlir/test/python/dialects/transform_tune_ext.py
@@ -84,14 +84,15 @@ def testAlternativesOp(target):
     left_or_right = tune.AlternativesOp(
         [transform.AnyParamType.get()], "left_or_right", 2
     )
-    with ir.InsertionPoint(left_or_right.alternatives[_left:=0].blocks[0]):
+    idx_for_left, idx_for_right = 0, 1
+    with ir.InsertionPoint(left_or_right.alternatives[idx_for_left].blocks[0]):
         # CHECK: %[[C0:.*]] = transform.param.constant 0
         i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
         c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0)
         # CHECK: transform.yield %[[C0]]
         transform.yield_(c0)
     # CHECK-NEXT: }, {
-    with ir.InsertionPoint(left_or_right.alternatives[_right:=1].blocks[0]):
+    with ir.InsertionPoint(left_or_right.alternatives[idx_for_right].blocks[0]):
         # CHECK: %[[C1:.*]] = transform.param.constant 1
         i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1)
         c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1)
@@ -104,14 +105,14 @@ def testAlternativesOp(target):
     fork_in_the_road = tune.AlternativesOp(
         [transform.AnyParamType.get()], "fork_in_the_road", 2, selected_region=0
     )
-    with ir.InsertionPoint(fork_in_the_road.alternatives[_left:=0].blocks[0]):
+    with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_left].blocks[0]):
         # CHECK: %[[C0:.*]] = transform.param.constant 0
         i32_0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)
         c0 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_0)
         # CHECK: transform.yield %[[C0]]
         transform.yield_(c0)
     # CHECK-NEXT: }, {
-    with ir.InsertionPoint(fork_in_the_road.alternatives[_right:=1].blocks[0]):
+    with ir.InsertionPoint(fork_in_the_road.alternatives[idx_for_right].blocks[0]):
         # CHECK: %[[C1:.*]] = transform.param.constant 1
         i32_1 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1)
         c1 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1)

>From d1da9f11bef86860ad45728b2c48a592bb1069cf Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 1 Oct 2025 06:35:33 -0700
Subject: [PATCH 7/7] Fix Python syntax (part 4)

---
 mlir/test/python/dialects/transform_tune_ext.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py
index 77a8794197758..eb2a083211ef7 100644
--- a/mlir/test/python/dialects/transform_tune_ext.py
+++ b/mlir/test/python/dialects/transform_tune_ext.py
@@ -127,7 +127,9 @@ def testAlternativesOp(target):
         2,
         selected_region=outcome_of_left_or_right_decision,
     )
-    with ir.InsertionPoint(left_or_right_as_before.alternatives[_left:=0].blocks[0]):
+    with ir.InsertionPoint(
+        left_or_right_as_before.alternatives[idx_for_left].blocks[0]
+    ):
         # CHECK: transform.param.constant 1337
         i32_1337 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 1337)
         c1337 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_1337)
@@ -135,7 +137,9 @@ def testAlternativesOp(target):
         debug.emit_param_as_remark(c1337)
         transform.yield_([])
     # CHECK-NEXT: }, {
-    with ir.InsertionPoint(left_or_right_as_before.alternatives[_right:=1].blocks[0]):
+    with ir.InsertionPoint(
+        left_or_right_as_before.alternatives[idx_for_right].blocks[0]
+    ):
         # CHECK: transform.param.constant 42
         i32_42 = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
         c42 = transform.ParamConstantOp(transform.AnyParamType.get(), i32_42)



More information about the Mlir-commits mailing list