[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)

Boian Petkantchin llvmlistbot at llvm.org
Tue Jul 9 18:55:36 PDT 2024


================
@@ -105,22 +105,221 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   ];
 }
 
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "Get the multi index of the current device along the specified mesh axes.";
+  let description = [{
+    It is used in the SPMD format of IR.
+    The `axes` must be non-negative and less than the total number of mesh axes.
+    If `axes` is empty, the index along all mesh axes is returned.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+  ];
+}
+
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "Get the linear index of the current device.";
+  let description = [{
+    Example:
+    ```
+    %idx = mesh.process_linear_index on @mesh : index
+    ```
+    If `@mesh` has shape `(10, 20, 30)`, a device with multi
+    index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1 = 663`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh);
+  let results = (outs Index:$result);
+  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Sharding operations.
+//===----------------------------------------------------------------------===//
+
+def Mesh_ShardingOp : Mesh_Op<"sharding", [
+    Pure,
+    AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Define a sharding of a tensor.";
+  let description = [{
+    The MeshSharding specifies how a tensor is sharded and distributed across the
+    process mesh. It is typically used in a `mesh.shard` operation.
+    The operation has the following attributes and operands:
+
+    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+    mesh where the distributed tensor is placed. The symbol must resolve to a
+    `mesh.mesh` operation.
+
+    2. `split_axes`: an array of int64_t sub-arrays. The outer array's
+    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+    its value is [x, y], the tensor's i-th dimension is split
+    along the x and y axes of the device mesh.
+
+    3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
+    along the specified mesh axes. An all-reduce should be applied to obtain
+    the complete tensor, with the reduction type specified by `partial_type`.
+
+    4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
+    op. It has 4 possible values:
+    `generic`: is not an allowed value inside a shard attribute.
+
+    5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
+    `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
+    `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
+    halo of size 1 at its start and a halo of size 2 at its end.
+    `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions,
+    i.e. the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos.
+    `?` indicates dynamic halo sizes.
+    
+    6. [Optional] Sizes of sharded dimensions of each shard.
+    `sharded_dims_sizes` is provided as a flattened 1d array of i64s containing, for each device of
+    the device-mesh, one value for each sharded tensor dimension.
+    Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
+    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
+    the device-mesh will get a shard of shape 16x8x32 and the second device will get a
+    shard of shape 16x24x32.
+    `?` indicates dynamic shard dimensions.
+    
+    `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+
+    Examples:
+
+    ```
+    mesh.mesh @mesh0(shape = 2x2x4)
+    mesh.mesh @mesh1d_4(shape = 4)
+
+    // The tensor is fully replicated on @mesh0.
+    // Currently, there must be at least one sub-array present in axes, even
+    // if it's empty. Otherwise, a parsing error will occur.
+    %sharding0 = mesh.sharding @mesh0, [[]]
+
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0
+    %sharding1 = mesh.sharding @mesh0, [[0]]
+
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+    // it is also a partial_sum along mesh axis 1.
+    %sharding2 = mesh.sharding @mesh0, [[0], []] partial = sum[1]
+
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+    // it is also a partial_max along mesh axis 1.
+    %sharding3 = mesh.sharding @mesh0, [[0]] partial = max[1]
+
+    // Could be used for a mesh.shard op
+    %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
+
+    // The tensor is sharded on its first dimension along axis 0 of @mesh0
+    // and it has halo sizes of 1 and 2 on the sharded dimension.
+    %halo_sharding = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2]
+    %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
+    
+    // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
+    // and it has pre-defined shard sizes. The shards of the devices will have
+    // the following shapes: [4x2, 4x3, 4x4, 4x5]
+    %sharding4 = mesh.sharding @mesh1d_4, [[], [0]] sharded_dims_sizes = [2, 3, 4, 5]
+    %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
+    ```
+  }];
+    
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    Mesh_MeshAxesArrayAttr:$split_axes,
+    OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
+    OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
+    Variadic<I64>:$dynamic_sharded_dims_sizes,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
+    Variadic<I64>:$dynamic_halo_sizes
+  );
+  let results = (outs
+    Mesh_Sharding:$result
+  );
+  let assemblyFormat = [{
+    $mesh `,` $split_axes
----------------
sogartar wrote:

Maybe the assembly format should be the same as the rest. It would definitely be clearer for outsiders.
```
(`split_axes` `=` $split_axes)
```

Just mentioning it. Does not have to be in this PR.
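The index and shard-shape arithmetic in the `mesh.process_linear_index` and `mesh.sharding` descriptions can be checked with a small standalone sketch. The helpers `linear_index`, `multi_index`, `shard_shape` and `with_halos` are hypothetical and not part of MLIR or its Python bindings. They assume row-major device linearization (matching the `(10, 20, 30)` example), that a device's position in the flattened `sharded_dims_sizes` array is its linear index, and that halos simply extend a shard's extent along each sharded dimension.

```python
# Hypothetical helpers, not part of MLIR or its Python bindings.


def linear_index(mesh_shape, index):
    """Row-major linearization of a device multi index (assumed ordering)."""
    if len(mesh_shape) != len(index):
        raise ValueError("mesh shape and multi index must have the same rank")
    linear = 0
    for extent, idx in zip(mesh_shape, index):
        if not 0 <= idx < extent:
            raise ValueError(f"index {idx} is out of range for a mesh axis of size {extent}")
        linear = linear * extent + idx
    return linear


def multi_index(mesh_shape, linear):
    """Inverse of linear_index: recover the multi index from a linear index."""
    digits = []
    for extent in reversed(mesh_shape):
        linear, rem = divmod(linear, extent)
        digits.append(rem)
    if linear != 0:
        raise ValueError("linear index is out of range for this mesh")
    return tuple(reversed(digits))


def shard_shape(tensor_shape, sharded_dims, sharded_dims_sizes, device):
    """Shape of the shard owned by `device` (assumed to be its linear index).

    `sharded_dims_sizes` is flattened as described: for each device, one value
    per sharded tensor dimension, in the order given by `sharded_dims`.
    """
    n = len(sharded_dims)
    shape = list(tensor_shape)
    for k, dim in enumerate(sharded_dims):
        shape[dim] = sharded_dims_sizes[device * n + k]
    return shape


def with_halos(local_shape, sharded_dims, halo_sizes):
    """Grow a local shard by its halos: 2 values (start, end) per sharded dim."""
    shape = list(local_shape)
    for k, dim in enumerate(sharded_dims):
        shape[dim] += halo_sizes[2 * k] + halo_sizes[2 * k + 1]
    return shape


print(linear_index((10, 20, 30), (1, 2, 3)))  # 663 == 3 + 30*2 + 20*30*1
print(multi_index((10, 20, 30), 663))  # (1, 2, 3)
print(shard_shape((32, 32, 32), [0, 1], [16, 8, 16, 24], 1))  # [16, 24, 32]
print([shard_shape((4, 14), [1], [2, 3, 4, 5], d) for d in range(4)])
# [[4, 2], [4, 3], [4, 4], [4, 5]]
# Assumed local shard of 2x8 with halo_sizes = [1, 2] on the sharded dim 0.
print(with_halos([2, 8], [0], [1, 2]))  # [5, 8]
```

With these helpers, the `@mesh1d_4` example yields shard shapes 4x2, 4x3, 4x4 and 4x5, which together cover the 14 columns of `tensor<4x14xf32>`.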

https://github.com/llvm/llvm-project/pull/98145

