[Mlir-commits] [mlir] [MLIR] Fix Markdown docs of `shard.neighbors_linear_indices` (PR #163409)
Shenghang Tsai
llvmlistbot at llvm.org
Tue Oct 14 07:48:27 PDT 2025
https://github.com/jackalcooper created https://github.com/llvm/llvm-project/pull/163409
(Screenshot attachment: rendered docs showing the unclosed backquote in `shard.neighbors_linear_indices`.)
>From 383716317dce376206154b47f8100dbd849d87b2 Mon Sep 17 00:00:00 2001
From: Shenghang Tsai <jackalcooper at gmail.com>
Date: Tue, 14 Oct 2025 15:56:38 +0800
Subject: [PATCH 1/2] Closing unclosed backquotes
---
.../include/mlir/Dialect/Shard/IR/ShardOps.td | 30 +++++++++----------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 29b384f401876..6f9a213cc04e5 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -49,7 +49,7 @@ def Shard_GridOp : Shard_Op<"grid", [Symbol, Pure]> {
Example:
```
// A device grid with 3 axes, the total device number is 4 * 8 * 12
- // The dimension sizes are 4, 8, 12
+ // The dimension sizes are 4, 8, 12
shard.grid @grid0(shape = 4x8x12)
// A device grid with 2 axes, the total device number is unknown
@@ -173,8 +173,8 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
%idx = shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index
```
The above returns two indices, `633` and `693`, which correspond to the
- index of the previous process `(1, 1, 3)`, and the next process
- `(1, 3, 3) along the split axis `1`.
+ index of the previous process `(1, 1, 3)`, and the next process
+ `(1, 3, 3)` along the split axis `1`.
A negative value is returned if there is no neighbor in the respective
direction along the given `split_axes`.
@@ -222,7 +222,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
size is 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.
-
+
4. [Optional] Offsets for each shard and sharded tensor dimension.
`sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
sharded tensor dimension the offsets (starting index) of all shards in that
@@ -230,12 +230,12 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
For a 1d sharding this means that position `i` has the exclusive prefix sum for
shard `i`, and since only contiguous sharding is supported, its inclusive prefix
sum is at position 'i+1'.
-
+
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
`sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
the device-grid will get a shard of shape 24x20x32 and the second device will get
a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
-
+
`halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
Examples:
@@ -259,7 +259,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
// and it has halo-sizes of 1 and 2 on the sharded dim.
%halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
%sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
-
+
// The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
@@ -267,7 +267,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
%sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
```
}];
-
+
let arguments = (ins
FlatSymbolRefAttr:$grid,
Shard_GridAxesArrayAttr:$split_axes,
@@ -389,7 +389,7 @@ def Shard_ShardOp : Shard_Op<"shard", [
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
...
}
-
+
func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
@@ -589,7 +589,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
- The operation does not communicate data between devices.
+ The operation does not communicate data between devices.
Example:
```mlir
@@ -706,7 +706,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
The operation broadcasts along grid axes `grid_axes`.
The `root` device specifies the in-group multi-index that is broadcast to
all other devices in the group.
-
+
Example:
```
shard.grid @grid0(shape = 2x2)
@@ -716,13 +716,13 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
```
-
+
Input:
```
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
- device (1, 0) -> | | | <- device (1, 1)
+ device (1, 0) -> | | | <- device (1, 1)
+-------+-------+
```
@@ -978,7 +978,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
device
(1, 1)
```
-
+
Result:
```
device
@@ -986,7 +986,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
↓
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 |
- +-------+-------+
+ +-------+-------+
device (1, 0) -> | 3 4 | 7 8 |
+-------+-------+
↑
>From 07e67c8dad57dbcb22186272bb8c455accab4515 Mon Sep 17 00:00:00 2001
From: Shenghang Tsai <jackalcooper at gmail.com>
Date: Tue, 14 Oct 2025 22:47:28 +0800
Subject: [PATCH 2/2] restore whitespaces
---
.../include/mlir/Dialect/Shard/IR/ShardOps.td | 28 +++++++++----------
1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 6f9a213cc04e5..b9d7163ea4c1e 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -49,7 +49,7 @@ def Shard_GridOp : Shard_Op<"grid", [Symbol, Pure]> {
Example:
```
// A device grid with 3 axes, the total device number is 4 * 8 * 12
- // The dimension sizes are 4, 8, 12
+ // The dimension sizes are 4, 8, 12
shard.grid @grid0(shape = 4x8x12)
// A device grid with 2 axes, the total device number is unknown
@@ -173,7 +173,7 @@ def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
%idx = shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index
```
The above returns two indices, `633` and `693`, which correspond to the
- index of the previous process `(1, 1, 3)`, and the next process
+ index of the previous process `(1, 1, 3)`, and the next process
`(1, 3, 3)` along the split axis `1`.
A negative value is returned if there is no neighbor in the respective
@@ -222,7 +222,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
size is 2 at its end. `halo_sizes = [1, 2, 2, 3]` defines halos for the first 2
sharded dimensions e.g. the first sharded dimension gets `[1,2]` halos and the
seconds gets `[2,3]` halos. `?` indicates dynamic halo sizes.
-
+
4. [Optional] Offsets for each shard and sharded tensor dimension.
`sharded_dims_offsets` is provided as a flattened 1d array of i64s. For each
sharded tensor dimension the offsets (starting index) of all shards in that
@@ -230,12 +230,12 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
For a 1d sharding this means that position `i` has the exclusive prefix sum for
shard `i`, and since only contiguous sharding is supported, its inclusive prefix
sum is at position 'i+1'.
-
+
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
`sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
the device-grid will get a shard of shape 24x20x32 and the second device will get
a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
-
+
`halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
Examples:
@@ -259,7 +259,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
// and it has halo-sizes of 1 and 2 on the sharded dim.
%halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
%sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
-
+
// The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
@@ -267,7 +267,7 @@ def Shard_ShardingOp : Shard_Op<"sharding", [
%sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
```
}];
-
+
let arguments = (ins
FlatSymbolRefAttr:$grid,
Shard_GridAxesArrayAttr:$split_axes,
@@ -389,7 +389,7 @@ def Shard_ShardOp : Shard_Op<"shard", [
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
...
}
-
+
func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
%sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
%0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
@@ -589,7 +589,7 @@ def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
- The operation does not communicate data between devices.
+ The operation does not communicate data between devices.
Example:
```mlir
@@ -706,7 +706,7 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
The operation broadcasts along grid axes `grid_axes`.
The `root` device specifies the in-group multi-index that is broadcast to
all other devices in the group.
-
+
Example:
```
shard.grid @grid0(shape = 2x2)
@@ -716,13 +716,13 @@ def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
```
-
+
Input:
```
+-------+-------+ | broadcast
device (0, 0) -> | 1 2 | 3 4 | <- device (0, 1) | along axis 0
+-------+-------+ ↓
- device (1, 0) -> | | | <- device (1, 1)
+ device (1, 0) -> | | | <- device (1, 1)
+-------+-------+
```
@@ -978,7 +978,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
device
(1, 1)
```
-
+
Result:
```
device
@@ -986,7 +986,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
↓
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 |
- +-------+-------+
+ +-------+-------+
device (1, 0) -> | 3 4 | 7 8 |
+-------+-------+
↑