[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)
Chengji Yao
llvmlistbot at llvm.org
Sun Nov 12 22:37:59 PST 2023
================
@@ -171,4 +182,209 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+class Mesh_CollectiveCommunicationOpBase<
+ string mnemonic, list<Trait> traits = []> :
+ Mesh_Op<mnemonic,
+ !listconcat(traits,
+ [SymbolUserOpInterface])> {
+ let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+ code extraClassDeclarationBase = [{
+ ::mlir::LogicalResult verifySymbolUses(
+ ::mlir::SymbolTableCollection &symbolTable);
+ }];
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank
+ ]> {
+ let summary = "All-gather over a device mesh.";
+ let description = [{
+ Gathers along the `gather_axis` tensor axis.
+ The order of input tensors in the resulting tensor is the same as the
+ order of the corresponding devices' multi-index in the mesh.
+
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.all_gather %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+ } : tensor<2x2xi8> -> tensor<2x4xi8>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ +-------------+
+ | 1 2 5 6 | <- devices (0, 0) and (0, 1)
+ | 3 4 7 8 |
+ +-------------+
+ | 9 10 13 14 | <- devices (1, 0) and (1, 1)
+ | 11 12 15 16 |
+ +-------------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ APIntAttr:$gather_axis
+ );
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+ SameOperandsAndResultShape]> {
+ let summary = "All-reduce over a device mesh.";
+ let description = [{
+ The accumulation element type is specified by the result type and
+ it does not need to match the input element type.
+ The input element is converted to the result element type before
+ performing the reduction.
+
+ Attributes:
+ `reduction`: Indicates the reduction method.
+
+ Example:
+ ```
+ %1 = mesh.all_reduce %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+ } : tensor<3x4xf32> -> tensor<3x4xf64>
+ ```
+ }];
+ let arguments = (ins
+ AnyRankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+ );
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
+ let summary = "All-to-all over a device mesh.";
+ let description = [{
+ Performs an all-to-all on tensor pieces split along `split_axis`.
+    The resulting pieces are concatenated along `concat_axis` on each device.
+ Example:
+ ```
+ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+ ...
+ %1 = mesh.all_to_all %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
+ } : tensor<3x6xi8> -> tensor<3x6xi8>
+ ```
+ Input:
+ ```
+ device device device
+ (0) (1) (2)
+ +-------+-------+-------+
----------------
yaochengji wrote:
Could you add tensor axis annotation here to enhance the readability? Maybe all_gather and reduce_scatter also need it.
https://github.com/llvm/llvm-project/pull/71960
More information about the Mlir-commits
mailing list