[llvm] [mlir] [mlir] expose transform interpreter to Python (PR #82365)

Oleksandr Alex Zinenko via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 02:00:59 PST 2024


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/82365

>From 3dd816b6fa19764cd98963fe10b74d51a9dfcfc9 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 20 Feb 2024 14:55:31 +0000
Subject: [PATCH 1/2] [mlir] expose transform interpreter to Python

Transform interpreter functionality can be used standalone, without going
through the interpreter pass; make it available in Python.
---
 .../mlir-c/Dialect/Transform/Interpreter.h    | 77 ++++++++++++++++
 .../mlir/Bindings/Python/PybindAdaptors.h     | 36 ++++++++
 mlir/lib/Bindings/Python/DialectLLVM.cpp      | 31 -------
 mlir/lib/Bindings/Python/IRCore.cpp           |  7 ++
 mlir/lib/Bindings/Python/IRModule.h           |  1 +
 .../Bindings/Python/TransformInterpreter.cpp  | 90 +++++++++++++++++++
 mlir/lib/CAPI/Dialect/CMakeLists.txt          |  9 ++
 .../lib/CAPI/Dialect/TransformInterpreter.cpp | 74 +++++++++++++++
 mlir/python/CMakeLists.txt                    | 19 ++++
 .../transform/interpreter/__init__.py         | 33 +++++++
 mlir/test/CAPI/CMakeLists.txt                 |  9 ++
 mlir/test/CAPI/transform_interpreter.c        | 69 ++++++++++++++
 mlir/test/CMakeLists.txt                      |  1 +
 mlir/test/lit.cfg.py                          |  1 +
 .../python/dialects/transform_interpreter.py  | 56 ++++++++++++
 15 files changed, 482 insertions(+), 31 deletions(-)
 create mode 100644 mlir/include/mlir-c/Dialect/Transform/Interpreter.h
 create mode 100644 mlir/lib/Bindings/Python/TransformInterpreter.cpp
 create mode 100644 mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
 create mode 100644 mlir/python/mlir/dialects/transform/interpreter/__init__.py
 create mode 100644 mlir/test/CAPI/transform_interpreter.c
 create mode 100644 mlir/test/python/dialects/transform_interpreter.py

diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
new file mode 100644
index 00000000000000..00095d5040a0e5
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h
@@ -0,0 +1,84 @@
+//===-- mlir-c/Dialect/Transform/Interpreter.h --------------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C interface to the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
+#define MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirTransformOptions, void);
+
+#undef DEFINE_C_API_STRUCT
+
+//----------------------------------------------------------------------------//
+// MlirTransformOptions
+//----------------------------------------------------------------------------//
+
+/// Creates a default-initialized transform options object. The caller owns the
+/// result and must release it with mlirTransformOptionsDestroy.
+MLIR_CAPI_EXPORTED MlirTransformOptions mlirTransformOptionsCreate(void);
+
+/// Enables or disables expensive checks in transform options.
+MLIR_CAPI_EXPORTED void
+mlirTransformOptionsEnableExpensiveChecks(MlirTransformOptions transformOptions,
+                                          bool enable);
+
+/// Returns true if expensive checks are enabled in transform options.
+MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetExpensiveChecksEnabled(
+    MlirTransformOptions transformOptions);
+
+/// Enables or disables the enforcement of the top-level transform op being
+/// single in transform options.
+MLIR_CAPI_EXPORTED void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions, bool enable);
+
+/// Returns true if the enforcement of the top-level transform op being single
+/// is enabled in transform options.
+MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions);
+
+/// Destroys a transform options object previously created by
+/// mlirTransformOptionsCreate.
+MLIR_CAPI_EXPORTED void
+mlirTransformOptionsDestroy(MlirTransformOptions transformOptions);
+
+//----------------------------------------------------------------------------//
+// Transform interpreter.
+//----------------------------------------------------------------------------//
+
+/// Applies the transformation script starting at the given transform root
+/// operation to the given payload operation. The module containing the
+/// transform root as well as the transform options should be provided. The
+/// transform operation must implement TransformOpInterface and the module must
+/// be a ModuleOp; otherwise an error diagnostic is emitted on the offending
+/// operation and failure is returned. Returns the status of the application.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence(
+    MlirOperation payload, MlirOperation transformRoot,
+    MlirOperation transformModule, MlirTransformOptions transformOptions);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_TRANSFORM_INTERPRETER_H
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 66cf20e1c136f9..52f6321251919e 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -23,6 +23,7 @@
 #include <pybind11/stl.h>
 
 #include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 
 #include "llvm/ADT/Twine.h"
@@ -569,6 +570,60 @@ class mlir_value_subclass : public pure_subclass {
 };
 
 } // namespace adaptors
+
+/// RAII scope intercepting all diagnostics into a string. The message must be
+/// checked before this goes out of scope.
+class CollectDiagnosticsToStringScope {
+public:
+  explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
+    handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
+                                                   /*deleteUserData=*/nullptr);
+  }
+  ~CollectDiagnosticsToStringScope() {
+    assert(errorMessage.empty() && "unchecked error message");
+    mlirContextDetachDiagnosticHandler(context, handlerID);
+  }
+
+  // The attached handler holds a pointer to `errorMessage`, so the scope must
+  // not be copied (a copy would also detach the same handler twice).
+  CollectDiagnosticsToStringScope(const CollectDiagnosticsToStringScope &) =
+      delete;
+  CollectDiagnosticsToStringScope &
+  operator=(const CollectDiagnosticsToStringScope &) = delete;
+
+  /// Returns the diagnostics collected so far and leaves the internal buffer
+  /// empty. Swapping (rather than moving) guarantees the buffer is empty
+  /// afterwards, which the destructor's assertion relies on.
+  [[nodiscard]] std::string takeMessage() {
+    std::string message;
+    message.swap(errorMessage);
+    return message;
+  }
+
+private:
+  /// Appends "at <location>: <message>" for `diag` to the std::string passed
+  /// as `data`, separating consecutive diagnostics with a newline.
+  static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
+    auto *buffer = static_cast<std::string *>(data);
+    auto printer = +[](MlirStringRef message, void *data) {
+      *static_cast<std::string *>(data) +=
+          llvm::StringRef(message.data, message.length);
+    };
+    if (!buffer->empty())
+      *buffer += "\n";
+    MlirLocation loc = mlirDiagnosticGetLocation(diag);
+    *buffer += "at ";
+    mlirLocationPrint(loc, printer, data);
+    *buffer += ": ";
+    mlirDiagnosticPrint(diag, printer, data);
+    return mlirLogicalResultSuccess();
+  }
+
+  MlirContext context;
+  MlirDiagnosticHandlerID handlerID;
+  std::string errorMessage = "";
+};
+
 } // namespace python
 } // namespace mlir
 
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 780f5eacf0b8e5..843707751dd849 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir-c/Diagnostics.h"
 #include "mlir-c/Dialect/LLVM.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
@@ -19,36 +18,6 @@ using namespace mlir;
 using namespace mlir::python;
 using namespace mlir::python::adaptors;
 
-/// RAII scope intercepting all diagnostics into a string. The message must be
-/// checked before this goes out of scope.
-class CollectDiagnosticsToStringScope {
-public:
-  explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
-    handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
-                                                   /*deleteUserData=*/nullptr);
-  }
-  ~CollectDiagnosticsToStringScope() {
-    assert(errorMessage.empty() && "unchecked error message");
-    mlirContextDetachDiagnosticHandler(context, handlerID);
-  }
-
-  [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
-
-private:
-  static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
-    auto printer = +[](MlirStringRef message, void *data) {
-      *static_cast<std::string *>(data) +=
-          StringRef(message.data, message.length);
-    };
-    mlirDiagnosticPrint(diag, printer, data);
-    return mlirLogicalResultSuccess();
-  }
-
-  MlirContext context;
-  MlirDiagnosticHandlerID handlerID;
-  std::string errorMessage = "";
-};
-
 void populateDialectLLVMSubmodule(const pybind11::module &m) {
   auto llvmStructType =
       mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8a7951dc29fe5f..734f2f7f3f94cf 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -678,6 +678,14 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
   mlirOperationWalk(op.getOperation(), invalidatingCallback,
                     static_cast<void *>(&data), MlirWalkPreOrder);
 }
+
+// Overload for callers that only hold a raw MlirOperation (such as the
+// transform interpreter bindings, via `_clear_live_operations_inside`): it
+// obtains a PyOperation through PyOperation::forOperation and forwards to the
+// PyOperationBase overload.
+void PyMlirContext::clearOperationsInside(MlirOperation op) {
+  PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
+  clearOperationsInside(opRef->getOperation());
+}
 
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
@@ -2556,6 +2560,9 @@ void mlir::python::populateIRCore(py::module &m) {
       .def("_get_live_operation_objects",
            &PyMlirContext::getLiveOperationObjects)
       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
+      .def("_clear_live_operations_inside",
+           py::overload_cast<MlirOperation>(
+               &PyMlirContext::clearOperationsInside))
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyMlirContext::getCapsule)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 48f39c939340d7..9acfdde25ae047 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -223,6 +223,9 @@ class PyMlirContext {
   /// Clears all operations nested inside the given op using
   /// `clearOperation(MlirOperation)`.
   void clearOperationsInside(PyOperationBase &op);
+  /// Same as above, for an operation given as a raw MlirOperation handle
+  /// rather than a Python operation object.
+  void clearOperationsInside(MlirOperation op);
 
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
new file mode 100644
index 00000000000000..6517f8c39dfadd
--- /dev/null
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -0,0 +1,110 @@
+//===- TransformInterpreter.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pybind classes for the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Transform/Interpreter.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+#include <pybind11/detail/common.h>
+#include <pybind11/pybind11.h>
+
+#include <string>
+
+namespace py = pybind11;
+
+namespace {
+/// Owning wrapper around MlirTransformOptions exposed to Python. It is
+/// move-only; a moved-from wrapper holds a null handle and destroys nothing.
+struct PyMlirTransformOptions {
+  PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }
+  PyMlirTransformOptions(PyMlirTransformOptions &&other) {
+    options = other.options;
+    other.options.ptr = nullptr;
+  }
+  PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
+
+  ~PyMlirTransformOptions() {
+    if (options.ptr)
+      mlirTransformOptionsDestroy(options);
+  }
+
+  MlirTransformOptions options;
+};
+} // namespace
+
+/// Registers the `TransformOptions` class and the `apply_named_sequence`
+/// function in the given Python module.
+static void populateTransformInterpreterSubmodule(py::module &m) {
+  py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local())
+      .def(py::init())
+      .def_property(
+          "expensive_checks",
+          [](const PyMlirTransformOptions &self) {
+            return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
+          },
+          [](PyMlirTransformOptions &self, bool value) {
+            mlirTransformOptionsEnableExpensiveChecks(self.options, value);
+          })
+      .def_property(
+          "enforce_single_top_level_transform_op",
+          [](const PyMlirTransformOptions &self) {
+            return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+                self.options);
+          },
+          [](PyMlirTransformOptions &self, bool value) {
+            mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
+                                                                 value);
+          });
+
+  m.def(
+      "apply_named_sequence",
+      [](MlirOperation payloadRoot, MlirOperation transformRoot,
+         MlirOperation transformModule, const PyMlirTransformOptions &options) {
+        mlir::python::CollectDiagnosticsToStringScope scope(
+            mlirOperationGetContext(transformRoot));
+
+        // Calling back into Python to invalidate everything under the payload
+        // root. This is awkward, but we don't have access to PyMlirContext
+        // object here otherwise.
+        py::object obj = py::cast(payloadRoot);
+        obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
+
+        MlirLogicalResult result = mlirTransformApplyNamedSequence(
+            payloadRoot, transformRoot, transformModule, options.options);
+
+        // Always drain the collected diagnostics: the scope asserts on
+        // destruction that they were consumed, and diagnostics may be reported
+        // even when the application succeeds. In that case they are forwarded
+        // to Python's sys.stderr rather than being dropped.
+        std::string message = scope.takeMessage();
+        if (mlirLogicalResultIsSuccess(result)) {
+          if (!message.empty()) {
+            py::object stderrStream = py::module_::import("sys").attr("stderr");
+            py::print(message, py::arg("file") = stderrStream);
+          }
+          return;
+        }
+
+        throw py::value_error(
+            "Failed to apply named transform sequence.\nDiagnostic message " +
+            message);
+      },
+      py::arg("payload_root"), py::arg("transform_root"),
+      py::arg("transform_module"),
+      py::arg("transform_options") = PyMlirTransformOptions());
+}
+
+PYBIND11_MODULE(_mlirTransformInterpreter, m) {
+  m.doc() = "MLIR Transform dialect interpreter functionality.";
+  populateTransformInterpreterSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index b2952da17a41c0..58b8739043f9df 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -198,6 +198,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITransformDialect
   MLIRTransformDialect
 )
 
+add_mlir_upstream_c_api_library(MLIRCAPITransformDialectTransforms
+  TransformInterpreter.cpp
+
+  PARTIAL_SOURCES_INTENDED
+  LINK_LIBS PUBLIC
+  MLIRCAPIIR
+  MLIRTransformDialectTransforms
+)
+
 add_mlir_upstream_c_api_library(MLIRCAPIQuant
   Quant.cpp
 
diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
new file mode 100644
index 00000000000000..6a2cfb235fcfd4
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp
@@ -0,0 +1,78 @@
+//===- TransformInterpreter.cpp - C Interface for Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// C interface to the transform dialect interpreter.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Transform/Interpreter.h"
+#include "mlir-c/Support.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
+
+using namespace mlir;
+
+DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions)
+
+extern "C" {
+
+// The options object is heap-allocated here and owned by the caller until it
+// is passed to mlirTransformOptionsDestroy.
+MlirTransformOptions mlirTransformOptionsCreate() {
+  return wrap(new transform::TransformOptions);
+}
+
+void mlirTransformOptionsEnableExpensiveChecks(
+    MlirTransformOptions transformOptions, bool enable) {
+  unwrap(transformOptions)->enableExpensiveChecks(enable);
+}
+
+bool mlirTransformOptionsGetExpensiveChecksEnabled(
+    MlirTransformOptions transformOptions) {
+  return unwrap(transformOptions)->getExpensiveChecksEnabled();
+}
+
+void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions, bool enable) {
+  unwrap(transformOptions)->enableEnforceSingleToplevelTransformOp(enable);
+}
+
+bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
+    MlirTransformOptions transformOptions) {
+  return unwrap(transformOptions)->getEnforceSingleToplevelTransformOp();
+}
+
+void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions) {
+  delete unwrap(transformOptions);
+}
+
+// Validates the kinds of the transform root and module, reporting an error on
+// the offending operation, before delegating to the C++ interpreter utility.
+MlirLogicalResult mlirTransformApplyNamedSequence(
+    MlirOperation payload, MlirOperation transformRoot,
+    MlirOperation transformModule, MlirTransformOptions transformOptions) {
+  Operation *transformRootOp = unwrap(transformRoot);
+  Operation *transformModuleOp = unwrap(transformModule);
+  if (!isa<transform::TransformOpInterface>(transformRootOp)) {
+    transformRootOp->emitError()
+        << "must implement TransformOpInterface to be used as transform root";
+    return mlirLogicalResultFailure();
+  }
+  auto moduleOp = dyn_cast<ModuleOp>(transformModuleOp);
+  if (!moduleOp) {
+    transformModuleOp->emitError()
+        << "must be a " << ModuleOp::getOperationName();
+    return mlirLogicalResultFailure();
+  }
+  return wrap(transform::applyTransformNamedSequence(
+      unwrap(payload), transformRootOp, moduleOp, *unwrap(transformOptions)));
+}
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index ed167afeb69a62..563d035f155267 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -181,6 +181,13 @@ declare_mlir_python_sources(
   SOURCES
     dialects/transform/extras/__init__.py)
 
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.interpreter
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  SOURCES
+    dialects/transform/interpreter/__init__.py)
+
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -609,6 +616,18 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
     MLIRCAPISparseTensor
 )
 
+declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
+  MODULE_NAME _mlirTransformInterpreter
+  ADD_TO_PARENT MLIRPythonSources.Dialects.transform
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  SOURCES
+    TransformInterpreter.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPITransformDialectTransforms
+)
+
 # TODO: Figure out how to put this in the test tree.
 # This should not be included in the main Python extension. However,
 # putting it into MLIRPythonTestSources along with the dialect declaration
diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
new file mode 100644
index 00000000000000..6145b99224eb54
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py
@@ -0,0 +1,42 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ....ir import Operation
+from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter
+
+
+TransformOptions = _cextTransformInterpreter.TransformOptions
+
+
+def _unpack_operation(op):
+    """Returns `op` itself if it is an `Operation`, otherwise its `.operation`
+    attribute (e.g. for `Module` or op view objects)."""
+    if isinstance(op, Operation):
+        return op
+    return op.operation
+
+
+def apply_named_sequence(
+    payload_root, transform_root, transform_module, transform_options=None
+):
+    """Applies the transformation script starting at the given transform root
+    operation to the given payload operation. The module containing the
+    transform root as well as the transform options should be provided.
+    The transform operation must implement TransformOpInterface and the module
+    must be a ModuleOp.
+
+    Each of `payload_root`, `transform_root` and `transform_module` may be an
+    `Operation` or any object exposing one via `.operation`. When
+    `transform_options` is None, default `TransformOptions` are used.
+
+    Raises ValueError if the application fails; the message includes the
+    collected diagnostics."""
+
+    args = tuple(
+        map(_unpack_operation, (payload_root, transform_root, transform_module))
+    )
+    if transform_options is None:
+        _cextTransformInterpreter.apply_named_sequence(*args)
+    else:
+        _cextTransformInterpreter.apply_named_sequence(*args, transform_options)
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
index 1096a3b0806648..79b61fdef38b49 100644
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -86,6 +86,15 @@ _add_capi_test_executable(mlir-capi-transform-test
     MLIRCAPITransformDialect
 )
 
+_add_capi_test_executable(mlir-capi-transform-interpreter-test
+  transform_interpreter.c
+  LINK_LIBS PRIVATE
+    MLIRCAPIIR
+    MLIRCAPIRegisterEverything
+    MLIRCAPITransformDialect
+    MLIRCAPITransformDialectTransforms
+)
+
 _add_capi_test_executable(mlir-capi-translation-test
   translation.c
   LINK_LIBS PRIVATE
diff --git a/mlir/test/CAPI/transform_interpreter.c b/mlir/test/CAPI/transform_interpreter.c
new file mode 100644
index 00000000000000..8fe37b47b7f874
--- /dev/null
+++ b/mlir/test/CAPI/transform_interpreter.c
@@ -0,0 +1,74 @@
+//===- transform_interpreter.c - Test of the Transform interpreter C API --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// RUN: mlir-capi-transform-interpreter-test 2>&1 | FileCheck %s
+
+#include "mlir-c/Dialect/Transform.h"
+#include "mlir-c/Dialect/Transform/Interpreter.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+/// Parses a module whose `transform.named_sequence` prints its argument, then
+/// applies that sequence with the module as both payload and transform module.
+/// Returns 0 on success, 1 if parsing fails and 2 if the application fails.
+int testApplyNamedSequence(MlirContext ctx) {
+  fprintf(stderr, "%s\n", __func__);
+
+  const char module[] =
+      "module attributes {transform.with_named_sequence} {"
+      "  transform.named_sequence @__transform_main(%root: !transform.any_op) {"
+      "    transform.print %root { name = \"from interpreter\" }: "
+      "!transform.any_op"
+      "    transform.yield"
+      "  }"
+      "}";
+
+  MlirStringRef moduleStringRef = mlirStringRefCreateFromCString(module);
+  MlirStringRef nameStringRef = mlirStringRefCreateFromCString("inline-module");
+
+  MlirOperation root =
+      mlirOperationCreateParse(ctx, moduleStringRef, nameStringRef);
+  if (mlirOperationIsNull(root))
+    return 1;
+  MlirBlock body = mlirRegionGetFirstBlock(mlirOperationGetRegion(root, 0));
+  MlirOperation entry = mlirBlockGetFirstOperation(body);
+
+  MlirTransformOptions options = mlirTransformOptionsCreate();
+  mlirTransformOptionsEnableExpensiveChecks(options, true);
+  mlirTransformOptionsEnforceSingleTopLevelTransformOp(options, true);
+
+  MlirLogicalResult result =
+      mlirTransformApplyNamedSequence(root, entry, root, options);
+  mlirTransformOptionsDestroy(options);
+  // `root` was created by mlirOperationCreateParse and is owned here.
+  mlirOperationDestroy(root);
+  if (mlirLogicalResultIsFailure(result))
+    return 2;
+
+  return 0;
+}
+// CHECK-LABEL: testApplyNamedSequence
+// CHECK: from interpreter
+// CHECK: transform.named_sequence @__transform_main
+// CHECK:   transform.print %arg0
+// CHECK:   transform.yield
+
+int main(void) {
+  MlirContext ctx = mlirContextCreate();
+  mlirDialectHandleRegisterDialect(mlirGetDialectHandle__transform__(), ctx);
+  int result = testApplyNamedSequence(ctx);
+  mlirContextDestroy(ctx);
+  if (result)
+    return result;
+
+  return EXIT_SUCCESS;
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 6724dd4bdd1bcd..74921544c55578 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -100,6 +100,7 @@ set(MLIR_TEST_DEPENDS
   mlir-capi-quant-test
   mlir-capi-sparse-tensor-test
   mlir-capi-transform-test
+  mlir-capi-transform-interpreter-test
   mlir-capi-translation-test
   mlir-linalg-ods-yaml-gen
   mlir-lsp-server
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 38e65e4549c559..904dfb680a0404 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -106,6 +106,7 @@ def add_runtime(name):
     "mlir-capi-quant-test",
     "mlir-capi-sparse-tensor-test",
     "mlir-capi-transform-test",
+    "mlir-capi-transform-interpreter-test",
     "mlir-capi-translation-test",
     "mlir-cpu-runner",
     add_runtime("mlir_runner_utils"),
diff --git a/mlir/test/python/dialects/transform_interpreter.py b/mlir/test/python/dialects/transform_interpreter.py
new file mode 100644
index 00000000000000..740c49f76a26c4
--- /dev/null
+++ b/mlir/test/python/dialects/transform_interpreter.py
@@ -0,0 +1,61 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir import ir
+from mlir.dialects.transform import interpreter as interp
+
+
+def test_in_context(f):
+    """Decorator: runs `f` immediately inside a fresh `ir.Context` with an
+    unknown default location, then returns `f` unchanged."""
+    with ir.Context(), ir.Location.unknown():
+        f()
+    return f
+
+
+print_root_module = """
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op) {
+    transform.print %root { name = \"from interpreter\" }: !transform.any_op
+    transform.yield
+  }
+}"""
+
+
+@test_in_context
+def print_self():
+    m = ir.Module.parse(print_root_module.replace("from interpreter", "print_self"))
+    interp.apply_named_sequence(m, m.body.operations[0], m)
+
+
+# CHECK-LABEL: print_self
+# CHECK: transform.named_sequence @__transform_main
+# CHECK: transform.print
+# CHECK: transform.yield
+
+
+@test_in_context
+def print_other():
+    transform = ir.Module.parse(
+        print_root_module.replace("from interpreter", "print_other")
+    )
+    payload = ir.Module.parse("module attributes { this.is.payload } {}")
+    interp.apply_named_sequence(payload, transform.body.operations[0], transform)
+
+
+# CHECK-LABEL: print_other
+# CHECK-NOT: transform
+# CHECK: this.is.payload
+
+
+@test_in_context
+def failed():
+    payload = ir.Module.parse("module attributes { this.is.payload } {}")
+    try:
+        interp.apply_named_sequence(payload, payload, payload)
+    except ValueError as e:
+        assert (
+            "must implement TransformOpInterface to be used as transform root" in str(e)
+        )
+    else:
+        # Reaching here means the invalid transform root was accepted.
+        raise AssertionError("expected apply_named_sequence to raise ValueError")

>From bb5f234484322dcca1f63d672d00dceaae31feab Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 21 Feb 2024 10:00:33 +0000
Subject: [PATCH 2/2] add bazel

---
 .../llvm-project-overlay/mlir/BUILD.bazel     | 19 +++++++++++++++++++
 .../mlir/python/BUILD.bazel                   |  5 +++++
 2 files changed, 24 insertions(+)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9ad33aeb8b1e77..2edc656745fdab 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -739,6 +739,25 @@ mlir_c_api_cc_library(
     ],
 )
 
+mlir_c_api_cc_library(
+    name = "CAPITransformDialectTransforms",
+    srcs = [
+        "lib/CAPI/Dialect/TransformInterpreter.cpp",
+    ],
+    hdrs = [
+        "include/mlir-c/Dialect/Transform/Interpreter.h",
+    ],
+    capi_deps = [
+        ":CAPIIR",
+        ":CAPITransformDialect",
+    ],
+    includes = ["include"],
+    deps = [
+        ":TransformDialect",
+        ":TransformDialectTransforms",
+    ],
+)
+
 mlir_c_api_cc_library(
     name = "CAPIMLProgram",
     srcs = [
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index f19c2336e6bcb4..0c3ed22e736018 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -1483,6 +1483,11 @@ filegroup(
     srcs = glob(["mlir/dialects/transform/extras/*.py"]),
 )
 
+filegroup(
+    name = "TransformInterpreterPackagePyFiles",
+    srcs = glob(["mlir/dialects/transform/interpreter/*.py"]),
+)
+
 ##---------------------------------------------------------------------------##
 # Vector dialect.
 ##---------------------------------------------------------------------------##



More information about the llvm-commits mailing list