[Mlir-commits] [mlir] [mlir] expose transform interpreter to Python (PR #82365)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Feb 21 01:31:31 PST 2024
https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/82365
>From 3dd816b6fa19764cd98963fe10b74d51a9dfcfc9 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 20 Feb 2024 14:55:31 +0000
Subject: [PATCH] [mlir] expose transform interpreter to Python
Transform interpreter functionality can be used standalone, without going
through the interpreter pass; make it available in Python.
---
.../mlir-c/Dialect/Transform/Interpreter.h | 77 ++++++++++++++++
.../mlir/Bindings/Python/PybindAdaptors.h | 36 ++++++++
mlir/lib/Bindings/Python/DialectLLVM.cpp | 31 -------
mlir/lib/Bindings/Python/IRCore.cpp | 7 ++
mlir/lib/Bindings/Python/IRModule.h | 1 +
.../Bindings/Python/TransformInterpreter.cpp | 90 +++++++++++++++++++
mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 ++
.../lib/CAPI/Dialect/TransformInterpreter.cpp | 74 +++++++++++++++
mlir/python/CMakeLists.txt | 19 ++++
.../transform/interpreter/__init__.py | 33 +++++++
mlir/test/CAPI/CMakeLists.txt | 9 ++
mlir/test/CAPI/transform_interpreter.c | 69 ++++++++++++++
mlir/test/CMakeLists.txt | 1 +
mlir/test/lit.cfg.py | 1 +
.../python/dialects/transform_interpreter.py | 56 ++++++++++++
15 files changed, 482 insertions(+), 31 deletions(-)
create mode 100644 mlir/include/mlir-c/Dialect/Transform/Interpreter.h
create mode 100644 mlir/lib/Bindings/Python/TransformInterpreter.cpp
create mode 100644 mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
create mode 100644 mlir/python/mlir/dialects/transform/interpreter/__init__.py
create mode 100644 mlir/test/CAPI/transform_interpreter.c
create mode 100644 mlir/test/python/dialects/transform_interpreter.py
diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
new file mode 100644
index 00000000000000..00095d5040a0e5
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
@@ -0,0 +1,84 @@
+//===-- mlir-c/Dialect/Transform/Interpreter.h --------------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C interface to the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
+#define MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirTransformOptions, void);
+
+#undef DEFINE_C_API_STRUCT
+
+//----------------------------------------------------------------------------//
+// MlirTransformOptions
+//----------------------------------------------------------------------------//
+
+/// Creates a default-initialized transform options object. The result must be
+/// released with mlirTransformOptionsDestroy.
+MLIR_CAPI_EXPORTED MlirTransformOptions mlirTransformOptionsCreate(void);
+
+/// Enables or disables expensive checks in transform options.
+MLIR_CAPI_EXPORTED void
+mlirTransformOptionsEnableExpensiveChecks(MlirTransformOptions transformOptions,
+                                          bool enable);
+
+/// Returns true if expensive checks are enabled in transform options.
+MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetExpensiveChecksEnabled(
+    MlirTransformOptions transformOptions);
+
+/// Enables or disables the enforcement of the top-level transform op being
+/// single in transform options.
+MLIR_CAPI_EXPORTED void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions, bool enable);
+
+/// Returns true if the enforcement of the top-level transform op being single
+/// is enabled in transform options.
+MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions);
+
+/// Destroys a transform options object previously created by
+/// mlirTransformOptionsCreate.
+MLIR_CAPI_EXPORTED void
+mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
+
+//----------------------------------------------------------------------------//
+// Transform interpreter.
+//----------------------------------------------------------------------------//
+
+/// Applies the transformation script starting at the given transform root
+/// operation to the given payload operation. The module containing the
+/// transform root as well as the transform options should be provided. The
+/// transform operation must implement TransformOpInterface and the module must
+/// be a ModuleOp; otherwise an error diagnostic is emitted on the offending op
+/// and failure is returned. Returns the status of the application.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
+    MlirOperation payload, MlirOperation transformRoot,
+    MlirOperation transformModule, MlirTransformOptions transformOptions);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 66cf20e1c136f9..52f6321251919e 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -23,6 +23,7 @@
#include <pybind11/stl.h>
#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "llvm/ADT/Twine.h"
@@ -569,6 +570,41 @@ class mlir_value_subclass : public pure_subclass {
};
} // namespace adaptors
+
+/// RAII scope collecting every diagnostic emitted in `ctx` into a string while
+/// alive. Call takeMessage() on all paths: the destructor asserts it is empty.
+class CollectDiagnosticsToStringScope {
+public:
+ explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
+ handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
+ /*deleteUserData=*/nullptr);
+ }
+ ~CollectDiagnosticsToStringScope() {
+ assert(errorMessage.empty() && "unchecked error message");
+ mlirContextDetachDiagnosticHandler(context, handlerID);
+ }
+
+ [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
+
+private:
+ static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { // Appends "at <loc>: <message>"; no separator is inserted between diagnostics.
+ auto printer = +[](MlirStringRef message, void *data) {
+ *static_cast<std::string *>(data) +=
+ llvm::StringRef(message.data, message.length);
+ };
+ MlirLocation loc = mlirDiagnosticGetLocation(diag);
+ *static_cast<std::string *>(data) += "at ";
+ mlirLocationPrint(loc, printer, data);
+ *static_cast<std::string *>(data) += ": ";
+ mlirDiagnosticPrint(diag, printer, data);
+ return mlirLogicalResultSuccess();
+ }
+
+ MlirContext context;
+ MlirDiagnosticHandlerID handlerID;
+ std::string errorMessage = ""; // Accumulated diagnostics; handler user data.
+};
+
} // namespace python
} // namespace mlir
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 780f5eacf0b8e5..843707751dd849 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir-c/Diagnostics.h"
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
@@ -19,36 +18,6 @@ using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;
-/// RAII scope intercepting all diagnostics into a string. The message must be
-/// checked before this goes out of scope.
-class CollectDiagnosticsToStringScope {
-public:
- explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
- handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
- /*deleteUserData=*/nullptr);
- }
- ~CollectDiagnosticsToStringScope() {
- assert(errorMessage.empty() && "unchecked error message");
- mlirContextDetachDiagnosticHandler(context, handlerID);
- }
-
- [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
-
-private:
- static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
- auto printer = +[](MlirStringRef message, void *data) {
- *static_cast<std::string *>(data) +=
- StringRef(message.data, message.length);
- };
- mlirDiagnosticPrint(diag, printer, data);
- return mlirLogicalResultSuccess();
- }
-
- MlirContext context;
- MlirDiagnosticHandlerID handlerID;
- std::string errorMessage = "";
-};
-
void populateDialectLLVMSubmodule(const pybind11::module &m) {
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8a7951dc29fe5f..734f2f7f3f94cf 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -678,6 +678,10 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
}
+void PyMlirContext::clearOperationsInside(MlirOperation op) { // Raw-handle overload; bound to Python as _clear_live_operations_inside.
+ PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
+ clearOperationsInside(opRef->getOperation());
+}
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
@@ -2556,6 +2560,9 @@ void mlir::python::populateIRCore(py::module &m) {
.def("_get_live_operation_objects",
&PyMlirContext::getLiveOperationObjects)
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
+ .def("_clear_live_operations_inside", // Called by the transform interpreter bindings before applying a script.
+ py::overload_cast<MlirOperation>(
+ &PyMlirContext::clearOperationsInside))
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 48f39c939340d7..9acfdde25ae047 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -223,6 +223,7 @@ class PyMlirContext {
/// Clears all operations nested inside the given op using
/// `clearOperation(MlirOperation)`.
void clearOperationsInside(PyOperationBase &op);
+ void clearOperationsInside(MlirOperation op); // Same as above, for a raw MlirOperation handle.
/// Gets the count of live modules associated with this context.
/// Used for testing.
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
new file mode 100644
index 00000000000000..6517f8c39dfadd
--- /dev/null
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -0,0 +1,106 @@
+//===- TransformInterpreter.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pybind classes for the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Transform/Interpreter.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+#include <pybind11/detail/common.h>
+#include <pybind11/pybind11.h>
+
+#include <cstdio>
+#include <string>
+
+namespace py = pybind11;
+
+namespace {
+/// Owning, move-only wrapper around MlirTransformOptions exposed to Python.
+/// A moved-from wrapper holds a null handle, whose destruction is a no-op.
+struct PyMlirTransformOptions {
+  PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }
+  PyMlirTransformOptions(PyMlirTransformOptions &&other) {
+    options = other.options;
+    other.options.ptr = nullptr;
+  }
+  PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
+
+  ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }
+
+  MlirTransformOptions options;
+};
+} // namespace
+
+static void populateTransformInterpreterSubmodule(py::module &m) {
+  py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local())
+      .def(py::init())
+      .def_property(
+          "expensive_checks",
+          [](const PyMlirTransformOptions &self) {
+            return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
+          },
+          [](PyMlirTransformOptions &self, bool value) {
+            mlirTransformOptionsEnableExpensiveChecks(self.options, value);
+          })
+      .def_property(
+          "enforce_single_top_level_transform_op",
+          [](const PyMlirTransformOptions &self) {
+            return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+                self.options);
+          },
+          [](PyMlirTransformOptions &self, bool value) {
+            mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
+                                                                 value);
+          });
+
+  m.def(
+      "apply_named_sequence",
+      [](MlirOperation payloadRoot, MlirOperation transformRoot,
+         MlirOperation transformModule, const PyMlirTransformOptions &options) {
+        mlir::python::CollectDiagnosticsToStringScope scope(
+            mlirOperationGetContext(transformRoot));
+
+        // Calling back into Python to invalidate everything under the payload
+        // root. This is awkward, but we don't have access to PyMlirContext
+        // object here otherwise.
+        py::object obj = py::cast(payloadRoot);
+        obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
+
+        MlirLogicalResult result = mlirTransformApplyNamedSequence(
+            payloadRoot, transformRoot, transformModule, options.options);
+
+        // Always drain the collected diagnostics: the scope's destructor
+        // asserts that the message has been taken, and a successful
+        // application may still have emitted diagnostics.
+        std::string message = scope.takeMessage();
+        if (mlirLogicalResultIsSuccess(result)) {
+          if (!message.empty())
+            fprintf(stderr,
+                    "Diagnostic generated while applying "
+                    "transform.named_sequence:\n%s\n",
+                    message.c_str());
+          return;
+        }
+
+        throw py::value_error(
+            "Failed to apply named transform sequence.\nDiagnostic message " +
+            message);
+      },
+      py::arg("payload_root"), py::arg("transform_root"),
+      py::arg("transform_module"),
+      py::arg("transform_options") = PyMlirTransformOptions());
+}
+
+PYBIND11_MODULE(_mlirTransformInterpreter, m) {
+  m.doc() = "MLIR Transform dialect interpreter functionality.";
+  populateTransformInterpreterSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index b2952da17a41c0..58b8739043f9df 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -198,6 +198,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITransformDialect
MLIRTransformDialect
)
+add_mlir_upstream_c_api_library(MLIRCAPITransformDialectTransforms # C API of the transform interpreter (mlir-c/Dialect/Transform/Interpreter.h)
+ TransformInterpreter.cpp
+
+ PARTIAL_SOURCES_INTENDED
+ LINK_LIBS PUBLIC
+ MLIRCAPIIR
+ MLIRTransformDialectTransforms
+)
+
add_mlir_upstream_c_api_library(MLIRCAPIQuant
Quant.cpp
diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
new file mode 100644
index 00000000000000..6a2cfb235fcfd4
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
@@ -0,0 +1,76 @@
+//===- TransformInterpreter.cpp - C Interface for Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C interface to the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Transform/Interpreter.h"
+#include "mlir-c/Support.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+// wrap()/unwrap() between MlirTransformOptions and TransformOptions *.
+DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions)
+
+extern "C" {
+
+MlirTransformOptions mlirTransformOptionsCreate() {
+  return wrap(new transform::TransformOptions);
+}
+
+void mlirTransformOptionsEnableExpensiveChecks(
+    MlirTransformOptions transformOptions, bool enable) {
+  unwrap(transformOptions)->enableExpensiveChecks(enable);
+}
+
+bool mlirTransformOptionsGetExpensiveChecksEnabled(
+    MlirTransformOptions transformOptions) {
+  return unwrap(transformOptions)->getExpensiveChecksEnabled();
+}
+
+void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions, bool enable) {
+  unwrap(transformOptions)->enableEnforceSingleToplevelTransformOp(enable);
+}
+
+bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions) {
+  return unwrap(transformOptions)->getEnforceSingleToplevelTransformOp();
+}
+
+void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions) {
+  delete unwrap(transformOptions);
+}
+
+MlirLogicalResult mlirTransformApplyNamedSequence(
+    MlirOperation payload, MlirOperation transformRoot,
+    MlirOperation transformModule, MlirTransformOptions transformOptions) {
+  Operation *transformRootOp = unwrap(transformRoot);
+  if (!isa<transform::TransformOpInterface>(transformRootOp)) {
+    transformRootOp->emitError()
+        << "must implement TransformOpInterface to be used as transform root";
+    return mlirLogicalResultFailure();
+  }
+  Operation *transformModuleOp = unwrap(transformModule);
+  auto transformModuleAsModule = dyn_cast<ModuleOp>(transformModuleOp);
+  if (!transformModuleAsModule) {
+    transformModuleOp->emitError()
+        << "must be a " << ModuleOp::getOperationName();
+    return mlirLogicalResultFailure();
+  }
+  return wrap(transform::applyTransformNamedSequence(
+      unwrap(payload), transformRootOp, transformModuleAsModule,
+      *unwrap(transformOptions)));
+}
+} // extern "C"
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index ed167afeb69a62..563d035f155267 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -181,6 +181,13 @@ declare_mlir_python_sources(
SOURCES
dialects/transform/extras/__init__.py)
+declare_mlir_python_sources( # Pure-Python package mlir.dialects.transform.interpreter
+ MLIRPythonSources.Dialects.transform.interpreter
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ SOURCES
+ dialects/transform/interpreter/__init__.py)
+
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -609,6 +616,18 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
MLIRCAPISparseTensor
)
+declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter # Native module imported by dialects/transform/interpreter/__init__.py
+ MODULE_NAME _mlirTransformInterpreter
+ ADD_TO_PARENT MLIRPythonSources.Dialects.transform
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ SOURCES
+ TransformInterpreter.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPITransformDialectTransforms
+)
+
# TODO: Figure out how to put this in the test tree.
# This should not be included in the main Python extension. However,
# putting it into MLIRPythonTestSources along with the dialect declaration
diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
new file mode 100644
index 00000000000000..6145b99224eb54
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
@@ -0,0 +1,33 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ....ir import Operation
+from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
+
+
+TransformOptions = _cextTransformInterpreter.TransformOptions
+
+
+def _unpack_operation(op):
+ if isinstance(op, Operation):
+ return op
+ return op.operation
+
+
+def apply_named_sequence(
+ payload_root, transform_root, transform_module, transform_options=None
+):
+ """Applies the transformation script starting at the given transform root
+ operation to the given payload operation. The module containing the
+ transform root as well as the transform options should be provided.
+ The transform operation must implement TransformOpInterface and the module
+ must be a ModuleOp."""
+
+ args = tuple(
+ map(_unpack_operation, (payload_root, transform_root, transform_module))
+ )
+ if transform_options is None:
+ _cextTransformInterpreter.apply_named_sequence(*args)
+ else:
+ _cextTransformInterpreter(*args, transform_options)
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index 1096a3b0806648..79b61fdef38b49 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -86,6 +86,15 @@ _add_capi_test_executable(mlir-capi-transform-test
MLIRCAPITransformDialect
)
+_add_capi_test_executable(mlir-capi-transform-interpreter-test # Built from test/CAPI/transform_interpreter.c
+ transform_interpreter.c
+ LINK_LIBS PRIVATE
+ MLIRCAPIIR
+ MLIRCAPIRegisterEverything
+ MLIRCAPITransformDialect
+ MLIRCAPITransformDialectTransforms
+)
+
_add_capi_test_executable(mlir-capi-translation-test
translation.c
LINK_LIBS PRIVATE
diff --git a/mlir/test/CAPI/transform_interpreter.c b/mlir/test/CAPI/transform_interpreter.c
new file mode 100644
index 00000000000000..8fe37b47b7f874
--- /dev/null
+++ b/mlir/test/CAPI/transform_interpreter.c
@@ -0,0 +1,75 @@
+//===- transform_interpreter.c - Test of the Transform interpreter C API --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// RUN: mlir-capi-transform-interpreter-test 2>&1 | FileCheck %s
+
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir-c/Dialect/Transform/Interpreter.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+/// Parses a module holding a named sequence that prints its root and applies
+/// that sequence to the module itself. Returns 0 on success, 1 if parsing
+/// fails and 2 if the application fails.
+int testApplyNamedSequence(MlirContext ctx) {
+  fprintf(stderr, "%s\n", __FUNCTION__);
+
+  const char module[] =
+      "module attributes {transform.with_named_sequence} {"
+      " transform.named_sequence @__transform_main(%root: !transform.any_op) {"
+      " transform.print %root { name = \"from interpreter\" }: "
+      "!transform.any_op"
+      " transform.yield"
+      " }"
+      "}";
+
+  MlirStringRef moduleStringRef = mlirStringRefCreateFromCString(module);
+  MlirStringRef nameStringRef = mlirStringRefCreateFromCString("inline-module");
+
+  MlirOperation root =
+      mlirOperationCreateParse(ctx, moduleStringRef, nameStringRef);
+  if (mlirOperationIsNull(root))
+    return 1;
+  MlirBlock body = mlirRegionGetFirstBlock(mlirOperationGetRegion(root, 0));
+  MlirOperation entry = mlirBlockGetFirstOperation(body);
+
+  MlirTransformOptions options = mlirTransformOptionsCreate();
+  mlirTransformOptionsEnableExpensiveChecks(options, true);
+  mlirTransformOptionsEnforceSingleTopLevelTransformOp(options, true);
+
+  MlirLogicalResult result =
+      mlirTransformApplyNamedSequence(root, entry, root, options);
+  mlirTransformOptionsDestroy(options);
+  // The parsed module is owned by this function and must be released
+  // regardless of the outcome of the application.
+  mlirOperationDestroy(root);
+  if (mlirLogicalResultIsFailure(result))
+    return 2;
+
+  return 0;
+}
+// CHECK-LABEL: testApplyNamedSequence
+// CHECK: from interpreter
+// CHECK: transform.named_sequence @__transform_main
+// CHECK: transform.print %arg0
+// CHECK: transform.yield
+
+int main(void) {
+  MlirContext ctx = mlirContextCreate();
+  mlirDialectHandleRegisterDialect(mlirGetDialectHandle__transform__(), ctx);
+  int result = testApplyNamedSequence(ctx);
+  mlirContextDestroy(ctx);
+  if (result)
+    return result;
+
+  return EXIT_SUCCESS;
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 6724dd4bdd1bcd..74921544c55578 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -100,6 +100,7 @@ set(MLIR_TEST_DEPENDS
mlir-capi-quant-test
mlir-capi-sparse-tensor-test
mlir-capi-transform-test
+ mlir-capi-transform-interpreter-test # C API test of the transform interpreter
mlir-capi-translation-test
mlir-linalg-ods-yaml-gen
mlir-lsp-server
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 38e65e4549c559..904dfb680a0404 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -106,6 +106,7 @@ def add_runtime(name):
"mlir-capi-quant-test",
"mlir-capi-sparse-tensor-test",
"mlir-capi-transform-test",
+ "mlir-capi-transform-interpreter-test", # C API test of the transform interpreter
"mlir-capi-translation-test",
"mlir-cpu-runner",
add_runtime("mlir_runner_utils"),
diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
new file mode 100644
index 00000000000000..740c49f76a26c4
--- /dev/null
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -0,0 +1,56 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir import ir
+from mlir.dialects.transform import interpreter as interp
+
+
+def test_in_context(f):
+ with ir.Context(), ir.Location.unknown():
+ f()
+ return f
+
+
+print_root_module = """
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root: !transform.any_op) {
+ transform.print %root { name = \"from interpreter\" }: !transform.any_op
+ transform.yield
+ }
+}"""
+
+
+ at test_in_context
+def print_self():
+ m = ir.Module.parse(print_root_module.replace("from interpreter", "print_self"))
+ interp.apply_named_sequence(m, m.body.operations[0], m)
+
+
+# CHECK-LABEL: print_self
+# CHECK: transform.named_sequence @__transform_main
+# CHECK: transform.print
+# CHECK: transform.yield
+
+
+ at test_in_context
+def print_other():
+ transform = ir.Module.parse(
+ print_root_module.replace("from interpreter", "print_other")
+ )
+ payload = ir.Module.parse("module attributes { this.is.payload } {}")
+ interp.apply_named_sequence(payload, transform.body.operations[0], transform)
+
+
+# CHECK-LABEL: print_other
+# CHECK-NOT: transform
+# CHECK: this.is.payload
+
+
+ at test_in_context
+def failed():
+ payload = ir.Module.parse("module attributes { this.is.payload } {}")
+ try:
+ interp.apply_named_sequence(payload, payload, payload)
+ except ValueError as e:
+ assert (
+ "must implement TransformOpInterface to be used as transform root" in str(e)
+ )
More information about the Mlir-commits
mailing list