[libcxx-commits] [lldb] [lld] [clang-tools-extra] [mlir] [compiler-rt] [clang] [flang] [libclc] [libunwind] [llvm] [libcxx] [libc] [mlir][python] Fix generation of python bindings for async dialect (PR #75960)
Abhishek Kulkarni via libcxx-commits
libcxx-commits at lists.llvm.org
Thu Jan 25 18:50:51 PST 2024
https://github.com/adk9 updated https://github.com/llvm/llvm-project/pull/75960
From a43ef7289cd7f5353fc4b365566011b93879e8f6 Mon Sep 17 00:00:00 2001
From: Abhishek Kulkarni <abkulkarni at microsoft.com>
Date: Tue, 19 Dec 2023 10:50:26 -0800
Subject: [PATCH] Fix generation of python bindings for async dialect
---
.../mlir/Dialect/Async/IR/CMakeLists.txt | 1 +
mlir/python/CMakeLists.txt | 9 ++--
.../mlir/dialects/async_dialect/__init__.py | 2 +-
mlir/test/python/dialects/async_dialect.py | 16 ++++++-
.../mlir/python/BUILD.bazel | 47 +++++++++++++++++++
5 files changed, 68 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt
index ebbf2df760faa4..2525eee2a34ec9 100644
--- a/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt
@@ -1,2 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS AsyncOps.td)
add_mlir_dialect(AsyncOps async)
add_mlir_doc(AsyncOps AsyncDialect Dialects/ -gen-dialect-doc)
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 41d91cf6778338..550465f6e37467 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -72,7 +72,7 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/AsyncOps.td
SOURCES_GLOB dialects/async_dialect/*.py
- DIALECT_NAME async_dialect)
+ DIALECT_NAME async)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -522,7 +522,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
- ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect
+ ADD_TO_PARENT MLIRPythonSources.Dialects.async
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
AsyncPasses.cpp
@@ -664,11 +664,10 @@ if(NOT LLVM_ENABLE_IDE)
COMPONENT mlir-python-sources
)
endif()
-
-################################################################################
+# ###############################################################################
# The fully assembled package of modules.
# This must come last.
-################################################################################
+# ###############################################################################
add_mlir_python_modules(MLIRPythonModules
ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir"
diff --git a/mlir/python/mlir/dialects/async_dialect/__init__.py b/mlir/python/mlir/dialects/async_dialect/__init__.py
index dcf9d6cb2638f6..6a5ecfc20956cf 100644
--- a/mlir/python/mlir/dialects/async_dialect/__init__.py
+++ b/mlir/python/mlir/dialects/async_dialect/__init__.py
@@ -2,4 +2,4 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from .._async_dialect_ops_gen import *
+from .._async_ops_gen import *
diff --git a/mlir/test/python/dialects/async_dialect.py b/mlir/test/python/dialects/async_dialect.py
index f6181cc76118ed..13e3c42e57c21e 100644
--- a/mlir/test/python/dialects/async_dialect.py
+++ b/mlir/test/python/dialects/async_dialect.py
@@ -1,7 +1,8 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
-import mlir.dialects.async_dialect
+from mlir.dialects import arith
+import mlir.dialects.async_dialect as async_dialect
import mlir.dialects.async_dialect.passes
from mlir.passmanager import *
@@ -11,6 +12,19 @@ def run(f):
f()
+# CHECK-LABEL: TEST: testCreateGroupOp
+@run
+def testCreateGroupOp():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i32 = IntegerType.get_signless(32)
+ group_size = arith.ConstantOp(i32, 4)
+ async_dialect.create_group(group_size)
+ # CHECK: %0 = "arith.constant"() <{value = 4 : i32}> : () -> i32
+ # CHECK: %1 = "async.create_group"(%0) : (i32) -> !async.group
+ print(module)
+
def testAsyncPass():
with Context() as context:
PassManager.parse("any(async-to-async-runtime)")
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index 049098b158f294..18e84ac7b68103 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -331,6 +331,53 @@ filegroup(
],
)
+##---------------------------------------------------------------------------##
+# Async dialect.
+##---------------------------------------------------------------------------##
+
+gentbl_filegroup(
+ name = "AsyncOpsPyGen",
+ tbl_outs = [
+ (
+ [
+ "-gen-python-enum-bindings",
+ "-bind-dialect=async",
+ ],
+ "mlir/dialects/_async_enum_gen.py",
+ ),
+ (
+ [
+ "-gen-python-op-bindings",
+ "-bind-dialect=async",
+ ],
+ "mlir/dialects/_async_ops_gen.py",
+ ),
+ ],
+ tblgen = "//mlir:mlir-tblgen",
+ td_file = "mlir/dialects/AsyncOps.td",
+ deps = [
+ "//mlir:AsyncOpsTdFiles",
+ "//mlir:OpBaseTdFiles",
+ ],
+)
+
+filegroup(
+ name = "AsyncOpsPyFiles",
+ srcs = [
+ ":AsyncOpsPyGen",
+ ],
+)
+
+filegroup(
+ name = "AsyncOpsPackagePyFiles",
+ srcs = glob(["mlir/dialects/async_dialect/*.py"]),
+)
+
+filegroup(
+ name = "AsyncOpsPackagePassesPyFiles",
+ srcs = glob(["mlir/dialects/async_dialect/passes/*.py"]),
+)
+
##---------------------------------------------------------------------------##
# Arith dialect.
##---------------------------------------------------------------------------##
More information about the libcxx-commits mailing list