[Mlir-commits] [mlir] 73140da - [mlir] expose transform dialect symbol merge to python (#87690)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 17 06:02:04 PDT 2024
Author: Oleksandr "Alex" Zinenko
Date: 2024-04-17T15:01:59+02:00
New Revision: 73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd
URL: https://github.com/llvm/llvm-project/commit/73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd
DIFF: https://github.com/llvm/llvm-project/commit/73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd.diff
LOG: [mlir] expose transform dialect symbol merge to python (#87690)
This functionality is available in C++; make it available in Python as well
so that it can operate directly on transform modules.
Added:
Modified:
mlir/include/mlir-c/Dialect/Transform/Interpreter.h
mlir/lib/Bindings/Python/TransformInterpreter.cpp
mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
mlir/python/mlir/dialects/transform/interpreter/__init__.py
mlir/test/python/dialects/transform_interpreter.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
index 00095d5040a0e5..fa320324234e8d 100644
--- a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
+++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
@@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void
mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
//----------------------------------------------------------------------------//
-// Transform interpreter.
+// Transform interpreter and utilities.
//----------------------------------------------------------------------------//
/// Applies the transformation script starting at the given transform root
@@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
MlirOperation payload, MlirOperation transformRoot,
MlirOperation transformModule, MlirTransformOptions transformOptions);
+/// Merges the symbols from `other` into `target`, potentially renaming them to
+/// avoid conflicts. Private symbols may be renamed during the merge; public
+/// symbols must have at most one declaration. A name conflict between public
+/// symbols is reported as an error diagnostic before returning a failure.
+///
+/// Unlike the C++ counterpart (`transform::detail::mergeSymbolsInto`), which
+/// takes ownership of the operation it merges from, this function merges a
+/// clone of `other`: the caller keeps ownership and `other` is not modified.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 6517f8c39dfadd..f6b4532b1b6be4 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
py::arg("payload_root"), py::arg("transform_root"),
py::arg("transform_module"),
py::arg("transform_options") = PyMlirTransformOptions());
+
+  m.def(
+      "copy_symbols_and_merge_into",
+      [](MlirOperation target, MlirOperation other) {
+        // Collect diagnostics emitted while merging (e.g. conflicting public
+        // symbols) so they can be surfaced in the Python exception message.
+        mlir::python::CollectDiagnosticsToStringScope scope(
+            mlirOperationGetContext(target));
+        if (!mlirLogicalResultIsFailure(
+                mlirMergeSymbolsIntoFromClone(target, other)))
+          return;
+        throw py::value_error("Failed to merge symbols.\nDiagnostic message " +
+                              scope.takeMessage());
+      },
+      py::arg("target"), py::arg("other"));
}
PYBIND11_MODULE(_mlirTransformInterpreter, m) {
diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
index eb6951dc5584d6..145455e1c1b3d2 100644
--- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
+++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
@@ -15,6 +15,7 @@
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
@@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence(
unwrap(payload), unwrap(transformRoot),
cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
}
+
+MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
+                                                MlirOperation other) {
+  // `mergeSymbolsInto` consumes the op it merges from, so hand it a fresh
+  // clone and leave the caller-owned `other` untouched.
+  return wrap(transform::detail::mergeSymbolsInto(
+      unwrap(target), OwningOpRef<Operation *>(unwrap(other)->clone())));
+}
}
diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
index 6145b99224eb54..34cdc43cb617fd 100644
--- a/mlir/python/mlir/dialects/transform/interpreter/__init__.py
+++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
@@ -5,7 +5,6 @@
from ....ir import Operation
from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
-
TransformOptions = _cextTransformInterpreter.TransformOptions
@@ -31,3 +30,12 @@ def apply_named_sequence(
_cextTransformInterpreter.apply_named_sequence(*args)
else:
_cextTransformInterpreter(*args, transform_options)
+
+
+def copy_symbols_and_merge_into(target, other):
+ """Copies symbols from other into target, renaming private symbols to avoid
+ duplicates. Raises an error if copying would lead to duplicate public
+ symbols."""
+ _cextTransformInterpreter.copy_symbols_and_merge_into(
+ _unpack_operation(target), _unpack_operation(other)
+ )
diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
index 740c49f76a26c4..807a98c4932797 100644
--- a/mlir/test/python/dialects/transform_interpreter.py
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -54,3 +54,79 @@ def failed():
assert (
"must implement TransformOpInterface to be used as transform root" in str(e)
)
+
+
+# Main transform module: privately declares @callee1 and @callee2 without
+# bodies; its entry point @__transform_main includes @callee2.
+print_root_via_include_module = """
+module @print_root_via_include_module attributes {transform.with_named_sequence} {
+ transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+ transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ transform.include @callee2 failures(propagate)
+ (%root) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}"""
+
+# Provides a public definition of @callee2, which includes @callee1 (only
+# privately declared in this module).
+callee2_definition = """
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+ transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
+ transform.include @callee1 failures(propagate)
+ (%root) : (!transform.any_op) -> ()
+ transform.yield
+ }
+}
+"""
+
+# Provides a public definition of @callee1, which prints its root handle with
+# the name "from interpreter".
+callee1_definition = """
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
+ transform.print %root { name = \"from interpreter\" }: !transform.any_op
+ transform.yield
+ }
+}
+"""
+
+
+ at test_in_context
+def include():
+ main = ir.Module.parse(print_root_via_include_module)
+ callee1 = ir.Module.parse(callee1_definition)
+ callee2 = ir.Module.parse(callee2_definition)
+ interp.copy_symbols_and_merge_into(main, callee1)
+ interp.copy_symbols_and_merge_into(main, callee2)
+
+ # CHECK: @print_root_via_include_module
+ # CHECK: transform.named_sequence @__transform_main
+ # CHECK: transform.include @callee2
+ #
+ # CHECK: transform.named_sequence @callee1
+ # CHECK: transform.print
+ #
+ # CHECK: transform.named_sequence @callee2
+ # CHECK: transform.include @callee1
+ interp.apply_named_sequence(main, main.body.operations[0], main)
+
+
+ at test_in_context
+def partial_include():
+ main = ir.Module.parse(print_root_via_include_module)
+ callee2 = ir.Module.parse(callee2_definition)
+ interp.copy_symbols_and_merge_into(main, callee2)
+
+ try:
+ interp.apply_named_sequence(main, main.body.operations[0], main)
+ except ValueError as e:
+ assert "Failed to apply" in str(e)
+
+
+ at test_in_context
+def repeated_include():
+ main = ir.Module.parse(print_root_via_include_module)
+ callee2 = ir.Module.parse(callee2_definition)
+ interp.copy_symbols_and_merge_into(main, callee2)
+
+ try:
+ interp.copy_symbols_and_merge_into(main, callee2)
+ except ValueError as e:
+ assert "doubly defined symbol @callee2" in str(e)
More information about the Mlir-commits
mailing list