[llvm] [mlir] [MLIR][Python] Add shard Dialect Python Bindings (PR #162578)
Siavash Nazari via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 8 22:56:25 PDT 2025
https://github.com/Svoch updated https://github.com/llvm/llvm-project/pull/162578
>From c618791277e78944f8e1c6edd6ea251833403d27 Mon Sep 17 00:00:00 2001
From: Siavash Nazari <sinazari at lusll2vwgjhwq.teslamotors.com>
Date: Wed, 8 Oct 2025 22:07:33 -0700
Subject: [PATCH] [MLIR][Python] Add shard Dialect Python Bindings
---
mlir/python/CMakeLists.txt | 9 +++
mlir/python/mlir/dialects/ShardOps.td | 14 ++++
mlir/python/mlir/dialects/shard.py | 6 ++
mlir/test/python/dialects/shard.py | 67 +++++++++++++++++++
.../mlir/python/BUILD.bazel | 32 +++++++++
5 files changed, 128 insertions(+)
create mode 100644 mlir/python/mlir/dialects/ShardOps.td
create mode 100644 mlir/python/mlir/dialects/shard.py
create mode 100644 mlir/test/python/dialects/shard.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 9f5246de6bda0..20f07440df2c3 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -336,6 +336,15 @@ declare_mlir_dialect_python_bindings(
dialects/memref.py
DIALECT_NAME memref)
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/ShardOps.td
+  SOURCES
+    dialects/shard.py
+  DIALECT_NAME shard
+  GEN_ENUM_BINDINGS)  # also emit _shard_enum_gen.py, imported by shard.py
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/ShardOps.td b/mlir/python/mlir/dialects/ShardOps.td
new file mode 100644
index 0000000000000..f8527664df67b
--- /dev/null
+++ b/mlir/python/mlir/dialects/ShardOps.td
@@ -0,0 +1,14 @@
+//===-- ShardOps.td - Entry point for ShardOps bindings ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_SHARD_OPS
+#define PYTHON_BINDINGS_SHARD_OPS
+
+include "mlir/Dialect/Shard/IR/ShardOps.td"
+
+#endif
diff --git a/mlir/python/mlir/dialects/shard.py b/mlir/python/mlir/dialects/shard.py
new file mode 100644
index 0000000000000..8d69f17954290
--- /dev/null
+++ b/mlir/python/mlir/dialects/shard.py
@@ -0,0 +1,6 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._shard_ops_gen import *
+from ._shard_enum_gen import *
diff --git a/mlir/test/python/dialects/shard.py b/mlir/test/python/dialects/shard.py
new file mode 100644
index 0000000000000..cfe31e7d1e930
--- /dev/null
+++ b/mlir/test/python/dialects/shard.py
@@ -0,0 +1,67 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import shard
+from mlir.dialects import func
+
+
+def constructAndPrintInModule(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()  # ops created by f() are inserted into the module body
+        print(module)  # the RUN line pipes this output into FileCheck
+    return f  # keep the decorated test function usable by name
+
+
+# CHECK-LABEL: TEST: testShardGrid
+@constructAndPrintInModule
+def testShardGrid():
+    # Grids of rank 2 and 1; dynamic dims use the ShapedType sentinel, not -1.
+    grid2d = shard.GridOp("grid_2d", [2, 2])
+    grid1d = shard.GridOp("grid_1d", [4])
+    grid_dynamic = shard.GridOp("grid_dynamic", [2, ShapedType.get_dynamic_size()])
+
+    # CHECK: shard.grid @grid_2d(shape = 2x2)
+    # CHECK: shard.grid @grid_1d(shape = 4)
+    # CHECK: shard.grid @grid_dynamic(shape = 2x?)
+
+
+# CHECK-LABEL: TEST: testCollectiveOperations
+@constructAndPrintInModule
+def testCollectiveOperations():
+    # Create the grid referenced by the collectives below, and the tensor types.
+    grid = shard.GridOp("grid_2x2", [2, 2])
+    i32 = IntegerType.get_signless(32)
+    input_type = RankedTensorType.get([4, 2], i32)
+    gather_result_type = RankedTensorType.get([4, 4], i32)
+
+    # Create a function to hold the operations.
+    func_type = FunctionType.get([input_type], [input_type])
+    test_func = func.FuncOp("test_collectives", func_type)
+
+    with InsertionPoint(test_func.add_entry_block()):
+        arg = test_func.entry_block.arguments[0]
+
+        gather_op = shard.AllGatherOp(
+            input=arg,
+            grid=FlatSymbolRefAttr.get("grid_2x2"),
+            grid_axes=DenseI16ArrayAttr.get([1]),  # GridAxesAttr is an i16 dense array
+            gather_axis=IntegerAttr.get(IndexType.get(), 1),  # IndexAttr, not i32
+            result=gather_result_type,
+        )
+
+        reduce_op = shard.AllReduceOp(
+            input=arg,
+            grid=FlatSymbolRefAttr.get("grid_2x2"),
+            reduction=shard.ReductionKind.Max,  # sum is the default and is not printed
+            result=input_type,
+        )
+
+        func.ReturnOp([reduce_op.result])
+
+    # CHECK: shard.grid @grid_2x2(shape = 2x2)
+    # CHECK: func @test_collectives(%{{.*}}: tensor<4x2xi32>) -> tensor<4x2xi32>
+    # CHECK: %{{.*}} = shard.all_gather %{{.*}} on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
+    # CHECK: %{{.*}} = shard.all_reduce %{{.*}} on @grid_2x2 reduction = max : tensor<4x2xi32> -> tensor<4x2xi32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index 102c4161eb74c..72af4f08bde57 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -981,6 +981,38 @@ filegroup(
],
)
+##---------------------------------------------------------------------------##
+# Shard dialect.
+##---------------------------------------------------------------------------##
+
+gentbl_filegroup(
+    name = "ShardOpsPyGen",
+    tbl_outs = {
+        "mlir/dialects/_shard_enum_gen.py": [  # enum bindings, as GEN_ENUM_BINDINGS
+            "-gen-python-enum-bindings",
+            "-bind-dialect=shard",
+        ],
+        "mlir/dialects/_shard_ops_gen.py": [  # op classes re-exported by shard.py
+            "-gen-python-op-bindings",
+            "-bind-dialect=shard",
+        ],
+    },
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/ShardOps.td",
+    deps = [
+        "//mlir:OpBaseTdFiles",
+        "//mlir:ShardTdFiles",
+    ],
+)
+
+filegroup(
+    name = "ShardOpsPyFiles",
+    srcs = [
+        "mlir/dialects/shard.py",  # hand-written entry module
+        ":ShardOpsPyGen",  # generated _shard_ops_gen.py / _shard_enum_gen.py
+    ],
+)
+
##---------------------------------------------------------------------------##
# Shape dialect.
##---------------------------------------------------------------------------##
More information about the llvm-commits mailing list