[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)

Frank Schlimbach llvmlistbot at llvm.org
Wed Jul 10 01:11:34 PDT 2024


================
@@ -105,22 +105,221 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   ];
 }
 
+def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "Get the multi-index of the current device along the specified mesh axes.";
+  let description = [{
+    It is used in the SPMD format of IR.
+    The `axes` must be non-negative and less than the total number of mesh axes.
+    If `axes` is empty, the index along all axes is returned.
+  }];
+  let arguments = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+  );
+  let results = (outs
+    Variadic<Index>:$result
+  );
+  let assemblyFormat = [{
+    `on` $mesh (`axes` `=` $axes^)?
+    attr-dict `:` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+    OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+  ];
+}
+
+def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary = "Get the linear index of the current device.";
+  let description = [{
+    Example:
+    ```
+    %idx = mesh.process_linear_index on @mesh : index
+    ```
+    If `@mesh` has shape `(10, 20, 30)`, a device with multi-index
+    `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1 = 663`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh);
+  let results = (outs Index:$result);
+  let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+  let builders = [
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// Sharding operations.
+//===----------------------------------------------------------------------===//
+
+def Mesh_ShardingOp : Mesh_Op<"sharding", [
+    Pure,
+    AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+  ]> {
+  let summary = "Define a sharding of a tensor.";
+  let description = [{
+    The MeshSharding specifies how a tensor is sharded and distributed across the
+    process mesh. It is typically used in a `mesh.shard` operation.
+    The operation has the following attributes and operands:
+
+    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+    mesh where the distributed tensor is placed. The symbol must resolve to a
+    `mesh.mesh` operation.
+
+    2. `split_axes`: an array of int64_t sub-arrays. The outer array's
+    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+    its value is [x, y], it indicates that the tensor's i-th dimension is split
+    along the x and y axes of the device mesh.
+
+    3. [Optional] `partial_axes`: if not empty, this signifies that the tensor is partial
+    along the specified mesh axes. An all-reduce should be applied to obtain
+    the complete tensor, with the reduction type specified by `partial_type`.
+
+    4. [Optional] `partial_type`: indicates the reduction type of the possible all-reduce
+    op. It has 4 possible values, and `generic` is not allowed inside a
+    shard attribute.
+
+    5. [Optional] `halo_sizes`: sizes of halos to be added for each sharded tensor dimension.
+    `halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each sharded dimension.
+    `halo_sizes` = [1, 2] means that the first sharded dimension gets an additional
+    halo of size 1 at its start and a halo of size 2 at its end.
+    `halo_sizes` = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions,
+    i.e. the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos.
+    `?` indicates dynamic halo sizes.
+
+    6. [Optional] `sharded_dims_sizes`: sizes of the sharded dimensions of each shard.
+    `sharded_dims_sizes` is provided as a flattened 1d array of i64s: for each device of the
+    device mesh, one value for each sharded tensor dimension.
+    Assuming a 3d tensor of shape 32x32x32 with the first 2 dimensions being sharded,
+    `sharded_dims_sizes` = [16, 8, 16, 24] means that the first device of
+    the device mesh will get a shard of shape 16x8x32 and the second device will get a
+    shard of shape 16x24x32.
+    `?` indicates dynamic shard dimensions.
----------------
fschlimb wrote:

I guess this could be solved with a custom parser that accepts nested lists, like the `split_axes` attribute, but using the dynamic/static syntax. I was hoping we could either defer this, or that someone has done something like this before and can help or give a pointer. I don't see how a struct would make this easier; can you elaborate?
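The following is a minimal C++ sketch of how the flattened layouts described in the patch decode. It covers the `process_linear_index` formula, the per-device grouping of `sharded_dims_sizes` (the nested form a custom parser would accept), and `halo_sizes` pairs. It rests on several assumptions. All sizes are static, so `?` is not modeled. "First device" and "second device" are taken to mean linear indices 0 and 1. The sharded tensor dimensions are listed in the order their sizes appear per device. The helper names `linearIndex`, `perDeviceSizes`, `shardShape` and `withHalos` are hypothetical and are not part of the MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major linearization matching the `process_linear_index` example:
// mesh shape (10, 20, 30), multi-index (1, 2, 3) -> 3 + 30*2 + 20*30*1.
int64_t linearIndex(const std::vector<int64_t> &meshShape,
                    const std::vector<int64_t> &multiIndex) {
  assert(meshShape.size() == multiIndex.size());
  int64_t linear = 0;
  for (std::size_t i = 0; i < meshShape.size(); ++i)
    linear = linear * meshShape[i] + multiIndex[i];
  return linear;
}

// Hypothetical helper: regroup the flat `sharded_dims_sizes` array into one
// list per device, i.e. the nested form a custom parser might accept.
std::vector<std::vector<int64_t>>
perDeviceSizes(const std::vector<int64_t> &flat, std::size_t numShardedDims) {
  std::vector<std::vector<int64_t>> nested;
  if (numShardedDims == 0)
    return nested; // nothing sharded, nothing to group
  for (std::size_t i = 0; i + numShardedDims <= flat.size(); i += numShardedDims)
    nested.emplace_back(flat.begin() + i, flat.begin() + i + numShardedDims);
  return nested;
}

// Hypothetical helper: shape of the shard owned by `device` (assumed to be
// its linear index). Only static sizes are handled.
std::vector<int64_t> shardShape(std::vector<int64_t> tensorShape,
                                const std::vector<std::size_t> &shardedDims,
                                const std::vector<int64_t> &shardedDimsSizes,
                                std::size_t device) {
  const std::size_t k = shardedDims.size();
  assert(shardedDimsSizes.size() >= (device + 1) * k);
  for (std::size_t j = 0; j < k; ++j)
    tensorShape[shardedDims[j]] = shardedDimsSizes[device * k + j];
  return tensorShape;
}

// Hypothetical helper: grow a shard by its halos. `haloSizes` holds
// [start, end] pairs for the leading sharded dimensions; sharded dimensions
// without a pair get no halo.
std::vector<int64_t> withHalos(std::vector<int64_t> shape,
                               const std::vector<std::size_t> &shardedDims,
                               const std::vector<int64_t> &haloSizes) {
  for (std::size_t j = 0;
       j < shardedDims.size() && 2 * j + 1 < haloSizes.size(); ++j)
    shape[shardedDims[j]] += haloSizes[2 * j] + haloSizes[2 * j + 1];
  return shape;
}
```

For the example in the description, `perDeviceSizes({16, 8, 16, 24}, 2)` regroups the flat array into `{{16, 8}, {16, 24}}`. This shows that the flat encoding carries the same information as a nested per-device list, and only the textual syntax differs.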

https://github.com/llvm/llvm-project/pull/98145

