[Mlir-commits] [mlir] [mlir] expose transform dialect symbol merge to python (PR #87690)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Tue Apr 16 03:06:18 PDT 2024
https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/87690
From 6acb09fc7e84ae6f1406eb60d763610003eee648 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Thu, 4 Apr 2024 19:44:13 +0000
Subject: [PATCH] [mlir] expose transform dialect symbol merge to python
This functionality is available in C++; make it available in Python as well
so that transform modules can be operated on directly from Python.
---
.../mlir-c/Dialect/Transform/Interpreter.h | 12 ++-
.../Bindings/Python/TransformInterpreter.cpp | 15 ++++
.../lib/CAPI/Dialect/TransformInterpreter.cpp | 9 +++
.../transform/interpreter/__init__.py | 8 +-
.../python/dialects/transform_interpreter.py | 76 +++++++++++++++++++
5 files changed, 118 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
index 00095d5040a0e5..fa320324234e8d 100644
--- a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
+++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
@@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void
mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
//----------------------------------------------------------------------------//
-// Transform interpreter.
+// Transform interpreter and utilities.
//----------------------------------------------------------------------------//
/// Applies the transformation script starting at the given transform root
@@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
MlirOperation payload, MlirOperation transformRoot,
MlirOperation transformModule, MlirTransformOptions transformOptions);
+/// Merges the symbols from `other` into `target`, potentially renaming them to
+/// avoid conflicts. Private symbols may be renamed during the merge; public
+/// symbols must have at most one declaration. A name conflict among public
+/// symbols is reported as an error diagnostic and a failure is returned.
+///
+/// Unlike the C++ counterpart, which takes ownership of `other`, this merges a
+/// clone of `other`, so the caller keeps ownership of the original operation.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 6517f8c39dfadd..6448ae433b5c3f 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
py::arg("payload_root"), py::arg("transform_root"),
py::arg("transform_module"),
py::arg("transform_options") = PyMlirTransformOptions());
+
+ m.def(
+ "merge_symbols_into",
+ [](MlirOperation target, MlirOperation other) {
+ mlir::python::CollectDiagnosticsToStringScope scope(
+ mlirOperationGetContext(target));
+
+ MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
+ if (mlirLogicalResultIsSuccess(result))
+ return;
+ // On failure, raise ValueError carrying the diagnostics gathered in `scope`.
+ throw py::value_error("Failed to merge symbols.\nDiagnostic message " +
+ scope.takeMessage());
+ },
+ py::arg("target"), py::arg("other"));
}
PYBIND11_MODULE(_mlirTransformInterpreter, m) {
PYBIND11_MODULE(_mlirTransformInterpreter, m) {
diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
index eb6951dc5584d6..145455e1c1b3d2 100644
--- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
+++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
@@ -15,6 +15,7 @@
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
@@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence(
unwrap(payload), unwrap(transformRoot),
cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
}
+// mergeSymbolsInto consumes its argument, so hand it a clone of `other`.
+MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
+ MlirOperation other) {
+ OwningOpRef<Operation *> otherOwning(unwrap(other)->clone());
+ LogicalResult result = transform::detail::mergeSymbolsInto(
+ unwrap(target), std::move(otherOwning));
+ return wrap(result);
+}
}
diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
index 6145b99224eb54..4de827257174ab 100644
--- a/mlir/python/mlir/dialects/transform/interpreter/__init__.py
+++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
@@ -5,7 +5,6 @@
 from ....ir import Operation
 from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
 
-
 TransformOptions = _cextTransformInterpreter.TransformOptions
 
 
@@ -31,3 +30,15 @@ def apply_named_sequence(
         _cextTransformInterpreter.apply_named_sequence(*args)
     else:
-        _cextTransformInterpreter(*args, transform_options)
+        _cextTransformInterpreter.apply_named_sequence(*args, transform_options)
+
+
+def merge_symbols_into(target, other):
+    """Copies symbols from `other` into `target`, renaming private symbols to
+    avoid duplicates. `other` itself is left unmodified.
+
+    Raises ValueError, carrying the collected diagnostics, if copying would lead
+    to duplicate public symbols.
+    """
+    _cextTransformInterpreter.merge_symbols_into(
+        _unpack_operation(target), _unpack_operation(other)
+    )
diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
index 740c49f76a26c4..d3ada7f32d8d59 100644
--- a/mlir/test/python/dialects/transform_interpreter.py
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -54,3 +54,83 @@ def failed():
         assert (
             "must implement TransformOpInterface to be used as transform root" in str(e)
         )
+
+
+print_root_via_include_module = """
+module @print_root_via_include_module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    transform.include @callee2 failures(propagate)
+        (%root) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}"""
+
+callee2_definition = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
+    transform.include @callee1 failures(propagate)
+        (%root) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+"""
+
+callee1_definition = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
+    transform.print %root { name = \"from interpreter\" }: !transform.any_op
+    transform.yield
+  }
+}
+"""
+
+
+@test_in_context
+def include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee1 = ir.Module.parse(callee1_definition)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee1)
+    interp.merge_symbols_into(main, callee2)
+
+    # CHECK: @print_root_via_include_module
+    # CHECK: transform.named_sequence @__transform_main
+    # CHECK: transform.include @callee2
+    #
+    # CHECK: transform.named_sequence @callee1
+    # CHECK: transform.print
+    #
+    # CHECK: transform.named_sequence @callee2
+    # CHECK: transform.include @callee1
+    interp.apply_named_sequence(main, main.body.operations[0], main)
+
+
+@test_in_context
+def partial_include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee2)
+
+    try:
+        interp.apply_named_sequence(main, main.body.operations[0], main)
+    except ValueError as e:
+        assert "Failed to apply" in str(e)
+    else:
+        raise AssertionError("expected applying with unresolved @callee1 to fail")
+
+
+@test_in_context
+def repeated_include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee2)
+
+    try:
+        interp.merge_symbols_into(main, callee2)
+    except ValueError as e:
+        assert "doubly defined symbol @callee2" in str(e)
+    else:
+        raise AssertionError("expected merging @callee2 twice to fail")
More information about the Mlir-commits
mailing list