[Mlir-commits] [mlir] [MLIR][Transform][Python] expose transform.debug extension in Python (PR #145550)
Rolf Morel
llvmlistbot at llvm.org
Wed Jun 25 08:10:16 PDT 2025
https://github.com/rolfmorel updated https://github.com/llvm/llvm-project/pull/145550
>From d502b3599db08c42cdad8bb5bd03b9c08525bb0a Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 24 Jun 2025 05:19:39 -0700
Subject: [PATCH 1/5] [MLIR][Transform][Python] expose transform.debug
extension in Python
---
mlir/python/CMakeLists.txt | 9 ++
.../dialects/TransformDebugExtensionOps.td | 19 ++++
mlir/python/mlir/dialects/transform/debug.py | 86 +++++++++++++++++++
.../python/dialects/transform_debug_ext.py | 47 ++++++++++
4 files changed, 161 insertions(+)
create mode 100644 mlir/python/mlir/dialects/TransformDebugExtensionOps.td
create mode 100644 mlir/python/mlir/dialects/transform/debug.py
create mode 100644 mlir/test/python/dialects/transform_debug_ext.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index ee07081246fc7..b2daabb2a5957 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/TransformDebugExtensionOps.td
+  SOURCES
+    dialects/transform/debug.py
+  DIALECT_NAME transform
+  EXTENSION_NAME transform_debug_extension)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/TransformDebugExtensionOps.td b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td
new file mode 100644
index 0000000000000..22a85d2366994
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td
@@ -0,0 +1,19 @@
+//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the generated Python bindings for the Debug extension of the
+// Transform dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
+#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
+
+include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
diff --git a/mlir/python/mlir/dialects/transform/debug.py b/mlir/python/mlir/dialects/transform/debug.py
new file mode 100644
index 0000000000000..738c556b1d362
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/debug.py
@@ -0,0 +1,86 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional
+
+from ...ir import Attribute, Operation, Value, StringAttr
+from .._transform_debug_extension_ops_gen import *
+from .._transform_debug_extension_ops_gen import _Dialect
+
+try:
+    from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DebugEmitParamAsRemarkOp(DebugEmitParamAsRemarkOp):
+    def __init__(
+        self,
+        param: Attribute,
+        *,
+        anchor: Optional[Value] = None,
+        message: Optional[Union[StringAttr, str]] = None,
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(message, str):
+            message = StringAttr.get(message)
+
+        super().__init__(
+            param,
+            anchor=anchor,
+            message=message,
+            loc=loc,
+            ip=ip,
+        )
+
+
+def emit_param_as_remark(
+    param: Attribute,
+    *,
+    anchor: Optional[Value] = None,
+    message: Optional[Union[StringAttr, str]] = None,
+    loc=None,
+    ip=None,
+):
+    return DebugEmitParamAsRemarkOp(
+        param, anchor=anchor, message=message, loc=loc, ip=ip
+    )
+
+del debug_emit_param_as_remark
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DebugEmitRemarkAtOp(DebugEmitRemarkAtOp):
+    def __init__(
+        self,
+        at: Union[Operation, Value],
+        message: Optional[Union[StringAttr, str]] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if isinstance(message, str):
+            message = StringAttr.get(message)
+
+        super().__init__(
+            at,
+            message,
+            loc=loc,
+            ip=ip,
+        )
+
+
+def emit_remark_at(
+    at: Union[Operation, Value],
+    message: Optional[Union[StringAttr, str]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    return DebugEmitRemarkAtOp(at, message, loc=loc, ip=ip)
+
+del debug_emit_remark_at
diff --git a/mlir/test/python/dialects/transform_debug_ext.py b/mlir/test/python/dialects/transform_debug_ext.py
new file mode 100644
index 0000000000000..c96e7e66e03d1
--- /dev/null
+++ b/mlir/test/python/dialects/transform_debug_ext.py
@@ -0,0 +1,47 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import debug
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            sequence = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                f(sequence.bodyTarget)
+                transform.YieldOp()
+        print(module)
+    return f
+
+
+@run
+def testDebugEmitParamAsRemark(target):
+    i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
+    i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
+    debug.emit_param_as_remark(i0_param)
+    debug.emit_param_as_remark(i0_param, anchor=target, message="some text")
+    # CHECK-LABEL: TEST: testDebugEmitParamAsRemark
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK: %[[PARAM:.*]] = transform.param.constant
+    # CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
+    # CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
+    # CHECK-SAME: "some text"
+    # CHECK-SAME: at %[[ARG0]]
+
+
+@run
+def testDebugEmitRemarkAtOp(target):
+    i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
+    i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
+    debug.emit_remark_at(target, "some text")
+    # CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
+    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+    # CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"
>From 3cdc54067d7acf53fa63042b35f46996886dff0a Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Tue, 24 Jun 2025 10:03:57 -0700
Subject: [PATCH 2/5] Formatting fix
---
mlir/python/mlir/dialects/transform/debug.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/python/mlir/dialects/transform/debug.py b/mlir/python/mlir/dialects/transform/debug.py
index 738c556b1d362..cb7d12dcec250 100644
--- a/mlir/python/mlir/dialects/transform/debug.py
+++ b/mlir/python/mlir/dialects/transform/debug.py
@@ -51,8 +51,10 @@ def emit_param_as_remark(
param, anchor=anchor, message=message, loc=loc, ip=ip
)
+
del debug_emit_param_as_remark
+
@_ods_cext.register_operation(_Dialect, replace=True)
class DebugEmitRemarkAtOp(DebugEmitRemarkAtOp):
def __init__(
@@ -83,4 +85,5 @@ def emit_remark_at(
):
return DebugEmitRemarkAtOp(at, message, loc=loc, ip=ip)
+
del debug_emit_remark_at
>From 45eaa96f0babe6a302b66f1a58e58a375093ecc1 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 25 Jun 2025 03:56:43 -0700
Subject: [PATCH 3/5] Remove Debug... prefix from ops
---
.../DebugExtension/DebugExtensionOps.td | 4 ++--
.../DebugExtension/DebugExtensionOps.cpp | 4 ++--
mlir/python/mlir/dialects/transform/debug.py | 16 ++++------------
3 files changed, 8 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td b/mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td
index 0275f241fda35..4a6898e36d343 100644
--- a/mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td
@@ -20,7 +20,7 @@ include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
-def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
+def EmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
[MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
MemoryEffectsOpInterface, NavigationTransformOpTrait]> {
@@ -39,7 +39,7 @@ def DebugEmitRemarkAtOp : TransformDialectOp<"debug.emit_remark_at",
let assemblyFormat = "$at `,` $message attr-dict `:` type($at)";
}
-def DebugEmitParamAsRemarkOp
+def EmitParamAsRemarkOp
: TransformDialectOp<"debug.emit_param_as_remark",
[MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
index 7a9f8f4b1b528..94f102ea123da 100644
--- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
@@ -19,7 +19,7 @@ using namespace mlir;
#include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp.inc"
DiagnosedSilenceableFailure
-transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
+transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
@@ -52,7 +52,7 @@ transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
-DiagnosedSilenceableFailure transform::DebugEmitParamAsRemarkOp::apply(
+DiagnosedSilenceableFailure transform::EmitParamAsRemarkOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
std::string str;
diff --git a/mlir/python/mlir/dialects/transform/debug.py b/mlir/python/mlir/dialects/transform/debug.py
index cb7d12dcec250..f7c04268dc03d 100644
--- a/mlir/python/mlir/dialects/transform/debug.py
+++ b/mlir/python/mlir/dialects/transform/debug.py
@@ -17,7 +17,7 @@
@_ods_cext.register_operation(_Dialect, replace=True)
-class DebugEmitParamAsRemarkOp(DebugEmitParamAsRemarkOp):
+class EmitParamAsRemarkOp(EmitParamAsRemarkOp):
def __init__(
self,
param: Attribute,
@@ -47,16 +47,11 @@ def emit_param_as_remark(
loc=None,
ip=None,
):
- return DebugEmitParamAsRemarkOp(
- param, anchor=anchor, message=message, loc=loc, ip=ip
- )
-
-
-del debug_emit_param_as_remark
+ return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
-class DebugEmitRemarkAtOp(DebugEmitRemarkAtOp):
+class EmitRemarkAtOp(EmitRemarkAtOp):
def __init__(
self,
at: Union[Operation, Value],
@@ -83,7 +78,4 @@ def emit_remark_at(
loc=None,
ip=None,
):
- return DebugEmitRemarkAtOp(at, message, loc=loc, ip=ip)
-
-
-del debug_emit_remark_at
+ return EmitRemarkAtOp(at, message, loc=loc, ip=ip)
>From 25edc5c6473a3ba5bef1a483eacbcabf832c5680 Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 25 Jun 2025 07:56:22 -0700
Subject: [PATCH 4/5] Remove extraneous lines, per Alex's review
---
mlir/test/python/dialects/transform_debug_ext.py | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/test/python/dialects/transform_debug_ext.py b/mlir/test/python/dialects/transform_debug_ext.py
index c96e7e66e03d1..2dfdaed343865 100644
--- a/mlir/test/python/dialects/transform_debug_ext.py
+++ b/mlir/test/python/dialects/transform_debug_ext.py
@@ -39,8 +39,6 @@ def testDebugEmitParamAsRemark(target):
@run
def testDebugEmitRemarkAtOp(target):
- i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
- i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
debug.emit_remark_at(target, "some text")
# CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
# CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
>From 1c87e8d6da4c6d53daf8983ff1f9dc49000a7c0a Mon Sep 17 00:00:00 2001
From: Rolf Morel <rolf.morel at intel.com>
Date: Wed, 25 Jun 2025 08:09:19 -0700
Subject: [PATCH 5/5] Formatting fix
---
.../Transform/DebugExtension/DebugExtensionOps.cpp | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
index 94f102ea123da..12257da878a40 100644
--- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtensionOps.cpp
@@ -20,8 +20,8 @@ using namespace mlir;
DiagnosedSilenceableFailure
transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
- transform::TransformResults &results,
- transform::TransformState &state) {
+ transform::TransformResults &results,
+ transform::TransformState &state) {
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
auto payload = state.getPayloadOps(getAt());
for (Operation *op : payload)
@@ -52,9 +52,10 @@ transform::EmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
-DiagnosedSilenceableFailure transform::EmitParamAsRemarkOp::apply(
- transform::TransformRewriter &rewriter,
- transform::TransformResults &results, transform::TransformState &state) {
+DiagnosedSilenceableFailure
+transform::EmitParamAsRemarkOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
std::string str;
llvm::raw_string_ostream os(str);
if (getMessage())
More information about the Mlir-commits
mailing list