[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)

Boian Petkantchin llvmlistbot at llvm.org
Tue Nov 14 10:58:24 PST 2023


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/71960

>From 5b8c7f70202f7cac6b6b1a29e168fe6fa43fc720 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Mon, 6 Nov 2023 15:55:31 -0800
Subject: [PATCH 1/2] [mlir][mesh] Add collective communication operations

Add all-gather, all-reduce, all-to-all and reduce-scatter.
These operations have device mesh semantics.
---
 mlir/docs/Dialects/Mesh.md                    |  34 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |   8 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |   2 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 216 +++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 417 ++++++++++++++++++
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  72 +++
 mlir/test/Dialect/Mesh/invalid.mlir           | 240 ++++++++++
 mlir/test/Dialect/Mesh/ops.mlir               | 119 +++++
 8 files changed, 1105 insertions(+), 3 deletions(-)
 create mode 100644 mlir/docs/Dialects/Mesh.md
 create mode 100644 mlir/test/Dialect/Mesh/canonicalization.mlir

diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..6dd4f79022061ee
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,34 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the user is familiar with collective operations.
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition is that the collectives in this dialect have mesh
+semantics.
+The operation attributes `mesh` and `mesh_axes` specify a set of device mesh
+axes that partition the devices into disjoint groups.
+The collective operation is performed between devices in the same group.
+Devices that have the same coordinates outside of axes `mesh_axes` are in the
+same group.
+For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
+axes set is `{0, 1}`, then the devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
   let cppNamespace = "::mlir::mesh";
 
   let description = [{
-    The `mesh` dialect contains a set of attributes, operations, interfaces that
-    are useful for representing sharding and communication on device mesh
-    cluster.
+    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
   }];
 
   let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+// Attribute wrapper around the Mesh_Partial enum. The value is printed inside
+// angle brackets, e.g. `#mesh.partial<sum>` as used by the `reduction`
+// attribute of the collective communication ops.
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
 // distributed tensors. Mesh_IteratorType annotates loops in an operation, while
 // Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..7698d60813a8f10 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,11 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..15354babe870599 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -77,6 +79,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
       attr-dict
   }];
+  let extraClassDeclaration = [{
+    ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+    // Writes getRank() mesh dimension sizes to `outIt`: the explicit
+    // `dim_sizes` first, then a 0 for every trailing dimension that has no
+    // explicit size.
+    template <typename OutIt>
+    void canonicalDimSizes(OutIt outIt) {
+      // Continue from the iterator std::copy returns. Reusing the original
+      // `outIt` would make the zero fill overwrite the explicit sizes for any
+      // non-inserter output iterator (e.g. a raw pointer).
+      outIt = std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -171,4 +182,209 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+// Common base for the mesh collective communication ops.
+// - Appends SymbolUserOpInterface to the derived op's traits so the `mesh`
+//   symbol reference can be verified via verifySymbolUses.
+// - Provides the shared `$input attr-dict : type($input) -> type($result)`
+//   assembly format.
+// - `extraClassDeclarationBase` holds the verifySymbolUses declaration;
+//   derived ops assign it to their `extraClassDeclaration`.
+class Mesh_CollectiveCommunicationOpBase<
+    string mnemonic, list<Trait> traits = []> :
+    Mesh_Op<mnemonic,
+      !listconcat(traits,
+      [SymbolUserOpInterface])> {
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+  code extraClassDeclarationBase = [{
+    ::mlir::LogicalResult verifySymbolUses(
+          ::mlir::SymbolTableCollection &symbolTable);
+  }];
+}
+
+// All-gather within each collective device group selected by `mesh_axes`.
+// The custom verifier (hasVerifier) checks `gather_axis` and the result shape
+// against the referenced mesh.
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-gather over a device mesh.";
+  let description = [{
+    Gathers along the `gather_axis` tensor axis.
+    The order of input tensors in the resulting tensor is the same as the
+    order of the corresponding devices' multi-index in the mesh.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.all_gather %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+      } : tensor<2x2xi8> -> tensor<2x4xi8>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+  }];
+  // `gather_axis` is an APInt attribute; the verifier reads it with
+  // getSExtValue() and requires it to lie in [0, rank).
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    APIntAttr:$gather_axis
+  );
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+// All-reduce within each collective device group selected by `mesh_axes`.
+// Operand and result must have the same shape (SameOperandsAndResultShape),
+// but the element types may differ. No custom verifier is declared.
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+    SameOperandsAndResultShape]> {
+  let summary = "All-reduce over a device mesh.";
+  let description = [{
+    The accumulation element type is specified by the result type and
+    it does not need to match the input element type.
+    The input element is converted to the result element type before
+    performing the reduction.
+
+    Attributes:
+    `reduction`: Indicates the reduction method.
+
+    Example:
+    ```
+    %1 = mesh.all_reduce %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+      } : tensor<3x4xf32> -> tensor<3x4xf64>
+    ```
+  }];
+  // `reduction` defaults to Partial::Sum when omitted.
+  let arguments = (ins
+    AnyRankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+// All-to-all exchange within each collective device group selected by
+// `mesh_axes`.
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank]> {
+  let summary = "All-to-all over a device mesh.";
+  let description = [{
+    Performs an all-to-all on tensor pieces split along `split_axis`.
+    The resulting pieces are concatenated along `concat_axis` on each device.
+    When `split_axis` and `concat_axis` differ, the `split_axis` dimension
+    must be divisible by the number of devices in the group. In the result it
+    shrinks by that factor, while the `concat_axis` dimension grows by it.
+
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+    ...
+    %1 = mesh.all_to_all %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
+      } : tensor<3x2xi8> -> tensor<3x2xi8>
+    ```
+    Input:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 21 22 | 31 32 |
+    | 13 14 | 23 24 | 33 34 |
+    | 15 16 | 25 26 | 35 36 |
+    +-------+-------+-------+
+    ```
+    Result:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 13 14 | 15 16 |
+    | 21 22 | 23 24 | 25 26 |
+    | 31 32 | 33 34 | 35 36 |
+    +-------+-------+-------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    APIntAttr:$split_axis,
+    APIntAttr:$concat_axis
+  );
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+// Reduce-scatter within each collective device group selected by
+// `mesh_axes`.
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+    SameOperandsAndResultRank]> {
+  let summary = "Reduce-scatter over a device mesh.";
+  let description = [{
+    After the reduction, the result is scattered within each device group.
+    The tensor is split along `scatter_axis` and the pieces are distributed
+    across the device group.
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.reduce_scatter %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, reduction = #mesh.partial<sum>, scatter_axis = 0
+      } : tensor<2x2xf32> -> tensor<1x2xf64>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------+
+    |  6  8 | <- device (0, 0)
+    +-------+
+    | 10 12 | <- device (0, 1)
+    +-------+
+    | 22 24 | <- device (1, 0)
+    +-------+
+    | 26 28 | <- device (1, 1)
+    +-------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    APIntAttr:$scatter_axis
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..6efc4c4ecc326ad 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,26 @@
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
 
 #define DEBUG_TYPE "mesh-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +37,60 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
+namespace {
+
+/// Sorts [first, last) and compacts equal neighbours so that [first, result)
+/// holds every distinct value exactly once, in ascending order. Elements from
+/// the returned iterator onward are left in an unspecified state.
+template <typename It>
+It canonicalizeSetAsArray(It first, It last) {
+  std::sort(first, last);
+  const It uniqueEnd = std::unique(first, last);
+  return uniqueEnd;
+}
+
+/// Range overload: canonicalizes all of `range` in place and returns the end
+/// of the distinct, sorted prefix.
+template <typename R>
+auto canonicalizeSetAsArray(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return canonicalizeSetAsArray(first, last);
+}
+
+/// Turns `vec` into a sorted vector without duplicates, shrinking it as
+/// needed, and returns it.
+template <typename T>
+SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+  auto uniqueEnd = canonicalizeSetAsArray(vec);
+  vec.erase(uniqueEnd, vec.end());
+  return vec;
+}
+
+/// A mesh dimension whose size is not strictly positive (0 or negative) is
+/// treated as dynamic.
+template <typename DimSize>
+bool isMeshDimensionDynamic(DimSize size) {
+  const DimSize zero(0);
+  return size <= zero;
+}
+
+/// Index of a device mesh axis, matching the element type of `mesh_axes`.
+using MeshAxis = int16_t;
+
+/// A tensor dimension size that may be dynamic (ShapedType::kDynamic).
+/// Converts implicitly to and from int64_t.
+struct DimensionSize {
+  /// Returns a size marked as dynamic.
+  static DimensionSize dynamic() { return {ShapedType::kDynamic}; }
+  DimensionSize(int64_t sizeValue) : storedSize(sizeValue) {}
+  int64_t value() const { return storedSize; }
+  operator int64_t() const { return value(); }
+  bool isDynamic() const { return ShapedType::isDynamic(value()); }
+
+private:
+  int64_t storedSize;
+};
+
+/// Integer division of two sizes; the quotient is dynamic if either operand
+/// is dynamic.
+DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+  if (!lhs.isDynamic() && !rhs.isDynamic()) {
+    return lhs.value() / rhs.value();
+  }
+  return DimensionSize::dynamic();
+}
+
+/// Product of two sizes; the product is dynamic if either operand is dynamic.
+DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+  if (!lhs.isDynamic() && !rhs.isDynamic()) {
+    return lhs.value() * rhs.value();
+  }
+  return DimensionSize::dynamic();
+}
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Mesh dialect
 //===----------------------------------------------------------------------===//
@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() {
   return success();
 }
 
+/// Returns the mesh shape as a vector, filled in by the templated
+/// canonicalDimSizes overload (explicit sizes followed by 0 padding up to
+/// the mesh rank).
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+  SmallVector<int64_t> dimSizes;
+  auto inserter = std::back_inserter(dimSizes);
+  canonicalDimSizes(inserter);
+  return dimSizes;
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard op
 //===----------------------------------------------------------------------===//
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Returns the axes in `attr` as a sorted, duplicate-free DenseI16ArrayAttr.
+/// Returns std::nullopt when `attr` is null or holds no axes.
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (attr) {
+    ArrayRef<int16_t> original = attr.asArrayRef();
+    SmallVector<int16_t> axes(original.begin(), original.end());
+    if (!canonicalizeSetAsVector(axes).empty()) {
+      return DenseI16ArrayAttr::get(attr.getContext(), axes);
+    }
+  }
+  return std::nullopt;
+}
+
+/// Canonicalizes the DenseI16ArrayAttr named `axisSetAttr` on ops of type
+/// `Op` into set form: sorted ascending with duplicates removed. If the
+/// canonical set is empty, the attribute is removed from the op.
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto currentAttr =
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr);
+    // Nothing to canonicalize when the attribute is absent.
+    if (!currentAttr) {
+      return failure();
+    }
+    std::optional<DenseI16ArrayAttr> canonicalAttr =
+        canonicalizeAxesSetAttribute(currentAttr);
+    // Attributes are uniqued, so equality means the set is already canonical.
+    // Failure must be reported here: a pattern that always reports success
+    // keeps the greedy rewrite driver from ever converging.
+    if (canonicalAttr && *canonicalAttr == currentAttr) {
+      return failure();
+    }
+    // Modify the op through the rewriter so listeners are notified.
+    rewriter.updateRootInPlace(op, [&]() {
+      if (canonicalAttr) {
+        op->setAttr(axisSetAttr, *canonicalAttr);
+      } else {
+        op->removeAttr(axisSetAttr);
+      }
+    });
+    return success();
+  }
+
+  /// Name of the axes-set attribute this pattern canonicalizes.
+  std::string axisSetAttr;
+};
+
+/// Adds the pattern that canonicalizes the `mesh_axes` attribute of `Op`.
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  const char *const meshAxesAttrName = "mesh_axes";
+  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, meshAxesAttrName);
+}
+
+/// Verifies the mesh reference of a collective op:
+/// - the `mesh` attribute must be present;
+/// - it must resolve to a mesh.cluster op;
+/// - every entry of `mesh_axes` must be a valid axis of that mesh.
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  // Look up through the caller-provided collection so its cached symbol
+  // tables are reused instead of being rebuilt on every call.
+  mesh::ClusterOp mesh =
+      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op.getOperation(),
+                                                           symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+  // Keep the rank at full width; narrowing it to MeshAxis (int16_t) could
+  // wrap for very large ranks.
+  int64_t rank = mesh.getRank();
+  for (MeshAxis axis : meshAxes.asArrayRef()) {
+    if (axis < 0 || axis >= rank) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+/// Multiplies all elements of [first, last) together, starting from 1, so an
+/// empty range yields 1. The accumulator has the decayed element type.
+template <typename It>
+auto product(It first, It last) {
+  using ElementType = std::decay_t<decltype(*first)>;
+  ElementType result(1);
+  for (; first != last; ++first) {
+    result = result * *first;
+  }
+  return result;
+}
+
+/// Range overload of product().
+template <typename R>
+auto product(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return product(first, last);
+}
+
+/// Number of devices in each collective group: the product of the mesh
+/// dimension sizes over the axes listed in `meshAxes`.
+/// - Returns ShapedType::kDynamic if any such dimension is dynamic.
+/// - Each mesh axis counts at most once, even if listed several times.
+/// - Axes outside [0, meshShape.size()) are ignored.
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t groupSize = 1;
+  const MeshAxis axisCount = MeshAxis(meshShape.size());
+  for (MeshAxis axis = 0; axis < axisCount; ++axis) {
+    const bool isGrouped = llvm::find(meshAxes, axis) != meshAxes.end();
+    if (isGrouped) {
+      const int64_t dimSize = meshShape[axis];
+      if (isMeshDimensionDynamic(dimSize)) {
+        return ShapedType::kDynamic;
+      }
+      groupSize *= dimSize;
+    }
+  }
+  return groupSize;
+}
+
+/// Checks a result dimension size against the expected size. A dynamic
+/// result size is always accepted. Otherwise the sizes must be equal, and on
+/// mismatch an error naming `resultAxis` is emitted at `loc`.
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  const bool compatible = ShapedType::isDynamic(resultDimSize) ||
+                          expectedDimSize == resultDimSize;
+  if (compatible) {
+    return success();
+  }
+  return emitError(loc) << "Dimension size mismatch for result axis "
+                        << resultAxis << ". Expected "
+                        << (ShapedType::isDynamic(expectedDimSize)
+                                ? Twine("dynamic")
+                                : Twine(expectedDimSize))
+                        << ", but got " << resultDimSize << ".";
+}
+
+/// Verifies the result shape of a gather. Along `gatherAxis` the result size
+/// must be the operand size times the collective device group size. Every
+/// other axis must match the operand.
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+                                                int64_t gatherAxis,
+                                                ArrayRef<MeshAxis> meshAxes,
+                                                ArrayRef<int64_t> meshShape) {
+  auto operandType = operand.getType().cast<ShapedType>();
+  auto resultType = result.getType().cast<ShapedType>();
+  const DimensionSize groupSize =
+      collectiveDeviceGroupSize(meshAxes, meshShape);
+  const int64_t rank = operandType.getRank();
+  for (int64_t axis = 0; axis < rank; ++axis) {
+    const DimensionSize operandDimSize = operandType.getDimSize(axis);
+    DimensionSize expectedResultDimSize = operandDimSize;
+    if (axis == gatherAxis) {
+      expectedResultDimSize = groupSize * operandDimSize;
+    }
+    if (failed(verifyDimensionCompatibility(result.getLoc(),
+                                            expectedResultDimSize,
+                                            resultType.getDimSize(axis),
+                                            axis))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+/// Resolves the `mesh` symbol of `op` to its ClusterOp. The symbol reference
+/// (and `mesh_axes`) is verified first, and failure is returned if that
+/// verification fails.
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+  // The symbol has to be checked here because op verifiers run before
+  // SymbolUserOpInterface's verifySymbolUses.
+  SymbolTableCollection symbolTables;
+  if (succeeded(verifyMeshSymbolUses(op, symbolTables))) {
+    return symbolTables.lookupNearestSymbolFrom<mesh::ClusterOp>(
+        op.getOperation(), op.getMeshAttr());
+  }
+  return failure();
+}
+
+/// Shared verifier for gather-style ops. It checks that `gather_axis` is a
+/// valid tensor axis, resolves the mesh, and then checks the result shape.
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+  const auto rank =
+      op.getResult().getType().template cast<ShapedType>().getRank();
+  const auto gatherAxis = op.getGatherAxis().getSExtValue();
+  const bool axisInBounds = gatherAxis >= 0 && gatherAxis < rank;
+  if (!axisInBounds) {
+    return op.emitError() << "Gather axis " << gatherAxis
+                          << " is out of bounds [0, " << rank << ").";
+  }
+
+  FailureOr<ClusterOp> mesh = getMesh(op);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyGatherOperandAndResultShape(op.getOperand(), op.getResult(),
+                                           gatherAxis, op.getMeshAxes(),
+                                           mesh->canonicalDimSizes());
+}
+
+/// Verifies the result shape of an all-to-all.
+/// - Axes other than the split and concat axes must match the operand.
+/// - If the split and concat axes coincide, the whole shape is unchanged.
+/// - Otherwise the split dimension must be divisible by the collective
+///   device group size and shrinks by it, while the concat dimension grows
+///   by it.
+LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
+                                                  int64_t splitAxis,
+                                                  int64_t concatAxis,
+                                                  ArrayRef<MeshAxis> meshAxes,
+                                                  ArrayRef<int64_t> meshShape) {
+  auto operandType = operand.getType().cast<ShapedType>();
+  auto resultType = result.getType().cast<ShapedType>();
+  Location loc = result.getLoc();
+  const bool sameAxis = splitAxis == concatAxis;
+
+  for (int64_t axis = 0, rank = operandType.getRank(); axis < rank; ++axis) {
+    const bool reshapedAxis =
+        !sameAxis && (axis == splitAxis || axis == concatAxis);
+    if (reshapedAxis) {
+      continue;
+    }
+    if (failed(verifyDimensionCompatibility(loc, operandType.getDimSize(axis),
+                                            resultType.getDimSize(axis),
+                                            axis))) {
+      return failure();
+    }
+  }
+
+  if (sameAxis) {
+    return success();
+  }
+
+  const DimensionSize groupSize =
+      collectiveDeviceGroupSize(meshAxes, meshShape);
+  const DimensionSize concatDimSize = operandType.getDimSize(concatAxis);
+  const DimensionSize splitDimSize = operandType.getDimSize(splitAxis);
+  const bool bothStatic = !splitDimSize.isDynamic() && !groupSize.isDynamic();
+  if (bothStatic && splitDimSize.value() % groupSize.value() != 0) {
+    return emitError(loc)
+           << "Operand dimension size " << splitDimSize.value()
+           << " is not divisible by collective device group size "
+           << groupSize.value() << " for split axis " << splitAxis << ".";
+  }
+  if (failed(verifyDimensionCompatibility(
+          loc, (concatDimSize * groupSize).value(),
+          resultType.getDimSize(concatAxis), concatAxis))) {
+    return failure();
+  }
+  return verifyDimensionCompatibility(loc, (splitDimSize / groupSize).value(),
+                                      resultType.getDimSize(splitAxis),
+                                      splitAxis);
+}
+
+/// Verifies the result shape of a reduce-scatter.
+/// - Every axis except `scatterAxis` must match the operand.
+/// - Along `scatterAxis` the operand size must be divisible by the collective
+///   device group size, and the result size must equal the quotient.
+LogicalResult verifyReduceScatterOperandAndResultShape(
+    Value operand, Value result, int64_t scatterAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  auto operandType = operand.getType().cast<ShapedType>();
+  auto resultType = result.getType().cast<ShapedType>();
+  Location loc = result.getLoc();
+
+  for (int64_t axis = 0, rank = operandType.getRank(); axis < rank; ++axis) {
+    if (axis == scatterAxis) {
+      continue;
+    }
+    if (failed(verifyDimensionCompatibility(loc, operandType.getDimSize(axis),
+                                            resultType.getDimSize(axis),
+                                            axis))) {
+      return failure();
+    }
+  }
+
+  const DimensionSize groupSize =
+      collectiveDeviceGroupSize(meshAxes, meshShape);
+  const DimensionSize scatterDimSize = operandType.getDimSize(scatterAxis);
+  if (!scatterDimSize.isDynamic() && !groupSize.isDynamic() &&
+      scatterDimSize.value() % groupSize.value() != 0) {
+    return emitError(loc)
+           << "Operand dimension size " << scatterDimSize.value()
+           << " is not divisible by collective device group size "
+           << groupSize.value() << " for scatter axis " << scatterAxis << ".";
+  }
+  return verifyDimensionCompatibility(loc, (scatterDimSize / groupSize).value(),
+                                      resultType.getDimSize(scatterAxis),
+                                      scatterAxis);
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// mesh.all_reduce op
+//===----------------------------------------------------------------------===//
+
+/// Verifies the `mesh` symbol reference and the `mesh_axes` bounds.
+LogicalResult
+mlir::mesh::AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  LogicalResult result = verifyMeshSymbolUses(*this, symbolTable);
+  return result;
+}
+
+/// Canonicalization normalizes the `mesh_axes` attribute into a sorted set.
+void mlir::mesh::AllReduceOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<mlir::mesh::AllReduceOp>(
+      patterns, context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_gather op
+//===----------------------------------------------------------------------===//
+
+/// Verifies the `mesh` symbol reference and the `mesh_axes` bounds.
+LogicalResult
+mlir::mesh::AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  LogicalResult result = verifyMeshSymbolUses(*this, symbolTable);
+  return result;
+}
+
+/// Canonicalization normalizes the `mesh_axes` attribute into a sorted set.
+void mlir::mesh::AllGatherOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<mlir::mesh::AllGatherOp>(
+      patterns, context);
+}
+
+/// Checks `gather_axis` bounds, the mesh reference and the result shape.
+LogicalResult mlir::mesh::AllGatherOp::verify() {
+  return verifyGather(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_to_all op
+//===----------------------------------------------------------------------===//
+
+/// Verifies the `mesh` symbol reference and the `mesh_axes` bounds.
+LogicalResult
+mlir::mesh::AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  LogicalResult result = verifyMeshSymbolUses(*this, symbolTable);
+  return result;
+}
+
+/// Canonicalization normalizes the `mesh_axes` attribute into a sorted set.
+void mlir::mesh::AllToAllOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<mlir::mesh::AllToAllOp>(
+      patterns, context);
+}
+
+/// Verifies an all-to-all:
+/// - `split_axis` and `concat_axis` must be valid tensor axes;
+/// - the mesh reference must resolve;
+/// - the result shape must agree with the collective device group size.
+LogicalResult mlir::mesh::AllToAllOp::verify() {
+  // Reject out-of-range axes up front. The shape verification below indexes
+  // the operand and result shapes with them.
+  int64_t rank = getResult().getType().cast<ShapedType>().getRank();
+  int64_t splitAxis = getSplitAxis().getSExtValue();
+  if (splitAxis < 0 || splitAxis >= rank) {
+    return emitError() << "Split axis " << splitAxis
+                       << " is out of bounds [0, " << rank << ").";
+  }
+  int64_t concatAxis = getConcatAxis().getSExtValue();
+  if (concatAxis < 0 || concatAxis >= rank) {
+    return emitError() << "Concat axis " << concatAxis
+                       << " is out of bounds [0, " << rank << ").";
+  }
+
+  auto mesh = ::getMesh(*this);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyAllToAllOperandAndResultShape(getOperand(), getResult(),
+                                             splitAxis, concatAxis,
+                                             getMeshAxes(),
+                                             mesh.value().canonicalDimSizes());
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.reduce_scatter op
+//===----------------------------------------------------------------------===//
+
+/// Verifies the `mesh` symbol reference and the `mesh_axes` bounds.
+LogicalResult mlir::mesh::ReduceScatterOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  LogicalResult result = verifyMeshSymbolUses(*this, symbolTable);
+  return result;
+}
+
+/// Canonicalization normalizes the `mesh_axes` attribute into a sorted set.
+void mlir::mesh::ReduceScatterOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<mlir::mesh::ReduceScatterOp>(
+      patterns, context);
+}
+
+/// Verifies a reduce-scatter:
+/// - `scatter_axis` must be a valid tensor axis;
+/// - the mesh reference must resolve;
+/// - the result shape must agree with the collective device group size.
+LogicalResult mlir::mesh::ReduceScatterOp::verify() {
+  // Reject an out-of-range axis up front. The shape verification below
+  // indexes the operand and result shapes with it.
+  int64_t rank = getResult().getType().cast<ShapedType>().getRank();
+  int64_t scatterAxis = getScatterAxis().getSExtValue();
+  if (scatterAxis < 0 || scatterAxis >= rank) {
+    return emitError() << "Scatter axis " << scatterAxis
+                       << " is out of bounds [0, " << rank << ").";
+  }
+
+  auto mesh = ::getMesh(*this);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyReduceScatterOperandAndResultShape(
+      getOperand(), getResult(), scatterAxis, getMeshAxes(),
+      mesh.value().canonicalDimSizes());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
new file mode 100644
index 000000000000000..21fa1887d14a53e
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+// 2x4 mesh shared by every test in this file. Grouping over both axes yields
+// a collective device group of 8 devices.
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// Unsorted, duplicated mesh axes are expected to be rewritten into a sorted
+// set without duplicates.
+// CHECK-LABEL: func @all_reduce_mesh_axes
+func.func @all_reduce_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// The empty mesh axes attribute is expected to be removed, and the default
+// `sum` reduction is expected not to appear in the printed op.
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
+func.func @all_reduce_empty_mesh_axes_and_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>,
+// CHECK-NOT: reduction
+    reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// An empty mesh axes attribute on all_gather is expected to be removed.
+// CHECK-LABEL: func @all_gather_empty_mesh_axes
+func.func @all_gather_empty_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  %0 = mesh.all_gather %arg0 {
+    gather_axis = 0 : index,
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// Axes {1, 0, 0} canonicalize to {0, 1}. The group spans the whole 2x4 mesh
+// (8 devices), so gathering grows 4 -> 32.
+// CHECK-LABEL: func @all_gather_mesh_axes
+func.func @all_gather_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<32xf32> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, gather_axis = 0
+    } : tensor<4xf32> -> tensor<32xf32>
+  return %0 : tensor<32xf32>
+}
+
+// Axes {1, 0, 0} canonicalize to {0, 1}. A group of 8 devices scatters the
+// 8 elements into 1 per device.
+// CHECK-LABEL: func @reduce_scatter_mesh_axes
+func.func @reduce_scatter_mesh_axes(
+    %arg0 : tensor<8xf32>) -> tensor<1xf64> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.reduce_scatter %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, scatter_axis = 0
+    } : tensor<8xf32> -> tensor<1xf64>
+  return %0 : tensor<1xf64>
+}
+
+// With no mesh axes the group has a single device, so the shape is
+// unchanged. The empty axes attribute and the default `sum` reduction are
+// both expected to disappear from the printed op.
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_and_default_reduction
+func.func @reduce_scatter_empty_mesh_axes_and_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = mesh.reduce_scatter %arg0 {
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>,
+// CHECK-NOT: reduction
+    reduction = #mesh.partial<sum>,
+    scatter_axis = 0
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 246439dd4be7122..413b01459f507a2 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -67,3 +67,243 @@ func.func @mesh_axis_negtive_in_partial(
             tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
 }
+
+// -----
+
+func.func @all_reduce_invalid_mesh_symbol(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @this_mesh_symbol_does_not_exist, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_invalid_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_invalid_tensor_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<5xf64> {
+  // expected-error at +1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
+  %0 = mesh.all_reduce %arg0 { mesh = @mesh0 } : tensor<4xf32> -> tensor<5xf64>
+  return %0 : tensor<5xf64>
+}
+
+// -----
+
+func.func @all_gather_invalid_mesh_symbol(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @this_mesh_symbol_does_not_exist, gather_axis = 0 : index
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_gather_invalid_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 2>, gather_axis = 0 : index
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_non_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 0 : index
+    } : tensor<3x4xf32> -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+
+func.func @all_gather_invalid_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+    } : tensor<3x4xf32> -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_gather_axis_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.all_gather %arg0 {
+    gather_axis = 0 : index,
+    mesh = @mesh0
+    } : tensor<?xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error at +1 {{Gather axis 1 is out of bounds [0, 1).}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 1 : index
+    } : tensor<3xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_negative_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error at +1 {{Gather axis -1 is out of bounds [0, 1).}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = -1 : index
+    } : tensor<3xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+func.func @all_to_all_gather_invalid_mesh_symbol(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @this_mesh_symbol_does_not_exist, split_axis = 1
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
+    %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
+    } : tensor<?x6xi8> -> tensor<3x?xi8>
+  return %0 : tensor<3x?xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
+    %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
+    } : tensor<3x?xi8> -> tensor<?x3xi8>
+  return %0 : tensor<?x3xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
+    %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x2xi8> -> tensor<1x7xi8>
+  return %0 : tensor<1x7xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
+    %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x2xi8> -> tensor<2x6xi8>
+  return %0 : tensor<2x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<2xf64> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, scatter_axis = 0
+    } : tensor<?xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_static_dimension_size(
+    %arg0 : tensor<3xf32>) -> tensor<2xf64> {
+  // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
+    } : tensor<3xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_operand_static_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<?xf64> {
+  // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
+    } : tensor<4xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index ee5f8f67792b928..5ec6c4a439327f0 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -12,6 +12,8 @@ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
 // CHECK: mesh.cluster @mesh3
 mesh.cluster @mesh3(rank = 2)
 
+mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
@@ -126,3 +128,120 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
+
+// CHECK-LABEL: func @all_reduce
+func.func @all_reduce(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
+  // CHECK-NEXT: mesh.all_reduce %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
+  %0 = mesh.all_reduce %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+    } : tensor<3x4xf32> -> tensor<3x4xf64>
+  return %0 : tensor<3x4xf64>
+}
+
+// CHECK-LABEL: func @all_gather
+func.func @all_gather(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
+  %0 = mesh.all_gather %arg0 {
+      gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>
+    } : tensor<3x4xf32> -> tensor<3x16xf32>
+  return %0 : tensor<3x16xf32>
+}
+
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
+func.func @all_gather_dynamic_dims_in_tensor(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
+    %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+  // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
+  %0 = mesh.all_gather %arg0 {
+      gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>
+    } : tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
+func.func @all_gather_dynamic_dims_in_mesh(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
+    %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>}
+  // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
+  %0 = mesh.all_gather %arg0 {
+      gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>
+    } : tensor<5x6xf32> -> tensor<5x?xf32>
+  return %0 : tensor<5x?xf32>
+}
+
+// CHECK-LABEL: func @all_to_all
+func.func @all_to_all(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh4, split_axis = 1
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
+func.func @all_to_all_dynamic_dims_in_result(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+    %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh4, split_axis = 1
+    } : tensor<3x6xi8> -> tensor<3x?xi8>
+  return %0 : tensor<3x?xi8>
+}
+
+// CHECK-LABEL: func @all_to_all
+func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
+    %arg0 : tensor<3xi8>) -> tensor<3xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 0 : i64}
+  // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh4, split_axis = 0
+    } : tensor<3xi8> -> tensor<3xi8>
+  return %0 : tensor<3xi8>
+}
+
+// CHECK-LABEL: func @reduce_scatter_static_dimensions
+func.func @reduce_scatter_static_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
+  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1
+    } : tensor<3x4xf32> -> tensor<3x1xf64>
+  return %0 : tensor<3x1xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
+func.func @reduce_scatter_dynamic_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0 : i64}
+  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0
+    } : tensor<?xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}

>From cb3347fa2c23d7d40ca685b92254088c6707a933 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Tue, 14 Nov 2023 10:57:51 -0800
Subject: [PATCH 2/2] Address comments in PR (to squash)

---
 mlir/docs/Dialects/Mesh.md                   |  13 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |  43 ++----
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         | 138 ++++++++-----------
 mlir/test/Dialect/Mesh/canonicalization.mlir |  29 ----
 mlir/test/Dialect/Mesh/invalid.mlir          |  54 +++++++-
 5 files changed, 138 insertions(+), 139 deletions(-)

diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 6dd4f79022061ee..03877f1a6544817 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -14,16 +14,25 @@ It is assumed that the user is familiar with collective operations.
 explanation.
 The main addition is that the collectives in this dialect have mesh
 semantics.
-The operation attributes `mesh` and `mesh_axes` specifies a set of device mesh
+
+The operation attributes `mesh` and `mesh_axes` specify a list of device mesh
 axes that partition the devices into disjoint groups.
 The collective operation is performed between devices in the same group.
 Devices that have the same coordinates outside of axes `mesh_axes` are in the
 same group.
 For example if we have a device mesh of size `2x3x4x5` and the partition mesh
-axes set is `{0, 1}` then devices are partitioned into the groups
+axes list is `[0, 1]` then devices are partitioned into the groups
 `{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
 Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
 Device (1, 0, 2, 4) will be in another group.
+Some collective operations, such as all-to-all and all-gather, depend on the
+order of devices.
+The order of devices in a device group is induced by the order of axes in
+`mesh_axes`, with axes ordered from outer to inner.
+For example, with the axis list `[3, 1]`, device `(i, 1, k, 0)` precedes
+both `(i, 0, k, 1)` (smaller index along the outer axis 3) and `(i, 2, k, 0)`
+(same index along axis 3, smaller index along the inner axis 1).
+
 
 ## Operations
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 15354babe870599..34110d400d7017e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -190,12 +190,12 @@ class Mesh_CollectiveCommunicationOpBase<
     string mnemonic, list<Trait> traits = []> :
     Mesh_Op<mnemonic,
       !listconcat(traits,
-      [SymbolUserOpInterface])> {
+      [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
   let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
-  code extraClassDeclarationBase = [{
-    ::mlir::LogicalResult verifySymbolUses(
-          ::mlir::SymbolTableCollection &symbolTable);
-  }];
+  dag commonArgs = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+  );
 }
 
 def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
@@ -237,18 +237,14 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     +-------------+
     ```
   }];
-  let arguments = (ins
+  let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
     APIntAttr:$gather_axis
-  );
+  ));
   let results = (outs
     AnyNon0RankedTensor:$result
   );
-  let hasCanonicalizer = 1;
   let hasVerifier = 1;
-  let extraClassDeclaration = extraClassDeclarationBase;
 }
 
 def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
@@ -270,17 +266,14 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
       } : tensor<3x4xf32> -> tensor<3x4xf64>
     ```
   }];
-  let arguments = (ins
+  let arguments = !con(commonArgs, (ins
     AnyRankedTensor:$input,
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
     DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
-  );
+  ));
   let results = (outs
     AnyRankedTensor:$result
   );
-  let hasCanonicalizer = 1;
-  let extraClassDeclaration = extraClassDeclarationBase;
+  let hasVerifier = 1;
 }
 
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
@@ -319,19 +312,15 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     +-------+-------+-------+
     ```
   }];
-  let arguments = (ins
+  let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
     APIntAttr:$split_axis,
     APIntAttr:$concat_axis
-  );
+  ));
   let results = (outs
     AnyNon0RankedTensor:$result
   );
-  let hasCanonicalizer = 1;
   let hasVerifier = 1;
-  let extraClassDeclaration = extraClassDeclarationBase;
 }
 
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -372,19 +361,15 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     +-------+
     ```
   }];
-  let arguments = (ins
+  let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
     DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
     APIntAttr:$scatter_axis
-  );
+  ));
   let results = (outs
     AnyRankedTensor:$result
   );
-  let hasCanonicalizer = 1;
   let hasVerifier = 1;
-  let extraClassDeclaration = extraClassDeclarationBase;
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 6efc4c4ecc326ad..bdbd3ad1514dfc6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
@@ -169,6 +170,7 @@ LogicalResult ClusterOp::verify() {
 SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
   SmallVector<int64_t> result;
   canonicalDimSizes(std::back_inserter(result));
+  result.reserve(getRank());
   return result;
 }
 
@@ -211,44 +213,6 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 
 namespace {
 
-std::optional<DenseI16ArrayAttr>
-canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
-  if (!attr) {
-    return std::nullopt;
-  }
-  SmallVector<int16_t> axes = llvm::to_vector(attr.asArrayRef());
-  canonicalizeSetAsVector(axes);
-  if (axes.empty()) {
-    return std::nullopt;
-  }
-  return DenseI16ArrayAttr::get(attr.getContext(), axes);
-}
-
-template <typename Op>
-struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
-  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
-      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
-  LogicalResult matchAndRewrite(Op op,
-                                PatternRewriter &rewriter) const override {
-    auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(
-        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr));
-    if (!canonicalMeshAxesAttr) {
-      op->removeAttr(axisSetAttr);
-    } else {
-      op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
-    }
-    return success();
-  }
-
-  std::string axisSetAttr;
-};
-
-template <typename Op>
-void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
-                                                 MLIRContext *context) {
-  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
-}
-
 template <typename Op>
 LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
   FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
@@ -280,10 +244,36 @@ LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
   return success();
 }
 
+template <typename It>
+bool isUnique(It begin, It end) {
+  if (begin == end) {
+    return true;
+  }
+  It next = std::next(begin);
+  if (next == end) {
+    return true;
+  }
+  for (; next != end; ++next, ++begin) {
+    if (*begin == *next) {
+      return false;
+    }
+  }
+  return true;
+}
+
+LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
+  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+  std::sort(sorted.begin(), sorted.end());
+  if (!isUnique(sorted.begin(), sorted.end())) {
+    return emitError(loc) << "Mesh axes contains duplicate elements.";
+  }
+  return success();
+}
+
 template <typename It>
 auto product(It begin, It end) {
   using ElementType = std::decay_t<decltype(*begin)>;
-  return std::accumulate(begin, end, ElementType(1),
+  return std::accumulate(begin, end, static_cast<ElementType>(1),
                          std::multiplies<ElementType>());
 }
 
@@ -295,15 +285,15 @@ auto product(R &&range) {
 int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
                                   ArrayRef<int64_t> meshShape) {
   int64_t res = 1;
-  for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) {
-    if (llvm::find(meshAxes, axis) == meshAxes.end()) {
-      continue;
-    }
+
+  for (MeshAxis axis : meshAxes) {
     if (isMeshDimensionDynamic(meshShape[axis])) {
       return ShapedType::kDynamic;
     }
+    assert(size_t(axis) < meshShape.size());
     res *= meshShape[axis];
   }
+
   return res;
 }
 
@@ -324,10 +314,9 @@ LogicalResult verifyDimensionCompatibility(Location loc,
   return success();
 }
 
-LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
-                                                int64_t gatherAxis,
-                                                ArrayRef<MeshAxis> meshAxes,
-                                                ArrayRef<int64_t> meshShape) {
+LogicalResult verifyAllGatherOperandAndResultShape(
+    Value operand, Value result, int64_t gatherAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   auto deviceGroupSize =
@@ -358,7 +347,7 @@ FailureOr<ClusterOp> getMesh(Op op) {
 }
 
 template <typename Op>
-LogicalResult verifyGather(Op op) {
+LogicalResult verifyAllGather(Op op) {
   auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
   auto gatherAxis = op.getGatherAxis().getSExtValue();
   if (gatherAxis < 0 || gatherAxis >= rank) {
@@ -370,9 +359,9 @@ LogicalResult verifyGather(Op op) {
   if (failed(mesh)) {
     return failure();
   }
-  return verifyGatherOperandAndResultShape(op.getOperand(), op.getResult(),
-                                           gatherAxis, op.getMeshAxes(),
-                                           mesh.value().canonicalDimSizes());
+  return verifyAllGatherOperandAndResultShape(op.getOperand(), op.getResult(),
+                                              gatherAxis, op.getMeshAxes(),
+                                              mesh.value().canonicalDimSizes());
 }
 
 LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
@@ -471,13 +460,12 @@ LogicalResult verifyReduceScatterOperandAndResultShape(
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-mlir::mesh::AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyMeshSymbolUses(*this, symbolTable);
 }
 
-void mlir::mesh::AllReduceOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  populateMeshAxesSetCanonicalizationPatterns<AllReduceOp>(patterns, context);
+LogicalResult mlir::mesh::AllReduceOp::verify() {
+  return verifyMeshAxes(getLoc(), getMeshAxes());
 }
 
 //===----------------------------------------------------------------------===//
@@ -485,32 +473,29 @@ void mlir::mesh::AllReduceOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-mlir::mesh::AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyMeshSymbolUses(*this, symbolTable);
 }
 
-void mlir::mesh::AllGatherOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  populateMeshAxesSetCanonicalizationPatterns<AllGatherOp>(patterns, context);
+LogicalResult mlir::mesh::AllGatherOp::verify() {
+  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
+    return failure();
+  }
+  return verifyAllGather(*this);
 }
 
-LogicalResult mlir::mesh::AllGatherOp::verify() { return verifyGather(*this); }
-
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-mlir::mesh::AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyMeshSymbolUses(*this, symbolTable);
 }
 
-void mlir::mesh::AllToAllOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  populateMeshAxesSetCanonicalizationPatterns<AllToAllOp>(patterns, context);
-}
-
-LogicalResult mlir::mesh::AllToAllOp::verify() {
+LogicalResult AllToAllOp::verify() {
+  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
+    return failure();
+  }
   auto mesh = ::getMesh(*this);
   if (failed(mesh)) {
     return failure();
@@ -525,18 +510,15 @@ LogicalResult mlir::mesh::AllToAllOp::verify() {
 // mesh.reduce_scatter op
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::mesh::ReduceScatterOp::verifySymbolUses(
-    SymbolTableCollection &symbolTable) {
+LogicalResult
+ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyMeshSymbolUses(*this, symbolTable);
 }
 
-void mlir::mesh::ReduceScatterOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  populateMeshAxesSetCanonicalizationPatterns<ReduceScatterOp>(patterns,
-                                                               context);
-}
-
-LogicalResult mlir::mesh::ReduceScatterOp::verify() {
+LogicalResult ReduceScatterOp::verify() {
+  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
+    return failure();
+  }
   auto mesh = ::getMesh(*this);
   if (failed(mesh)) {
     return failure();
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 21fa1887d14a53e..2bacfbe992525ee 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -2,15 +2,6 @@
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 
-// CHECK-LABEL: func @all_reduce_mesh_axes
-func.func @all_reduce_mesh_axes(
-    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-// CHECK: mesh_axes = array<i16: 0, 1>
-  %0 = mesh.all_reduce %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, reduction = #mesh.partial<sum>
-    } : tensor<4xf32> -> tensor<4xf64>
-  return %0 : tensor<4xf64>
-}
 
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
 func.func @all_reduce_empty_mesh_axes_and_default_reduction(
@@ -37,26 +28,6 @@ func.func @all_gather_empty_mesh_axes(
   return %0 : tensor<4xf32>
 }
 
-// CHECK-LABEL: func @all_gather_mesh_axes
-func.func @all_gather_mesh_axes(
-    %arg0 : tensor<4xf32>) -> tensor<32xf32> {
-// CHECK: mesh_axes = array<i16: 0, 1>
-  %0 = mesh.all_gather %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, gather_axis = 0
-    } : tensor<4xf32> -> tensor<32xf32>
-  return %0 : tensor<32xf32>
-}
-
-// CHECK-LABEL: func @reduce_scatter_mesh_axes
-func.func @reduce_scatter_mesh_axes(
-    %arg0 : tensor<8xf32>) -> tensor<1xf64> {
-// CHECK: mesh_axes = array<i16: 0, 1>
-  %0 = mesh.reduce_scatter %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, scatter_axis = 0
-    } : tensor<8xf32> -> tensor<1xf64>
-  return %0 : tensor<1xf64>
-}
-
 // CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_and_default_reduction
 func.func @reduce_scatter_empty_mesh_axes_and_default_reduction(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 413b01459f507a2..27d80e285d83778 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -96,6 +96,19 @@ func.func @all_reduce_invalid_mesh_axis(
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 
+func.func @all_reduce_duplicate_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error at +1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0, 1, 0>, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
 func.func @all_reduce_invalid_tensor_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<5xf64> {
   // expected-error at +1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
@@ -129,6 +142,19 @@ func.func @all_gather_invalid_mesh_axis(
 
 // -----
 
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_gather_duplicate_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 2, 2>, gather_axis = 0 : index
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
 mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
 
 func.func @all_gather_invalid_non_gather_axis_dimension_size(
@@ -195,7 +221,7 @@ func.func @all_gather_invalid_negative_gather_axis(
 
 // -----
 
-func.func @all_to_all_gather_invalid_mesh_symbol(
+func.func @all_to_all_invalid_mesh_symbol(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
   %0 = mesh.all_to_all %arg0 {
@@ -206,6 +232,19 @@ func.func @all_to_all_gather_invalid_mesh_symbol(
 
 // -----
 
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_to_all_duplicate_mesh_axis(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh0, mesh_axes = array<i16: 0, 0>, split_axis = 0
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
 mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
 
 func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
@@ -273,6 +312,19 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
 
+func.func @reduce_scatter_duplicate_mesh_axis(
+    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, scatter_axis = 0, mesh_axes = array<i16: 0, 0>
+    } : tensor<?xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
   // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}



More information about the Mlir-commits mailing list