[Mlir-commits] [mlir] [mlir] expose transform dialect symbol merge to python (PR #87690)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 4 12:46:56 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

This functionality is already available in C++; make it available in Python as well so that it can operate directly on transform modules.

---
Full diff: https://github.com/llvm/llvm-project/pull/87690.diff


5 Files Affected:

- (modified) mlir/include/mlir-c/Dialect/Transform/Interpreter.h (+11-1) 
- (modified) mlir/lib/Bindings/Python/TransformInterpreter.cpp (+15) 
- (modified) mlir/lib/CAPI/Dialect/TransformInterpreter.cpp (+9) 
- (modified) mlir/python/mlir/dialects/transform/interpreter/__init__.py (+7-1) 
- (modified) mlir/test/python/dialects/transform_interpreter.py (+76) 


``````````diff
diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
index 00095d5040a0e5..fa320324234e8d 100644
--- a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
+++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
@@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void
 mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
 
 //----------------------------------------------------------------------------//
-// Transform interpreter.
+// Transform interpreter and utilities.
 //----------------------------------------------------------------------------//
 
 /// Applies the transformation script starting at the given transform root
@@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
     MlirOperation payload, MlirOperation transformRoot,
     MlirOperation transformModule, MlirTransformOptions transformOptions);
 
+/// Merge the symbols from `other` into `target`, potentially renaming them to
+/// avoid conflicts. Private symbols may be renamed during the merge, public
+/// symbols must have at most one declaration. A name conflict in public symbols
+/// is reported as an error before returning a failure.
+///
+/// Note that this clones the `other` operation unlike the C++ counterpart that
+/// takes ownership.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 6517f8c39dfadd..6448ae433b5c3f 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
       py::arg("payload_root"), py::arg("transform_root"),
       py::arg("transform_module"),
       py::arg("transform_options") = PyMlirTransformOptions());
+
+  m.def(
+      "merge_symbols_into",
+      [](MlirOperation target, MlirOperation other) {
+        mlir::python::CollectDiagnosticsToStringScope scope(
+            mlirOperationGetContext(target));
+
+        MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
+        if (mlirLogicalResultIsSuccess(result))
+          return;
+
+        throw py::value_error("Failed to merge symbols.\nDiagnostic message " +
+                              scope.takeMessage());
+      },
+      py::arg("target"), py::arg("other"));
 }
 
 PYBIND11_MODULE(_mlirTransformInterpreter, m) {
diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
index eb6951dc5584d6..145455e1c1b3d2 100644
--- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
+++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
@@ -15,6 +15,7 @@
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
 
@@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence(
       unwrap(payload), unwrap(transformRoot),
       cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
 }
+
+MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
+                                                MlirOperation other) {
+  OwningOpRef<Operation *> otherOwning(unwrap(other)->clone());
+  LogicalResult result = transform::detail::mergeSymbolsInto(
+      unwrap(target), std::move(otherOwning));
+  return wrap(result);
+}
 }
diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
index 6145b99224eb54..4de827257174ab 100644
--- a/mlir/python/mlir/dialects/transform/interpreter/__init__.py
+++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
@@ -5,7 +5,6 @@
 from ....ir import Operation
 from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
 
-
 TransformOptions = _cextTransformInterpreter.TransformOptions
 
 
@@ -31,3 +30,10 @@ def apply_named_sequence(
         _cextTransformInterpreter.apply_named_sequence(*args)
     else:
         _cextTransformInterpreter(*args, transform_options)
+
+
+def merge_symbols_into(target, other):
+    """Copies symbols from other into target, renaming private symbols to avoid duplicates. Raises an error if copying would lead to duplicate public symbols."""
+    _cextTransformInterpreter.merge_symbols_into(
+        _unpack_operation(target), _unpack_operation(other)
+    )
diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
index 740c49f76a26c4..d3ada7f32d8d59 100644
--- a/mlir/test/python/dialects/transform_interpreter.py
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -54,3 +54,79 @@ def failed():
         assert (
             "must implement TransformOpInterface to be used as transform root" in str(e)
         )
+
+
+print_root_via_include_module = """
+module @print_root_via_include_module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    transform.include @callee2 failures(propagate)
+        (%root) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}"""
+
+callee2_definition = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
+    transform.include @callee1 failures(propagate)
+        (%root) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+"""
+
+callee1_definition = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
+    transform.print %root { name = \"from interpreter\" }: !transform.any_op
+    transform.yield
+  }
+}
+"""
+
+
+@test_in_context
+def include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee1 = ir.Module.parse(callee1_definition)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee1)
+    interp.merge_symbols_into(main, callee2)
+
+    # CHECK: @print_root_via_include_module
+    # CHECK: transform.named_sequence @__transform_main
+    # CHECK: transform.include @callee2
+    #
+    # CHECK: transform.named_sequence @callee1
+    # CHECK: transform.print
+    #
+    # CHECK: transform.named_sequence @callee2
+    # CHECK: transform.include @callee1
+    interp.apply_named_sequence(main, main.body.operations[0], main)
+
+
+@test_in_context
+def partial_include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee2)
+
+    try:
+        interp.apply_named_sequence(main, main.body.operations[0], main)
+    except ValueError as e:
+        assert "Failed to apply" in str(e)
+
+
+@test_in_context
+def repeated_include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.merge_symbols_into(main, callee2)
+
+    try:
+        interp.merge_symbols_into(main, callee2)
+    except ValueError as e:
+        assert "doubly defined symbol @callee2" in str(e)

``````````

</details>


https://github.com/llvm/llvm-project/pull/87690


More information about the Mlir-commits mailing list