[Mlir-commits] [mlir] [mlir][python] C++ API demo (PR #71133)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 2 19:46:51 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

This PR demonstrates how to call MLIR C++ APIs directly from a Python extension module, without going through the C API and without relying on RTTI.

In fact, it demonstrates something more exciting: tunneling Python callbacks all the way into MLIR passes:

```python
def print_ops(op):
    print(op.name)

test.register_python_test_pass_demo_pass(print_ops)

module = """
module {
  func.func @main() {
    %memref = memref.alloca() : memref<1xi64>
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : i64
    memref.store %c1, %memref[%c0] : memref<1xi64>
    %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
    return
  }
}
"""

mlir_module = Module.parse(module)
PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation)
```

This will print

```
memref.alloca
arith.constant
arith.constant
memref.store
memref.cast
func.return
func.func
builtin.module
```

by calling `print_ops` _from inside the pass_, whose `runOnOperation` does `this->getOperation()->walk([this](Operation *op) { ... });`.

---
Full diff: https://github.com/llvm/llvm-project/pull/71133.diff


7 Files Affected:

- (modified) mlir/python/CMakeLists.txt (+2) 
- (modified) mlir/python/mlir/dialects/python_test.py (+6) 
- (modified) mlir/test/python/dialects/python_test.py (+35) 
- (modified) mlir/test/python/lib/CMakeLists.txt (+1-1) 
- (modified) mlir/test/python/lib/PythonTestModule.cpp (+6) 
- (added) mlir/test/python/lib/PythonTestPass.cpp (+53) 
- (added) mlir/test/python/lib/PythonTestPass.h (+16) 


``````````diff
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 88e6e13602d291a..e45f55e0381063d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -607,11 +607,13 @@ if(MLIR_INCLUDE_TESTS)
     ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib"
     SOURCES
       PythonTestModule.cpp
+      PythonTestPass.cpp
     PRIVATE_LINK_LIBS
       LLVMSupport
     EMBED_CAPI_LINK_LIBS
       MLIRCAPIPythonTestDialect
   )
+  set_source_files_properties(${MLIR_SOURCE_DIR}/test/python/lib/PythonTestPass.cpp PROPERTIES COMPILE_FLAGS -fno-rtti)
 endif()
 
 ################################################################################
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 6579e02d8549efa..401ac260139e0a1 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -15,3 +15,15 @@ def register_python_test_dialect(context, load=True):
     from .._mlir_libs import _mlirPythonTest
 
     _mlirPythonTest.register_python_test_dialect(context, load)
+
+
+def register_python_test_pass_demo_pass(func):
+    """Registers the `python-pass-demo` pass with `func` as its callback.
+
+    When the pass runs, e.g. via
+    ``PassManager.parse("builtin.module(python-pass-demo)")``, `func` is
+    called with every operation the pass walks, including the module itself.
+    """
+    from .._mlir_libs import _mlirPythonTest
+
+    _mlirPythonTest.register_python_test_pass_demo_pass(func)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 472db7e5124dbed..39e3a2dc8ee45ca 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -5,6 +5,7 @@
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
 import mlir.dialects.arith as arith
+from mlir.passmanager import PassManager
 
 
 def run(f):
@@ -551,3 +552,37 @@ def testInferTypeOpInterface():
             two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
             # CHECK: f32
             print(two_operands.result.type)
+
+
+# CHECK-LABEL: testPythonPassDemo
+@run
+def testPythonPassDemo():
+    def print_ops(op):  # invoked by the C++ pass on every op it walks
+        print(op.name)
+
+    module = """
+    module {
+      func.func @main() {
+        %memref = memref.alloca() : memref<1xi64>
+        %c0 = arith.constant 0 : index
+        %c1 = arith.constant 1 : i64
+        memref.store %c1, %memref[%c0] : memref<1xi64>
+        %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
+        return
+      }
+    }
+    """
+
+    # CHECK: memref.alloca
+    # CHECK: arith.constant
+    # CHECK: arith.constant
+    # CHECK: memref.store
+    # CHECK: memref.cast
+    # CHECK: func.return
+    # CHECK: func.func
+    # CHECK: builtin.module
+    with Context() as ctx, Location.unknown():
+        test.register_python_test_dialect(ctx)
+        test.register_python_test_pass_demo_pass(print_ops)
+        mlir_module = Module.parse(module)
+        PassManager.parse("builtin.module(python-pass-demo)").run(mlir_module.operation)
diff --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt
index d7cbbfbc214772b..8354a08e7b7139f 100644
--- a/mlir/test/python/lib/CMakeLists.txt
+++ b/mlir/test/python/lib/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   PythonTestCAPI.cpp
   PythonTestDialect.cpp
   PythonTestModule.cpp
+  PythonTestPass.cpp
 )
 
 add_mlir_library(MLIRPythonTestDialect
@@ -29,4 +30,3 @@ add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect
   MLIRCAPIIR
   MLIRPythonTestDialect
 )
-
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f533082a0a147c0..5be16e37abc338c 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "PythonTestCAPI.h"
+#include "PythonTestPass.h"
+
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/IR.h"
@@ -34,6 +36,10 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
       },
       py::arg("context"), py::arg("load") = true);
 
+  m.def("register_python_test_pass_demo_pass", [](py::function func) {
+    registerPythonTestPassDemoPassWithFunc(func.ptr());
+  });
+
   mlir_attribute_subclass(m, "TestAttr",
                           mlirAttributeIsAPythonTestTestAttribute)
       .def_classmethod(
diff --git a/mlir/test/python/lib/PythonTestPass.cpp b/mlir/test/python/lib/PythonTestPass.cpp
new file mode 100644
index 000000000000000..ce2b9f34c3c5c3d
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestPass.cpp
@@ -0,0 +1,125 @@
+//===- PythonTestPass.cpp - Demo pass calling a Python callback -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PythonTestPass.h"
+
+#include "mlir-c/Bindings/Python/Interop.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Strong reference to the Python callable invoked by the demo pass, or null
+/// if none has been registered. Owning it keeps the callable alive after the
+/// registering Python frame returns; it is intentionally never released.
+PyObject *registeredCallback = nullptr;
+
+/// Passes the pending Python exception to sys.unraisablehook (which prints it
+/// by default), clearing it, then emits an MLIR error on `op`. Unlike
+/// PyErr_Print, this cannot terminate the process on SystemExit.
+LogicalResult reportPythonError(Operation *op) {
+  PyErr_WriteUnraisable(nullptr);
+  return op->emitError("python-pass-demo: Python callback failed");
+}
+
+/// Wraps `op` as an `mlir.ir.Operation` using `capiFactory` and passes it to
+/// the registered callback, releasing every temporary Python reference.
+/// On failure the Python error indicator is left set for the caller.
+LogicalResult invokeCallback(PyObject *capiFactory, Operation *op) {
+  PyObject *capsule = mlirPythonOperationToCapsule(wrap(op));
+  if (!capsule)
+    return failure();
+  PyObject *pyOp = PyObject_CallFunctionObjArgs(capiFactory, capsule, nullptr);
+  Py_DECREF(capsule);
+  if (!pyOp)
+    return failure();
+  PyObject *result =
+      PyObject_CallFunctionObjArgs(registeredCallback, pyOp, nullptr);
+  Py_DECREF(pyOp);
+  if (!result)
+    return failure();
+  Py_DECREF(result);
+  return success();
+}
+
+/// Demo pass: walks the module in post-order and hands every operation,
+/// including the module itself, to the registered Python callback.
+struct PythonTestPassDemo
+    : public PassWrapper<PythonTestPassDemo, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PythonTestPassDemo)
+
+  StringRef getArgument() const final { return "python-pass-demo"; }
+
+  void runOnOperation() override {
+    // All Python C-API calls need the GIL; PyGILState_Ensure is safe whether
+    // or not the thread running the pass already holds it.
+    PyGILState_STATE gilState = PyGILState_Ensure();
+    if (failed(walkWithCallback()))
+      signalPassFailure();
+    PyGILState_Release(gilState);
+  }
+
+private:
+  /// Invokes the callback on every nested operation; stops at the first
+  /// Python error, which is reported as an MLIR error on that operation.
+  LogicalResult walkWithCallback() {
+    Operation *root = getOperation();
+    if (!registeredCallback)
+      return root->emitError("python-pass-demo: no callback registered");
+
+    // Resolve `mlir.ir.Operation`'s C-API factory once, not per operation.
+    PyObject *irModule =
+        PyImport_ImportModule(MAKE_MLIR_PYTHON_QUALNAME("ir"));
+    if (!irModule)
+      return reportPythonError(root);
+    PyObject *opClass = PyObject_GetAttrString(irModule, "Operation");
+    Py_DECREF(irModule);
+    if (!opClass)
+      return reportPythonError(root);
+    PyObject *capiFactory =
+        PyObject_GetAttrString(opClass, MLIR_PYTHON_CAPI_FACTORY_ATTR);
+    Py_DECREF(opClass);
+    if (!capiFactory)
+      return reportPythonError(root);
+
+    WalkResult walkResult = root->walk([&](Operation *op) {
+      if (succeeded(invokeCallback(capiFactory, op)))
+        return WalkResult::advance();
+      (void)reportPythonError(op);
+      return WalkResult::interrupt();
+    });
+    Py_DECREF(capiFactory);
+    return failure(walkResult.wasInterrupted());
+  }
+};
+
+} // namespace
+
+/// Installs `func` as the callback used by the `python-pass-demo` pass and
+/// registers that pass on first use. Must be called with the GIL held.
+void registerPythonTestPassDemoPassWithFunc(PyObject *func) {
+  // Publish the new (now owned) callable before dropping the old one so a
+  // finalizer triggered by the decref never observes a stale pointer.
+  PyObject *previous = registeredCallback;
+  Py_XINCREF(func);
+  registeredCallback = func;
+  Py_XDECREF(previous);
+
+  // MLIR's pass registry keeps the first allocator registered for a given
+  // pass argument, so register once and let later calls just swap the
+  // callback above.
+  static const bool registered = [] {
+    registerPass([] { return std::make_unique<PythonTestPassDemo>(); });
+    return true;
+  }();
+  (void)registered;
+}
diff --git a/mlir/test/python/lib/PythonTestPass.h b/mlir/test/python/lib/PythonTestPass.h
new file mode 100644
index 000000000000000..4df4f965857eda6
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestPass.h
@@ -0,0 +1,18 @@
+//===- PythonTestPass.h - Demo pass calling a Python callback ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTPASS_H
+#define MLIR_TEST_PYTHON_LIB_PYTHONTESTPASS_H
+
+#include <Python.h>
+
+/// Registers the `python-pass-demo` pass, configured to call the Python
+/// callable `func` on the operations it visits.
+void registerPythonTestPassDemoPassWithFunc(PyObject *func);
+
+#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTPASS_H

``````````

</details>


https://github.com/llvm/llvm-project/pull/71133


More information about the Mlir-commits mailing list