[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)
Frank Schlimbach
llvmlistbot at llvm.org
Fri Aug 2 06:51:44 PDT 2024
================
@@ -105,22 +105,221 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
];
}
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "Get the multi-index of the current device along the specified mesh axes.";
+ let description = [{
+ It is used in the SPMD format of the IR.
+ The `axes` must be non-negative and less than the total number of mesh axes.
+ If `axes` is empty, the index along all axes is returned.
+ }];
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+ let results = (outs
+ Variadic<Index>:$result
+ );
+ let assemblyFormat = [{
+ `on` $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+ ];
+}
+
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+ let summary = "Get the linear index of the current device.";
+ let description = [{
+ Example:
+ ```
+ %idx = mesh.process_linear_index on @mesh : index
+ ```
+ If `@mesh` has shape `(10, 20, 30)`, a device with multi-index
+ `(1, 2, 3)` has linear index `3 + 30*2 + 20*30*1` (i.e. 663).
+ }];
+ let arguments = (ins FlatSymbolRefAttr:$mesh);
+ let results = (outs Index:$result);
+ let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// Sharding operations.
+//===----------------------------------------------------------------------===//
+
+def Mesh_ShardingOp : Mesh_Op<"sharding", [
+ Pure,
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Define a sharding of a tensor.";
+ let description = [{
+ The MeshSharding specifies how a tensor is sharded and distributed across the
+ process mesh. It is typically used in a `mesh.shard` operation.
+ The operation has the following attributes and operands:
+
+ 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+ mesh where the distributed tensor is placed. The symbol must resolve to a
+ `mesh.mesh` operation.
+
+ 2. `split_axes`: an array composed of int64_t sub-arrays. The outer array's
+ maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+ its value is [x, y], it indicates that the tensor's i-th dimension is split
+ along the x and y axes of the device mesh.
+
+ 3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
+ along the specified mesh axes. An all-reduce should be applied to obtain
+ the complete tensor, with the reduction type specified by `partial_type`.
+
+ 4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
+ op. It has 4 possible values; note that
+ `generic` is not an allowed value inside a shard attribute.
+
+ 5. [Optional] Sizes of halos to be added for each sharded tensor dimension.
+ `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
+ `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
+ halo of size 1 at its start and a halo of size 2 at its end.
+ `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions,
+ i.e. the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos.
+ `?` indicates dynamic halo sizes.
+
+ 6. [Optional] Sizes of sharded dimensions of each shard.
+ `sharded_dims_sizes` is provided as a flattened 1d array of i64s: for each device of the
+ device mesh, one value for each sharded tensor dimension.
+ Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
+ `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
+ the device mesh will get a shard of shape 16x8x32 and the second device will get a
+ shard of shape 16x24x32.
+ `?` indicates dynamic shard dimensions.
----------------
fschlimb wrote:
Good question. It is conceivable to write a dynamic check, at least for some basic cases.
For now, I guess it would be a good idea to disallow this case. What do you think?
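To make the `sharded_dims_sizes` and `halo_sizes` layouts from the diff concrete, here is a minimal Python sketch of how a shard's local shape could be derived. It is not MLIR code and not the PR's implementation. The helper names (`sharded_dims`, `shard_shape`, `add_halos`, `DYNAMIC`) are hypothetical. It assumes that the sharded dimensions are the tensor dimensions with a non-empty `split_axes` sub-array, that `sharded_dims_sizes` is device-major (all sharded-dimension sizes of device 0, then device 1, ...), that halo pairs apply to the leading sharded dimensions in order, and that `?` is modeled as `None`.

```python
from typing import List, Optional, Sequence

# Hypothetical stand-in for a dynamic (`?`) entry.
DYNAMIC = None
Size = Optional[int]


def sharded_dims(split_axes: Sequence[Sequence[int]]) -> List[int]:
    """Tensor dimensions whose split_axes sub-array is non-empty."""
    return [dim for dim, axes in enumerate(split_axes) if axes]


def shard_shape(tensor_shape: Sequence[int],
                split_axes: Sequence[Sequence[int]],
                sharded_dims_sizes: Sequence[Size],
                device: int) -> List[Size]:
    """Shape of the shard owned by `device` (assumed device-major layout)."""
    dims = sharded_dims(split_axes)
    k = len(dims)
    # Assumption: device d owns entries [d*k, (d+1)*k).
    per_device = sharded_dims_sizes[device * k:(device + 1) * k]
    if len(per_device) != k:
        raise ValueError("sharded_dims_sizes has no entries for this device")
    shape: List[Size] = list(tensor_shape)
    for dim, size in zip(dims, per_device):
        shape[dim] = size  # may be DYNAMIC
    return shape


def add_halos(shape: Sequence[Size],
              split_axes: Sequence[Sequence[int]],
              halo_sizes: Sequence[Size]) -> List[Size]:
    """Grow sharded dimensions by their (start, end) halo pairs."""
    dims = sharded_dims(split_axes)
    if len(halo_sizes) % 2 != 0 or len(halo_sizes) // 2 > len(dims):
        raise ValueError("halo_sizes needs 2 entries per sharded dimension")
    result: List[Size] = list(shape)
    for j in range(len(halo_sizes) // 2):
        start, end = halo_sizes[2 * j], halo_sizes[2 * j + 1]
        dim = dims[j]
        if result[dim] is DYNAMIC or start is DYNAMIC or end is DYNAMIC:
            result[dim] = DYNAMIC  # any dynamic term keeps the size dynamic
        else:
            result[dim] = result[dim] + start + end
    return result
```

With the description's example, `shard_shape([32, 32, 32], [[0], [1]], [16, 8, 16, 24], 1)` returns `[16, 24, 32]`, matching the second device, and `add_halos([16, 8, 32], [[0], [1]], [1, 2])` grows the first sharded dimension from 16 to 19. Propagating `None` rather than guessing keeps dynamic entries visibly unresolved, so a verifier built this way could only reject them or defer to a runtime check.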
https://github.com/llvm/llvm-project/pull/98145
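The `mesh.process_linear_index` example in the diff implies a row-major ordering: the last mesh axis varies fastest, so `(1, 2, 3)` on a `(10, 20, 30)` mesh maps to `3 + 30*2 + 20*30*1`. A minimal Python sketch of that mapping and its inverse follows, assuming the same ordering holds for every mesh shape. The inverse presumably corresponds to what `mesh.process_multi_index` reports when `axes` is empty. `linear_index` and `multi_index` are hypothetical helper names, not MLIR APIs.

```python
from typing import List, Sequence


def linear_index(shape: Sequence[int], multi: Sequence[int]) -> int:
    """Row-major linear index (assumed): the last mesh axis varies fastest."""
    if len(shape) != len(multi):
        raise ValueError("rank mismatch between mesh shape and multi-index")
    linear = 0
    for extent, idx in zip(shape, multi):
        if not 0 <= idx < extent:
            raise ValueError("multi-index out of range")
        linear = linear * extent + idx  # Horner-style accumulation
    return linear


def multi_index(shape: Sequence[int], linear: int) -> List[int]:
    """Inverse of linear_index over all mesh axes."""
    result = []
    for extent in reversed(shape):
        linear, idx = divmod(linear, extent)
        result.append(idx)
    if linear != 0:
        raise ValueError("linear index out of range")
    return result[::-1]
```

For the description's example, `linear_index((10, 20, 30), (1, 2, 3))` returns 663, and `multi_index((10, 20, 30), 663)` recovers `[1, 2, 3]`.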