[llvm] [mlir] [mlir][python] Fix generation of python bindings for async dialect (PR #75960)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 19 10:54:40 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-async

@llvm/pr-subscribers-mlir

Author: Abhishek Kulkarni (adk9)

<details>
<summary>Changes</summary>

There were no Python bindings being generated for the MLIR "async" dialect. This PR fixes the generation of Python bindings for the "async" dialect and adds a test case that uses them.

---
Full diff: https://github.com/llvm/llvm-project/pull/75960.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt (+1) 
- (modified) mlir/python/CMakeLists.txt (+4-5) 
- (modified) mlir/python/mlir/dialects/async_dialect/__init__.py (+1-1) 
- (modified) mlir/test/python/dialects/async_dialect.py (+15-1) 
- (modified) utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel (+47) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt
index ebbf2df760faa4..2525eee2a34ec9 100644
--- a/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt
@@ -1,2 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS AsyncOps.td)
 add_mlir_dialect(AsyncOps async)
 add_mlir_doc(AsyncOps AsyncDialect Dialects/ -gen-dialect-doc)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..550465f6e37467 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -72,7 +72,7 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/AsyncOps.td
   SOURCES_GLOB dialects/async_dialect/*.py
-  DIALECT_NAME async_dialect)
+  DIALECT_NAME async)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -522,7 +522,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
 
 declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
   MODULE_NAME _mlirAsyncPasses
-  ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect
+  ADD_TO_PARENT MLIRPythonSources.Dialects.async
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
   SOURCES
     AsyncPasses.cpp
@@ -664,11 +664,10 @@ if(NOT LLVM_ENABLE_IDE)
     COMPONENT mlir-python-sources
   )
 endif()
-
-################################################################################
+# ###############################################################################
 # The fully assembled package of modules.
 # This must come last.
-################################################################################
+# ###############################################################################
 
 add_mlir_python_modules(MLIRPythonModules
   ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir"
diff --git a/mlir/python/mlir/dialects/async_dialect/__init__.py b/mlir/python/mlir/dialects/async_dialect/__init__.py
index dcf9d6cb2638f6..6a5ecfc20956cf 100644
--- a/mlir/python/mlir/dialects/async_dialect/__init__.py
+++ b/mlir/python/mlir/dialects/async_dialect/__init__.py
@@ -2,4 +2,4 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .._async_dialect_ops_gen import *
+from .._async_ops_gen import *
diff --git a/mlir/test/python/dialects/async_dialect.py b/mlir/test/python/dialects/async_dialect.py
index f6181cc76118ed..13e3c42e57c21e 100644
--- a/mlir/test/python/dialects/async_dialect.py
+++ b/mlir/test/python/dialects/async_dialect.py
@@ -1,7 +1,8 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-import mlir.dialects.async_dialect
+from mlir.dialects import arith
+import mlir.dialects.async_dialect as async_dialect
 import mlir.dialects.async_dialect.passes
 from mlir.passmanager import *
 
@@ -11,6 +12,19 @@ def run(f):
     f()
 
 
+# CHECK-LABEL: TEST: testCreateGroupOp
+@run
+def testCreateGroupOp():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i32 = IntegerType.get_signless(32)
+            group_size = arith.ConstantOp(i32, 4)
+            async_dialect.create_group(group_size)
+        # CHECK:         %0 = "arith.constant"() <{value = 4 : i32}> : () -> i32
+        # CHECK:         %1 = "async.create_group"(%0) : (i32) -> !async.group
+        print(module)
+
 def testAsyncPass():
     with Context() as context:
         PassManager.parse("any(async-to-async-runtime)")
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index 049098b158f294..18e84ac7b68103 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -331,6 +331,53 @@ filegroup(
     ],
 )
 
+##---------------------------------------------------------------------------##
+# Async dialect.
+##---------------------------------------------------------------------------##
+
+gentbl_filegroup(
+    name = "AsyncOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-enum-bindings",
+                "-bind-dialect=async",
+            ],
+            "mlir/dialects/_async_enum_gen.py",
+        ),
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=async",
+            ],
+            "mlir/dialects/_async_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/AsyncOps.td",
+    deps = [
+        "//mlir:AsyncOpsTdFiles",
+        "//mlir:OpBaseTdFiles",
+    ],
+)
+
+filegroup(
+    name = "AsyncOpsPyFiles",
+    srcs = [
+        ":AsyncOpsPyGen",
+    ],
+)
+
+filegroup(
+    name = "AsyncOpsPackagePyFiles",
+    srcs = glob(["mlir/dialects/async_dialect/*.py"]),
+)
+
+filegroup(
+    name = "AsyncOpsPackagePassesPyFiles",
+    srcs = glob(["mlir/dialects/async_dialect/passes/*.py"]),
+)
+
 ##---------------------------------------------------------------------------##
 # Arith dialect.
 ##---------------------------------------------------------------------------##

``````````

</details>


https://github.com/llvm/llvm-project/pull/75960


More information about the llvm-commits mailing list