[Mlir-commits] [mlir] [mlir] expose transform interpreter to Python (PR #82365)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 20 07:01:56 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
Transform interpreter functionality can be used standalone without going through the interpreter pass; make it available in Python.
---
Patch is 23.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82365.diff
13 Files Affected:
- (added) mlir/include/mlir-c/Dialect/Transform/Interpreter.h (+77)
- (modified) mlir/include/mlir/Bindings/Python/PybindAdaptors.h (+36)
- (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (-31)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+7)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+1)
- (added) mlir/lib/Bindings/Python/TransformInterpreter.cpp (+90)
- (modified) mlir/lib/CAPI/Dialect/CMakeLists.txt (+9)
- (added) mlir/lib/CAPI/Dialect/TransformInterpreter.cpp (+74)
- (modified) mlir/python/CMakeLists.txt (+19)
- (added) mlir/python/mlir/dialects/transform/interpreter/__init__.py (+33)
- (modified) mlir/test/CAPI/CMakeLists.txt (+9)
- (added) mlir/test/CAPI/transform_interpreter.c (+69)
- (added) mlir/test/python/dialects/transform_interpreter.py (+55)
``````````diff
diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
new file mode 100644
index 00000000000000..00095d5040a0e5
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
@@ -0,0 +1,77 @@
+//===-- mlir-c/Dialect/Transform/Interpreter.h --------------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C interface to the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
+#define MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define DEFINE_C_API_STRUCT(name, storage) \
+  struct name { \
+    storage *ptr; \
+  }; \
+  typedef struct name name
+
+/// Opaque handle to a set of transform interpreter options. Objects are
+/// created with mlirTransformOptionsCreate and must be released with
+/// mlirTransformOptionsDestroy.
+DEFINE_C_API_STRUCT(MlirTransformOptions, void);
+
+#undef DEFINE_C_API_STRUCT
+
+//----------------------------------------------------------------------------//
+// MlirTransformOptions
+//----------------------------------------------------------------------------//
+
+/// Creates a default-initialized transform options object. The caller owns
+/// the result and must destroy it with mlirTransformOptionsDestroy.
+MLIR_CAPI_EXPORTED MlirTransformOptions mlirTransformOptionsCreate(void);
+
+/// Enables or disables expensive checks in transform options.
+MLIR_CAPI_EXPORTED void
+mlirTransformOptionsEnableExpensiveChecks(MlirTransformOptions transformOptions,
+                                          bool enable);
+
+/// Returns true if expensive checks are enabled in transform options.
+MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetExpensiveChecksEnabled(
+    MlirTransformOptions transformOptions);
+
+/// Enables or disables the enforcement of the top-level transform op being
+/// single in transform options.
+MLIR_CAPI_EXPORTED void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions, bool enable);
+
+/// Returns true if the enforcement of the top-level transform op being single
+/// is enabled in transform options.
+MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions);
+
+/// Destroys a transform options object previously created by
+/// mlirTransformOptionsCreate.
+MLIR_CAPI_EXPORTED void
+mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
+
+//----------------------------------------------------------------------------//
+// Transform interpreter.
+//----------------------------------------------------------------------------//
+
+/// Applies the transformation script starting at the given transform root
+/// operation to the given payload operation. The module containing the
+/// transform root as well as the transform options should be provided. The
+/// transform operation must implement TransformOpInterface and the module must
+/// be a ModuleOp; if either requirement is violated, an error is emitted on
+/// the offending operation and failure is returned. Returns the status of the
+/// application.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
+    MlirOperation payload, MlirOperation transformRoot,
+    MlirOperation transformModule, MlirTransformOptions transformOptions);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 66cf20e1c136f9..52f6321251919e 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -23,6 +23,7 @@
#include <pybind11/stl.h>
#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "llvm/ADT/Twine.h"
@@ -569,6 +570,41 @@ class mlir_value_subclass : public pure_subclass {
};
} // namespace adaptors
+
+/// RAII scope intercepting all diagnostics emitted in the given context into a
+/// string. Each diagnostic is appended as "at <location>: <message>". The
+/// collected message must be retrieved with `takeMessage` before this goes out
+/// of scope; an unchecked message triggers an assertion in debug builds.
+class CollectDiagnosticsToStringScope {
+public:
+  explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
+    handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
+                                                   /*deleteUserData=*/nullptr);
+  }
+  ~CollectDiagnosticsToStringScope() {
+    assert(errorMessage.empty() && "unchecked error message");
+    mlirContextDetachDiagnosticHandler(context, handlerID);
+  }
+
+  // The attached handler stores a pointer to this object's `errorMessage`, and
+  // the destructor detaches `handlerID`: a copy would dangle and detach twice.
+  CollectDiagnosticsToStringScope(const CollectDiagnosticsToStringScope &) =
+      delete;
+  CollectDiagnosticsToStringScope &
+  operator=(const CollectDiagnosticsToStringScope &) = delete;
+
+  /// Returns the diagnostics collected so far and resets the stored message
+  /// to empty, which satisfies the destructor's check.
+  [[nodiscard]] std::string takeMessage() {
+    // Swap rather than move: a moved-from std::string is only guaranteed to be
+    // in a valid-but-unspecified state, while the destructor requires empty.
+    std::string result;
+    result.swap(errorMessage);
+    return result;
+  }
+
+private:
+  static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
+    auto printer = +[](MlirStringRef message, void *data) {
+      *static_cast<std::string *>(data) +=
+          llvm::StringRef(message.data, message.length);
+    };
+    MlirLocation loc = mlirDiagnosticGetLocation(diag);
+    *static_cast<std::string *>(data) += "at ";
+    mlirLocationPrint(loc, printer, data);
+    *static_cast<std::string *>(data) += ": ";
+    mlirDiagnosticPrint(diag, printer, data);
+    // Returning success marks the diagnostic as handled by this scope.
+    return mlirLogicalResultSuccess();
+  }
+
+  MlirContext context;
+  MlirDiagnosticHandlerID handlerID;
+  std::string errorMessage;
+};
+
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 780f5eacf0b8e5..843707751dd849 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir-c/Diagnostics.h"
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
@@ -19,36 +18,6 @@ using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;
-/// RAII scope intercepting all diagnostics into a string. The message must be
-/// checked before this goes out of scope.
-class CollectDiagnosticsToStringScope {
-public:
- explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
- handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
- /*deleteUserData=*/nullptr);
- }
- ~CollectDiagnosticsToStringScope() {
- assert(errorMessage.empty() && "unchecked error message");
- mlirContextDetachDiagnosticHandler(context, handlerID);
- }
-
- [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
-
-private:
- static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
- auto printer = +[](MlirStringRef message, void *data) {
- *static_cast<std::string *>(data) +=
- StringRef(message.data, message.length);
- };
- mlirDiagnosticPrint(diag, printer, data);
- return mlirLogicalResultSuccess();
- }
-
- MlirContext context;
- MlirDiagnosticHandlerID handlerID;
- std::string errorMessage = "";
-};
-
void populateDialectLLVMSubmodule(const pybind11::module &m) {
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8a7951dc29fe5f..734f2f7f3f94cf 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -678,6 +678,10 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
}
+/// Raw-handle overload: wraps `op` into a PyOperation owned by this context
+/// and delegates to the PyOperationBase overload to invalidate nested ops.
+void PyMlirContext::clearOperationsInside(MlirOperation op) {
+  PyOperationRef wrapped = PyOperation::forOperation(getRef(), op);
+  auto &pyOp = wrapped->getOperation();
+  clearOperationsInside(pyOp);
+}
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
@@ -2556,6 +2560,9 @@ void mlir::python::populateIRCore(py::module &m) {
.def("_get_live_operation_objects",
&PyMlirContext::getLiveOperationObjects)
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
+ .def("_clear_live_operations_inside",
+ py::overload_cast<MlirOperation>(
+ &PyMlirContext::clearOperationsInside))
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 48f39c939340d7..9acfdde25ae047 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -223,6 +223,7 @@ class PyMlirContext {
/// Clears all operations nested inside the given op using
/// `clearOperation(MlirOperation)`.
void clearOperationsInside(PyOperationBase &op);
+ void clearOperationsInside(MlirOperation op);
/// Gets the count of live modules associated with this context.
/// Used for testing.
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
new file mode 100644
index 00000000000000..6517f8c39dfadd
--- /dev/null
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -0,0 +1,90 @@
+//===- TransformInterpreter.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pybind classes for the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Transform/Interpreter.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+#include <pybind11/detail/common.h>
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace {
+/// Owning wrapper around an MlirTransformOptions handle exposed to Python.
+/// The handle is created on construction and destroyed with the wrapper. The
+/// wrapper is move-constructible (pybind11 moves it into Python objects) but
+/// neither copyable nor assignable, so exactly one object owns each handle.
+struct PyMlirTransformOptions {
+  PyMlirTransformOptions() : options(mlirTransformOptionsCreate()) {}
+  PyMlirTransformOptions(PyMlirTransformOptions &&other) noexcept
+      : options(other.options) {
+    // Leave the source empty so its destructor does not free our handle.
+    other.options.ptr = nullptr;
+  }
+  PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
+  PyMlirTransformOptions &operator=(const PyMlirTransformOptions &) = delete;
+  PyMlirTransformOptions &operator=(PyMlirTransformOptions &&) = delete;
+
+  ~PyMlirTransformOptions() {
+    // Moved-from wrappers hold a null handle and own nothing.
+    if (options.ptr)
+      mlirTransformOptionsDestroy(options);
+  }
+
+  MlirTransformOptions options;
+};
+} // namespace
+
+/// Registers the `TransformOptions` class and the `apply_named_sequence`
+/// function into the given Python module.
+static void populateTransformInterpreterSubmodule(py::module &m) {
+  py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local())
+      .def(py::init())
+      .def_property(
+          "expensive_checks",
+          [](const PyMlirTransformOptions &self) {
+            return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
+          },
+          [](PyMlirTransformOptions &self, bool value) {
+            mlirTransformOptionsEnableExpensiveChecks(self.options, value);
+          })
+      .def_property(
+          "enforce_single_top_level_transform_op",
+          [](const PyMlirTransformOptions &self) {
+            return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+                self.options);
+          },
+          [](PyMlirTransformOptions &self, bool value) {
+            mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
+                                                                 value);
+          });
+
+  m.def(
+      "apply_named_sequence",
+      [](MlirOperation payloadRoot, MlirOperation transformRoot,
+         MlirOperation transformModule, const PyMlirTransformOptions &options) {
+        mlir::python::CollectDiagnosticsToStringScope scope(
+            mlirOperationGetContext(transformRoot));
+
+        // Calling back into Python to invalidate everything under the payload
+        // root. This is awkward, but we don't have access to PyMlirContext
+        // object here otherwise.
+        py::object obj = py::cast(payloadRoot);
+        obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
+
+        MlirLogicalResult result = mlirTransformApplyNamedSequence(
+            payloadRoot, transformRoot, transformModule, options.options);
+
+        // Always drain the collected diagnostics: even a successful
+        // application may emit non-error diagnostics, and the scope asserts
+        // on an unchecked message when it is destroyed. On success they are
+        // dropped, as they would have been without this drain.
+        std::string message = scope.takeMessage();
+        if (mlirLogicalResultIsSuccess(result))
+          return;
+
+        throw py::value_error(
+            "Failed to apply named transform sequence.\nDiagnostic message " +
+            message);
+      },
+      py::arg("payload_root"), py::arg("transform_root"),
+      py::arg("transform_module"),
+      py::arg("transform_options") = PyMlirTransformOptions());
+}
+
+// Entry point of the `_mlirTransformInterpreter` native extension module.
+PYBIND11_MODULE(_mlirTransformInterpreter, m) {
+  populateTransformInterpreterSubmodule(m);
+  m.doc() = "MLIR Transform dialect interpreter functionality.";
+}
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index b2952da17a41c0..58b8739043f9df 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -198,6 +198,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITransformDialect
MLIRTransformDialect
)
+# C API for the transform dialect interpreter, implemented in
+# TransformInterpreter.cpp (declarations in mlir-c/Dialect/Transform/Interpreter.h).
+add_mlir_upstream_c_api_library(MLIRCAPITransformDialectTransforms
+ TransformInterpreter.cpp
+
+ PARTIAL_SOURCES_INTENDED
+ LINK_LIBS PUBLIC
+ MLIRCAPIIR
+ MLIRTransformDialectTransforms
+)
+
add_mlir_upstream_c_api_library(MLIRCAPIQuant
Quant.cpp
diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
new file mode 100644
index 00000000000000..6a2cfb235fcfd4
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
@@ -0,0 +1,74 @@
+//===- TransformInterpreter.cpp - C Interface for Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C interface to the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Transform/Interpreter.h"
+#include "mlir-c/Support.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+// Generates wrap()/unwrap() between MlirTransformOptions and
+// transform::TransformOptions *.
+DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions)
+
+extern "C" {
+
+MlirTransformOptions mlirTransformOptionsCreate() {
+  return wrap(new transform::TransformOptions);
+}
+
+void mlirTransformOptionsEnableExpensiveChecks(
+    MlirTransformOptions transformOptions, bool enable) {
+  unwrap(transformOptions)->enableExpensiveChecks(enable);
+}
+
+bool mlirTransformOptionsGetExpensiveChecksEnabled(
+    MlirTransformOptions transformOptions) {
+  return unwrap(transformOptions)->getExpensiveChecksEnabled();
+}
+
+void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions, bool enable) {
+  unwrap(transformOptions)->enableEnforceSingleToplevelTransformOp(enable);
+}
+
+bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions) {
+  return unwrap(transformOptions)->getEnforceSingleToplevelTransformOp();
+}
+
+void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions) {
+  delete unwrap(transformOptions);
+}
+
+/// Validates the transform root and module operations, emitting an error on
+/// the offending operation if either has the wrong kind, then runs the named
+/// sequence through transform::applyTransformNamedSequence.
+MlirLogicalResult mlirTransformApplyNamedSequence(
+    MlirOperation payload, MlirOperation transformRoot,
+    MlirOperation transformModule, MlirTransformOptions transformOptions) {
+  Operation *transformRootOp = unwrap(transformRoot);
+  Operation *transformModuleOp = unwrap(transformModule);
+  if (!isa<transform::TransformOpInterface>(transformRootOp)) {
+    transformRootOp->emitError()
+        << "must implement TransformOpInterface to be used as transform root";
+    return mlirLogicalResultFailure();
+  }
+  // A single dyn_cast both checks the kind and yields the typed module.
+  auto moduleOp = dyn_cast<ModuleOp>(transformModuleOp);
+  if (!moduleOp) {
+    transformModuleOp->emitError()
+        << "must be a " << ModuleOp::getOperationName();
+    return mlirLogicalResultFailure();
+  }
+  return wrap(transform::applyTransformNamedSequence(
+      unwrap(payload), transformRootOp, moduleOp, *unwrap(transformOptions)));
+}
+} // extern "C"
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index ed167afeb69a62..563d035f155267 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -181,6 +181,13 @@ declare_mlir_python_sources(
SOURCES
dialects/transform/extras/__init__.py)
+# Pure-Python wrapper package dialects/transform/interpreter/__init__.py.
+declare_mlir_python_sources(
+ MLIRPythonSources.Dialects.transform.interpreter
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ SOURCES
+ dialects/transform/interpreter/__init__.py)
+
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -609,6 +616,18 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
MLIRCAPISparseTensor
)
+# Native extension _mlirTransformInterpreter built from TransformInterpreter.cpp;
+# imported by dialects/transform/interpreter/__init__.py.
+declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
+ MODULE_NAME _mlirTransformInterpreter
+ ADD_TO_PARENT MLIRPythonSources.Dialects.transform
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ SOURCES
+ TransformInterpreter.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPITransformDialectTransforms
+)
+
# TODO: Figure out how to put this in the test tree.
# This should not be included in the main Python extension. However,
# putting it into MLIRPythonTestSources along with the dialect declaration
diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
new file mode 100644
index 00000000000000..6145b99224eb54
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
@@ -0,0 +1,33 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ....ir import Operation
+from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
+
+
+TransformOptions = _cextTransformInterpreter.TransformOptions
+
+
+def _unpack_operation(op):
+ if isinstance(op, Operation):
+ return op
+ return op.operation
+
+
+def apply_named_sequence(
+ payload_root, transform_root, transform_module, transform_options=None
+):
+ """Applies the transformation script starting at the given transform root
+ operation to the given payload operation. The module containing the
+ transform root as well as the transform options should be provided.
+ The transform operation must implement TransformOpInterface and the module
+ must be a ModuleOp."""
+
+ args = tuple(
+ map(_unpack_operation, (payload_root, transform_root, transform_module))
+ )
+ if transform_options is None:
+ _cextTransformInterpreter.apply_named_sequence(*args)
+ else:
+ _cextTransformInterpreter(*args, transform_options)
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index 1096a3b0806648..79b61fdef38b49 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -86,6 +86,15 @@ _add_capi_test_executable(mlir-capi-transform-test
MLIRCAPITransformDialect
)
+# C API test driver for the transform interpreter (transform_interpreter.c).
+_add_capi_test_executable(mlir-capi-transform-interpreter-test
+ transform_interpreter.c
+ LINK_LIBS PRIVATE
+ MLIRCAPIIR
+ MLIRCAPIRegisterEverything
+ MLIRCAPITransformDialect
+ MLIRCAPITransformDialectTransforms
+)
+
_add_capi_test_executable(mlir-capi-translation-test
translation.c
LINK_LIBS PRIVATE
diff --git a/mlir/test/CAPI/transform_interpreter.c b/mlir/test/CAPI/transform_interpreter.c
new file mode 100644
index 00000000000000..8fe37b47b7f874
--- /dev/null
+++ b/mlir/test/CAPI/transform_interpreter.c
@@ -0,0 +1,69 @@
+//===- transform_interpreter.c - Test of the Transform interpreter C API --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+/...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/82365
More information about the Mlir-commits
mailing list