[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)

Mehdi Amini llvmlistbot at llvm.org
Wed Nov 15 01:30:47 PST 2023


================
@@ -171,4 +182,194 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+class Mesh_CollectiveCommunicationOpBase<
+    string mnemonic, list<Trait> traits = []> :
+    Mesh_Op<mnemonic,
+      !listconcat(traits,
+      [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+  dag commonArgs = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+  );
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-gather over a device mesh.";
+  let description = [{
+    Gathers along the `gather_axis` tensor axis.
+    The order of input tensors in the resulting tensor is the same as the
+    order of the corresponding devices' multi-index in the mesh.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.all_gather %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+      } : tensor<2x2xi8> -> tensor<2x4xi8>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    APIntAttr:$gather_axis
----------------
joker-eph wrote:

I32Attr?

https://github.com/llvm/llvm-project/pull/71960
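For illustration, the all-gather semantics in the op description above can be modeled in plain Python. This is a minimal sketch, not the MLIR lowering. It assumes 2-D shards, a single gathered mesh axis, and peers ordered by their coordinate along that axis. The function `all_gather_2d` and its parameters are hypothetical.

```python
from typing import Dict, List, Tuple

Matrix = List[List[int]]
MeshIndex = Tuple[int, ...]


def all_gather_2d(
    shards: Dict[MeshIndex, Matrix],
    mesh_shape: Tuple[int, ...],
    mesh_axis: int,
    gather_axis: int,
) -> Dict[MeshIndex, Matrix]:
    """Hypothetical model of mesh.all_gather over a single mesh axis.

    Every device receives the concatenation, along tensor axis
    `gather_axis`, of the shards held by all devices that differ from it
    only in coordinate `mesh_axis`, ordered by that coordinate.
    """
    if gather_axis not in (0, 1):
        raise ValueError("this sketch only handles 2-D tensors")
    result: Dict[MeshIndex, Matrix] = {}
    for device in shards:
        # Collect the peer group in mesh multi-index order.
        group = []
        for k in range(mesh_shape[mesh_axis]):
            peer = list(device)
            peer[mesh_axis] = k
            group.append(shards[tuple(peer)])
        if gather_axis == 0:
            # Stack the shards' rows on top of each other.
            gathered = [list(row) for shard in group for row in shard]
        else:
            # Concatenate row r of every shard, left to right.
            gathered = [
                [x for shard in group for x in shard[r]]
                for r in range(len(group[0]))
            ]
        result[device] = gathered
    return result


# Shards from the "Input" diagram, keyed by device multi-index.
shards = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}

# mesh_axes = [1], gather_axis = 1, as in the MLIR example.
result = all_gather_2d(shards, mesh_shape=(2, 2), mesh_axis=1, gather_axis=1)
for device, tensor in sorted(result.items()):
    print(device, tensor)
```

Running it on the four shards from the `Input` diagram reproduces the `Result` diagram: devices (0, 0) and (0, 1) both hold `[[1, 2, 5, 6], [3, 4, 7, 8]]`, and devices (1, 0) and (1, 1) both hold `[[9, 10, 13, 14], [11, 12, 15, 16]]`.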