[Mlir-commits] [mlir] [mlir][mesh] Add TableGen definitions of more collective ops (PR #73842)

Boian Petkantchin llvmlistbot at llvm.org
Fri Dec 1 17:11:16 PST 2023


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/73842

>From baf29d52d58868728cd85f2f0de62791d78b7a88 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Mon, 20 Nov 2023 17:41:21 -0800
Subject: [PATCH 1/4] [mlir][mesh] Add TableGen definitions of more collective
 ops

Add definitions for
broadcast, gather, receive, reduce, scatter, send and shift.
---
 mlir/docs/Dialects/Mesh.md                   |  14 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 324 +++++++++++++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         |  65 ++++
 3 files changed, 403 insertions(+)
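
The device-group rules that this patch documents in `Mesh.md` can be illustrated with a small C++ sketch. `MultiIndex`, `groupIndex` and `deviceIndex` are hypothetical helpers, not part of the MLIR Mesh dialect API. The sketch assumes that an in-group multi-index lists its coordinates in the order given by `mesh_axes`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helpers (not part of the MLIR Mesh dialect API).
using MultiIndex = std::vector<int64_t>;

// True if `axis` is one of the mesh axes the collective operates along.
static bool isInGroupAxis(const std::vector<int64_t> &meshAxes, int64_t axis) {
  return std::find(meshAxes.begin(), meshAxes.end(), axis) != meshAxes.end();
}

// Index of the group a device belongs to: its coordinates along the mesh
// axes that are NOT listed in meshAxes, from outer to inner.
MultiIndex groupIndex(const MultiIndex &device,
                      const std::vector<int64_t> &meshAxes) {
  MultiIndex group;
  for (int64_t axis = 0; axis < static_cast<int64_t>(device.size()); ++axis)
    if (!isInGroupAxis(meshAxes, axis))
      group.push_back(device[axis]);
  return group;
}

// Full device coordinates from a group index and an in-group multi-index.
// Assumption: inGroup[i] is the coordinate along mesh axis meshAxes[i].
MultiIndex deviceIndex(const MultiIndex &group, const MultiIndex &inGroup,
                       const std::vector<int64_t> &meshAxes,
                       std::size_t rank) {
  assert(inGroup.size() == meshAxes.size());
  assert(group.size() + meshAxes.size() == rank);
  MultiIndex device(rank);
  std::size_t next = 0;
  // Fill the axes outside meshAxes from the group index, in order.
  for (std::size_t axis = 0; axis < rank; ++axis)
    if (!isInGroupAxis(meshAxes, static_cast<int64_t>(axis)))
      device[axis] = group[next++];
  // Fill the meshAxes coordinates from the in-group multi-index.
  for (std::size_t i = 0; i < meshAxes.size(); ++i)
    device[static_cast<std::size_t>(meshAxes[i])] = inGroup[i];
  return device;
}
```

Consider the `2x3x4x5` mesh with `mesh_axes = [0, 1]`. There, `groupIndex` maps both `(1, 0, 2, 3)` and `(1, 1, 2, 3)` to group `(2, 3)`, and maps `(1, 0, 2, 4)` to group `(2, 4)`. On a 3D mesh with `mesh_axes = [0, 2]`, `deviceIndex({g}, {i, j}, {0, 2}, 3)` yields `(i, g, j)`.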

diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 03877f1a6544817..77da2f10d8902c9 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -15,14 +15,17 @@ explanation.
 The main addition is that the collectives in this dialect have mesh
 semantics.
 
+### Device groups
 The operation attributes `mesh` and `mesh_axes` specifies a list of device mesh
 axes that partition the devices into disjoint groups.
 The collective operation is performed between devices in the same group.
 Devices that have the same coordinates outside of axes `mesh_axes` are in the
 same group.
+A group is described by its multi-index along the axes outside of `mesh_axes`.
 For example if we have a device mesh of size `2x3x4x5` and the partition mesh
 axes list is `[0, 1]` then devices are partitioned into the groups
 `{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+The device group indices would be `{ (k, m) | 0<=k<4, 0<=m<5 }`.
 Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
 Device (1, 0, 2, 4) will be in another group.
 Some collective operations like all-to-all and all-gather care about the
@@ -33,6 +36,17 @@ The axes are ordered from outer to inner.
 If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
 both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
 
+### In-group device
+Some operations like `broadcast`, `scatter` and `send` specify devices in each
+device group.
+These devices are represented by their multi-index over the mesh axes that
+vary within a device group.
+These are the axes specified by the `mesh_axes` attribute.
+
+For example, on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
+an in-group device with `(i, j)`. Then for each group with index `g` on the
+second axis, the in-group device would be `(i, g, j)`.
+
 
 ## Operations
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 5cce15dd1015ecc..1f63b7824133fab 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -339,6 +339,182 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
   let hasCanonicalizer = 1;
 }
 
+def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+    SameOperandsAndResultShape
+  ]> {
+  let summary = "Broadcast over a device mesh.";
+  let description = [{
+    Broadcast the tensor on `root` to all devices in each respective group.
+    The operation broadcasts along mesh axes `mesh_axes`.
+    `root` is the in-group multi-index of the device whose tensor is
+    broadcast to all other devices in the group.
+    
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+
+    %1 = mesh.broadcast %0 on @mesh0
+      mesh_axes = [0]
+      root = [0]
+      : tensor <2xi8> -> tensor <2xi8>
+    ```
+    
+    Input:
+    ```
+                     +-------+-------+                   | broadcast
+    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
+                     +-------+-------+                   ↓
+    device (1, 0) -> |       |       | <- device (1, 1) 
+                     +-------+-------+
+    ```
+
+    Output:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
+                     +-------+-------+
+    device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
+                     +-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyRankedTensor:$input,
+    DenseI64ArrayAttr:$root,
+    Variadic<Index>:$root_dynamic
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
+    attr-dict `:` type($input) `->` type($result)
+  }];
+}
+
+def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "Gather over a device mesh.";
+  let description = [{
+    Gathers on device `root` along the `gather_axis` tensor axis.
+    `root` specifies the coordinates of a device along `mesh_axes`.
+    It uniquely identifies the root device for each device group.
+    The result tensor on non-root devices is undefined.
+    Using it will result in undefined behavior.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
+      gather_axis = 1 root = [1]
+      : tensor<2x2xi8> -> tensor<2x4xi8>
+    ```
+    Input:
+    ```
+                      gather tensor
+                      axis 1
+                      ------------>
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- device (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- device (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+    Devices `(0, 0)` and `(1, 0)` have undefined result.
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$gather_axis,
+    DenseI64ArrayAttr:$root,
+    Variadic<Index>:$root_dynamic
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `gather_axis` `=` $gather_axis
+    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
+    attr-dict `:` type($input) `->` type($result)
+  }];
+}
+
+def Mesh_ReceiveOp : Mesh_CollectiveCommunicationOpBase<"receive", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape]> {
+  let summary = "Send over a device mesh.";
+  let description = [{
+    Receive from a device within a device group.
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    OptionalAttr<DenseI64ArrayAttr>:$source,
+    Variadic<Index>:$source_dynamic
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
+    attr-dict `:` type($input) `->` type($result)
+  }];
+}
+
+def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
+    SameOperandsAndResultShape]> {
+  let summary = "Reduce over a device mesh.";
+  let description = [{
+    Reduces on device `root` within each device group.
+    `root` specifies the coordinates of a device along `mesh_axes`.
+    It uniquely identifies the root device within its device group.
+    The accumulation element type is specified by the result type and
+    it does not need to match the input element type.
+    Each input element is converted to the result element type before
+    performing the reduction.
+
+    Attributes:
+    `reduction`: The reduction method. Defaults to `sum`.
+
+    Example:
+    ```
+    %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
+      reduction = <max> root = [2, 3]
+      : tensor<3x4xf32> -> tensor<3x4xf64>
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyRankedTensor:$input,
+    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    DenseI64ArrayAttr:$root,
+    Variadic<Index>:$root_dynamic
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    (`reduction` `=` $reduction^)?
+    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
+    attr-dict `:` type($input) `->` type($result)
+  }];
+}
+
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
     SameOperandsAndResultRank]> {
   let summary = "Reduce-scatter over a device mesh.";
@@ -400,4 +576,152 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
   let hasCanonicalizer = 1;
 }
 
+def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank]> {
+  let summary = "Scatter over a device mesh.";
+  let description = [{
+    For each device group split the input tensor on the `root` device along
+    axis `scatter_axis` and scatter the parts across the devices in the group.
+
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
+      scatter_axis = 0
+      root = [1]
+      : tensor<2x2xi8> -> tensor<1x2xi8>
+    ```
+
+    Input:
+    ```
+                              device
+                              (0, 1)
+                                 ↓
+                     +-------+-------+  | scatter tensor
+    device (0, 0) -> |       |       |  | axis 0
+                     |       |       |  ↓
+                     +-------+-------+
+    device (1, 0) -> |  1  2 |  5  6 |
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+                                ↑
+                              device
+                              (1, 1)
+    ```
+    
+    Result:
+    ```
+                              device
+                              (0, 1)
+                                 ↓
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 |
+                     +-------+-------+ 
+    device (1, 0) -> |  3  4 |  7  8 |
+                     +-------+-------+
+                                ↑
+                              device
+                              (1, 1)
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$scatter_axis,
+    DenseI64ArrayAttr:$root,
+    Variadic<Index>:$root_dynamic
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `scatter_axis` `=` $scatter_axis
+    `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
+    attr-dict `:` type($input) `->` type($result)
+  }];
+}
+
+def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape]> {
+  let summary = "Send over a device mesh.";
+  let description = [{
+    Send from one device to another within a device group.
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    DenseI64ArrayAttr:$destination,
+    Variadic<Index>:$destination_dynamic
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
+    attr-dict `:` type($input) `->` type($result)
+  }];
+}
+
+def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape]> {
+  let summary = "Sift over a device mesh.";
+  let description = [{
+    Within each device group shift along mesh axis `shift_axis` by an offset
+    `offset`.
+    The result on devices that do not have a corresponding source is undefined.
+    `shift_axis` must be one of `mesh_axes`.
+    If the `rotate` attribute is present, the shift wraps around along
+    `shift_axis` (a rotation), so every device has a source.
+
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+    %1 = mesh.shift on @mesh0 mesh_axes = [1]
+      shift_axis = 1
+      offset = 2
+      rotate
+    ```
+
+    Input:
+    ```
+    mesh axis 1
+    ----------->
+
+    +----+----+----+----+
+    |  1 |  2 |  3 |  4 |
+    +----+----+----+----+
+    |  5 |  6 |  7 |  8 |
+    +----+----+----+----+
+    ```
+
+    Result:
+    ```
+    +----+----+----+----+
+    |  3 |  4 |  1 |  2 |
+    +----+----+----+----+
+    |  7 |  8 |  5 |  6 |
+    +----+----+----+----+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$shift_axis,
+    IndexAttr:$offset,
+    UnitAttr:$rotate
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `shift_axis` `=` $shift_axis
+    `offset` `=` $offset
+    (`rotate` $rotate^)?
+    attr-dict `:` type($input) `->` type($result)
+  }];
+}
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b45f7cd21ce9217..a56ee36d8016993 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -507,6 +508,43 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.broadcast op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.gather op
+//===----------------------------------------------------------------------===//
+
+LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.receive op
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReceiveOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.reduce op
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.reduce_scatter op
 //===----------------------------------------------------------------------===//
@@ -528,6 +566,33 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// mesh.scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.send op
+//===----------------------------------------------------------------------===//
+
+LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shift op
+//===----------------------------------------------------------------------===//
+
+LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // TODO
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
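
The `gather` and `scatter` examples above can be made concrete with a rough C++ model of both ops within a single device group. The model is restricted to 2-D tensors and one mesh axis. `Matrix`, `gather2D` and `scatter2D` are hypothetical names. The sketch assumes that parts are concatenated and handed out in in-group index order, which is consistent with both examples.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// A 2-D tensor stored as a vector of rows (illustration only).
using Matrix = std::vector<std::vector<int>>;

// Hypothetical model of mesh.gather for one device group over one mesh axis:
// concatenate the group's inputs, ordered by in-group index, along tensor
// axis gatherAxis (0 = rows, 1 = columns). Only the root holds this result.
Matrix gather2D(const std::vector<Matrix> &groupInputs, int gatherAxis) {
  assert(!groupInputs.empty() && (gatherAxis == 0 || gatherAxis == 1));
  Matrix out;
  if (gatherAxis == 0) {
    for (const Matrix &m : groupInputs)
      out.insert(out.end(), m.begin(), m.end()); // stack rows
    return out;
  }
  out = groupInputs.front();
  for (std::size_t d = 1; d < groupInputs.size(); ++d)
    for (std::size_t r = 0; r < out.size(); ++r) // extend every row
      out[r].insert(out[r].end(), groupInputs[d][r].begin(),
                    groupInputs[d][r].end());
  return out;
}

// Hypothetical model of mesh.scatter: split the root's tensor into
// numDevices equal parts along scatterAxis; part k goes to in-group index k.
std::vector<Matrix> scatter2D(const Matrix &rootInput, int scatterAxis,
                              std::size_t numDevices) {
  assert(numDevices > 0 && !rootInput.empty());
  std::vector<Matrix> parts(numDevices);
  if (scatterAxis == 0) {
    assert(rootInput.size() % numDevices == 0);
    const auto rows =
        static_cast<std::ptrdiff_t>(rootInput.size() / numDevices);
    for (std::size_t k = 0; k < numDevices; ++k) {
      auto first = rootInput.begin() + static_cast<std::ptrdiff_t>(k) * rows;
      parts[k].assign(first, first + rows);
    }
    return parts;
  }
  assert(scatterAxis == 1 && rootInput[0].size() % numDevices == 0);
  const auto cols =
      static_cast<std::ptrdiff_t>(rootInput[0].size() / numDevices);
  for (std::size_t k = 0; k < numDevices; ++k)
    for (const auto &row : rootInput) {
      auto first = row.begin() + static_cast<std::ptrdiff_t>(k) * cols;
      parts[k].emplace_back(first, first + cols); // slice of this row
    }
  return parts;
}
```

With the `gather` example inputs `[[1, 2], [3, 4]]` and `[[5, 6], [7, 8]]`, gathering along axis 1 gives `[[1, 2, 5, 6], [3, 4, 7, 8]]`, the result shown for the root device. Scattering `[[1, 2], [3, 4]]` along axis 0 over two devices gives `[[1, 2]]` and `[[3, 4]]`, matching the left column of the `scatter` example.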

>From 03a508f35d246405d8c987b075431ad0eb98f0bf Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 09:46:57 -0800
Subject: [PATCH 2/4] Update collectives TableGen definitions

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 55 +++++++++++---------
 1 file changed, 30 insertions(+), 25 deletions(-)
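
This revision changes the `offset` attribute of `shift` from `IndexAttr` to `I64Attr`. The C++ sketch below is a hedged reference model of `mesh.shift` along `shift_axis` within one device group. `shiftAlongAxis` is a hypothetical helper. The op description does not pin down the direction of the shift, so the sketch assumes that data moves from position `j` to `j + offset`. It also accepts negative offsets, which is an assumption. `std::nullopt` stands in for an undefined result.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical reference model of mesh.shift for ONE device group (C++17).
// values[j] is the input held by the device at position j along shift_axis.
// Assumption: data moves from position j to position j + offset.
// std::nullopt marks a result that the op leaves undefined.
std::vector<std::optional<int>> shiftAlongAxis(const std::vector<int> &values,
                                               int64_t offset, bool rotate) {
  const auto n = static_cast<int64_t>(values.size());
  std::vector<std::optional<int>> result(values.size());
  for (int64_t j = 0; j < n; ++j) {
    int64_t dst = j + offset;
    if (rotate)
      dst = ((dst % n) + n) % n; // wrap around; also handles negative offsets
    else if (dst < 0 || dst >= n)
      continue; // shifted past the end of the axis: no destination
    result[static_cast<std::size_t>(dst)] = values[j];
  }
  return result;
}
```

With `values = {1, 2, 3, 4}`, `offset = 2` and `rotate = true`, this yields `{3, 4, 1, 2}`, the row shown in the op's example. Without `rotate`, positions 0 and 1 are left undefined. For `offset = 2` on an axis of size 4, either shift direction gives the same result, so the example does not settle the direction.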

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1f63b7824133fab..493021dabb27416 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -340,7 +340,8 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 }
 
 def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
-    SameOperandsAndResultShape
+    AllShapesMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
   ]> {
   let summary = "Broadcast over a device mesh.";
   let description = [{
@@ -356,7 +357,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     %1 = mesh.broadcast %0 on @mesh0
       mesh_axes = [0]
       root = [0]
-      : tensor <2xi8> -> tensor <2xi8>
+      : (tensor<2xi8>) -> tensor<2xi8>
     ```
     
     Input:
@@ -388,13 +389,13 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
   let assemblyFormat = [{
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
-    attr-dict `:` type($input) `->` type($result)
+    attr-dict `:` functional-type(operands, results)
   }];
 }
 
 def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank
+    AllRanksMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
   ]> {
   let summary = "Gather over a device mesh.";
   let description = [{
@@ -410,7 +411,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
     ...
     %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
       gather_axis = 1 root = [1]
-      : tensor<2x2xi8> -> tensor<2x4xi8>
+      : (tensor<2x2xi8>) -> tensor<2x4xi8>
     ```
     Input:
     ```
@@ -450,13 +451,14 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
     `gather_axis` `=` $gather_axis
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
-    attr-dict `:` type($input) `->` type($result)
+    attr-dict `:` functional-type(operands, results)
   }];
 }
 
 def Mesh_ReceiveOp : Mesh_CollectiveCommunicationOpBase<"receive", [
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultShape]> {
+    AllShapesMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
+  ]> {
   let summary = "Send over a device mesh.";
   let description = [{
     Receive from a device within a device group.
@@ -472,12 +474,13 @@ def Mesh_ReceiveOp : Mesh_CollectiveCommunicationOpBase<"receive", [
   let assemblyFormat = [{
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
     (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
-    attr-dict `:` type($input) `->` type($result)
+    attr-dict `:` functional-type(operands, results)
   }];
 }
 
 def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
-    SameOperandsAndResultShape]> {
+    AllShapesMatch<["input", "result"]>
+  ]> {
   let summary = "Reduce over a device mesh.";
   let description = [{
     Reduces on device `root` within each device group.
@@ -495,7 +498,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
     ```
     %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
       reduction = <max> root = [2, 3]
-      : tensor<3x4xf32> -> tensor<3x4xf64>
+      : (tensor<3x4xf32>) -> tensor<3x4xf64>
     ```
   }];
   let arguments = !con(commonArgs, (ins
@@ -511,7 +514,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
     (`reduction` `=` $reduction^)?
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
-    attr-dict `:` type($input) `->` type($result)
+    attr-dict `:` functional-type(operands, results)
   }];
 }
 
@@ -577,8 +580,9 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
 }
 
 def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultRank]> {
+    AllRanksMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
+  ]> {
   let summary = "Scatter over a device mesh.";
   let description = [{
     For each device group split the input tensor on the `root` device along
@@ -590,7 +594,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
     %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
       scatter_axis = 0
       root = [1]
-      : tensor<2x2xi8> -> tensor<1x2xi8>
+      : (tensor<2x2xi8>) -> tensor<1x2xi8>
     ```
 
     Input:
@@ -638,13 +642,14 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
     `scatter_axis` `=` $scatter_axis
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
-    attr-dict `:` type($input) `->` type($result)
+    attr-dict `:` functional-type(operands, results)
   }];
 }
 
 def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
-    SameOperandsAndResultElementType,
-    SameOperandsAndResultShape]> {
+    AllShapesMatch<["input", "result"]>,
+    AllElementTypesMatch<["input", "result"]>
+  ]> {
   let summary = "Send over a device mesh.";
   let description = [{
     Send from one device to another within a device group.
@@ -660,13 +665,14 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
   let assemblyFormat = [{
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
     `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
-    attr-dict `:` type($input) `->` type($result)
+    attr-dict `:` functional-type(operands, results)
   }];
 }
 
 def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
     SameOperandsAndResultElementType,
-    SameOperandsAndResultShape]> {
+    SameOperandsAndResultShape
+  ]> {
   let summary = "Sift over a device mesh.";
   let description = [{
     Within each device group shift along mesh axis `shift_axis` by an offset
@@ -680,9 +686,8 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
     ```
     mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
     %1 = mesh.shift on @mesh0 mesh_axes = [1]
-      shift_axis = 1
-      offset = 2
-      rotate
+      shift_axis = 1 offset = 2 rotate
+      : tensor<2xi8> -> tensor<2xi8>
     ```
 
     Input:
@@ -709,7 +714,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
     IndexAttr:$shift_axis,
-    IndexAttr:$offset,
+    I64Attr:$offset,
     UnitAttr:$rotate
   ));
   let results = (outs

>From 9dca77f9a8895a04b61fc443c5a58d1033284438 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 16:44:53 -0800
Subject: [PATCH 3/4] Add missing include

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a56ee36d8016993..6f8de72d034ef48 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"

>From fc79158f09cfc33cbe79adc9702ba385c9721c1b Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 17:09:51 -0800
Subject: [PATCH 4/4] Rename receive -> recv

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 2 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 493021dabb27416..361e67fd1e19ac6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -455,7 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
   }];
 }
 
-def Mesh_ReceiveOp : Mesh_CollectiveCommunicationOpBase<"receive", [
+def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
     AllShapesMatch<["input", "result"]>,
     AllElementTypesMatch<["input", "result"]>
   ]> {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 6f8de72d034ef48..3b89860c14d9362 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -532,7 +532,7 @@ LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // mesh.receive op
 //===----------------------------------------------------------------------===//
 
-LogicalResult ReceiveOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   // TODO
   return failure();
 }


