[llvm] [mlir] Add shard Dialect Python Bindings (PR #162578)

Siavash Nazari via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 8 21:49:38 PDT 2025


https://github.com/Svoch updated https://github.com/llvm/llvm-project/pull/162578

>From c99fa28a22b4040ddb32e84ecd27de2f3e2f6c27 Mon Sep 17 00:00:00 2001
From: Siavash Nazari <sinazari at lusll2vwgjhwq.teslamotors.com>
Date: Wed, 8 Oct 2025 18:20:51 -0700
Subject: [PATCH] Add shard Dialect Python Bindings

---
 mlir/python/CMakeLists.txt                    |  9 +++
 mlir/python/mlir/dialects/ShardOps.td         | 14 ++++
 mlir/python/mlir/dialects/shard.py            |  6 ++
 mlir/test/python/dialects/shard.py            | 68 +++++++++++++++++++
 .../mlir/python/BUILD.bazel                   | 32 +++++++++
 5 files changed, 129 insertions(+)
 create mode 100644 mlir/python/mlir/dialects/ShardOps.td
 create mode 100644 mlir/python/mlir/dialects/shard.py
 create mode 100644 mlir/test/python/dialects/shard.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 9f5246de6bda0..20f07440df2c3 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -336,6 +336,15 @@ declare_mlir_dialect_python_bindings(
     dialects/memref.py
   DIALECT_NAME memref)
 
+# Shard dialect. GEN_ENUM_BINDINGS is needed for shard.py's _shard_enum_gen.
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/ShardOps.td
+  SOURCES dialects/shard.py
+  DIALECT_NAME shard
+  GEN_ENUM_BINDINGS)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/ShardOps.td b/mlir/python/mlir/dialects/ShardOps.td
new file mode 100644
index 0000000000000..f8527664df67b
--- /dev/null
+++ b/mlir/python/mlir/dialects/ShardOps.td
@@ -0,0 +1,14 @@
+//===-- ShardOps.td - Entry point for ShardOps bindings ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_SHARD_OPS
+#define PYTHON_BINDINGS_SHARD_OPS
+
+include "mlir/Dialect/Shard/IR/ShardOps.td"
+
+#endif // PYTHON_BINDINGS_SHARD_OPS
diff --git a/mlir/python/mlir/dialects/shard.py b/mlir/python/mlir/dialects/shard.py
new file mode 100644
index 0000000000000..8d69f17954290
--- /dev/null
+++ b/mlir/python/mlir/dialects/shard.py
@@ -0,0 +1,6 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._shard_ops_gen import *  # op classes from -gen-python-op-bindings
+from ._shard_enum_gen import *  # enum classes from -gen-python-enum-bindings
diff --git a/mlir/test/python/dialects/shard.py b/mlir/test/python/dialects/shard.py
new file mode 100644
index 0000000000000..59ece43583b1f
--- /dev/null
+++ b/mlir/test/python/dialects/shard.py
@@ -0,0 +1,68 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import func, shard
+
+
+def constructAndPrintInModule(f):
+    """Print a TEST header, build f's ops into a new module, print it, return f."""
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
+
+
+# CHECK-LABEL: TEST: testShardGrid
+@constructAndPrintInModule
+def testShardGrid():
+    # Static 2-D and 1-D grids, plus one with a dynamic ("?") dimension.
+    shard.GridOp("grid_2d", [2, 2])
+    shard.GridOp("grid_1d", [4])
+    shard.GridOp("grid_dynamic", [2, ShapedType.get_dynamic_size()])
+
+    # CHECK: shard.grid @grid_2d(shape = 2x2)
+    # CHECK: shard.grid @grid_1d(shape = 4)
+    # CHECK: shard.grid @grid_dynamic(shape = 2x?)
+
+
+# CHECK-LABEL: TEST: testCollectiveOperations
+@constructAndPrintInModule
+def testCollectiveOperations():
+    # Create the grid (referenced by symbol below) and the tensor types.
+    shard.GridOp("grid_2x2", [2, 2])
+    i32 = IntegerType.get_signless(32)
+    input_type = RankedTensorType.get([4, 2], i32)
+    gather_result_type = RankedTensorType.get([4, 4], i32)
+
+    # Create a function to hold the operations
+    func_type = FunctionType.get([input_type], [input_type])
+    test_func = func.FuncOp("test_collectives", func_type)
+
+    with InsertionPoint(test_func.add_entry_block()):
+        arg = test_func.entry_block.arguments[0]
+
+        shard.AllGatherOp(
+            input=arg,
+            grid=FlatSymbolRefAttr.get("grid_2x2"),
+            grid_axes=DenseI16ArrayAttr.get([1]),
+            gather_axis=IntegerAttr.get(IndexType.get(), 1),
+            result=gather_result_type,
+        )
+
+        reduce_op = shard.AllReduceOp(
+            input=arg,
+            grid=FlatSymbolRefAttr.get("grid_2x2"),
+            reduction=shard.ReductionKind.Sum,
+            result=input_type,
+        )
+
+        func.ReturnOp([reduce_op])
+
+    # CHECK: shard.grid @grid_2x2(shape = 2x2)
+    # CHECK: func.func @test_collectives(%{{.*}}: tensor<4x2xi32>) -> tensor<4x2xi32>
+    # CHECK: %{{.*}} = shard.all_gather %{{.*}} on @grid_2x2 grid_axes = [1] gather_axis = 1 : tensor<4x2xi32> -> tensor<4x4xi32>
+    # CHECK: %{{.*}} = shard.all_reduce %{{.*}} on @grid_2x2 : tensor<4x2xi32> -> tensor<4x2xi32>
+
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index 102c4161eb74c..72af4f08bde57 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -981,6 +981,38 @@ filegroup(
     ],
 )
 
+##---------------------------------------------------------------------------##
+# Shard dialect.
+##---------------------------------------------------------------------------##
+
+gentbl_filegroup(
+    name = "ShardOpsPyGen",
+    tbl_outs = {
+        "mlir/dialects/_shard_enum_gen.py": [  # star-imported by shard.py
+            "-gen-python-enum-bindings",
+            "-bind-dialect=shard",
+        ],
+        "mlir/dialects/_shard_ops_gen.py": [  # star-imported by shard.py
+            "-gen-python-op-bindings",
+            "-bind-dialect=shard",
+        ],
+    },
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/ShardOps.td",  # keep in sync with CMake TD_FILE
+    deps = [
+        "//mlir:OpBaseTdFiles",
+        "//mlir:ShardTdFiles",
+    ],
+)
+
+filegroup(
+    name = "ShardOpsPyFiles",
+    srcs = [
+        "mlir/dialects/shard.py",  # hand-written module re-exporting the generated files
+        ":ShardOpsPyGen",  # _shard_ops_gen.py and _shard_enum_gen.py
+    ],
+)
+
 ##---------------------------------------------------------------------------##
 # Shape dialect.
 ##---------------------------------------------------------------------------##



More information about the llvm-commits mailing list