[Mlir-commits] [mlir] 73140da - [mlir] expose transform dialect symbol merge to python (#87690)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 17 06:02:04 PDT 2024


Author: Oleksandr "Alex" Zinenko
Date: 2024-04-17T15:01:59+02:00
New Revision: 73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd

URL: https://github.com/llvm/llvm-project/commit/73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd
DIFF: https://github.com/llvm/llvm-project/commit/73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd.diff

LOG: [mlir] expose transform dialect symbol merge to python (#87690)

This functionality is available in C++, make it available in Python
directly to operate on transform modules.

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/Transform/Interpreter.h
    mlir/lib/Bindings/Python/TransformInterpreter.cpp
    mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
    mlir/python/mlir/dialects/transform/interpreter/__init__.py
    mlir/test/python/dialects/transform_interpreter.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
index 00095d5040a0e5..fa320324234e8d 100644
--- a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
+++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
@@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void
 mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
 
 //----------------------------------------------------------------------------//
-// Transform interpreter.
+// Transform interpreter and utilities.
 //----------------------------------------------------------------------------//
 
 /// Applies the transformation script starting at the given transform root
@@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
     MlirOperation payload, MlirOperation transformRoot,
     MlirOperation transformModule, MlirTransformOptions transformOptions);
 
+/// Merge the symbols from `other` into `target`, potentially renaming them to
+/// avoid conflicts. Private symbols may be renamed during the merge, public
+/// symbols must have at most one declaration. A name conflict in public symbols
+/// is reported as an error before returning a failure.
+///
+/// Note that this clones the `other` operation unlike the C++ counterpart that
+/// takes ownership.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 6517f8c39dfadd..f6b4532b1b6be4 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
       py::arg("payload_root"), py::arg("transform_root"),
       py::arg("transform_module"),
       py::arg("transform_options") = PyMlirTransformOptions());
+
+  m.def(
+      "copy_symbols_and_merge_into",
+      [](MlirOperation target, MlirOperation other) {
+        mlir::python::CollectDiagnosticsToStringScope scope(
+            mlirOperationGetContext(target));
+
+        MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
+        if (mlirLogicalResultIsFailure(result)) {
+          throw py::value_error(
+              "Failed to merge symbols.\nDiagnostic message " +
+              scope.takeMessage());
+        }
+      },
+      py::arg("target"), py::arg("other"));
 }
 
 PYBIND11_MODULE(_mlirTransformInterpreter, m) {

diff  --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
index eb6951dc5584d6..145455e1c1b3d2 100644
--- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
+++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
@@ -15,6 +15,7 @@
 #include "mlir/CAPI/IR.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/IR/Utils.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
 
@@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence(
       unwrap(payload), unwrap(transformRoot),
       cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
 }
+
+MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
+                                                MlirOperation other) {
+  OwningOpRef<Operation *> otherOwning(unwrap(other)->clone());
+  LogicalResult result = transform::detail::mergeSymbolsInto(
+      unwrap(target), std::move(otherOwning));
+  return wrap(result);
+}
 }

diff  --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
index 6145b99224eb54..34cdc43cb617fd 100644
--- a/mlir/python/mlir/dialects/transform/interpreter/__init__.py
+++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
@@ -5,7 +5,6 @@
 from ....ir import Operation
 from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
 
-
 TransformOptions = _cextTransformInterpreter.TransformOptions
 
 
@@ -31,3 +30,12 @@ def apply_named_sequence(
         _cextTransformInterpreter.apply_named_sequence(*args)
     else:
         _cextTransformInterpreter(*args, transform_options)
+
+
+def copy_symbols_and_merge_into(target, other):
+    """Copies symbols from other into target, renaming private symbols to avoid
+    duplicates. Raises an error if copying would lead to duplicate public
+    symbols."""
+    _cextTransformInterpreter.copy_symbols_and_merge_into(
+        _unpack_operation(target), _unpack_operation(other)
+    )

diff  --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
index 740c49f76a26c4..807a98c4932797 100644
--- a/mlir/test/python/dialects/transform_interpreter.py
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -54,3 +54,79 @@ def failed():
         assert (
             "must implement TransformOpInterface to be used as transform root" in str(e)
         )
+
+
+print_root_via_include_module = """
+module @print_root_via_include_module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence private @callee2(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    transform.include @callee2 failures(propagate)
+        (%root) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}"""
+
+callee2_definition = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @callee1(%root: !transform.any_op {transform.readonly})
+  transform.named_sequence @callee2(%root: !transform.any_op {transform.readonly}) {
+    transform.include @callee1 failures(propagate)
+        (%root) : (!transform.any_op) -> ()
+    transform.yield
+  }
+}
+"""
+
+callee1_definition = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @callee1(%root: !transform.any_op {transform.readonly}) {
+    transform.print %root { name = \"from interpreter\" }: !transform.any_op
+    transform.yield
+  }
+}
+"""
+
+
+@test_in_context
+def include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee1 = ir.Module.parse(callee1_definition)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.copy_symbols_and_merge_into(main, callee1)
+    interp.copy_symbols_and_merge_into(main, callee2)
+
+    # CHECK: @print_root_via_include_module
+    # CHECK: transform.named_sequence @__transform_main
+    # CHECK: transform.include @callee2
+    #
+    # CHECK: transform.named_sequence @callee1
+    # CHECK: transform.print
+    #
+    # CHECK: transform.named_sequence @callee2
+    # CHECK: transform.include @callee1
+    interp.apply_named_sequence(main, main.body.operations[0], main)
+
+
+@test_in_context
+def partial_include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.copy_symbols_and_merge_into(main, callee2)
+
+    try:
+        interp.apply_named_sequence(main, main.body.operations[0], main)
+    except ValueError as e:
+        assert "Failed to apply" in str(e)
+
+
+@test_in_context
+def repeated_include():
+    main = ir.Module.parse(print_root_via_include_module)
+    callee2 = ir.Module.parse(callee2_definition)
+    interp.copy_symbols_and_merge_into(main, callee2)
+
+    try:
+        interp.copy_symbols_and_merge_into(main, callee2)
+    except ValueError as e:
+        assert "doubly defined symbol @callee2" in str(e)


        


More information about the Mlir-commits mailing list