[Mlir-commits] [mlir] [mlir][python] C++ API demo (PR #71133)
Maksim Levental
llvmlistbot at llvm.org
Thu Nov 2 19:43:03 PDT 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/71133
>From 87f8596229eb97195176b676764597568f269dd4 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 2 Nov 2023 21:41:54 -0500
Subject: [PATCH] [mlir][python] C++ API demo
---
mlir/python/CMakeLists.txt | 2 +
mlir/python/mlir/dialects/python_test.py | 6 +++
mlir/test/python/dialects/python_test.py | 35 +++++++++++++++
mlir/test/python/lib/CMakeLists.txt | 2 +-
mlir/test/python/lib/PythonTestModule.cpp | 6 +++
mlir/test/python/lib/PythonTestPass.cpp | 53 +++++++++++++++++++++++
mlir/test/python/lib/PythonTestPass.h | 16 +++++++
7 files changed, 119 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/python/lib/PythonTestPass.cpp
create mode 100644 mlir/test/python/lib/PythonTestPass.h
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 88e6e13602d291a..e45f55e0381063d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -607,11 +607,13 @@ if(MLIR_INCLUDE_TESTS)
ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
SOURCES
PythonTestModule.cpp
+ PythonTestPass.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIPythonTestDialect
)
+ set_source_files_properties(${MLIR_SOURCE_DIR}/test/python/lib/PythonTestPass.cpp PROPERTIES COMPILE_FLAGS -fno-rtti)
endif()
################################################################################
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 6579e02d8549efa..401ac260139e0a1 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -15,3 +15,9 @@ def register_python_test_dialect(context, load=True):
from .._mlir_libs import _mlirPythonTest
_mlirPythonTest.register_python_test_dialect(context, load)
+
+
+def register_python_test_pass_demo_pass(func):
+ from .._mlir_libs import _mlirPythonTest
+
+ _mlirPythonTest.register_python_test_pass_demo_pass(func)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 472db7e5124dbed..39e3a2dc8ee45ca 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -5,6 +5,7 @@
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
+from mlir.passmanager import PassManager
def run(f):
@@ -551,3 +552,37 @@ def testInferTypeOpInterface():
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
# CHECK: f32
print(two_operands.result.type)
+
+
+# CHECK-LABEL: testPythonPassDemo
+ at run
+def testPythonPassDemo():
+ def print_ops(op):
+ print(op.name)
+
+ module = """
+ module {
+ func.func @main() {
+ %memref = memref.alloca() : memref<1xi64>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i64
+ memref.store %c1, %memref[%c0] : memref<1xi64>
+ %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
+ return
+ }
+ }
+ """
+
+ # CHECK: memref.alloca
+ # CHECK: arith.constant
+ # CHECK: arith.constant
+ # CHECK: memref.store
+ # CHECK: memref.cast
+ # CHECK: func.return
+ # CHECK: func.func
+ # CHECK: builtin.module
+ with Context() as ctx, Location.unknown():
+ test.register_python_test_dialect(ctx)
+ test.register_python_test_pass_demo_pass(print_ops)
+ mlir_module = Module.parse(module)
+ PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation)
diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt
index d7cbbfbc214772b..8354a08e7b7139f 100644
--- a/mlir/test/python/lib/CMakeLists.txt
+++ b/mlir/test/python/lib/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
PythonTestCAPI.cpp
PythonTestDialect.cpp
PythonTestModule.cpp
+ PythonTestPass.cpp
)
add_mlir_library(MLIRPythonTestDialect
@@ -29,4 +30,3 @@ add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect
MLIRCAPIIR
MLIRPythonTestDialect
)
-
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f533082a0a147c0..5be16e37abc338c 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
+#include "PythonTestPass.h"
+
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"
@@ -34,6 +36,10 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
},
py::arg("context"), py::arg("load") = true);
+  m.def("register_python_test_pass_demo_pass", [](const py::function &callback) {
+    registerPythonTestPassDemoPassWithFunc(callback.ptr());
+  });
+
mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute)
.def_classmethod(
diff --git a/mlir/test/python/lib/PythonTestPass.cpp b/mlir/test/python/lib/PythonTestPass.cpp
new file mode 100644
index 000000000000000..ce2b9f34c3c5c3d
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestPass.cpp
@@ -0,0 +1,139 @@
+//===- PythonTestPass.cpp - Python callback demo pass ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the `python-pass-demo` module pass, which hands the module and every
+// operation nested in it to a Python callable, using the MLIR Python C-API
+// interop helpers to wrap each operation as an `mlir.ir.Operation`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PythonTestPass.h"
+
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Module pass that calls a Python callable once per operation, visiting the
+/// module and everything nested in it in `Operation::walk`'s default
+/// post-order (so the module itself is visited last).
+struct PythonTestPassDemo
+    : public PassWrapper<PythonTestPassDemo, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PythonTestPassDemo)
+
+  /// Stores `func` as a borrowed reference: whoever creates the pass must keep
+  /// the callable alive for as long as the pass (or a clone of it) may run.
+  explicit PythonTestPassDemo(PyObject *func) : func(func) {}
+
+  StringRef getArgument() const final { return "python-pass-demo"; }
+
+  void runOnOperation() override {
+    // The Python C API may only be used while holding the GIL.
+    // PyGILState_Ensure is reentrant, so this is safe whether or not the
+    // thread running the pass already holds it.
+    PyGILState_STATE gilState = PyGILState_Ensure();
+    LogicalResult result = invokeCallbackOnNestedOps();
+    PyGILState_Release(gilState);
+    if (failed(result))
+      signalPassFailure();
+  }
+
+private:
+  /// Returns a new reference to the C-API factory of `mlir.ir.Operation`, or
+  /// nullptr with a Python exception set.
+  static PyObject *getOperationCAPIFactory() {
+    PyObject *irModule = PyImport_ImportModule(MAKE_MLIR_PYTHON_QUALNAME("ir"));
+    if (!irModule)
+      return nullptr;
+    PyObject *operationClass = PyObject_GetAttrString(irModule, "Operation");
+    Py_DECREF(irModule);
+    if (!operationClass)
+      return nullptr;
+    PyObject *factory =
+        PyObject_GetAttrString(operationClass, MLIR_PYTHON_CAPI_FACTORY_ATTR);
+    Py_DECREF(operationClass);
+    return factory;
+  }
+
+  /// Prints and clears any pending Python exception, then emits `message` as
+  /// an error on `op`. Always returns failure.
+  static LogicalResult reportPythonError(Operation *op, const Twine &message) {
+    // Clear the Python error first: a diagnostic handler may itself call back
+    // into Python, which must not happen with an exception pending.
+    if (PyErr_Occurred())
+      PyErr_Print();
+    return op->emitError(message);
+  }
+
+  /// Calls `func` on every visited operation, ignoring its return value.
+  /// Stops at the first Python error, reporting it via reportPythonError.
+  /// Requires the GIL to be held.
+  LogicalResult invokeCallbackOnNestedOps() {
+    // Look the factory up once per run instead of once per visited operation.
+    PyObject *capiFactory = getOperationCAPIFactory();
+    if (!capiFactory)
+      return reportPythonError(
+          getOperation(), "could not load mlir.ir.Operation C-API factory");
+
+    WalkResult walkResult = getOperation()->walk([&](Operation *op) {
+      // Each PyObject* created here is a new reference released before return.
+      PyObject *capsule = mlirPythonOperationToCapsule(wrap(op));
+      if (!capsule) {
+        (void)reportPythonError(op, "could not wrap operation in a capsule");
+        return WalkResult::interrupt();
+      }
+      PyObject *pyOp =
+          PyObject_CallFunctionObjArgs(capiFactory, capsule, nullptr);
+      Py_DECREF(capsule);
+      if (!pyOp) {
+        (void)reportPythonError(op, "could not create mlir.ir.Operation");
+        return WalkResult::interrupt();
+      }
+      PyObject *callResult = PyObject_CallFunctionObjArgs(func, pyOp, nullptr);
+      Py_DECREF(pyOp);
+      if (!callResult) {
+        (void)reportPythonError(op, "Python callback raised an exception");
+        return WalkResult::interrupt();
+      }
+      Py_DECREF(callResult);
+      return WalkResult::advance();
+    });
+
+    Py_DECREF(capiFactory);
+    return failure(walkResult.wasInterrupted());
+  }
+
+  /// Borrowed reference to the Python callable invoked for each operation.
+  PyObject *func;
+};
+
+/// Creates a `python-pass-demo` pass that calls `func` (borrowed) on each op.
+std::unique_ptr<OperationPass<ModuleOp>>
+createPythonTestPassDemoPassWithFunc(PyObject *func) {
+  return std::make_unique<PythonTestPassDemo>(func);
+}
+
+} // namespace
+
+void registerPythonTestPassDemoPassWithFunc(PyObject *func) {
+  // The allocator registered below can run at any later point (whenever a
+  // pipeline naming `python-pass-demo` is built), so take a strong reference
+  // to `func` that is intentionally never released; otherwise the pointer
+  // could dangle once the Python caller drops its own reference. Like every
+  // refcount operation, this requires the GIL to be held.
+  // NOTE(review): MLIR's registry appears to keep the first allocator
+  // registered for a given pass argument, so calling this again with a
+  // different callable may not replace it -- confirm before relying on that.
+  Py_INCREF(func);
+  registerPass([func]() { return createPythonTestPassDemoPassWithFunc(func); });
+}
diff --git a/mlir/test/python/lib/PythonTestPass.h b/mlir/test/python/lib/PythonTestPass.h
new file mode 100644
index 000000000000000..4df4f965857eda6
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestPass.h
@@ -0,0 +1,19 @@
+//===- PythonTestPass.h - Python callback demo pass -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTPASS_H
+#define MLIR_TEST_PYTHON_LIB_PYTHONTESTPASS_H
+
+#include <Python.h>
+
+/// Registers the `python-pass-demo` pass. When run on a module, the pass calls
+/// the Python callable `func` once per operation (the module included),
+/// passing each one wrapped as an `mlir.ir.Operation`.
+void registerPythonTestPassDemoPassWithFunc(PyObject *func);
+
+#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTPASS_H
More information about the Mlir-commits
mailing list