[Mlir-commits] [mlir] [mlir][mesh] Add TableGen definitions of more collective ops (PR #73842)

Boian Petkantchin llvmlistbot at llvm.org
Fri Dec 1 17:11:45 PST 2023


================
@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
   let hasCanonicalizer = 1;
 }
 
+def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+    AllShapesMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
+  ]> {
+  let summary = "Broadcast over a device mesh.";
+  let description = [{
+    Broadcast the tensor on `root` to all devices in each respective group.
+    The operation broadcasts along mesh axes `mesh_axes`.
+    `root` is the in-group multi-index (coordinates along `mesh_axes`) of the
+    device whose tensor is broadcast to all other devices in the group.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+
+    %1 = mesh.broadcast %0 on @mesh0
+      mesh_axes = [0]
+      root = [0]
+      : (tensor<2xi8>) -> tensor<2xi8>
+    ```
+
+    Input:
+    ```
+                     +-------+-------+                   | broadcast
+    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along mesh axis 0
+                     +-------+-------+                   ↓
+    device (1, 0) -> |       |       | <- device (1, 1)
+                     +-------+-------+
+    ```
+
+    Output:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
+                     +-------+-------+
+    device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
+                     +-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyRankedTensor:$input,
+    DenseI64ArrayAttr:$root,
+    Variadic<Index>:$root_dynamic
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+    AllRanksMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
+  ]> {
+  let summary = "Gather over a device mesh.";
+  let description = [{
+    Gathers on device `root` along the `gather_axis` tensor axis.
+    `root` specifies the coordinates of a device along `mesh_axes`.
+    It uniquely identifies the root device for each device group.
+    The result tensor on non-root devices is undefined.
+    Using it will result in undefined behavior.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
+      gather_axis = 1 root = [1]
+      : (tensor<2x2xi8>) -> tensor<2x4xi8>
+    ```
+    Input:
+    ```
+                      gather tensor
+                      axis 1
+                      ------------>
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- device (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- device (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+    Devices `(0, 0)` and `(1, 0)` have undefined result.
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$gather_axis,
+    DenseI64ArrayAttr:$root,
+    Variadic<Index>:$root_dynamic
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `gather_axis` `=` $gather_axis
+    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def Mesh_ReceiveOp : Mesh_CollectiveCommunicationOpBase<"receive", [
----------------
sogartar wrote:

Done.

https://github.com/llvm/llvm-project/pull/73842


More information about the Mlir-commits mailing list