[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)

Boian Petkantchin llvmlistbot at llvm.org
Fri Nov 10 09:18:19 PST 2023


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/71960

Add all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics.

I have not included ops like reduce, gather, send and recv to see first if reviewers notice any systemic issues. Also this PR is already big enough.


>From 5b8c7f70202f7cac6b6b1a29e168fe6fa43fc720 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Mon, 6 Nov 2023 15:55:31 -0800
Subject: [PATCH] [mlir][mesh] Add collective communication operations

Add all-gather, all-reduce, all-to-all and reduce-scatter.
These operations have device mesh semantics.
---
 mlir/docs/Dialects/Mesh.md                    |  34 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |   8 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |   2 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 216 +++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 417 ++++++++++++++++++
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  72 +++
 mlir/test/Dialect/Mesh/invalid.mlir           | 240 ++++++++++
 mlir/test/Dialect/Mesh/ops.mlir               | 119 +++++
 8 files changed, 1105 insertions(+), 3 deletions(-)
 create mode 100644 mlir/docs/Dialects/Mesh.md
 create mode 100644 mlir/test/Dialect/Mesh/canonicalization.mlir

diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..6dd4f79022061ee
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,34 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the user is familiar with collective operations.
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition is that the collectives in this dialect have mesh
+semantics.
+The operation attributes `mesh` and `mesh_axes` specify a set of device mesh
+axes that partition the devices into disjoint groups.
+The collective operation is performed between devices in the same group.
+Devices that have the same coordinates outside of axes `mesh_axes` are in the
+same group.
+For example if we have a device mesh of size `2x3x4x5` and the partition mesh
+axes set is `{0, 1}` then devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
   let cppNamespace = "::mlir::mesh";
 
   let description = [{
-    The `mesh` dialect contains a set of attributes, operations, interfaces that
-    are useful for representing sharding and communication on device mesh
-    cluster.
+    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
   }];
 
   let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+// Attribute form of the Mesh_Partial enum. With the mnemonic `partial` and
+// the assembly format below it is written as e.g. `#mesh.partial<sum>`.
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
 // distributed tensors. Mesh_IteratorType annotates loops in an operation, while
 // Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..7698d60813a8f10 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,11 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..15354babe870599 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -77,6 +79,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
       attr-dict
   }];
+  let extraClassDeclaration = [{
+    ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+    /// Writes the explicitly listed `dim_sizes` to `outIt`, followed by a 0
+    /// for each of the remaining `getRank() - getDimSizes().size()` axes.
+    template <typename OutIt>
+    void canonicalDimSizes(OutIt outIt) {
+      // Continue from the iterator returned by std::copy. Reusing the
+      // original `outIt` would overwrite the sizes just copied whenever the
+      // output iterator is not an insert iterator (e.g. a raw pointer).
+      outIt = std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -171,4 +182,209 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+// Common base for the collective communication ops of the mesh dialect.
+// It appends SymbolUserOpInterface to the traits supplied by the derived op,
+// fixes the shared assembly format, and provides `extraClassDeclarationBase`,
+// a declaration of `verifySymbolUses` that derived ops splice into their own
+// `extraClassDeclaration`.
+class Mesh_CollectiveCommunicationOpBase<
+    string mnemonic, list<Trait> traits = []> :
+    Mesh_Op<mnemonic,
+      !listconcat(traits,
+      [SymbolUserOpInterface])> {
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+  code extraClassDeclarationBase = [{
+    ::mlir::LogicalResult verifySymbolUses(
+          ::mlir::SymbolTableCollection &symbolTable);
+  }];
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-gather over a device mesh.";
+  let description = [{
+    Gathers along the `gather_axis` tensor axis.
+    The order of input tensors in the resulting tensor is the same as the
+    order of the corresponding devices' multi-index in the mesh.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.all_gather %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+      } : tensor<2x2xi8> -> tensor<2x4xi8>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    APIntAttr:$gather_axis
+  );
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+    SameOperandsAndResultShape]> {
+  let summary = "All-reduce over a device mesh.";
+  let description = [{
+    The accumulation element type is specified by the result type and
+    it does not need to match the input element type.
+    The input element is converted to the result element type before
+    performing the reduction.
+
+    Attributes:
+    `reduction`: Indicates the reduction method.
+
+    Example:
+    ```
+    %1 = mesh.all_reduce %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+      } : tensor<3x4xf32> -> tensor<3x4xf64>
+    ```
+  }];
+  let arguments = (ins
+    AnyRankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank]> {
+  let summary = "All-to-all over a device mesh.";
+  let description = [{
+    Performs an all-to-all on tensor pieces split along `split_axis`.
+    The resulting pieces are concatenated along `concat_axis` on each device.
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+    ...
+    %1 = mesh.all_to_all %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
+      } : tensor<3x6xi8> -> tensor<3x6xi8>
+    ```
+    Input:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 21 22 | 31 32 |
+    | 13 14 | 23 24 | 33 34 |
+    | 15 16 | 25 26 | 35 36 |
+    +-------+-------+-------+
+    ```
+    Result:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 13 14 | 15 16 |
+    | 21 22 | 23 24 | 25 26 |
+    | 31 32 | 33 34 | 35 36 |
+    +-------+-------+-------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    APIntAttr:$split_axis,
+    APIntAttr:$concat_axis
+  );
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+    SameOperandsAndResultRank]> {
+  let summary = "Reduce-scatter over a device mesh.";
+  let description = [{
+    After the reduction, the result is scattered within each device group.
+    The tensor is split along `scatter_axis` and the pieces distributed
+    across the device group.
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.reduce_scatter %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, reduction = #mesh.partial<max>, scatter_axis = 0
+      } : tensor<3x4xf32> -> tensor<1x4xf64>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------+
+    |  6  8 | <- devices (0, 0)
+    +-------+
+    | 10 12 | <- devices (0, 1)
+    +-------+
+    | 22 24 | <- devices (1, 0)
+    +-------+
+    | 26 28 | <- devices (1, 1)
+    +-------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    APIntAttr:$scatter_axis
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..6efc4c4ecc326ad 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,26 @@
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
 
 #define DEBUG_TYPE "mesh-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +37,60 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
+namespace {
+
+/// Sorts [begin, end) in ascending order and collapses runs of equal
+/// elements. Returns the end of the resulting duplicate-free prefix; the
+/// elements past that iterator are left in an unspecified state.
+template <typename It>
+It canonicalizeSetAsArray(It begin, It end) {
+  std::sort(begin, end);
+  It uniqueEnd = std::unique(begin, end);
+  return uniqueEnd;
+}
+
+/// Range overload: canonicalizes the elements of `range` in place and
+/// returns the end of the sorted, duplicate-free prefix.
+template <typename R>
+auto canonicalizeSetAsArray(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return canonicalizeSetAsArray(first, last);
+}
+
+/// Turns `vec` into a canonical set representation in place: sorted and
+/// without duplicates. Returns `vec` to allow chaining.
+template <typename T>
+SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+  auto uniqueEnd = canonicalizeSetAsArray(vec);
+  // Drop the leftover tail behind the duplicate-free prefix.
+  vec.erase(uniqueEnd, vec.end());
+  return vec;
+}
+
+/// A mesh dimension counts as dynamic when its size is zero or negative.
+template <typename DimSize>
+bool isMeshDimensionDynamic(DimSize size) {
+  const DimSize zero(0);
+  return size <= zero;
+}
+
+using MeshAxis = int16_t;
+
+/// Wrapper around an int64_t dimension size that knows about
+/// `ShapedType::kDynamic`. Conversions from and to int64_t are deliberately
+/// implicit: the arithmetic operators on this type return raw int64_t
+/// results that are converted back on return.
+struct DimensionSize {
+  DimensionSize(int64_t val) : val(val) {}
+
+  /// The size that marks a dynamic dimension.
+  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
+
+  bool isDynamic() const { return ShapedType::isDynamic(val); }
+  int64_t value() const { return val; }
+  operator int64_t() const { return value(); }
+
+private:
+  int64_t val;
+};
+
+/// Quotient of two dimension sizes; dynamic if either operand is dynamic.
+DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+  const bool bothStatic = !lhs.isDynamic() && !rhs.isDynamic();
+  return bothStatic ? DimensionSize(lhs.value() / rhs.value())
+                    : DimensionSize::dynamic();
+}
+
+/// Product of two dimension sizes; dynamic if either operand is dynamic.
+DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+  const bool bothStatic = !lhs.isDynamic() && !rhs.isDynamic();
+  return bothStatic ? DimensionSize(lhs.value() * rhs.value())
+                    : DimensionSize::dynamic();
+}
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Mesh dialect
 //===----------------------------------------------------------------------===//
@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() {
   return success();
 }
 
+/// Convenience overload that collects the canonical dimension sizes produced
+/// by the output-iterator overload into a freshly allocated vector.
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+  SmallVector<int64_t> sizes;
+  auto inserter = std::back_inserter(sizes);
+  canonicalDimSizes(inserter);
+  return sizes;
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard op
 //===----------------------------------------------------------------------===//
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Returns the canonical form of an axes-set attribute: sorted and without
+/// duplicates. Yields std::nullopt when `attr` is null or holds no axes.
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (attr) {
+    SmallVector<int16_t> sortedAxes = llvm::to_vector(attr.asArrayRef());
+    if (!canonicalizeSetAsVector(sortedAxes).empty()) {
+      return DenseI16ArrayAttr::get(attr.getContext(), sortedAxes);
+    }
+  }
+  return std::nullopt;
+}
+
+/// Rewrites the DenseI16ArrayAttr named `axisSetAttr` on `Op` into its
+/// canonical set form: sorted, without duplicates, and removed altogether
+/// when it holds no axes.
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto currentAttr =
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr);
+    // A null attribute stands for "attribute absent".
+    DenseI16ArrayAttr canonicalAttr =
+        canonicalizeAxesSetAttribute(currentAttr)
+            .value_or(DenseI16ArrayAttr());
+    if (canonicalAttr == currentAttr) {
+      // Already canonical. Reporting success without changing the IR would
+      // make the rewrite driver re-apply this pattern over and over.
+      return failure();
+    }
+    // Attribute changes are in-place modifications and have to go through
+    // the rewriter so that the driver is notified.
+    rewriter.updateRootInPlace(op, [&]() {
+      if (canonicalAttr) {
+        op->setAttr(axisSetAttr, canonicalAttr);
+      } else {
+        op->removeAttr(axisSetAttr);
+      }
+    });
+    return success();
+  }
+
+  /// Name of the attribute that stores the axes set.
+  std::string axisSetAttr;
+};
+
+/// Registers an AxesSetCanonicalizationPattern for `Op` that targets the
+/// attribute named "mesh_axes".
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
+}
+
+/// Verifies the mesh reference of a collective op:
+///  - the `mesh` attribute is present,
+///  - it resolves to a `mesh.cluster` op reachable from `op`,
+///  - every entry of `mesh_axes` (if present) lies in [0, rank of the mesh).
+/// Emits an error on `op` and returns failure otherwise.
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  // Resolve through the caller-provided collection so that its cached symbol
+  // tables are reused instead of being rebuilt for every verified op.
+  mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+  // Keep the rank at full width: narrowing it to MeshAxis could wrap around
+  // and corrupt the bounds comparison below.
+  int64_t rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis >= rank || axis < 0) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+/// Multiplies all elements of [begin, end); an empty range yields 1.
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  ElementType result(1);
+  for (It it = begin; it != end; ++it) {
+    result = result * *it;
+  }
+  return result;
+}
+
+/// Range overload: multiplies all elements of `range`.
+template <typename R>
+auto product(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return product(first, last);
+}
+
+/// Number of devices in one collective group: the product of the mesh
+/// dimension sizes along the axes listed in `meshAxes`. Returns
+/// ShapedType::kDynamic as soon as one of those dimensions is dynamic.
+/// Each mesh axis contributes at most once, and entries of `meshAxes` that
+/// do not name an axis of `meshShape` are ignored.
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t groupSize = 1;
+  const MeshAxis meshRank = MeshAxis(meshShape.size());
+  for (MeshAxis axis = 0; axis < meshRank; ++axis) {
+    const bool participates = llvm::is_contained(meshAxes, axis);
+    if (participates) {
+      if (isMeshDimensionDynamic(meshShape[axis])) {
+        return ShapedType::kDynamic;
+      }
+      groupSize *= meshShape[axis];
+    }
+  }
+  return groupSize;
+}
+
+/// Checks one result dimension against its expected size. A dynamic result
+/// dimension is compatible with anything; a static one has to match
+/// `expectedDimSize` exactly. Emits an error at `loc` on mismatch.
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  const bool compatible = ShapedType::isDynamic(resultDimSize) ||
+                          expectedDimSize == resultDimSize;
+  if (compatible) {
+    return success();
+  }
+
+  return emitError(loc) << "Dimension size mismatch for result axis "
+                        << resultAxis << ". Expected "
+                        << (ShapedType::isDynamic(expectedDimSize)
+                                ? Twine("dynamic")
+                                : Twine(expectedDimSize))
+                        << ", but got " << resultDimSize << ".";
+}
+
+/// Verifies the result shape of a gather: along `gatherAxis` the result
+/// size has to be the operand size times the device group size, every other
+/// axis has to keep the operand size. Errors are emitted at the location of
+/// `result`.
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+                                                int64_t gatherAxis,
+                                                ArrayRef<MeshAxis> meshAxes,
+                                                ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  const DimensionSize groupSize(
+      collectiveDeviceGroupSize(meshAxes, meshShape));
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    DimensionSize expected(operandType.getDimSize(axis));
+    if (axis == gatherAxis) {
+      expected = groupSize * expected;
+    }
+    const int64_t actual = resultType.getDimSize(axis);
+    if (failed(verifyDimensionCompatibility(result.getLoc(), expected, actual,
+                                            axis))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+/// Resolves the `mesh.cluster` op referenced by the `mesh` attribute of `op`.
+/// The reference is validated with verifyMeshSymbolUses first, so failure is
+/// returned (with an error already emitted on `op`) when the symbol or the
+/// mesh axes are invalid. A function-local SymbolTableCollection is used for
+/// both the validation and the lookup.
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+  SymbolTableCollection symbolTableCollection;
+  if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+    // We need to check the symbol here since this runs before
+    // SymbolUserOpInterface.
+    return failure();
+  }
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), op.getMeshAttr());
+}
+
+/// Op verifier shared by gather-like ops. Checks that `gather_axis` lies in
+/// [0, result rank), resolves the referenced mesh, and then verifies the
+/// result shape against the operand shape and the device group size.
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+  auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
+  auto gatherAxis = op.getGatherAxis().getSExtValue();
+  // The axis is validated first because the shape check below uses it to
+  // select the gathered dimension.
+  if (gatherAxis < 0 || gatherAxis >= rank) {
+    return op.emitError() << "Gather axis " << gatherAxis
+                          << " is out of bounds [0, " << rank << ").";
+  }
+
+  auto mesh = getMesh(op);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyGatherOperandAndResultShape(op.getOperand(), op.getResult(),
+                                           gatherAxis, op.getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
+}
+
+/// Verifies the result shape of an all-to-all.
+///  - If `splitAxis == concatAxis`, every result dimension has to be
+///    compatible with the corresponding operand dimension and nothing else
+///    is checked.
+///  - Otherwise all axes other than the split and concat axes have to keep
+///    the operand size, the operand size along `splitAxis` has to be
+///    divisible by the device group size (when both are static), the result
+///    size along `concatAxis` has to be operand size * group size and the
+///    result size along `splitAxis` operand size / group size.
+/// `splitAxis` and `concatAxis` are used to index the shapes without any
+/// range check in this function.
+/// Errors are emitted at the location of `result`.
+LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
+                                                  int64_t splitAxis,
+                                                  int64_t concatAxis,
+                                                  ArrayRef<MeshAxis> meshAxes,
+                                                  ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  // Axes that are not reshaped by the operation must be unchanged. When the
+  // split and concat axes coincide this covers every axis.
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  if (splitAxis == concatAxis) {
+    return success();
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
+  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
+  // Divisibility can only be checked when both sizes are statically known.
+  if (!operandSplitDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
+    return emitError(result.getLoc())
+           << "Operand dimension size " << int64_t(operandSplitDimSize)
+           << " is not divisible by collective device group size "
+           << int64_t(deviceGroupSize) << " for split axis " << splitAxis
+           << ".";
+  }
+  // The DimensionSize operators propagate dynamic sizes, so an expected size
+  // is dynamic whenever one of its inputs is.
+  DimensionSize expectedResultConcatDimSize =
+      operandConcatDimSize * deviceGroupSize;
+  DimensionSize expectedResultSplitDimSize =
+      operandSplitDimSize / deviceGroupSize;
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultConcatDimSize.value(),
+          resultType.getDimSize(concatAxis), concatAxis))) {
+    return failure();
+  }
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultSplitDimSize.value(),
+          resultType.getDimSize(splitAxis), splitAxis))) {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Verifies the result shape of a reduce-scatter: every axis other than
+/// `scatterAxis` has to keep the operand size, the operand size along
+/// `scatterAxis` has to be divisible by the device group size (when both are
+/// static), and the result size along `scatterAxis` has to be
+/// operand size / group size.
+/// `scatterAxis` is used to index the shapes without any range check in this
+/// function. Errors are emitted at the location of `result`.
+LogicalResult verifyReduceScatterOperandAndResultShape(
+    Value operand, Value result, int64_t scatterAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  // Non-scattered axes must be unchanged.
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    if (axis != scatterAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandScatterDimSize =
+      DimensionSize(operandType.getDimSize(scatterAxis));
+  // Divisibility can only be checked when both sizes are statically known.
+  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
+    return emitError(result.getLoc())
+           << "Operand dimension size " << int64_t(operandScatterDimSize)
+           << " is not divisible by collective device group size "
+           << int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
+           << ".";
+  }
+  // Dynamic if the operand size or the group size is dynamic.
+  DimensionSize expectedResultScatterDimSize =
+      operandScatterDimSize / deviceGroupSize;
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultScatterDimSize.value(),
+          resultType.getDimSize(scatterAxis), scatterAxis))) {
+    return failure();
+  }
+
+  return success();
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// mesh.all_reduce op
+//===----------------------------------------------------------------------===//
+
+/// Symbol-use verification hook: validates the `mesh` reference and the
+/// `mesh_axes` bounds by delegating to verifyMeshSymbolUses.
+LogicalResult
+mlir::mesh::AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+/// Registers the canonicalization of the `mesh_axes` attribute for this op.
+void mlir::mesh::AllReduceOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<AllReduceOp>(patterns, context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_gather op
+//===----------------------------------------------------------------------===//
+
+/// Symbol-use verification hook: validates the `mesh` reference and the
+/// `mesh_axes` bounds by delegating to verifyMeshSymbolUses.
+LogicalResult
+mlir::mesh::AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+/// Registers the canonicalization of the `mesh_axes` attribute for this op.
+void mlir::mesh::AllGatherOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<AllGatherOp>(patterns, context);
+}
+
+/// Op verifier: checks `gather_axis`, the mesh reference and the result
+/// shape through the shared verifyGather helper.
+LogicalResult mlir::mesh::AllGatherOp::verify() { return verifyGather(*this); }
+
+//===----------------------------------------------------------------------===//
+// mesh.all_to_all op
+//===----------------------------------------------------------------------===//
+
+/// Symbol-use verification hook: validates the `mesh` reference and the
+/// `mesh_axes` bounds by delegating to verifyMeshSymbolUses.
+LogicalResult
+mlir::mesh::AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+/// Registers the canonicalization of the `mesh_axes` attribute for this op.
+void mlir::mesh::AllToAllOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<AllToAllOp>(patterns, context);
+}
+
+/// Op verifier: checks that `split_axis` and `concat_axis` are valid tensor
+/// axes, resolves the referenced mesh, and verifies the result shape against
+/// the operand shape and the device group size.
+LogicalResult mlir::mesh::AllToAllOp::verify() {
+  // Reject out-of-range axes up front: the shape verification below indexes
+  // the operand and result shapes with them. Operand and result have the
+  // same rank (SameOperandsAndResultRank), so one bound covers both.
+  int64_t rank = getResult().getType().cast<ShapedType>().getRank();
+  int64_t splitAxis = getSplitAxis().getSExtValue();
+  if (splitAxis < 0 || splitAxis >= rank) {
+    return emitError() << "Split axis " << splitAxis
+                       << " is out of bounds [0, " << rank << ").";
+  }
+  int64_t concatAxis = getConcatAxis().getSExtValue();
+  if (concatAxis < 0 || concatAxis >= rank) {
+    return emitError() << "Concat axis " << concatAxis
+                       << " is out of bounds [0, " << rank << ").";
+  }
+
+  auto mesh = ::getMesh(*this);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyAllToAllOperandAndResultShape(
+      getOperand(), getResult(), splitAxis, concatAxis, getMeshAxes(),
+      mesh.value().canonicalDimSizes());
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.reduce_scatter op
+//===----------------------------------------------------------------------===//
+
+/// Symbol-use verification hook: validates the `mesh` reference and the
+/// `mesh_axes` bounds by delegating to verifyMeshSymbolUses.
+LogicalResult mlir::mesh::ReduceScatterOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+/// Registers the canonicalization of the `mesh_axes` attribute for this op.
+void mlir::mesh::ReduceScatterOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<ReduceScatterOp>(patterns,
+                                                               context);
+}
+
+/// Op verifier: checks that `scatter_axis` is a valid tensor axis, resolves
+/// the referenced mesh, and verifies the result shape against the operand
+/// shape and the device group size.
+LogicalResult mlir::mesh::ReduceScatterOp::verify() {
+  // Reject an out-of-range axis up front: the shape verification below
+  // indexes the operand and result shapes with it. Operand and result have
+  // the same rank (SameOperandsAndResultRank), so one bound covers both.
+  int64_t rank = getResult().getType().cast<ShapedType>().getRank();
+  int64_t scatterAxis = getScatterAxis().getSExtValue();
+  if (scatterAxis < 0 || scatterAxis >= rank) {
+    return emitError() << "Scatter axis " << scatterAxis
+                       << " is out of bounds [0, " << rank << ").";
+  }
+
+  auto mesh = ::getMesh(*this);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyReduceScatterOperandAndResultShape(
+      getOperand(), getResult(), scatterAxis, getMeshAxes(),
+      mesh.value().canonicalDimSizes());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
new file mode 100644
index 000000000000000..21fa1887d14a53e
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// CHECK-LABEL: func @all_reduce_mesh_axes
+func.func @all_reduce_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
+func.func @all_reduce_empty_mesh_axes_and_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>,
+// CHECK-NOT: reduction
+    reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_gather_empty_mesh_axes
+func.func @all_gather_empty_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  %0 = mesh.all_gather %arg0 {
+    gather_axis = 0 : index,
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_gather_mesh_axes
+func.func @all_gather_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<32xf32> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, gather_axis = 0
+    } : tensor<4xf32> -> tensor<32xf32>
+  return %0 : tensor<32xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_mesh_axes
+func.func @reduce_scatter_mesh_axes(
+    %arg0 : tensor<8xf32>) -> tensor<1xf64> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.reduce_scatter %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, scatter_axis = 0
+    } : tensor<8xf32> -> tensor<1xf64>
+  return %0 : tensor<1xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_and_default_reduction
+func.func @reduce_scatter_empty_mesh_axes_and_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = mesh.reduce_scatter %arg0 {
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>,
+// CHECK-NOT: reduction
+    reduction = #mesh.partial<sum>,
+    scatter_axis = 0
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 246439dd4be7122..413b01459f507a2 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -67,3 +67,243 @@ func.func @mesh_axis_negtive_in_partial(
             tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
 }
+
+// -----
+
+// An all_reduce referencing an undefined mesh symbol must be diagnosed.
+func.func @all_reduce_invalid_mesh_symbol(%arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @this_mesh_symbol_does_not_exist, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// Mesh axis 2 does not exist on the rank-2 mesh and must be diagnosed.
+func.func @all_reduce_invalid_mesh_axis(%arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// all_reduce must keep the operand shape; a 4-element to 5-element result is diagnosed.
+func.func @all_reduce_invalid_tensor_dimension_size(%arg0 : tensor<4xf32>) -> tensor<5xf64> {
+  // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
+  %0 = mesh.all_reduce %arg0 { mesh = @mesh0 } : tensor<4xf32> -> tensor<5xf64>
+  return %0 : tensor<5xf64>
+}
+
+// -----
+
+// An all_gather referencing an undefined mesh symbol must be diagnosed.
+func.func @all_gather_invalid_mesh_symbol(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @this_mesh_symbol_does_not_exist, gather_axis = 0 : index
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// Mesh axis 2 does not exist on the rank-2 mesh and must be diagnosed.
+func.func @all_gather_invalid_mesh_axis(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 2>, gather_axis = 0 : index
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// Result axis 1 is not the gather axis, so it should keep size 4; a size of 5 is diagnosed.
+func.func @all_gather_invalid_non_gather_axis_dimension_size(%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 0 : index
+    } : tensor<3x4xf32> -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+
+// Gathering axis 1 (size 4) over mesh axis 1 (size 2) should give 8; a size of 5 is diagnosed.
+func.func @all_gather_invalid_gather_axis_dimension_size(%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+    } : tensor<3x4xf32> -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// A dynamic operand gather axis requires a dynamic result axis; a static size of 3 is diagnosed.
+func.func @all_gather_invalid_gather_axis_dynamic_dimension(%arg0 : tensor<?xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.all_gather %arg0 {
+    gather_axis = 0 : index,
+    mesh = @mesh0
+    } : tensor<?xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// Gather axis 1 is out of bounds for a rank-1 tensor and must be diagnosed.
+func.func @all_gather_invalid_gather_axis(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 1 : index
+    } : tensor<3xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// A negative gather axis is out of bounds and must be diagnosed.
+func.func @all_gather_invalid_negative_gather_axis(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = -1 : index
+    } : tensor<3xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+// An all_to_all referencing an undefined mesh symbol must be diagnosed.
+func.func @all_to_all_gather_invalid_mesh_symbol(%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @this_mesh_symbol_does_not_exist, split_axis = 1
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+
+// Mesh axis 0 is declared with size 0; per the diagnostic, result concat axis 1 should then be dynamic.
+func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+// The operand's split axis 0 is dynamic, so result axis 0 should be dynamic; a static 3 is diagnosed.
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(%arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
+    } : tensor<?x6xi8> -> tensor<3x?xi8>
+  return %0 : tensor<3x?xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+// The operand's concat axis 1 is dynamic, so result axis 1 should be dynamic; a static 3 is diagnosed.
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(%arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
+    } : tensor<3x?xi8> -> tensor<?x3xi8>
+  return %0 : tensor<?x3xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// Concatenating axis 1 (size 2) over mesh axis 0 (size 3) should give 6; a size of 7 is diagnosed.
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(%arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x2xi8> -> tensor<1x7xi8>
+  return %0 : tensor<1x7xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// Splitting axis 0 (size 3) over mesh axis 0 (size 3) should leave size 1; a size of 2 is diagnosed.
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(%arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x2xi8> -> tensor<2x6xi8>
+  return %0 : tensor<2x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// A dynamic operand scatter axis requires a dynamic result axis; a static size of 2 is diagnosed.
+func.func @reduce_scatter_invalid_dynamic_dimension(%arg0 : tensor<?xf32>) -> tensor<2xf64> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, scatter_axis = 0
+    } : tensor<?xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// Scattering axis 0 (size 3) over mesh axis 0 (size 3) should leave size 1; a size of 2 is diagnosed.
+func.func @reduce_scatter_invalid_static_dimension_size(%arg0 : tensor<3xf32>) -> tensor<2xf64> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
+    } : tensor<3xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// Operand size 4 along the scatter axis is not divisible by the device group size 3.
+func.func @reduce_scatter_invalid_operand_static_dimension_size(%arg0 : tensor<4xf32>) -> tensor<?xf64> {
+  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
+    } : tensor<4xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index ee5f8f67792b928..5ec6c4a439327f0 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -12,6 +12,8 @@ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
 // CHECK: mesh.cluster @mesh3
 mesh.cluster @mesh3(rank = 2)
 
+mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
@@ -126,3 +128,120 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
+
+// mesh.all_reduce over mesh axes 1, 0 of @mesh0 with a max reduction.
+// CHECK-LABEL: func @all_reduce
+// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+func.func @all_reduce(%input : tensor<3x4xf32>) -> tensor<3x4xf64> {
+  // CHECK-NEXT: mesh.all_reduce %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
+  %reduced = mesh.all_reduce %input {
+      reduction = #mesh.partial<max>, mesh = @mesh0, mesh_axes = array<i16: 1, 0>
+    } : tensor<3x4xf32> -> tensor<3x4xf64>
+  return %reduced : tensor<3x4xf64>
+}
+
+// mesh.all_gather along tensor axis 1 over mesh axis 2 of @mesh0.
+// CHECK-LABEL: func @all_gather
+// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+func.func @all_gather(%input : tensor<3x4xf32>) -> tensor<3x16xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
+  %gathered = mesh.all_gather %input {
+      mesh = @mesh0, mesh_axes = array<i16: 2>, gather_axis = 1 : index
+    } : tensor<3x4xf32> -> tensor<3x16xf32>
+  return %gathered : tensor<3x16xf32>
+}
+
+// mesh.all_gather on a tensor whose dimensions are all dynamic.
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
+// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
+func.func @all_gather_dynamic_dims_in_tensor(%input : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+  // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
+  %gathered = mesh.all_gather %input {
+      mesh = @mesh0, mesh_axes = array<i16: 2>, gather_axis = 1 : index
+    } : tensor<?x?xf32> -> tensor<?x?xf32>
+  return %gathered : tensor<?x?xf32>
+}
+
+// mesh.all_gather over @mesh3 (declared without dim sizes) into a dynamic result gather axis.
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
+// CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
+func.func @all_gather_dynamic_dims_in_mesh(%input : tensor<5x6xf32>) -> tensor<5x?xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>}
+  // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
+  %gathered = mesh.all_gather %input {
+      mesh = @mesh3, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+    } : tensor<5x6xf32> -> tensor<5x?xf32>
+  return %gathered : tensor<5x?xf32>
+}
+
+// mesh.all_to_all on @mesh4 with split axis 1 and concat axis 0.
+// CHECK-LABEL: func @all_to_all
+// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+func.func @all_to_all(%input : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
+  %exchanged = mesh.all_to_all %input {
+      mesh = @mesh4, split_axis = 1, concat_axis = 0
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %exchanged : tensor<3x6xi8>
+}
+
+// mesh.all_to_all from a fully static operand into a result with a dynamic axis 1.
+// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
+// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+func.func @all_to_all_dynamic_dims_in_result(%input : tensor<3x6xi8>) -> tensor<3x?xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
+  %exchanged = mesh.all_to_all %input {
+      mesh = @mesh4, split_axis = 1, concat_axis = 0
+    } : tensor<3x6xi8> -> tensor<3x?xi8>
+  return %exchanged : tensor<3x?xi8>
+}
+
+// mesh.all_to_all that splits and concatenates along the same tensor axis 0.
+// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
+// CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
+func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(%arg0 : tensor<3xi8>) -> tensor<3xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 0 : i64}
+  // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh4, split_axis = 0
+    } : tensor<3xi8> -> tensor<3xi8>
+  return %0 : tensor<3xi8>
+}
+
+// mesh.reduce_scatter with static shapes: max reduction, scattered along tensor axis 1.
+// CHECK-LABEL: func @reduce_scatter_static_dimensions
+// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+func.func @reduce_scatter_static_dimensions(%input : tensor<3x4xf32>) -> tensor<3x1xf64> {
+  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
+  %scattered = mesh.reduce_scatter %input {
+      scatter_axis = 1, reduction = #mesh.partial<max>, mesh = @mesh0, mesh_axes = array<i16: 2>
+    } : tensor<3x4xf32> -> tensor<3x1xf64>
+  return %scattered : tensor<3x1xf64>
+}
+
+// mesh.reduce_scatter with a dynamic operand and result over mesh axes 0, 1 of @mesh3.
+// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
+// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+func.func @reduce_scatter_dynamic_dimensions(%input : tensor<?xf32>) -> tensor<?xf64> {
+  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0 : i64}
+  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
+  %scattered = mesh.reduce_scatter %input {
+      scatter_axis = 0, mesh = @mesh3, mesh_axes = array<i16: 0, 1>
+    } : tensor<?xf32> -> tensor<?xf64>
+  return %scattered : tensor<?xf64>
+}



More information about the Mlir-commits mailing list