[Mlir-commits] [mlir] [MLIR][Transform] Introduce `transform.tune.knob` op (PR #146732)
Rolf Morel
llvmlistbot at llvm.org
Tue Jul 8 02:53:07 PDT 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/146732
>From 25ab701d44721727a56ab5de8ca13021abad0db3 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 2 Jul 2025 08:35:02 -0700
Subject: [PATCH 1/7] [MLIR][Transform] Introduce `transform.tune.select` op
A new op to represent that an attribute is to be chosen from a set of
alternatives and that this choice is made available as a
`!transform.param`. When a `selected` argument is provided, the op's
`apply()` semantics is that of just making this selected attribute
available as the result. When `selected` is not provided, `apply()`
complains that nothing has resolved the non-determinism that the op is
representing.
---
.../mlir/Dialect/Transform/CMakeLists.txt | 1 +
.../Transform/TuneExtension/CMakeLists.txt | 6 ++
.../Transform/TuneExtension/TuneExtension.h | 21 +++++++
.../TuneExtension/TuneExtensionOps.h | 22 +++++++
.../TuneExtension/TuneExtensionOps.td | 36 +++++++++++
mlir/lib/Dialect/Transform/CMakeLists.txt | 1 +
.../Transform/TuneExtension/CMakeLists.txt | 12 ++++
.../Transform/TuneExtension/TuneExtension.cpp | 34 +++++++++++
.../TuneExtension/TuneExtensionOps.cpp | 61 +++++++++++++++++++
9 files changed, 194 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h
create mode 100644 mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
create mode 100644 mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
create mode 100644 mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
create mode 100644 mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index b6155b5f573f1..e70479b2a39f2 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
+add_subdirectory(TuneExtension)
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..9afca813afda6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS TuneExtensionOps.td)
+mlir_tablegen(TuneExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(TuneExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen)
+
+add_mlir_doc(TuneExtensionOps TuneExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h
new file mode 100644
index 0000000000000..1453d1754297f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h
@@ -0,0 +1,21 @@
+//===- TuneExtension.h - Tune extension for Transform dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the tune extension of the Transform dialect in the given registry.
+void registerTuneExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
new file mode 100644
index 0000000000000..de5bbc61919e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
@@ -0,0 +1,22 @@
+//===- TuneExtensionOps.h - Tune ext. for Transform dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
new file mode 100644
index 0000000000000..9366ab0ddd240
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -0,0 +1,36 @@
+//===- TuneExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/CommonAttrConstraints.td"
+
+def SelectOp : Op<Transform_Dialect, "tune.select", [
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+]> {
+ let summary = "Non-deterministically select a value from a set of values";
+ let description = [{
+ TODO
+ }];
+ let cppNamespace = [{ mlir::transform::tune }];
+ let hasVerifier = 1;
+
+ let arguments = (ins SymbolRefAttr:$name,
+ AnyAttr:$options,
+ OptionalAttr<AnyAttr>:$selected);
+ let results = (outs TransformParamTypeInterface:$result);
+ let assemblyFormat =
+ "$name (`=` $selected^ `selected`)? `from` $options attr-dict `->` type(results)";
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 0c0d5ebe0c212..6e628353258d6 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -5,4 +5,5 @@ add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
+add_subdirectory(TuneExtension)
add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..ff01d25e57f68
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRTransformTuneExtension
+ TuneExtension.cpp
+ TuneExtensionOps.cpp
+
+ DEPENDS
+ MLIRTransformDialectTuneExtensionOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRTransformDialect
+ MLIRTransforms
+)
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
new file mode 100644
index 0000000000000..c4581db83bb05
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
@@ -0,0 +1,34 @@
+//===- TuneExtension.cpp - Tune extension for the Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+/// Tune extension of the Transform dialect. This provides "core" transform
+/// operations for loop-like ops.
+class TuneExtension
+ : public transform::TransformDialectExtension<TuneExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TuneExtension)
+
+ void init() {
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
+ >();
+ }
+};
+
+void mlir::transform::registerTuneExtension(DialectRegistry &dialectRegistry) {
+ dialectRegistry.addExtensions<TuneExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
new file mode 100644
index 0000000000000..401f09eb4b6dc
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -0,0 +1,61 @@
+//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
+
+#define DEBUG_TYPE "transform-tune"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+void transform::tune::SelectOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ producesHandle(getOperation()->getOpResults(), effects);
+ onlyReadsPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::tune::SelectOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ if (getSelected()) {
+ results.setParams(getOperation()->getOpResults()[0], *getSelected());
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ return emitDefiniteFailure() << "non-deterministic choice is only resolved "
+ "through providing a `selected` attr!";
+}
+
+LogicalResult transform::tune::SelectOp::verify() {
+ if (auto selected = getSelected()) {
+ if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
+ if (!llvm::is_contained(optionsArray, selected))
+ return emitOpError("provided `selected` attribute is not an element of "
+ "`options` array of attributes");
+ } else
+ LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected
+ << " is an element of `options` attribute "
+ << getOptions() << "\n");
+ }
+
+ return success();
+}
>From 020562047c425177f6307c0fcf884736371ae2eb Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 3 Jul 2025 06:45:47 -0700
Subject: [PATCH 2/7] Move to new name, KnobOp, and new syntax and add test
cases and docs
---
.../TuneExtension/TuneExtensionOps.td | 28 +++++++++++++++----
mlir/include/mlir/InitAllExtensions.h | 2 ++
.../Transform/TuneExtension/TuneExtension.cpp | 2 --
.../TuneExtension/TuneExtensionOps.cpp | 17 +++++------
mlir/python/CMakeLists.txt | 9 ++++++
5 files changed, 43 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index 9366ab0ddd240..afb67e8fef250 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -12,25 +12,43 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonAttrConstraints.td"
-def SelectOp : Op<Transform_Dialect, "tune.select", [
+def KnobOp : Op<Transform_Dialect, "tune.knob", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
]> {
- let summary = "Non-deterministically select a value from a set of values";
+ let summary = "Represents a tunable parameter with a set of options";
+
let description = [{
- TODO
+ Provides a representation for "tunables" within schedules.
+
+ Each op represents a single tunable, which has a `name` and a set
+ of valid `options` described by an attribute. Without a specified
+ `selected` option, this op represents a non-deterministic choice
+ that has yet to be resolved -- as such, the interpreter runtime
+ semantics are to raise a failure.
+
+ The non-deterministic choice is resolved through providing a
+ `selected` attribute. When provided, the interpreter runtime
+ semantics are to return the `selected` attribute as a param through
+ the op's result.
+
+ -----
+
+ In case the `options` attribute is an `ArrayAttr`, the verifier checks that the provided `selected` attribute occurs in `options`.
}];
let cppNamespace = [{ mlir::transform::tune }];
let hasVerifier = 1;
- let arguments = (ins SymbolRefAttr:$name,
+ let arguments = (ins Builtin_StringAttr:$name,
AnyAttr:$options,
OptionalAttr<AnyAttr>:$selected);
let results = (outs TransformParamTypeInterface:$result);
+
let assemblyFormat =
- "$name (`=` $selected^ `selected`)? `from` $options attr-dict `->` type(results)";
+ "`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
}
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index f356b91b1b6c0..0f2d0e45008cc 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -52,6 +52,7 @@
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -107,6 +108,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
transform::registerIRDLExtension(registry);
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
+ transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
index c4581db83bb05..e18f1e2748540 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
@@ -14,8 +14,6 @@
using namespace mlir;
-/// Tune extension of the Transform dialect. This provides "core" transform
-/// operations for loop-like ops.
class TuneExtension
: public transform::TransformDialectExtension<TuneExtension> {
public:
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index 401f09eb4b6dc..6ee0753b3d556 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -23,29 +23,30 @@ using namespace mlir;
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
//===----------------------------------------------------------------------===//
-// SelectOp
+// KnobOp
//===----------------------------------------------------------------------===//
-void transform::tune::SelectOp::getEffects(
+void transform::tune::KnobOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}
DiagnosedSilenceableFailure
-transform::tune::SelectOp::apply(transform::TransformRewriter &rewriter,
- transform::TransformResults &results,
- transform::TransformState &state) {
+transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
if (getSelected()) {
results.setParams(getOperation()->getOpResults()[0], *getSelected());
return DiagnosedSilenceableFailure::success();
}
- return emitDefiniteFailure() << "non-deterministic choice is only resolved "
- "through providing a `selected` attr!";
+ return emitDefiniteFailure()
+ << "non-deterministic choice " << getName()
+ << " is only resolved through providing a `selected` attr";
}
-LogicalResult transform::tune::SelectOp::verify() {
+LogicalResult transform::tune::KnobOp::verify() {
if (auto selected = getSelected()) {
if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
if (!llvm::is_contained(optionsArray, selected))
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b2daabb2a5957..7a0c95ebb8200 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -180,6 +180,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_debug_extension)
+declare_mlir_dialect_extension_python_bindings(
+ADD_TO_PARENT MLIRPythonSources.Dialects
+ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/TransformTuneExtensionOps.td
+ SOURCES
+ dialects/transform/tune.py
+ DIALECT_NAME transform
+ EXTENSION_NAME transform_tune_extension)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
>From 8e0a4d475b81da1490fe93b972cb9f554d28dc99 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 3 Jul 2025 07:01:17 -0700
Subject: [PATCH 3/7] Fix up header includes
---
.../TuneExtension/TuneExtensionOps.h | 5 +-
.../TuneExtension/TuneExtensionOps.cpp | 4 +-
.../dialects/TransformTuneExtensionOps.td | 18 ++++
mlir/python/mlir/dialects/transform/tune.py | 82 +++++++++++++++++++
.../test-tune-extension-invalid.mlir | 21 +++++
.../Transform/test-tune-extension.mlir | 61 ++++++++++++++
.../python/dialects/transform_tune_ext.py | 56 +++++++++++++
7 files changed, 241 insertions(+), 6 deletions(-)
create mode 100644 mlir/python/mlir/dialects/TransformTuneExtensionOps.td
create mode 100644 mlir/python/mlir/dialects/transform/tune.py
create mode 100644 mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
create mode 100644 mlir/test/Dialect/Transform/test-tune-extension.mlir
create mode 100644 mlir/test/python/dialects/transform_tune_ext.py
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
index de5bbc61919e9..74e1d28ffac82 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h
@@ -9,12 +9,9 @@
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
-#include "mlir/Bytecode/BytecodeOpInterface.h"
-#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h.inc"
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index 6ee0753b3d556..0c77dbb0f05dd 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
-
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
+
using namespace mlir;
#define GET_OP_CLASSES
diff --git a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td
new file mode 100644
index 0000000000000..60ed95d110762
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td
@@ -0,0 +1,18 @@
+//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the generated Python bindings for the Tune extension of the
+// Transform dialect.
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS
+#define PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS
+
+include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py
new file mode 100644
index 0000000000000..15c43aba795eb
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/tune.py
@@ -0,0 +1,82 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional, Sequence
+
+from ...ir import (
+ Type,
+ Attribute,
+ ArrayAttr,
+ StringAttr,
+ F64Type,
+ IntegerType,
+ IntegerAttr,
+ FloatAttr,
+ BoolAttr,
+)
+from .._transform_tune_extension_ops_gen import *
+from .._transform_tune_extension_ops_gen import _Dialect
+
+try:
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class KnobOp(KnobOp):
+ def __init__(
+ self,
+ result: Type, # !transform.any_param or !transform.param<Type>
+ name: Union[StringAttr, str],
+ options: Union[
+ ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
+ ],
+ *,
+ selected: Optional[Attribute] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(name, str):
+ name = StringAttr.get(name)
+
+ def map_to_attr(value):
+ if isinstance(value, bool):
+ return BoolAttr.get(value)
+ if isinstance(value, int):
+ return IntegerAttr.get(IntegerType.get_signless(64), value)
+ if isinstance(value, float):
+ return FloatAttr.get(F64Type.get(), value)
+ if isinstance(value, str):
+ return StringAttr.get(value)
+ assert isinstance(value, Attribute)
+ return value
+
+ if isinstance(options, Sequence) and not isinstance(options, ArrayAttr):
+ options = ArrayAttr.get([map_to_attr(opt) for opt in options])
+
+ super().__init__(
+ result,
+ name,
+ options,
+ selected=selected and map_to_attr(selected),
+ loc=loc,
+ ip=ip,
+ )
+
+
+def knob(
+ result: Type, # !transform.any_param or !transform.param<Type>
+ name: Union[StringAttr, str],
+ options: Union[
+ ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
+ ],
+ *,
+ selected: Optional[Attribute] = None,
+ loc=None,
+ ip=None,
+):
+ return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
diff --git a/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
new file mode 100644
index 0000000000000..2e5f433abeb71
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // expected-error@below {{provided `selected` attribute is not an element of `options` array of attributes}}
+ %heads_or_tails = transform.tune.knob<"coin"> = 1 from options = [true, false] -> !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @f()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // expected-error@below {{non-deterministic choice "coin" is only resolved through providing a `selected` attr}}
+ %heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir
new file mode 100644
index 0000000000000..4e3ad0c8d18d9
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file \
+// RUN: --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @schedule_with_nondet_knobs
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @schedule_with_nondet_knobs(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
+ %heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
+ // CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
+ %chosen_category = transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
+ // CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
+ %chosen_tile_size = transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
+ // CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00 : f32, 2.250000e+00 : f32, 2.500000e+00 : f32, 2.750000e+00 : f32, 3.000000e+00 : f32] -> !transform.any_param
+ %chosen_constant = transform.tune.knob<"magic_value"> options = [2.0 : f32, 2.25 : f32, 2.5 : f32, 2.75 : f32, 3.0 : f32] -> !transform.any_param
+ // CHECK: transform.debug.emit_param_as_remark %[[HEADS_OR_TAILS]]
+ transform.debug.emit_param_as_remark %heads_or_tails : !transform.any_param
+ transform.yield
+ }
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // Dummy sequence to appease -transform-interpreter invocation
+ transform.yield
+ }
+}
+
+// -----
+
+// Schedule where non-determinism on knobs has been resolved by selecting a valid option.
+
+// CHECK-LABEL: payload_for_schedule_with_selected_knobs
+func.func private @payload_for_schedule_with_selected_knobs()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param
+ %heads_or_tails = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param
+ // expected-remark@below {{true}}
+ transform.debug.emit_param_as_remark %heads_or_tails : !transform.any_param
+
+ // CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param
+ %chosen_category = transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param
+ // CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
+ %chosen_tile_size = transform.tune.knob<"tile_size"> = 8 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
+ // CHECK: transform.tune.knob<"magic_value"> = 2.500000e+00 : f32 from options = [2.000000e+00 : f32, 2.250000e+00 : f32, 2.500000e+00 : f32, 2.750000e+00 : f32, 3.000000e+00 : f32] -> !transform.any_param
+ %chosen_constant = transform.tune.knob<"magic_value"> = 2.5 : f32 from options = [2.0 : f32, 2.25 : f32, 2.5 : f32, 2.75 : f32, 3.0 : f32] -> !transform.any_param
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK: #[[AFFINE_SET:.*]] = affine_set<(d0) : (d0 - 2 >= 0)>
+// CHECK: payload_for_schedule_where_selected_knob_being_a_member_of_options_is_unverified
+func.func private @payload_for_schedule_where_selected_knob_being_a_member_of_options_is_unverified()
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ // CHECK: transform.tune.knob<"bounded"> = 4242 : i64 from options = #[[AFFINE_SET]] -> !transform.any_param
+ %value_in_half_range = transform.tune.knob<"bounded"> = 4242 from options = affine_set<(d0) : (d0 - 2 >= 0)> -> !transform.any_param
+ transform.yield
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py
new file mode 100644
index 0000000000000..0065479328ec3
--- /dev/null
+++ b/mlir/test/python/dialects/transform_tune_ext.py
@@ -0,0 +1,56 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import tune, debug
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
+ f(sequence.bodyTarget)
+ transform.YieldOp()
+ print(module)
+ return f
+
+
+# CHECK-LABEL: TEST: testKnobOp
+ at run
+def testKnobOp(target):
+ any_param = transform.AnyParamType.get()
+
+ # CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
+ heads_or_tails = tune.KnobOp(
+ result=any_param, name=StringAttr.get("coin"), options=[True, False]
+ )
+ # CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
+ tune.KnobOp(any_param, name="animal", options=["cat", "dog", UnitAttr.get()])
+ # CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
+ tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32])
+ # CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param
+ tune.knob(any_param, "magic_value", [2.0, 2.25, 2.5, 2.75, 3.0])
+
+ # CHECK: transform.debug.emit_param_as_remark %[[HEADS_OR_TAILS]]
+ debug.emit_param_as_remark(heads_or_tails)
+
+ # CHECK: %[[HEADS:.*]] = transform.tune.knob<"coin"> = true from options = [true, false] -> !transform.any_param
+ heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True)
+ # CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param
+ tune.KnobOp(
+ any_param, name="animal", options=["cat", "dog", UnitAttr.get()], selected="dog"
+ )
+ # CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
+ tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8)
+ # CHECK: transform.tune.knob<"magic_value"> = 2.500000e+00 : f64 from options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param
+ tune.knob(any_param, "magic_value", [2.0, 2.25, 2.5, 2.75, 3.0], selected=2.5)
+
+ # CHECK: transform.debug.emit_param_as_remark %[[HEADS]]
+ debug.emit_param_as_remark(heads)
>From 91cc2f3fddcc55f541d76145cf471cf5772a8f32 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 3 Jul 2025 07:02:19 -0700
Subject: [PATCH 4/7] \n
---
mlir/test/Dialect/Transform/test-tune-extension.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Transform/test-tune-extension.mlir b/mlir/test/Dialect/Transform/test-tune-extension.mlir
index 4e3ad0c8d18d9..0a253c6d5f837 100644
--- a/mlir/test/Dialect/Transform/test-tune-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-tune-extension.mlir
@@ -58,4 +58,4 @@ module attributes {transform.with_named_sequence} {
%value_in_half_range = transform.tune.knob<"bounded"> = 4242 from options = affine_set<(d0) : (d0 - 2 >= 0)> -> !transform.any_param
transform.yield
}
-}
\ No newline at end of file
+}
>From 2f24acfc1556f1ff09ff53b74bc551e4f2c7cb42 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Thu, 3 Jul 2025 07:07:46 -0700
Subject: [PATCH 5/7] Python formatting
---
mlir/python/mlir/dialects/transform/tune.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py
index 15c43aba795eb..f63f88a382422 100644
--- a/mlir/python/mlir/dialects/transform/tune.py
+++ b/mlir/python/mlir/dialects/transform/tune.py
@@ -30,7 +30,7 @@
class KnobOp(KnobOp):
def __init__(
self,
- result: Type, # !transform.any_param or !transform.param<Type>
+ result: Type, # !transform.any_param or !transform.param<Type>
name: Union[StringAttr, str],
options: Union[
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
@@ -69,7 +69,7 @@ def map_to_attr(value):
def knob(
- result: Type, # !transform.any_param or !transform.param<Type>
+ result: Type, # !transform.any_param or !transform.param<Type>
name: Union[StringAttr, str],
options: Union[
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
>From de0d81e33726088891df7bbbca6a868e36f627b0 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Fri, 4 Jul 2025 09:45:27 -0700
Subject: [PATCH 6/7] Demonstrate options-spec as a dict attr from Python
---
mlir/test/python/dialects/transform_tune_ext.py | 16 ++++++++++++++++
1 file changed, 16 insertions(+)
diff --git a/mlir/test/python/dialects/transform_tune_ext.py b/mlir/test/python/dialects/transform_tune_ext.py
index 0065479328ec3..dfb93594bca52 100644
--- a/mlir/test/python/dialects/transform_tune_ext.py
+++ b/mlir/test/python/dialects/transform_tune_ext.py
@@ -54,3 +54,19 @@ def testKnobOp(target):
# CHECK: transform.debug.emit_param_as_remark %[[HEADS]]
debug.emit_param_as_remark(heads)
+
+ # CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param
+ # NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified.
+ i64 = IntegerType.get_signless(64)
+ tune.knob(
+ any_param,
+ "range_as_a_dict",
+ DictAttr.get(
+ {
+ "start": IntegerAttr.get(i64, 2),
+ "stop": IntegerAttr.get(i64, 16),
+ "step": IntegerAttr.get(i64, 2),
+ }
+ ),
+ selected=4,
+ )
>From 8afb10a79ad6a89f2853fff26821c35b85352c27 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 8 Jul 2025 02:51:58 -0700
Subject: [PATCH 7/7] Minor fixes
---
.../mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td | 3 ++-
mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt | 1 -
mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp | 2 +-
mlir/python/mlir/dialects/TransformTuneExtensionOps.td | 1 +
4 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index afb67e8fef250..d68d451afac40 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -37,7 +37,8 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
-----
- In case the `options` attribute is an `ArrayAttr`, the verifier checks that the provided `selected` attribute occurs in `options`.
+ In case the `options` attribute is an `ArrayAttr`, the verifier
+ checks that the provided `selected` attribute occurs in `options`.
}];
let cppNamespace = [{ mlir::transform::tune }];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
index ff01d25e57f68..56b90a9a04edf 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
@@ -8,5 +8,4 @@ add_mlir_dialect_library(MLIRTransformTuneExtension
LINK_LIBS PUBLIC
MLIRIR
MLIRTransformDialect
- MLIRTransforms
)
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index 0c77dbb0f05dd..75c1cc53e2606 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -37,7 +37,7 @@ transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (getSelected()) {
- results.setParams(getOperation()->getOpResults()[0], *getSelected());
+ results.setParams(llvm::cast<OpResult>(getResult()), *getSelected());
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td
index 60ed95d110762..ff3047592ab12 100644
--- a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td
+++ b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td
@@ -8,6 +8,7 @@
//
// Entry point of the generated Python bindings for the Tune extension of the
// Transform dialect.
+//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
More information about the Mlir-commits
mailing list