[Mlir-commits] [mlir] 5f7c8c1 - [mlir][mesh] Add collective communication operations (#71960)

llvmlistbot at llvm.org
Tue Nov 21 06:50:29 PST 2023


Author: Boian Petkantchin
Date: 2023-11-21T06:50:24-08:00
New Revision: 5f7c8c1068d23d1f23643ddd141ca6f4f23f3578

URL: https://github.com/llvm/llvm-project/commit/5f7c8c1068d23d1f23643ddd141ca6f4f23f3578
DIFF: https://github.com/llvm/llvm-project/commit/5f7c8c1068d23d1f23643ddd141ca6f4f23f3578.diff

LOG: [mlir][mesh] Add collective communication operations (#71960)

Add all-gather, all-reduce, all-to-all and reduce-scatter. These
operations have device mesh semantics.
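
As a rough usage sketch (the mesh, function, and tensor shapes below are
invented for illustration, not taken from this commit):

```mlir
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

func.func @example(%arg0 : tensor<2x2xi8>) -> tensor<2x8xi8> {
  // Gather along tensor axis 1 within device groups formed over mesh
  // axis 1 (group size 4), so that axis grows from 2 to 2 * 4 = 8.
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
    : tensor<2x2xi8> -> tensor<2x8xi8>
  return %0 : tensor<2x8xi8>
}
```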

Added: 
    mlir/docs/Dialects/Mesh.md
    mlir/test/Dialect/Mesh/canonicalization.mlir

Modified: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/test/Dialect/Mesh/invalid.mlir
    mlir/test/Dialect/Mesh/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..03877f1a6544817
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,43 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the user is familiar with collective operations.
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition is that the collectives in this dialect have device mesh
+semantics.
+
+The operation attributes `mesh` and `mesh_axes` specify a list of device mesh
+axes that partition the devices into disjoint groups.
+The collective operation is performed between devices in the same group.
+Devices that have the same coordinates on all axes outside of `mesh_axes` are
+in the same group.
+For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
+axes list is `[0, 1]`, then the devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+Some collective operations like all-to-all and all-gather care about the
+order of devices.
+The order of devices within a device group is induced by the order of axes in
+`mesh_axes`.
+The axes are ordered from outer to inner.
+If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
+both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
+
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
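
To make the grouping rules above concrete, here is a hypothetical snippet (the
cluster name and shapes are invented): on a `2x3` mesh, `mesh_axes = [1]`
splits the six devices into the groups `{(0, 0), (0, 1), (0, 2)}` and
`{(1, 0), (1, 1), (1, 2)}`, and the reduction runs independently in each group:

```mlir
mesh.cluster @m(rank = 2, dim_sizes = [2, 3])

func.func @g(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // Devices sharing their coordinate on mesh axis 0 form a group of 3.
  %0 = mesh.all_reduce %arg0 on @m mesh_axes = [1] reduction = <max>
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
```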

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
   let cppNamespace = "::mlir::mesh";
 
   let description = [{
-    The `mesh` dialect contains a set of attributes, operations, interfaces that
-    are useful for representing sharding and communication on device mesh
-    cluster.
+    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
   }];
 
   let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
 // distributed tensors. Mesh_IteratorType annotates loops in an operation, while
 // Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
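
The new `Mesh_PartialAttr` wraps the existing `Mesh_Partial` enum so it can be
used as an operation attribute; with the declared assembly format it is written
as the bare enum value in angle brackets, e.g. (a hypothetical snippet):

```mlir
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
  : tensor<4xf32> -> tensor<4xf64>
```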

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..9077d2eb0189b72 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,12 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
 

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..5cce15dd1015ecc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -77,6 +79,18 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
       attr-dict
   }];
+  let extraClassDeclaration = [{
+    // The `dim_sizes` attribute may have size less than the rank of the mesh.
+    // Returns the shape of the mesh with missing trailing dimensions
+    // explicitly set as dynamic.
+    ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+    template <typename OutIt>
+    void canonicalDimSizes(OutIt outIt) {
+      outIt = std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -171,4 +185,219 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+class Mesh_CollectiveCommunicationOpBase<
+    string mnemonic, list<Trait> traits = []> :
+    Mesh_Op<mnemonic,
+      !listconcat(traits,
+      [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+  dag commonArgs = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+  );
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-gather over a device mesh.";
+  let description = [{
+    Gathers along the `gather_axis` tensor axis.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
+      : tensor<2x2xi8> -> tensor<2x4xi8>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    gather tensor
+    axis 1
+    ------------>
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$gather_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+    SameOperandsAndResultShape]> {
+  let summary = "All-reduce over a device mesh.";
+  let description = [{
+    The accumulation element type is specified by the result type and
+    it does not need to match the input element type.
+    Each input element is converted to the result element type before
+    the reduction is performed.
+
+    Attributes:
+    `reduction`: Indicates the reduction method.
+
+    Example:
+    ```
+    %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+      : tensor<3x4xf32> -> tensor<3x4xf64>
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyRankedTensor:$input,
+    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank]> {
+  let summary = "All-to-all over a device mesh.";
+  let description = [{
+    Performs an all-to-all on tensor pieces split along `split_axis`.
+    The resulting pieces are concatenated along `concat_axis` on each device.
+
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+    ...
+    %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
+      split_axis = 0 concat_axis = 0
+      : tensor<3x2xi8> -> tensor<3x2xi8>
+    ```
+    Input:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+  | split and concat along
+    | 11 12 | 21 22 | 31 32 |  | tensor axis 0
+    | 13 14 | 23 24 | 33 34 |  ↓
+    | 15 16 | 25 26 | 35 36 |
+    +-------+-------+-------+
+    ```
+    Result:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 13 14 | 15 16 |
+    | 21 22 | 23 24 | 25 26 |
+    | 31 32 | 33 34 | 35 36 |
+    +-------+-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$split_axis,
+    IndexAttr:$concat_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `split_axis` `=` $split_axis
+    `concat_axis` `=` $concat_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+    SameOperandsAndResultRank]> {
+  let summary = "Reduce-scatter over a device mesh.";
+  let description = [{
+    After the reduction, the result is scattered within each device group.
+    The tensor is split along `scatter_axis` and the pieces distributed
+    across the device group.
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
+      scatter_axis = 0
+      : tensor<2x2xf32> -> tensor<1x2xf64>
+    ```
+    Input:
+    ```
+                              device
+                              (0, 1)
+                                 ↓
+                     +-------+-------+  | scatter tensor
+    device (0, 0) -> |  1  2 |  5  6 |  | axis 0
+                     |  3  4 |  7  8 |  ↓
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 |
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+                                ↑
+                              device
+                              (1, 1)
+    ```
+    Result:
+    ```
+    +-------+
+    |  6  8 | <- device (0, 0)
+    +-------+
+    | 10 12 | <- device (0, 1)
+    +-------+
+    | 22 24 | <- device (1, 0)
+    +-------+
+    | 26 28 | <- device (1, 1)
+    +-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    IndexAttr:$scatter_axis
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    (`reduction` `=` $reduction^)?
+    `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+}
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
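
Since `mesh_axes` and `reduction` are default-valued attributes inside optional
assembly groups, both may be elided; the two forms below should parse to the
same operation (a sketch with invented values):

```mlir
%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [] reduction = <sum>
  : tensor<4xf32> -> tensor<4xf64>
%2 = mesh.all_reduce %0 on @mesh0 : tensor<4xf32> -> tensor<4xf64>
```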

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..b45f7cd21ce9217 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,27 @@
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
 
 #define DEBUG_TYPE "mesh-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +38,60 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
+template <typename It>
+static It canonicalizeSetAsArray(It begin, It end) {
+  llvm::sort(begin, end);
+  return std::unique(begin, end);
+}
+
+template <typename R>
+static auto canonicalizeSetAsArray(R &&range) {
+  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
+}
+
+template <typename T>
+static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+  auto newEnd = canonicalizeSetAsArray(vec);
+  vec.resize(newEnd - vec.begin());
+  return vec;
+}
+
+template <typename DimSize>
+static bool isMeshDimensionDynamic(DimSize size) {
+  return size <= DimSize(0);
+}
+
+using MeshAxis = int16_t;
+
+namespace {
+
+struct DimensionSize {
+  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
+  DimensionSize(int64_t val) : val(val) {}
+  int64_t value() const { return val; }
+  operator int64_t() const { return val; }
+  bool isDynamic() const { return ShapedType::isDynamic(val); }
+
+private:
+  int64_t val;
+};
+
+} // namespace
+
+static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+  if (lhs.isDynamic() || rhs.isDynamic()) {
+    return DimensionSize::dynamic();
+  }
+  return lhs.value() / rhs.value();
+}
+
+static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+  if (lhs.isDynamic() || rhs.isDynamic()) {
+    return DimensionSize::dynamic();
+  }
+  return lhs.value() * rhs.value();
+}
+
 //===----------------------------------------------------------------------===//
 // Mesh dialect
 //===----------------------------------------------------------------------===//
@@ -96,6 +167,13 @@ LogicalResult ClusterOp::verify() {
   return success();
 }
 
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+  SmallVector<int64_t> result;
+  result.reserve(getRank());
+  canonicalDimSizes(std::back_inserter(result));
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard op
 //===----------------------------------------------------------------------===//
@@ -129,6 +207,327 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+template <typename Op>
+struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto meshAxes = op.getMeshAxes();
+    if (!meshAxes.empty()) {
+      return failure();
+    }
+    if (op.getInput().getType() != op.getResult().getType()) {
+      return failure();
+    }
+
+    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
+} // namespace
+
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                    SymbolTableCollection &symbolTable) {
+  mesh::ClusterOp mesh =
+      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+  if (!mesh) {
+    return op->emitError() << "Undefined required mesh symbol \""
+                           << meshSymbol.getValue() << "\".";
+  }
+
+  return mesh;
+}
+
+template <typename It>
+bool isUnique(It begin, It end) {
+  if (begin == end) {
+    return true;
+  }
+  It next = std::next(begin);
+  if (next == end) {
+    return true;
+  }
+  for (; next != end; ++next, ++begin) {
+    if (*begin == *next) {
+      return false;
+    }
+  }
+  return true;
+}
+
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+                                    ClusterOp mesh) {
+  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+  llvm::sort(sorted);
+  if (!isUnique(sorted.begin(), sorted.end())) {
+    return emitError(loc) << "Mesh axes contain duplicate elements.";
+  }
+
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : axes) {
+    if (axis >= rank || axis < 0) {
+      return emitError(loc)
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+             << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+template <typename Op>
+static FailureOr<ClusterOp>
+getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
+  auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
+    return failure();
+  }
+  return mesh;
+}
+
+template <typename It>
+static auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  return std::accumulate(begin, end, static_cast<ElementType>(1),
+                         std::multiplies<ElementType>());
+}
+
+template <typename R>
+static auto product(R &&range) {
+  return product(adl_begin(range), adl_end(range));
+}
+
+static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                         ArrayRef<int64_t> meshShape) {
+  int64_t res = 1;
+
+  for (MeshAxis axis : meshAxes) {
+    assert(size_t(axis) < meshShape.size());
+    if (isMeshDimensionDynamic(meshShape[axis])) {
+      return ShapedType::kDynamic;
+    }
+    res *= meshShape[axis];
+  }
+
+  return res;
+}
+
+static LogicalResult verifyDimensionCompatibility(Location loc,
+                                                  int64_t expectedDimSize,
+                                                  int64_t resultDimSize,
+                                                  int64_t resultAxis) {
+  if (!ShapedType::isDynamic(resultDimSize) &&
+      expectedDimSize != resultDimSize) {
+    return emitError(loc) << "Dimension size mismatch for result axis "
+                          << resultAxis << ". Expected "
+                          << (ShapedType::isDynamic(expectedDimSize)
+                                  ? Twine("dynamic")
+                                  : Twine(expectedDimSize))
+                          << ", but got " << resultDimSize << ".";
+  }
+
+  return success();
+}
+
+static LogicalResult verifyAllGatherOperandAndResultShape(
+    Value operand, Value result, int64_t gatherAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  auto resultRank = result.getType().template cast<ShapedType>().getRank();
+  if (gatherAxis < 0 || gatherAxis >= resultRank) {
+    return emitError(result.getLoc())
+           << "Gather axis " << gatherAxis << " is out of bounds [0, "
+           << resultRank << ").";
+  }
+
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
+    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
+    auto expectedResultDimSize =
+        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
+    if (failed(verifyDimensionCompatibility(
+            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+static LogicalResult verifyAllToAllOperandAndResultShape(
+    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  if (splitAxis == concatAxis) {
+    return success();
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
+  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
+  DimensionSize expectedResultConcatDimSize =
+      operandConcatDimSize * deviceGroupSize;
+  DimensionSize expectedResultSplitDimSize =
+      operandSplitDimSize / deviceGroupSize;
+  if (!expectedResultSplitDimSize.isDynamic() &&
+      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
+    expectedResultSplitDimSize = DimensionSize::dynamic();
+  }
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultConcatDimSize.value(),
+          resultType.getDimSize(concatAxis), concatAxis))) {
+    return failure();
+  }
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultSplitDimSize.value(),
+          resultType.getDimSize(splitAxis), splitAxis))) {
+    return failure();
+  }
+
+  return success();
+}
+
+static LogicalResult verifyReduceScatterOperandAndResultShape(
+    Value operand, Value result, int64_t scatterAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    if (axis != scatterAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandScatterDimSize =
+      DimensionSize(operandType.getDimSize(scatterAxis));
+  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
+    return emitError(result.getLoc())
+           << "Operand dimension size " << int64_t(operandScatterDimSize)
+           << " is not divisible by collective device group size "
+           << int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
+           << ".";
+  }
+  DimensionSize expectedResultScatterDimSize =
+      operandScatterDimSize / deviceGroupSize;
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultScatterDimSize.value(),
+          resultType.getDimSize(scatterAxis), scatterAxis))) {
+    return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_gather op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto gatherAxis = getGatherAxis().getSExtValue();
+  return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
+                                              gatherAxis, getMeshAxes(),
+                                              mesh.value().canonicalDimSizes());
+}
+
+void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_reduce op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return getMeshAndVerifyAxes(*this, symbolTable);
+}
+
+void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_to_all op
+//===----------------------------------------------------------------------===//
+
+LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+
+  return verifyAllToAllOperandAndResultShape(
+      getOperand(), getResult(), getSplitAxis().getSExtValue(),
+      getConcatAxis().getSExtValue(), getMeshAxes(),
+      mesh.value().canonicalDimSizes());
+}
+
+void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                             MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.reduce_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+
+  return verifyReduceScatterOperandAndResultShape(
+      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+      mesh.value().canonicalDimSizes());
+}
+
+void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
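
The shared `EmptyMeshAxesCanonicalizationPattern` above folds any collective
whose `mesh_axes` list is empty and whose operand and result types match, since
every device then forms its own singleton group. Sketched on invented IR:

```mlir
// Before --canonicalize:
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [] gather_axis = 0
  : tensor<4xf32> -> tensor<4xf32>
// After: the op is erased and uses of %0 are replaced with %arg0.
```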

diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
new file mode 100644
index 000000000000000..5802d198d368149
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes
+func.func @all_reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_reduce
+  %0 = mesh.all_reduce %arg0 on @mesh0
+    mesh_axes = []
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
+func.func @all_reduce_empty_mesh_axes_different_return_type(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh.all_reduce
+  %0 = mesh.all_reduce %arg0 on @mesh0
+// CHECK-NOT: mesh_axes
+    mesh_axes = []
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_reduce_default_reduction
+func.func @all_reduce_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = mesh.all_reduce %arg0 on @mesh0
+    mesh_axes = [0]
+// CHECK-NOT: reduction
+    reduction = <sum>
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_to_all_empty_mesh_axes
+func.func @all_to_all_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
+    %arg0 : tensor<8xf32>) -> tensor<8xf32> {
+// CHECK-NOT: mesh.all_to_all
+  %0 = mesh.all_to_all %arg0 on @mesh0
+    mesh_axes = []
+    split_axis = 0
+    concat_axis = 0
+    : tensor<8xf32> -> tensor<8xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @all_gather_empty_mesh_axes
+func.func @all_gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_gather
+  %0 = mesh.all_gather %arg0 on @mesh0
+    mesh_axes = []
+    gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
+func.func @reduce_scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce_scatter
+  %0 = mesh.reduce_scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
+func.func @reduce_scatter_empty_mesh_axes_different_return_type(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh.reduce_scatter
+  %0 = mesh.reduce_scatter %arg0 on @mesh0
+// CHECK-NOT: mesh_axes
+    mesh_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_default_reduction
+func.func @reduce_scatter_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<2xf64> {
+  %0 = mesh.reduce_scatter %arg0 on @mesh0
+    mesh_axes = [0]
+// CHECK-NOT: reduction
+    reduction = <sum>
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}

diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 246439dd4be7122..2999668f770baa7 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -67,3 +67,279 @@ func.func @mesh_axis_negtive_in_partial(
             tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
 }
+
+// -----
+
+func.func @all_reduce_invalid_mesh_symbol(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = <sum>
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_invalid_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = <sum>
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_duplicate_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error@+1 {{Mesh axes contain duplicate elements.}}
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = <sum>
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_invalid_tensor_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<5xf64> {
+  // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
+  %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
+  return %0 : tensor<5xf64>
+}
+
+// -----
+
+func.func @all_gather_invalid_mesh_symbol(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_gather_invalid_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_gather_duplicate_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{Mesh axes contain duplicate elements.}}
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_non_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+    : tensor<3x4xf32> -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+
+func.func @all_gather_invalid_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
+    : tensor<3x4xf32> -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_gather_axis_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
+    : tensor<?xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
+    : tensor<3xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_negative_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
+    : tensor<3xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+func.func @all_to_all_invalid_mesh_symbol(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
+    split_axis = 1 concat_axis = 0
+    : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_to_all_duplicate_mesh_axis(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Mesh axes contain duplicate elements.}}
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
+    split_axis = 0 concat_axis = 0
+    : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+    split_axis = 0 concat_axis = 1
+    : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
+    %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+    split_axis = 0 concat_axis = 1
+    : tensor<?x6xi8> -> tensor<3x?xi8>
+  return %0 : tensor<3x?xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
+    %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+    split_axis = 0 concat_axis = 1
+    : tensor<3x?xi8> -> tensor<?x3xi8>
+  return %0 : tensor<?x3xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
+    %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+    split_axis = 0 concat_axis = 1
+    : tensor<3x2xi8> -> tensor<1x7xi8>
+  return %0 : tensor<1x7xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
+    %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+    split_axis = 0 concat_axis = 1
+    : tensor<3x2xi8> -> tensor<2x6xi8>
+  return %0 : tensor<2x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_duplicate_mesh_axis(
+    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+  // expected-error@+1 {{Mesh axes contain duplicate elements.}}
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<2xf64> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
+    : tensor<?xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_static_dimension_size(
+    %arg0 : tensor<3xf32>) -> tensor<2xf64> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+    : tensor<3xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_operand_static_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<?xf64> {
+  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+    : tensor<4xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
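
For contrast with the divisibility error above, a hypothetical valid variant:
with a device group of size 3, an operand dimension of size 6 scatters cleanly
into pieces of size 6 / 3 = 2:

```mlir
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])

func.func @reduce_scatter_divisible(%arg0 : tensor<6xf32>) -> tensor<2xf64> {
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
    : tensor<6xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}
```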

diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index ee5f8f67792b928..5b264bc88dfc2a7 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -12,6 +12,8 @@ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
 // CHECK: mesh.cluster @mesh3
 mesh.cluster @mesh3(rank = 2)
 
+mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
@@ -126,3 +128,124 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
+
+// CHECK-LABEL: func @all_reduce
+func.func @all_reduce(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
+  // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = <max>
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+    : tensor<3x4xf32> -> tensor<3x4xf64>
+  return %0 : tensor<3x4xf64>
+}
+
+// CHECK-LABEL: func @all_gather
+func.func @all_gather(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+    : tensor<3x4xf32> -> tensor<3x16xf32>
+  return %0 : tensor<3x16xf32>
+}
+
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
+func.func @all_gather_dynamic_dims_in_tensor(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
+    %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+  // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+    : tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
+func.func @all_gather_dynamic_dims_in_mesh(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
+    %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
+  // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
+  %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
+    : tensor<5x6xf32> -> tensor<5x?xf32>
+  return %0 : tensor<5x?xf32>
+}
+
+// CHECK-LABEL: func @all_to_all
+func.func @all_to_all(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh4
+    split_axis = 1 concat_axis = 0
+    : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
+func.func @all_to_all_dynamic_dims_in_result(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+    %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh4
+    split_axis = 1 concat_axis = 0
+    : tensor<3x6xi8> -> tensor<3x?xi8>
+  return %0 : tensor<3x?xi8>
+}
+
+// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
+func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
+    %arg0 : tensor<3xi8>) -> tensor<3xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
+  // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh4
+    split_axis = 0 concat_axis = 0
+    : tensor<3xi8> -> tensor<3xi8>
+  return %0 : tensor<3xi8>
+}
+
+// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
+func.func @all_to_all_non_divisible_split_axis_size(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
+    %arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
+  // CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
+    split_axis = 0 concat_axis = 1
+    : tensor<2x3xi8> -> tensor<?x12xi8>
+  return %0 : tensor<?x12xi8>
+}
+
+// CHECK-LABEL: func @reduce_scatter_static_dimensions
+func.func @reduce_scatter_static_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
+  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = <max> scatter_axis = 1
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
+    reduction = <max> scatter_axis = 1
+    : tensor<3x4xf32> -> tensor<3x1xf64>
+  return %0 : tensor<3x1xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
+func.func @reduce_scatter_dynamic_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
+  %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}


More information about the Mlir-commits mailing list