[Mlir-commits] [mlir] [mlir][mesh] Add TableGen definitions of more collective ops (PR #73842)

Chengji Yao llvmlistbot at llvm.org
Fri Dec 1 13:49:22 PST 2023


================
@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
   let hasCanonicalizer = 1;
 }
 
+// Mesh_BroadcastOp: the `broadcast` collective communication op of the mesh
+// dialect. Its traits require `input` and `result` to have identical shapes
+// and identical element types. Besides the shared `commonArgs` (defined
+// elsewhere; presumably the `mesh` symbol and the optional `mesh_axes`
+// referenced by the assembly format -- confirm), it takes the `input` tensor
+// and a root multi-index split into a static part (`root`) and dynamic index
+// operands (`root_dynamic`).
+def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+    AllShapesMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
+  ]> {
+  let summary = "Broadcast over a device mesh.";
+  let description = [{
+    Broadcast the tensor on `root` to all devices in each respective group.
+    The operation broadcasts along mesh axes `mesh_axes`.
+    The `root` device specifies the in-group multi-index that is broadcast to
+    all other devices in the group.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+
+    %1 = mesh.broadcast %0 on @mesh0
+      mesh_axes = [0]
+      root = [0]
+      : (tensor<2xi8>) -> tensor<2xi8>
+    ```
+
+    Input:
+    ```
+                     +-------+-------+                   | broadcast
+    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
+                     +-------+-------+                   ↓
+    device (1, 0) -> |       |       | <- device (1, 1)
+                     +-------+-------+
+    ```
+
+    Output:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
+                     +-------+-------+
+    device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
+                     +-------+-------+
+    ```
+  }];
+  // `root` holds the static entries of the root multi-index; `root_dynamic`
+  // holds the SSA `index` operands. The two are printed/parsed together by
+  // the `DynamicIndexList` custom directive in the assembly format below.
+  let arguments = !con(commonArgs, (ins
+    AnyRankedTensor:$input,
+    DenseI64ArrayAttr:$root,
+    Variadic<Index>:$root_dynamic
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+    AllRanksMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
+  ]> {
+  let summary = "Gather over a device mesh.";
+  let description = [{
+    Gathers on device `root` along the `gather_axis` tensor axis.
+    `root` specifies the coordinates of a device along `mesh_axes`.
+    It uniquely identifies the root device for each device group.
+    The result tensor on non-root devices is undefined.
+    Using it will result in undefined behavior.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
+      gather_axis = 1 root = [1]
+      : (tensor<2x2xi8>) -> tensor<2x4xi8>
----------------
yaochengji wrote:

Is the output still `tensor<2x4xi8>` if it's not the root?

https://github.com/llvm/llvm-project/pull/73842


More information about the Mlir-commits mailing list