[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)
Boian Petkantchin
llvmlistbot at llvm.org
Tue Jul 9 18:55:36 PDT 2024
================
@@ -105,22 +105,221 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
];
}
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "Get the multi-index of the current device along the specified mesh axes.";
+ let description = [{
+ It is used in the SPMD format of IR.
+ The `axes` must be non-negative and less than the total number of mesh axes.
+ If `axes` is empty, the index along all mesh axes is returned.
+ }];
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+ let results = (outs
+ Variadic<Index>:$result
+ );
+ let assemblyFormat = [{
+ `on` $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+ ];
+}
+
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "Get the linear index of the current device.";
+ let description = [{
+ Example:
+ ```
+ %idx = mesh.process_linear_index on @mesh : index
+ ```
+ If `@mesh` has shape `(10, 20, 30)`, a device with multi-index
+ `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1 = 663`.
+ }];
+ let arguments = (ins FlatSymbolRefAttr:$mesh);
+ let results = (outs Index:$result);
+ let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Sharding operations.
+//===----------------------------------------------------------------------===//
+
+def Mesh_ShardingOp : Mesh_Op<"sharding", [
+ Pure,
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Define a sharding of a tensor.";
+ let description = [{
+ The MeshSharding specifies how a tensor is sharded and distributed across the
+ process mesh. It is typically used in a `mesh.shard` operation.
+ The operation has the following attributes and operands:
+
+ 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+ mesh where the distributed tensor is placed. The symbol must resolve to a
+ `mesh.mesh` operation.
+
+ 2. `split_axes`: an array of int64_t sub-arrays. The outer array's
+ maximum size is the `rank` of the related tensor. If the i-th sub-array
+ is [x, y], the tensor's i-th dimension is split along the x and y axes
+ of the device mesh.
+
+ 3. [Optional] `partial_axes`: if not empty, the tensor is partial along
+ the specified mesh axes. An all-reduce should be applied to obtain the
+ complete tensor, with the reduction type specified by `partial_type`.
+
+ 4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
+ op. It has 4 possible values:
+ `generic` is not an allowed value inside a shard attribute.
+
+ 5. [Optional] `halo_sizes`: sizes of halos to be added for each sharded tensor dimension.
+ `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
+ `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
+ halo of size 1 at its start and a halo of size 2 at its end.
+ `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions,
+ i.e. the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos.
+ `?` indicates dynamic halo sizes.
+
+ 6. [Optional] `sharded_dims_sizes`: sizes of the sharded dimensions of each shard.
+ `sharded_dims_sizes` is provided as a flattened 1d array of i64s: for each device of the
+ device mesh, one value for each sharded tensor dimension.
+ Assuming a 3d tensor of shape 32x32x32 with the first 2 dimensions sharded,
+ `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
+ the device mesh gets a shard of shape 16x8x32 and the second device gets a
+ shard of shape 16x24x32.
+ `?` indicates dynamic shard dimensions.
+
+ `halo_sizes` and `sharded_dims_sizes` are mutually exclusive.
+
+ Examples:
+
+ ```
+ mesh.mesh @mesh0(shape = 2x2x4)
+ mesh.mesh @mesh1d_4(shape = 4)
+
+ // The tensor is fully replicated on @mesh0.
+ // Currently, there must be at least one sub-array present in axes, even
+ // if it's empty. Otherwise, a parsing error will occur.
+ %sharding0 = mesh.sharding @mesh0, [[]]
+
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0
+ %sharding1 = mesh.sharding @mesh0, [[0]]
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // it is also a partial_sum along mesh axis 1.
+ %sharding2 = mesh.sharding @mesh0, [[0], []] partial = sum[1]
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // it is also a partial_max along mesh axis 1.
+ %sharding3 = mesh.sharding @mesh0, [[0]] partial = max[1]
+
+ // Could be used for a mesh.shard op
+ %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
+
+ // The tensor is sharded on its first dimension along axis 0 of @mesh0
+ // and it has halo sizes of 1 and 2 on the sharded dim.
+ %halo_sharding = mesh.sharding @mesh0, [[0]] halo_sizes = [1, 2]
+ %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
+
+ // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
+ // and it has pre-defined shard sizes. The shards of the devices will have
+ // the following shapes: [4x2, 4x3, 4x4, 4x5]
+ %sharding4 = mesh.sharding @mesh1d_4, [[], [0]] sharded_dims_sizes = [2, 3, 4, 5]
+ %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ Mesh_MeshAxesArrayAttr:$split_axes,
+ OptionalAttr<Mesh_MeshAxesAttr>:$partial_axes,
+ OptionalAttr<Mesh_ReductionKindAttr>:$partial_type,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_sizes,
+ Variadic<I64>:$dynamic_sharded_dims_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
+ Variadic<I64>:$dynamic_halo_sizes
+ );
+ let results = (outs
+ Mesh_Sharding:$result
+ );
+ let assemblyFormat = [{
+ $mesh `,` $split_axes
----------------
sogartar wrote:
Maybe the assembly format should be the same as the rest. It would definitely be clearer for outsiders.
```
(`split_axes` `=` $split_axes)
```
Just mentioning it. Does not have to be in this PR.
https://github.com/llvm/llvm-project/pull/98145
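For readers checking the index and shape arithmetic in the op descriptions, here is a small Python sketch. It is not part of the PR and not the mesh dialect's C++ implementation, and all function names are hypothetical. It assumes row-major device linearization (last mesh axis fastest, as in the `process_linear_index` example), a device-major layout for the flattened `sharded_dims_sizes`, and that halos are added on top of an evenly split shard extent.

```python
# Hypothetical helpers illustrating the index/shape rules from the
# mesh.process_linear_index and mesh.sharding descriptions.
# Illustration only; not the MLIR implementation.


def linear_index(multi_index, mesh_shape):
    """Row-major linearization: the last mesh axis varies fastest."""
    assert len(multi_index) == len(mesh_shape)
    linear = 0
    for idx, extent in zip(multi_index, mesh_shape):
        assert 0 <= idx < extent, "index out of range for mesh axis"
        linear = linear * extent + idx
    return linear


def shard_shapes(tensor_shape, sharded_dims, sharded_dims_sizes):
    """Per-device shard shapes from a flattened `sharded_dims_sizes`.

    Assumed layout: for each device in turn, one size per sharded tensor
    dimension, listed in `sharded_dims` order.
    """
    k = len(sharded_dims)
    assert k > 0 and len(sharded_dims_sizes) % k == 0
    shapes = []
    for dev in range(len(sharded_dims_sizes) // k):
        shape = list(tensor_shape)
        for j, dim in enumerate(sharded_dims):
            shape[dim] = sharded_dims_sizes[dev * k + j]
        shapes.append(tuple(shape))
    return shapes


def shard_shape_with_halos(tensor_shape, splits, halo_sizes):
    """Local shard shape for an even split plus halos.

    `splits` lists (tensor_dim, num_shards) pairs in sharded-dimension order;
    `halo_sizes` is flattened as [lo0, hi0, lo1, hi1, ...].
    Assumption: halos are added on top of the evenly split extent.
    """
    assert len(halo_sizes) == 2 * len(splits)
    shape = list(tensor_shape)
    for j, (dim, num_shards) in enumerate(splits):
        assert shape[dim] % num_shards == 0, "sketch only handles even splits"
        lo, hi = halo_sizes[2 * j], halo_sizes[2 * j + 1]
        shape[dim] = shape[dim] // num_shards + lo + hi
    return tuple(shape)


if __name__ == "__main__":
    # Mesh shape (10, 20, 30), device (1, 2, 3): 3 + 30*2 + 20*30*1
    print(linear_index((1, 2, 3), (10, 20, 30)))  # 663
    # %sharding4: tensor<4x14xf32> split on dim 1 over @mesh1d_4
    print(shard_shapes((4, 14), [1], [2, 3, 4, 5]))  # [(4, 2), (4, 3), (4, 4), (4, 5)]
    # %halo_sharding: tensor<4x8xf32>, dim 0 split over mesh axis 0 (size 2)
    print(shard_shape_with_halos((4, 8), [(0, 2)], [1, 2]))  # (5, 8)
```

Under the same layout assumption, the 32x32x32 example also checks out: `shard_shapes((32, 32, 32), [0, 1], [16, 8, 16, 24])` gives 16x8x32 for the first device and 16x24x32 for the second.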