[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)
Mehdi Amini
llvmlistbot at llvm.org
Wed Nov 15 01:30:47 PST 2023
================
@@ -171,4 +182,194 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+class Mesh_CollectiveCommunicationOpBase<
+ string mnemonic, list<Trait> traits = []> :
+ Mesh_Op<mnemonic,
+ !listconcat(traits,
+ [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+ let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+ dag commonArgs = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+ );
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank
+ ]> {
+ let summary = "All-gather over a device mesh.";
+ let description = [{
+ Gathers along the `gather_axis` tensor axis.
+ The order of input tensors in the resulting tensor is the same as the
+ order of the corresponding devices' multi-index in the mesh.
+
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.all_gather %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+ } : tensor<2x2xi8> -> tensor<2x4xi8>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ +-------------+
+ | 1 2 5 6 | <- devices (0, 0) and (0, 1)
+ | 3 4 7 8 |
+ +-------------+
+ | 9 10 13 14 | <- devices (1, 0) and (1, 1)
+ | 11 12 15 16 |
+ +-------------+
+ ```
+ }];
+ let arguments = !con(commonArgs, (ins
+ AnyNon0RankedTensor:$input,
+ APIntAttr:$gather_axis
----------------
joker-eph wrote:
Should this be `I32Attr` rather than `APIntAttr`?
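
For reference while reviewing, the semantics described in the quoted `all_gather` documentation can be sketched in plain Python. This is only an illustrative model of the documented example, not the dialect's implementation: the function name `all_gather`, the `shards` dictionary keyed by device multi-index, and the restriction to 2-D tensors are all assumptions made for this sketch.

```python
def all_gather(shards, mesh_shape, mesh_axis, gather_axis):
    """Hypothetical model of the documented all-gather behaviour.

    shards      -- dict mapping a device multi-index (tuple) to a 2-D tensor,
                   stored as a list of rows
    mesh_shape  -- tuple with the size of each mesh dimension
    mesh_axis   -- the single mesh axis the gather runs over
    gather_axis -- tensor axis (0 or 1) along which shards are concatenated
    """
    result = {}
    for device in shards:
        # Devices in the same group differ only in their index on mesh_axis.
        # They are visited in increasing index order, matching the documented
        # "order of the corresponding devices' multi-index in the mesh".
        group = [
            device[:mesh_axis] + (k,) + device[mesh_axis + 1:]
            for k in range(mesh_shape[mesh_axis])
        ]
        parts = [shards[d] for d in group]
        if gather_axis == 0:
            # Stack the rows of each shard one after another.
            gathered = [list(row) for part in parts for row in part]
        else:
            # Concatenate row r of every shard into one longer row.
            gathered = [
                [x for part in parts for x in part[r]]
                for r in range(len(parts[0]))
            ]
        result[device] = gathered
    return result


# Input from the documented example: a 2x2 mesh, one 2x2 shard per device.
shards = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
gathered = all_gather(shards, mesh_shape=(2, 2), mesh_axis=1, gather_axis=1)
```

With `mesh_axis=1` and `gather_axis=1`, devices (0, 0) and (0, 1) both end up with the rows `1 2 5 6` and `3 4 7 8`, which matches the "Result" diagram in the description.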
https://github.com/llvm/llvm-project/pull/71960