[Mlir-commits] [mlir] [MLIR] Fix Markdown docs of `shard.neighbors_linear_indices` (PR #163409)

Shenghang Tsai llvmlistbot at llvm.org
Tue Oct 14 07:48:27 PDT 2025


https://github.com/jackalcooper created https://github.com/llvm/llvm-project/pull/163409

Screenshot of the rendered docs affected by the unclosed backquote:
https://github.com/user-attachments/assets/5a907cd7-beed-4a03-9d5e-43ec594d1e04


>From 383716317dce376206154b47f8100dbd849d87b2 Mon Sep 17 00:00:00 2001
From: Shenghang Tsai <jackalcooper at gmail.com>
Date: Tue, 14 Oct 2025 15:56:38 +0800
Subject: [PATCH 1/2] Closing unclosed backquotes

---
 .../include/mlir/Dialect/Shard/IR/ShardOps.td | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 29b384f401876..6f9a213cc04e5 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -49,7 +49,7 @@ def Shard_GridOp : Shard_Op<"grid", [Symbol, Pure]> {
     Example:
     ```
     // A device grid with 3 axes, the total device number is 4 * 8 * 12
-    // The dimension sizes are 4, 8, 12 
+    // The dimension sizes are 4, 8, 12
     shard.grid @grid0(shape = 4x8x12)
 
     // A device grid with 2 axes, the total device number is unknown
@@ -173,8 +173,8 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
     %idx = shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index
     ```
     The above returns two indices, `633` and `693`, which correspond to the
-    index of the previous process `(1, 1, 3)`, and the next process 
-    `(1, 3, 3) along the split axis `1`.
+    index of the previous process `(1, 1, 3)`, and the next process
+    `(1, 3, 3)` along the split axis `1`.
 
     A negative value is returned if there is no neighbor in the respective
     direction along the given `split_axes`.
@@ -222,7 +222,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
     size is 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
     sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
     seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.
-    
+
     4. [Optional] Offsets for each shard and sharded tensor dimension.
     `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
     sharded tensor dimension the offsets (starting index) of all shards in that
@@ -230,12 +230,12 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
     For a 1d sharding this means that position `i` has the exclusive prefix sum for
     shard `i`, and since only contiguous sharding is supported, its inclusive prefix
     sum is at position 'i+1'.
-    
+
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
     `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
     the device-grid will get a shard of shape 24x20x32 and the second device will get
     a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
-    
+
     `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
     Examples:
@@ -259,7 +259,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
     // and it has halo-sizes of 1 and 2 on the sharded dim.
     %halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
     %sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
-    
+
     // The tensor is sharded on its second dimension along axis 0 of @grid1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
@@ -267,7 +267,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
     %sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
-    
+
   let arguments = (ins
     FlatSymbolRefAttr:$grid,
     Shard_GridAxesArrayAttr:$split_axes,
@@ -389,7 +389,7 @@ def Shard_ShardOp : Shard_Op<"shard", [
       %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
       ...
     }
-    
+
     func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
       %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
       %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
@@ -589,7 +589,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
     This operation can be thought of as the inverse of all-gather.
     Technically, it is not required that all processes have the same input tensor.
     Each process will slice a piece of its local tensor based on its in-group device index.
-    The operation does not communicate data between devices. 
+    The operation does not communicate data between devices.
 
     Example:
     ```mlir
@@ -706,7 +706,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
     The operation broadcasts along grid axes `grid_axes`.
     The `root` device specifies the in-group multi-index that is broadcast to
     all other devices in the group.
-    
+
     Example:
     ```
     shard.grid @grid0(shape = 2x2)
@@ -716,13 +716,13 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
       root = [0]
       : (tensor<2xi8>) -> tensor<2xi8>
     ```
-    
+
     Input:
     ```
                      +-------+-------+                   | broadcast
     device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                      +-------+-------+                   ↓
-    device (1, 0) -> |       |       | <- device (1, 1) 
+    device (1, 0) -> |       |       | <- device (1, 1)
                      +-------+-------+
     ```
 
@@ -978,7 +978,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
                               device
                               (1, 1)
     ```
-    
+
     Result:
     ```
                               device
@@ -986,7 +986,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
                                  ↓
                      +-------+-------+
     device (0, 0) -> |  1  2 |  5  6 |
-                     +-------+-------+ 
+                     +-------+-------+
     device (1, 0) -> |  3  4 |  7  8 |
                      +-------+-------+
                                 ↑

>From 07e67c8dad57dbcb22186272bb8c455accab4515 Mon Sep 17 00:00:00 2001
From: Shenghang Tsai <jackalcooper at gmail.com>
Date: Tue, 14 Oct 2025 22:47:28 +0800
Subject: [PATCH 2/2] restore whitespaces

---
 .../include/mlir/Dialect/Shard/IR/ShardOps.td | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 6f9a213cc04e5..b9d7163ea4c1e 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -49,7 +49,7 @@ def Shard_GridOp : Shard_Op<"grid", [Symbol, Pure]> {
     Example:
     ```
     // A device grid with 3 axes, the total device number is 4 * 8 * 12
-    // The dimension sizes are 4, 8, 12
+    // The dimension sizes are 4, 8, 12 
     shard.grid @grid0(shape = 4x8x12)
 
     // A device grid with 2 axes, the total device number is unknown
@@ -173,7 +173,7 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
     %idx = shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index
     ```
     The above returns two indices, `633` and `693`, which correspond to the
-    index of the previous process `(1, 1, 3)`, and the next process
+    index of the previous process `(1, 1, 3)`, and the next process 
     `(1, 3, 3)` along the split axis `1`.
 
     A negative value is returned if there is no neighbor in the respective
@@ -222,7 +222,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
     size is 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
     sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
     seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.
-
+    
     4. [Optional] Offsets for each shard and sharded tensor dimension.
     `sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
     sharded tensor dimension the offsets (starting index) of all shards in that
@@ -230,12 +230,12 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
     For a 1d sharding this means that position `i` has the exclusive prefix sum for
     shard `i`, and since only contiguous sharding is supported, its inclusive prefix
     sum is at position 'i+1'.
-
+    
     Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
     `sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
     the device-grid will get a shard of shape 24x20x32 and the second device will get
     a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
-
+    
     `halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
 
     Examples:
@@ -259,7 +259,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
     // and it has halo-sizes of 1 and 2 on the sharded dim.
     %halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
     %sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
-
+    
     // The tensor is sharded on its second dimension along axis 0 of @grid1d_4
     // and it has pre-defined shard sizes. The shards of the devices will have
     // the following shapes: [4x2, 4x3, 4x4, 4x5]
@@ -267,7 +267,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
     %sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
     ```
   }];
-
+    
   let arguments = (ins
     FlatSymbolRefAttr:$grid,
     Shard_GridAxesArrayAttr:$split_axes,
@@ -389,7 +389,7 @@ def Shard_ShardOp : Shard_Op<"shard", [
       %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
       ...
     }
-
+    
     func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
       %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
       %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
@@ -589,7 +589,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
     This operation can be thought of as the inverse of all-gather.
     Technically, it is not required that all processes have the same input tensor.
     Each process will slice a piece of its local tensor based on its in-group device index.
-    The operation does not communicate data between devices.
+    The operation does not communicate data between devices. 
 
     Example:
     ```mlir
@@ -706,7 +706,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
     The operation broadcasts along grid axes `grid_axes`.
     The `root` device specifies the in-group multi-index that is broadcast to
     all other devices in the group.
-
+    
     Example:
     ```
     shard.grid @grid0(shape = 2x2)
@@ -716,13 +716,13 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
       root = [0]
       : (tensor<2xi8>) -> tensor<2xi8>
     ```
-
+    
     Input:
     ```
                      +-------+-------+                   | broadcast
     device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                      +-------+-------+                   ↓
-    device (1, 0) -> |       |       | <- device (1, 1)
+    device (1, 0) -> |       |       | <- device (1, 1) 
                      +-------+-------+
     ```
 
@@ -978,7 +978,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
                               device
                               (1, 1)
     ```
-
+    
     Result:
     ```
                               device
@@ -986,7 +986,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
                                  ↓
                      +-------+-------+
     device (0, 0) -> |  1  2 |  5  6 |
-                     +-------+-------+
+                     +-------+-------+ 
     device (1, 0) -> |  3  4 |  7  8 |
                      +-------+-------+
                                 ↑
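For reference, the neighbor computation that the fixed paragraph describes can be sketched in Python. This is an illustrative model, not the MLIR implementation; the grid shape is not shown in the quoted excerpt, and `(4, 20, 30)` is simply one shape consistent with the indices `633` and `693` quoted in the docs for process `(1, 2, 3)` with `split_axes = [1]`.

```python
def linear_index(shape, coords):
    """Row-major linear index of a multi-index in a device grid."""
    idx = 0
    for dim, c in zip(shape, coords):
        idx = idx * dim + c
    return idx

def neighbors_linear_indices(shape, coords, split_axis):
    """Return (previous, next) linear indices along split_axis.

    A negative value stands in for a missing neighbor in that
    direction, mirroring the behavior described in the docs.
    """
    prev_c, next_c = list(coords), list(coords)
    prev_c[split_axis] -= 1
    next_c[split_axis] += 1
    prev = linear_index(shape, prev_c) if prev_c[split_axis] >= 0 else -1
    nxt = (linear_index(shape, next_c)
           if next_c[split_axis] < shape[split_axis] else -1)
    return prev, nxt

# Process (1, 2, 3), split axis 1: previous is (1, 1, 3), next is (1, 3, 3).
print(neighbors_linear_indices((4, 20, 30), (1, 2, 3), 1))  # (633, 693)
```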

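The `sharded_dims_offsets` semantics quoted in the docs can also be checked with a short sketch (illustrative code, not MLIR): offsets per sharded dimension are exclusive prefix sums, so shard `i` spans `[offsets[i], offsets[i+1])`.

```python
def shard_sizes(offsets_per_dim):
    """For each sharded dimension, turn its offset list (an exclusive
    prefix sum, as the docs describe) into per-shard extents."""
    return [[hi - lo for lo, hi in zip(offs, offs[1:])]
            for offs in offsets_per_dim]

# The flattened [0, 24, 32, 0, 20, 32] from the docs, split per dimension:
sizes = shard_sizes([[0, 24, 32], [0, 20, 32]])
print(sizes)  # [[24, 8], [20, 12]]
# Consistent with the docs for a 32x32x32 tensor (last dim unsharded):
# one device gets a 24x20x32 shard, another an 8x12x32 shard.
```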

