[Mlir-commits] [mlir] [MLIR][Transform][Tune] Introduce `transform.tune.alternatives` op (PR #160724)
llvmlistbot at llvm.org
Thu Sep 25 08:13:51 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Rolf Morel (rolfmorel)
<details>
<summary>Changes</summary>
This op enables expressing uncertainty about what should happen at particular places in transform-dialect schedules. In particular, it represents a choice among alternative regions. The choice is resolved by providing a `selected_region` argument. When this argument is provided, it is valid to rewrite the op by substituting in the selected region, and the op's interpreted semantics correspond exactly to this substitution. Without it, the choice remains unresolved and interpreting the op is a failure.
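The resolution rules above, as implemented by `AlternativesOp::apply` in the patch, can be modelled in a few lines of plain Python. This is a hypothetical model, not the MLIR Python API: `resolve_selected_region`, `apply_alternatives` and `DefiniteFailure` are illustrative names, and regions are stood in for by Python callables.

```python
from typing import Callable, Optional, Sequence


class DefiniteFailure(Exception):
    """Hypothetical stand-in for a definite failure of the transform interpreter."""


def resolve_selected_region(
    num_regions: int,
    selected_region_attr: Optional[int] = None,
    param_values: Optional[Sequence[object]] = None,
) -> int:
    """Return the index of the region an alternatives op would execute.

    `selected_region_attr` models the static integer attribute and
    `param_values` models the attributes associated with a param operand
    (None when no param operand is given).
    """
    index = selected_region_attr
    if param_values is not None:
        # A param must be associated with exactly one integer.
        value = param_values[0] if len(param_values) == 1 else None
        if not isinstance(value, int) or isinstance(value, bool):
            raise DefiniteFailure("param should hold exactly one integer attribute")
        index = value
    if index is None:
        # An unresolved non-deterministic choice cannot be interpreted.
        raise DefiniteFailure("choice is only resolved by providing `selected_region`")
    if not 0 <= index < num_regions:
        raise DefiniteFailure(
            f"region index {index} is out of range for {num_regions} regions"
        )
    return index


def apply_alternatives(
    alternatives: Sequence[Callable[[], object]], **selection
) -> object:
    """Run only the selected alternative; its result becomes the op's result."""
    index = resolve_selected_region(len(alternatives), **selection)
    return alternatives[index]()


regions = [lambda: "tiled with scf.forall", lambda: "tiled with scf.for"]
print(apply_alternatives(regions, selected_region_attr=1))  # tiled with scf.for
```

As in the patch, a param-provided index takes precedence over the attribute, and every malformed or missing selection is treated as a definite failure rather than a silenceable one.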
This op is another piece of a toolkit for expressing autotuning problems with the transform dialect. Note that this goes beyond tuning knobs _on_ transforms: it also makes tunable which (sequences of) transforms are applied.
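To make "tunable which transforms are applied" concrete: if a schedule contains several independent `transform.tune.alternatives` ops, the space an autotuner would search is the Cartesian product of their region indices. The sketch below is an illustration under that assumption, not part of the patch; ops are identified only by their `name` and region count, `enumerate_selections` is a hypothetical helper, and `inner_choice` is a made-up second op name.

```python
from itertools import product
from typing import Dict, Iterator, Mapping


def enumerate_selections(choices: Mapping[str, int]) -> Iterator[Dict[str, int]]:
    """Yield every assignment of a region index to each named choice.

    `choices` maps the name of an alternatives op to its number of regions;
    each yielded dict fixes `selected_region` for every op, i.e. it describes
    one fully resolved schedule.
    """
    names = sorted(choices)
    for indices in product(*(range(choices[name]) for name in names)):
        yield dict(zip(names, indices))


# Two independent two-way choices give a search space of four schedules.
space = list(enumerate_selections({"outer_par_or_seq_tiling": 2, "inner_choice": 2}))
print(len(space))  # 4
```

Each yielded assignment could then be fed back as `selected_region` attributes or params, turning one schedule with open choices into a family of deterministic ones.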
---
Patch is 30.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160724.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h (+1)
- (modified) mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td (+54)
- (modified) mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp (+185)
- (modified) mlir/python/mlir/dialects/transform/tune.py (+63-3)
- (modified) mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir (+85)
- (modified) mlir/test/Dialect/Transform/test-tune-extension.mlir (+99)
- (modified) mlir/test/python/dialects/transform_tune_ext.py (+73-14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
index 74e1d28ffac82..ba11259790676 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index d68d451afac40..d095659fc4838 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -11,10 +11,15 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonAttrConstraints.td"
+//===----------------------------------------------------------------------===//
+// KnobOp
+//===----------------------------------------------------------------------===//
+
def KnobOp : Op<Transform_Dialect, "tune.knob", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -52,4 +57,53 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
}
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands", "getSuccessorRegions",
+ "getRegionInvocationBounds"]>,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
+ NoRegionArguments
+]> {
+ let summary = "Represents a choice among its regions, i.e. sub-schedules";
+
+ let description = [{
+ This op represents a choice over which of its regions is to be used.
+
+ When `selected_region` is provided, the semantics are that this op is to be
+ substituted for by the selected region, meaning the region's results become
+ the results of this op. Without a provided `selected_region`, the semantics
+ are that this non-deterministic choice is yet to be resolved -- which in
+ terms of the op's interpreted semantics is a failure.
+
+ The `selected_region` argument is either an `IntegerAttr` or a param holding
+ an `IntegerAttr`, which should provide a valid zero-based index with respect
+ to the number of alternatives, i.e. regions.
+ }];
+ let cppNamespace = [{ mlir::transform::tune }];
+
+ let arguments = (ins Builtin_StringAttr:$name,
+ OptionalAttr<APIntAttr>:$selected_region_attr,
+ Optional<TransformParamTypeInterface>:$selected_region_param);
+ let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
+ let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);
+
+ let assemblyFormat = [{
+ `<` $name `>`
+ (`selected_region` `=` custom<AlternativesOpSelectedRegion>(
+ $selected_region_attr, $selected_region_param)^)?
+ attr-dict-with-keyword
+ (`:` type($selected_region_param)^)?
+ (`->` type($results)^)?
+ regions
+ }];
+
+ let hasVerifier = 1;
+}
+
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index 842e880ca9150..dad63586bc8d4 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -6,13 +6,25 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
#include "llvm/Support/Debug.h"
+#include <cstddef>
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
using namespace mlir;
+static ParseResult parseAlternativesOpSelectedRegion(
+ OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+ std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+ Operation *op,
+ IntegerAttr selectedRegionAttr,
+ Value selectedRegionParam);
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
@@ -57,3 +69,176 @@ LogicalResult transform::tune::KnobOp::verify() {
return success();
}
+
+//===----------------------------------------------------------------------===//
+// AlternativesOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAlternativesOpSelectedRegion(
+ OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
+ std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
+ size_t selectedRegionIdx;
+ OptionalParseResult attrParseRes =
+ parser.parseOptionalInteger(selectedRegionIdx);
+ if (attrParseRes.has_value()) {
+ if (failed(*attrParseRes))
+ return failure();
+
+ selectedRegionAttr = parser.getBuilder().getIndexAttr(selectedRegionIdx);
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand param;
+ auto paramParseRes = parser.parseOptionalOperand(param);
+ if (paramParseRes.has_value()) {
+ if (failed(*paramParseRes))
+ return failure();
+
+ selectedRegionParam = param;
+ return success();
+ }
+
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected either an integer attribute or a transform.param operand";
+}
+
+static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
+ Operation *op,
+ IntegerAttr selectedRegionAttr,
+ Value selectedRegionParam) {
+ if (selectedRegionAttr)
+ printer << selectedRegionAttr.getValue();
+ if (selectedRegionParam)
+ printer << selectedRegionParam;
+}
+
+OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
+ RegionBranchPoint point) {
+ // No operands will be forwarded to the region(s).
+ return getOperands().slice(0, 0);
+}
+
+void transform::tune::AlternativesOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent()) {
+ if (auto selectedRegionIdx = getSelectedRegionAttr()) {
+ regions.emplace_back(
+ &getAlternatives()[selectedRegionIdx->getSExtValue()],
+ Block::BlockArgListType());
+ } else {
+ for (Region &alternative : getAlternatives())
+ regions.emplace_back(&alternative, Block::BlockArgListType());
+ }
+ } else {
+ regions.emplace_back(getOperation()->getResults());
+ }
+}
+
+void transform::tune::AlternativesOp::getRegionInvocationBounds(
+ ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+ (void)operands;
+ bounds.reserve(getNumRegions());
+
+ if (auto selectedRegionIdx = getSelectedRegionAttr()) {
+ bounds.resize(getNumRegions(), InvocationBounds(0, 0));
+ bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
+ } else {
+ bounds.resize(getNumRegions(), InvocationBounds(0, 1));
+ }
+}
+
+void transform::tune::AlternativesOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getSelectedRegionParamMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ // TODO: should effects from regions be forwarded?
+}
+
+DiagnosedSilenceableFailure
+transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ std::optional<int64_t> selectedRegionIdx;
+
+ if (auto selectedRegionAttr = getSelectedRegionAttr())
+ selectedRegionIdx = selectedRegionAttr->getSExtValue();
+
+ if (Value selectedRegionParam = getSelectedRegionParam()) {
+ ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);
+ IntegerAttr selectedRegionAttr;
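+ if (associatedAttrs.size() != 1)
+ return emitDefiniteFailure()
+ << "param should hold exactly one integer attribute, got "
+ << associatedAttrs.size() << " attributes";
+ selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0]);
+ if (!selectedRegionAttr)
+ return emitDefiniteFailure()
+ << "param should hold exactly one integer attribute, got: "
+ << associatedAttrs[0];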
+ selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
+ }
+
+ if (!selectedRegionIdx)
+ return emitDefiniteFailure() << "non-deterministic choice " << getName()
+ << " is only resolved through providing a "
+ "`selected_region` attr/param";
+
+ if (*selectedRegionIdx < 0 || *selectedRegionIdx >= getNumRegions())
+ return emitDefiniteFailure()
+ << "'selected_region' attribute/param specifies region at index "
+ << *selectedRegionIdx << " while op has only " << getNumRegions()
+ << " regions";
+
+ Region &selectedRegion = getRegion(*selectedRegionIdx);
+ auto scope = state.make_region_scope(selectedRegion);
+ Block &block = selectedRegion.front();
+ // Apply the region's ops one by one.
+ for (Operation &transform : block.without_terminator()) {
+ DiagnosedSilenceableFailure result =
+ state.applyTransform(cast<transform::TransformOpInterface>(transform));
+ if (result.isDefiniteFailure())
+ return result;
+
+ if (result.isSilenceableFailure()) {
+ for (const auto &res : getResults())
+ results.set(res, {});
+ return result;
+ }
+ }
+ // Forward the operation mapping for values yielded from the region to the
+ // values produced by the alternatives op.
+ transform::detail::forwardTerminatorOperands(&block, state, results);
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::tune::AlternativesOp::verify() {
+ for (auto *region : getRegions()) {
+ auto yieldTerminator =
+ llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
+ if (!yieldTerminator)
+ return emitOpError() << "expected '"
+ << transform::YieldOp::getOperationName()
+ << "' as terminator";
+
+ if (yieldTerminator->getNumOperands() != getNumResults())
+ return yieldTerminator.emitOpError()
+ << "expected terminator to have as many operands as the parent op "
+ "has results";
+
+ for (auto [i, operandType, resultType] : llvm::zip_equal(
+ llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+ yieldTerminator->getOperands().getType(), getResultTypes())) {
+ if (operandType == resultType)
+ continue;
+ return yieldTerminator.emitOpError()
+ << "the type of the terminator operand #" << i
+ << " must match the type of the corresponding parent op result ("
+ << operandType << " vs " << resultType << ")";
+ }
+ }
+
+ if (auto selectedRegionAttr = getSelectedRegionAttr()) {
+ int64_t regionIdx = selectedRegionAttr->getSExtValue();
+ if (regionIdx < 0 || regionIdx >= getNumRegions())
+ return emitOpError()
+ << "'selected_region' attribute specifies region at index "
+ << regionIdx << " while op has only " << getNumRegions()
+ << " regions";
+ }
+
+ return success();
+}
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py
index f63f88a382422..b3bfa8015c4d8 100644
--- a/mlir/python/mlir/dialects/transform/tune.py
+++ b/mlir/python/mlir/dialects/transform/tune.py
@@ -6,6 +6,9 @@
from ...ir import (
Type,
+ Value,
+ Operation,
+ OpView,
Attribute,
ArrayAttr,
StringAttr,
@@ -19,7 +22,10 @@
from .._transform_tune_extension_ops_gen import _Dialect
try:
- from .._ods_common import _cext as _ods_cext
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ _cext as _ods_cext,
+ )
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -36,7 +42,7 @@ def __init__(
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
- selected: Optional[Attribute] = None,
+ selected: Optional[Union[Attribute, bool, int, float, str]] = None,
loc=None,
ip=None,
):
@@ -75,8 +81,62 @@ def knob(
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
- selected: Optional[Attribute] = None,
+ selected: Optional[Union[Attribute, bool, int, float, str]] = None,
loc=None,
ip=None,
):
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AlternativesOp(AlternativesOp):
+ def __init__(
+ self,
+ results: Sequence[Type],
+ name: Union[StringAttr, str],
+ num_alternatives: int,
+ *,
+ selected_region: Optional[
+ Union[int, IntegerAttr, Value, Operation, OpView]
+ ] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(name, str):
+ name = StringAttr.get(name)
+
+ selected_region_attr = selected_region_param = None
+ if isinstance(selected_region, IntegerAttr):
+ selected_region_attr = selected_region
+ elif isinstance(selected_region, int):
+ selected_region_attr = IntegerAttr.get(
+ IntegerType.get_signless(32), selected_region
+ )
+ elif isinstance(selected_region, (Value, Operation, OpView)):
+ selected_region_param = _get_op_result_or_value(selected_region)
+
+ super().__init__(
+ results,
+ name,
+ num_alternatives,
+ selected_region_attr=selected_region_attr,
+ selected_region_param=selected_region_param,
+ loc=loc,
+ ip=ip,
+ )
+ for region in self.regions:
+ region.blocks.append()
+
+
+def alternatives(
+ results: Sequence[Type],
+ name: Union[StringAttr, str],
+ num_alternatives: int,
+ *,
+ selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
+ loc=None,
+ ip=None,
+):
+ return AlternativesOp(
+ results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip
+ )
diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
index 2e5f433abeb71..efc3890288456 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
@@ -19,3 +19,88 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // expected-error @below {{'selected_region' attribute specifies region at index 2 while op has only 2 regions}}
+ transform.tune.alternatives<"bifurcation"> selected_region = 2 {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %singleton_of_c0 = transform.param.constant [0] -> !transform.any_param
+ // expected-error @below {{param should hold exactly one integer attribute, got: [0]}}
+ transform.tune.alternatives<"bifurcation"> selected_region = %singleton_of_c0 : !transform.any_param {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %c0 = transform.param.constant 0 -> !transform.any_param
+ %c1 = transform.param.constant 1 -> !transform.any_param
+ %c0_and_c1 = transform.merge_handles %c0, %c1 : !transform.any_param
+ // expected-error @below {{param should hold exactly one integer attribute}}
+ transform.tune.alternatives<"bifurcation"> selected_region = %c0_and_c1 : !transform.any_param {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %c2 = transform.param.constant 2 -> !transform.any_param
+ // expected-error @below {{'selected_region' attribute/param specifies region at index 2 while op has only 2 regions}}
+ transform.tune.alternatives<"bifurcation"> selected_region = %c2 : !transform.any_param {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // expected-error @below {{non-deterministic choice "bifurcation" is only resolved through providing a `selected_region` attr/param}}
+ transform.tune.alternatives<"bifurcation"> {
+ transform.yield
+ }, {
+ transform.yield
+ }
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir
index 0a253c6d5f837..80b7525136b33 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir
@@ -59,3 +59,102 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+
+// -----
+
+// CHECK-LABEL: schedule_with_two_independent_choices_already_made
+func.func @schedule_with_two_independent_choices_already_made(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32> {
+// CHECK-NOT: scf.forall
+// CHECK: scf.for
+// CHECK-NOT: scf.for
+// CHECK: scf.forall
+// CHECK-NOT: scf.for
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: linalg.matmul
+// CHECK: scf.forall.in_parallel
+// CHECK: tensor.parallel_insert_slice
+// CHECK: tensor.insert_slice
+// CHECK: scf.yield
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+ %tiled_matmul = transform.tune.alternatives<"outer_par_or_seq_tiling"> selected_region = 0 -> !transform.any_op
+ { // First alternative/region, with index = 0
+...
[truncated]
``````````
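For readers less familiar with `RegionBranchOpInterface`, the control-flow facts the patch reports through `getSuccessorRegions` and `getRegionInvocationBounds` can be summarized with a small pure-Python model. The function names are hypothetical; `selected` stands for a static `selected_region_attr` and is `None` when the choice is absent or only known at interpretation time through a param.

```python
from typing import List, Optional, Tuple

# (min, max) number of times a region may execute per invocation of the op.
Bounds = Tuple[int, int]


def region_invocation_bounds(num_regions: int, selected: Optional[int]) -> List[Bounds]:
    """Model of AlternativesOp::getRegionInvocationBounds."""
    if selected is not None:
        # Statically selected: that region runs exactly once, the others never.
        bounds = [(0, 0)] * num_regions
        bounds[selected] = (1, 1)
        return bounds
    # Not statically known: each region runs at most once.
    return [(0, 1)] * num_regions


def successor_regions(num_regions: int, selected: Optional[int]) -> List[int]:
    """Model of getSuccessorRegions when control enters from the parent op."""
    if selected is not None:
        return [selected]
    return list(range(num_regions))


print(region_invocation_bounds(3, 1))     # [(0, 0), (1, 1), (0, 0)]
print(region_invocation_bounds(2, None))  # [(0, 1), (0, 1)]
```

The asymmetry is deliberate in the patch: a param-provided index is only available while interpreting, so static analyses must conservatively treat every region as a possible successor.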
</details>
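The verifier added in the patch enforces that every region's `transform.yield` forwards values whose count and types match the op's results, and that a static `selected_region` is in range. A hedged Python model of those checks follows; `verify_alternatives` is a hypothetical name and types are plain strings. Unlike the C++ verifier, which stops at the first error, this model collects all diagnostics.

```python
from typing import List, Optional, Sequence


def verify_alternatives(
    result_types: Sequence[str],
    yield_types_per_region: Sequence[Sequence[str]],
    selected: Optional[int] = None,
) -> List[str]:
    """Return diagnostics mirroring AlternativesOp::verify (empty if valid)."""
    errors: List[str] = []
    for r, yield_types in enumerate(yield_types_per_region):
        if len(yield_types) != len(result_types):
            errors.append(
                f"region #{r}: expected terminator to have as many operands "
                "as the parent op has results"
            )
            continue
        for i, (operand_type, result_type) in enumerate(zip(yield_types, result_types)):
            if operand_type != result_type:
                errors.append(
                    f"region #{r}: terminator operand #{i} type must match parent "
                    f"op result ({operand_type} vs {result_type})"
                )
    num_regions = len(yield_types_per_region)
    if selected is not None and not 0 <= selected < num_regions:
        errors.append(
            f"'selected_region' attribute specifies region at index {selected} "
            f"while op has only {num_regions} regions"
        )
    return errors


# Mirrors the first invalid test: two empty regions with selected_region = 2.
print(verify_alternatives([], [[], []], selected=2))
```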
https://github.com/llvm/llvm-project/pull/160724