[libcxx-commits] [flang] [compiler-rt] [libunwind] [libc] [libcxx] [lld] [libclc] [llvm] [mlir] [clang] [lldb] [clang-tools-extra] [mlir][python] Fix generation of python bindings for async dialect (PR #75960)

Abhishek Kulkarni via libcxx-commits libcxx-commits at lists.llvm.org
Tue Jan 9 11:21:11 PST 2024


https://github.com/adk9 updated https://github.com/llvm/llvm-project/pull/75960

From a43ef7289cd7f5353fc4b365566011b93879e8f6 Mon Sep 17 00:00:00 2001
From: Abhishek Kulkarni <abkulkarni at microsoft.com>
Date: Tue, 19 Dec 2023 10:50:26 -0800
Subject: [PATCH] Fix generation of python bindings for async dialect

---
 .../mlir/Dialect/Async/IR/CMakeLists.txt      |  1 +
 mlir/python/CMakeLists.txt                    |  9 ++--
 .../mlir/dialects/async_dialect/__init__.py   |  2 +-
 mlir/test/python/dialects/async_dialect.py    | 16 ++++++-
 .../mlir/python/BUILD.bazel                   | 47 +++++++++++++++++++
 5 files changed, 68 insertions(+), 7 deletions(-)

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..550465f6e37467 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -72,7 +72,7 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/AsyncOps.td
   SOURCES_GLOB dialects/async_dialect/*.py
-  DIALECT_NAME async_dialect)
+  DIALECT_NAME async)
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -522,7 +522,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
 
 declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
   MODULE_NAME _mlirAsyncPasses
-  ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect
+  ADD_TO_PARENT MLIRPythonSources.Dialects.async
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
   SOURCES
     AsyncPasses.cpp
diff --git a/mlir/python/mlir/dialects/async_dialect/__init__.py b/mlir/python/mlir/dialects/async_dialect/__init__.py
index dcf9d6cb2638f6..6a5ecfc20956cf 100644
--- a/mlir/python/mlir/dialects/async_dialect/__init__.py
+++ b/mlir/python/mlir/dialects/async_dialect/__init__.py
@@ -2,4 +2,4 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .._async_dialect_ops_gen import *
+from .._async_ops_gen import *
diff --git a/mlir/test/python/dialects/async_dialect.py b/mlir/test/python/dialects/async_dialect.py
index f6181cc76118ed..13e3c42e57c21e 100644
--- a/mlir/test/python/dialects/async_dialect.py
+++ b/mlir/test/python/dialects/async_dialect.py
@@ -1,7 +1,8 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-import mlir.dialects.async_dialect
+from mlir.dialects import arith
+import mlir.dialects.async_dialect as async_dialect
 import mlir.dialects.async_dialect.passes
 from mlir.passmanager import *
 
@@ -11,6 +12,23 @@ def run(f):
     f()
 
 
+# CHECK-LABEL: TEST: testCreateGroupOp
+@run
+def testCreateGroupOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            # async.create_group takes its group size as an `index` operand;
+            # an i32 operand produces IR that fails verification.
+            index = IndexType.get()
+            group_size = arith.ConstantOp(index, 4)
+            async_dialect.create_group(group_size)
+        assert module.operation.verify()
+        # CHECK: %[[SIZE:.*]] = arith.constant 4 : index
+        # CHECK: %{{.*}} = async.create_group %[[SIZE]] : !async.group
+        print(module)
+
+
 def testAsyncPass():
     with Context() as context:
         PassManager.parse("any(async-to-async-runtime)")
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index 049098b158f294..18e84ac7b68103 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -331,6 +331,48 @@ filegroup(
     ],
 )
 
+##---------------------------------------------------------------------------##
+# Async dialect.
+##---------------------------------------------------------------------------##
+
+# Only op bindings are generated, matching the CMake build, which declares the
+# async dialect bindings without GEN_ENUM_BINDINGS.
+gentbl_filegroup(
+    name = "AsyncOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=async",
+            ],
+            "mlir/dialects/_async_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/AsyncOps.td",
+    deps = [
+        "//mlir:AsyncOpsTdFiles",
+        "//mlir:OpBaseTdFiles",
+    ],
+)
+
+filegroup(
+    name = "AsyncOpsPyFiles",
+    srcs = [
+        ":AsyncOpsPyGen",
+    ],
+)
+
+filegroup(
+    name = "AsyncOpsPackagePyFiles",
+    srcs = glob(["mlir/dialects/async_dialect/*.py"]),
+)
+
+filegroup(
+    name = "AsyncOpsPackagePassesPyFiles",
+    srcs = glob(["mlir/dialects/async_dialect/passes/*.py"]),
+)
+
 ##---------------------------------------------------------------------------##
 # Arith dialect.
 ##---------------------------------------------------------------------------##



More information about the libcxx-commits mailing list