[Mlir-commits] [mlir] [mlir][mesh] Add TableGen definitions of more collective ops (PR #73842)
Chengji Yao
llvmlistbot at llvm.org
Fri Dec 1 13:49:22 PST 2023
================
@@ -339,6 +339,185 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
let hasCanonicalizer = 1;
}
+def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+ AllShapesMatch<["input", "result"]>,
+ AllElementTypesMatch<["input", "result"]>
+ ]> {
+ let summary = "Broadcast over a device mesh.";
+ let description = [{
+ Broadcast the tensor on `root` to all devices in each respective group.
+ The operation broadcasts along mesh axes `mesh_axes`.
+ `root` specifies the in-group multi-index of the device whose tensor is
+ broadcast to all other devices in the group.
+
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+
+ %1 = mesh.broadcast %0 on @mesh0
+ mesh_axes = [0]
+ root = [0]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ ```
+
+ Input:
+ ```
+                   +-------+-------+                       | broadcast
+ device (0, 0) ->  |  1  2 |  3  4 |  <- device (0, 1)     | along axis 0
+                   +-------+-------+                       ↓
+ device (1, 0) ->  |       |       |  <- device (1, 1)
+                   +-------+-------+
+ ```
+
+ Output:
+ ```
+                   +-------+-------+
+ device (0, 0) ->  |  1  2 |  3  4 |  <- device (0, 1)
+                   +-------+-------+
+ device (1, 0) ->  |  1  2 |  3  4 |  <- device (1, 1)
+                   +-------+-------+
+ ```
+ }];
+ let arguments = !con(commonArgs, (ins
+ AnyRankedTensor:$input,
+ DenseI64ArrayAttr:$root,
+ Variadic<Index>:$root_dynamic
+ ));
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let assemblyFormat = [{
+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
+ attr-dict `:` functional-type(operands, results)
+ }];
+}
+
+def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+ AllRanksMatch<["input", "result"]>,
+ AllElementTypesMatch<["input", "result"]>
+ ]> {
+ let summary = "Gather over a device mesh.";
+ let description = [{
+ Gathers the tensors of all devices in each group onto the group's `root`
+ device, concatenating them along the `gather_axis` tensor axis.
+ `root` specifies the coordinates of a device along `mesh_axes`.
+ It uniquely identifies the root device for each device group.
+ The result tensor on non-root devices is undefined; using it results in
+ undefined behavior.
+
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
+ gather_axis = 1 root = [1]
+ : (tensor<2x2xi8>) -> tensor<2x4xi8>
----------------
yaochengji wrote:
Is the output still `tensor<2x4xi8>` if it's not the root?
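A small executable model of the `broadcast` and `gather` semantics described in the patch may help when checking the examples above. It is a plain-Python sketch, not the MLIR implementation. The helper names, the dict-of-devices representation, and the row-major ordering of group members in `gather` are all assumptions. `None` stands in for the undefined result on non-root devices. The model covers runtime values only and says nothing about the static result type the question asks about.

```python
from itertools import product

# Hypothetical model (not the MLIR implementation): a mesh is given by its
# shape, e.g. (2, 2); `values` maps each device coordinate tuple to the
# nested-list tensor that device holds.


def in_group_index(device, mesh_axes):
    """A device's multi-index within its group: its coordinates along mesh_axes."""
    return tuple(device[axis] for axis in mesh_axes)


def with_group_index(device, mesh_axes, index):
    """The device in the same group as `device` whose in-group index is `index`."""
    coords = list(device)
    for axis, i in zip(mesh_axes, index):
        coords[axis] = i
    return tuple(coords)


def group_members(device, mesh_shape, mesh_axes):
    """Members of `device`'s group, in row-major order of in-group index (assumed)."""
    ranges = [range(mesh_shape[axis]) for axis in mesh_axes]
    return [with_group_index(device, mesh_axes, idx) for idx in product(*ranges)]


def concat(tensors, axis):
    """Concatenate nested-list tensors along `axis`."""
    if axis == 0:
        return [row for t in tensors for row in t]
    return [concat(list(parts), axis - 1) for parts in zip(*tensors)]


def broadcast(values, mesh_axes, root):
    """Each device receives the tensor held by the root device of its group."""
    return {d: values[with_group_index(d, mesh_axes, root)] for d in values}


def gather(values, mesh_shape, mesh_axes, gather_axis, root):
    """Root devices get their group's tensors concatenated along gather_axis.

    Non-root results are undefined per the op description; None stands in here.
    """
    result = {}
    for d in values:
        if in_group_index(d, mesh_axes) == tuple(root):
            members = group_members(d, mesh_shape, mesh_axes)
            result[d] = concat([values[m] for m in members], gather_axis)
        else:
            result[d] = None
    return result


# Broadcast example from the description: mesh_axes = [0], root = [0].
# Devices (1, 0) and (1, 1) start with placeholder data.
bcast_in = {(0, 0): [1, 2], (0, 1): [3, 4], (1, 0): [0, 0], (1, 1): [0, 0]}
print(broadcast(bcast_in, mesh_axes=[0], root=[0]))
# {(0, 0): [1, 2], (0, 1): [3, 4], (1, 0): [1, 2], (1, 1): [3, 4]}

# Gather example: mesh_axes = [1], gather_axis = 1, root = [1], with
# hypothetical 2x2 inputs (the original example does not show them).
gather_in = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
print(gather(gather_in, (2, 2), mesh_axes=[1], gather_axis=1, root=[1]))
```

With the broadcast inputs from the diagram, devices (1, 0) and (1, 1) end up with `[1, 2]` and `[3, 4]`, which matches the Output diagram. In the gather run, the root devices (0, 1) and (1, 1) each hold a 2x4 value, consistent with `tensor<2x4xi8>`. The non-root devices hold only the `None` placeholder.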
https://github.com/llvm/llvm-project/pull/73842