[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)

Boian Petkantchin llvmlistbot at llvm.org
Thu Nov 16 16:46:33 PST 2023


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/71960

>From 5b8c7f70202f7cac6b6b1a29e168fe6fa43fc720 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Mon, 6 Nov 2023 15:55:31 -0800
Subject: [PATCH 01/10] [mlir][mesh] Add collective communication operations

Add all-gather, all-reduce, all-to-all and reduce-scatter.
These operations have device mesh semantics.
---
 mlir/docs/Dialects/Mesh.md                    |  34 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |   8 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |   2 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 216 +++++++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 417 ++++++++++++++++++
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  72 +++
 mlir/test/Dialect/Mesh/invalid.mlir           | 240 ++++++++++
 mlir/test/Dialect/Mesh/ops.mlir               | 119 +++++
 8 files changed, 1105 insertions(+), 3 deletions(-)
 create mode 100644 mlir/docs/Dialects/Mesh.md
 create mode 100644 mlir/test/Dialect/Mesh/canonicalization.mlir

diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..6dd4f79022061ee
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,34 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the user is familiar with collective operations.
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition is that the collectives in this dialect have mesh
+semantics.
+The operation attributes `mesh` and `mesh_axes` specify a set of device mesh
+axes that partition the devices into disjoint groups.
+The collective operation is performed between devices in the same group.
+Devices that have the same coordinates along the axes not listed in
+`mesh_axes` are in the same group.
+For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
+axes set is `{0, 1}`, then devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
   let cppNamespace = "::mlir::mesh";
 
   let description = [{
-    The `mesh` dialect contains a set of attributes, operations, interfaces that
-    are useful for representing sharding and communication on device mesh
-    cluster.
+    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
   }];
 
   let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 // Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
 // distributed tensors. Mesh_IteratorType annotates loops in an operation, while
 // Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..7698d60813a8f10 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,11 @@
 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..15354babe870599 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -77,6 +79,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
       attr-dict
   }];
+  let extraClassDeclaration = [{
+    ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+    template <typename OutIt>
+    void canonicalDimSizes(OutIt outIt) {
+      std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+    }
+  }];
   let hasVerifier = 1;
 }
 
@@ -171,4 +182,209 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+class Mesh_CollectiveCommunicationOpBase<
+    string mnemonic, list<Trait> traits = []> :
+    Mesh_Op<mnemonic,
+      !listconcat(traits,
+      [SymbolUserOpInterface])> {
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+  code extraClassDeclarationBase = [{
+    ::mlir::LogicalResult verifySymbolUses(
+          ::mlir::SymbolTableCollection &symbolTable);
+  }];
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-gather over a device mesh.";
+  let description = [{
+    Gathers along the `gather_axis` tensor axis.
+    The order of input tensors in the resulting tensor is the same as the
+    order of the corresponding devices' multi-indices in the mesh.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.all_gather %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+      } : tensor<2x2xi8> -> tensor<2x4xi8>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    APIntAttr:$gather_axis
+  );
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+    SameOperandsAndResultShape]> {
+  let summary = "All-reduce over a device mesh.";
+  let description = [{
+    The accumulation element type is specified by the result type and
+    it does not need to match the input element type.
+    Input elements are converted to the result element type before
+    the reduction is performed.
+
+    Attributes:
+    `reduction`: Indicates the reduction method.
+
+    Example:
+    ```
+    %1 = mesh.all_reduce %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+      } : tensor<3x4xf32> -> tensor<3x4xf64>
+    ```
+  }];
+  let arguments = (ins
+    AnyRankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank]> {
+  let summary = "All-to-all over a device mesh.";
+  let description = [{
+    Performs an all-to-all on tensor pieces split along `split_axis`.
+    The resulting pieces are concatenated along `concat_axis` on each device.
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+    ...
+    %1 = mesh.all_to_all %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
+      } : tensor<3x6xi8> -> tensor<3x6xi8>
+    ```
+    Input:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 21 22 | 31 32 |
+    | 13 14 | 23 24 | 33 34 |
+    | 15 16 | 25 26 | 35 36 |
+    +-------+-------+-------+
+    ```
+    Result:
+    ```
+     device  device  device
+     (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 13 14 | 15 16 |
+    | 21 22 | 23 24 | 25 26 |
+    | 31 32 | 33 34 | 35 36 |
+    +-------+-------+-------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    APIntAttr:$split_axis,
+    APIntAttr:$concat_axis
+  );
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+    SameOperandsAndResultRank]> {
+  let summary = "Reduce-scatter over a device mesh.";
+  let description = [{
+    After the reduction, scatters the result within each device group.
+    The tensor is split along `scatter_axis` and the pieces are distributed
+    across the device group.
+    Example:
+    ```
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.reduce_scatter %0 {
+        mesh = @mesh0, mesh_axes = array<i16: 1>, reduction = #mesh.partial<sum>, scatter_axis = 0
+      } : tensor<2x2xf32> -> tensor<1x2xf64>
+    ```
+    Input:
+    ```
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+    Result:
+    ```
+    +-------+
+    |  6  8 | <- device (0, 0)
+    +-------+
+    | 10 12 | <- device (0, 1)
+    +-------+
+    | 22 24 | <- device (1, 0)
+    +-------+
+    | 26 28 | <- device (1, 1)
+    +-------+
+    ```
+  }];
+  let arguments = (ins
+    AnyNon0RankedTensor:$input,
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    APIntAttr:$scatter_axis
+  );
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let hasCanonicalizer = 1;
+  let hasVerifier = 1;
+  let extraClassDeclaration = extraClassDeclarationBase;
+}
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
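
As a compact reference for the shape relationships that the op descriptions
above encode (and that the verifiers added in MeshOps.cpp below check), here is
a standalone C++ sketch. The helper names are made up for illustration, dynamic
dimensions are ignored, and `groupSize` stands for the product of the mesh
dimensions selected by `mesh_axes`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// all_gather: the gather_axis dimension grows by the device group size.
Shape allGatherResultShape(Shape operand, int64_t gatherAxis,
                           int64_t groupSize) {
  operand[gatherAxis] *= groupSize;
  return operand;
}

// all_to_all: the split_axis dimension shrinks and the concat_axis dimension
// grows by the device group size (no net change when the two axes coincide).
Shape allToAllResultShape(Shape operand, int64_t splitAxis, int64_t concatAxis,
                          int64_t groupSize) {
  assert(operand[splitAxis] % groupSize == 0);
  operand[splitAxis] /= groupSize;
  operand[concatAxis] *= groupSize;
  return operand;
}

// reduce_scatter: the reduced tensor is scattered, so the scatter_axis
// dimension shrinks by the device group size.
Shape reduceScatterResultShape(Shape operand, int64_t scatterAxis,
                               int64_t groupSize) {
  assert(operand[scatterAxis] % groupSize == 0);
  operand[scatterAxis] /= groupSize;
  return operand;
}

// E.g. the all_gather example above: a 2x2 operand with mesh_axes = [1] on a
// 2x2 mesh has groupSize = 2, so allGatherResultShape({2, 2}, 1, 2) == {2, 4}.
```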
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..6efc4c4ecc326ad 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,26 @@
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
 
 #define DEBUG_TYPE "mesh-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +37,60 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
+namespace {
+
+template <typename It>
+It canonicalizeSetAsArray(It begin, It end) {
+  std::sort(begin, end);
+  return std::unique(begin, end);
+}
+
+template <typename R>
+auto canonicalizeSetAsArray(R &&range) {
+  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
+}
+
+template <typename T>
+SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+  auto newEnd = canonicalizeSetAsArray(vec);
+  vec.resize(newEnd - vec.begin());
+  return vec;
+}
+
+template <typename DimSize>
+bool isMeshDimensionDynamic(DimSize size) {
+  return size <= DimSize(0);
+}
+
+using MeshAxis = int16_t;
+
+struct DimensionSize {
+  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
+  DimensionSize(int64_t val) : val(val) {}
+  int64_t value() const { return val; }
+  operator int64_t() const { return val; }
+  bool isDynamic() const { return ShapedType::isDynamic(val); }
+
+private:
+  int64_t val;
+};
+
+DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+  if (lhs.isDynamic() || rhs.isDynamic()) {
+    return DimensionSize::dynamic();
+  }
+  return lhs.value() / rhs.value();
+}
+
+DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+  if (lhs.isDynamic() || rhs.isDynamic()) {
+    return DimensionSize::dynamic();
+  }
+  return lhs.value() * rhs.value();
+}
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Mesh dialect
 //===----------------------------------------------------------------------===//
@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() {
   return success();
 }
 
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+  SmallVector<int64_t> result;
+  canonicalDimSizes(std::back_inserter(result));
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.shard op
 //===----------------------------------------------------------------------===//
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (!attr) {
+    return std::nullopt;
+  }
+  SmallVector<int16_t> axes = llvm::to_vector(attr.asArrayRef());
+  canonicalizeSetAsVector(axes);
+  if (axes.empty()) {
+    return std::nullopt;
+  }
+  return DenseI16ArrayAttr::get(attr.getContext(), axes);
+}
+
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr));
+    if (!canonicalMeshAxesAttr) {
+      op->removeAttr(axisSetAttr);
+    } else {
+      op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
+    }
+    return success();
+  }
+
+  std::string axisSetAttr;
+};
+
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
+}
+
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  SymbolTableCollection symbolTableCollection;
+  mesh::ClusterOp mesh =
+      symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+          op.getOperation(), symbolAttr);
+  if (!mesh) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : meshAxes.asArrayRef()) {
+    if (axis >= rank || axis < 0) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  return std::accumulate(begin, end, ElementType(1),
+                         std::multiplies<ElementType>());
+}
+
+template <typename R>
+auto product(R &&range) {
+  return product(adl_begin(range), adl_end(range));
+}
+
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t res = 1;
+  for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) {
+    if (llvm::find(meshAxes, axis) == meshAxes.end()) {
+      continue;
+    }
+    if (isMeshDimensionDynamic(meshShape[axis])) {
+      return ShapedType::kDynamic;
+    }
+    res *= meshShape[axis];
+  }
+  return res;
+}
+
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  if (!ShapedType::isDynamic(resultDimSize) &&
+      expectedDimSize != resultDimSize) {
+    return emitError(loc) << "Dimension size mismatch for result axis "
+                          << resultAxis << ". Expected "
+                          << (ShapedType::isDynamic(expectedDimSize)
+                                  ? Twine("dynamic")
+                                  : Twine(expectedDimSize))
+                          << ", but got " << resultDimSize << ".";
+  }
+
+  return success();
+}
+
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+                                                int64_t gatherAxis,
+                                                ArrayRef<MeshAxis> meshAxes,
+                                                ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
+    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
+    auto expectedResultDimSize =
+        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
+    if (failed(verifyDimensionCompatibility(
+            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+  SymbolTableCollection symbolTableCollection;
+  if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+    // We need to check the symbol here since this runs before
+    // SymbolUserOpInterface.
+    return failure();
+  }
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+      op.getOperation(), op.getMeshAttr());
+}
+
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+  auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
+  auto gatherAxis = op.getGatherAxis().getSExtValue();
+  if (gatherAxis < 0 || gatherAxis >= rank) {
+    return op.emitError() << "Gather axis " << gatherAxis
+                          << " is out of bounds [0, " << rank << ").";
+  }
+
+  auto mesh = getMesh(op);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyGatherOperandAndResultShape(op.getOperand(), op.getResult(),
+                                           gatherAxis, op.getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
+}
+
+LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
+                                                  int64_t splitAxis,
+                                                  int64_t concatAxis,
+                                                  ArrayRef<MeshAxis> meshAxes,
+                                                  ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  if (splitAxis == concatAxis) {
+    return success();
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
+  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
+  if (!operandSplitDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
+    return emitError(result.getLoc())
+           << "Operand dimension size " << int64_t(operandSplitDimSize)
+           << " is not divisible by collective device group size "
+           << int64_t(deviceGroupSize) << " for split axis " << splitAxis
+           << ".";
+  }
+  DimensionSize expectedResultConcatDimSize =
+      operandConcatDimSize * deviceGroupSize;
+  DimensionSize expectedResultSplitDimSize =
+      operandSplitDimSize / deviceGroupSize;
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultConcatDimSize.value(),
+          resultType.getDimSize(concatAxis), concatAxis))) {
+    return failure();
+  }
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultSplitDimSize.value(),
+          resultType.getDimSize(splitAxis), splitAxis))) {
+    return failure();
+  }
+
+  return success();
+}
+
+LogicalResult verifyReduceScatterOperandAndResultShape(
+    Value operand, Value result, int64_t scatterAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+    if (axis != scatterAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandScatterDimSize =
+      DimensionSize(operandType.getDimSize(scatterAxis));
+  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
+    return emitError(result.getLoc())
+           << "Operand dimension size " << int64_t(operandScatterDimSize)
+           << " is not divisible by collective device group size "
+           << int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
+           << ".";
+  }
+  DimensionSize expectedResultScatterDimSize =
+      operandScatterDimSize / deviceGroupSize;
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultScatterDimSize.value(),
+          resultType.getDimSize(scatterAxis), scatterAxis))) {
+    return failure();
+  }
+
+  return success();
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// mesh.all_reduce op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+mlir::mesh::AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+void mlir::mesh::AllReduceOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<AllReduceOp>(patterns, context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_gather op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+mlir::mesh::AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+void mlir::mesh::AllGatherOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<AllGatherOp>(patterns, context);
+}
+
+LogicalResult mlir::mesh::AllGatherOp::verify() { return verifyGather(*this); }
+
+//===----------------------------------------------------------------------===//
+// mesh.all_to_all op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+mlir::mesh::AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+void mlir::mesh::AllToAllOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<AllToAllOp>(patterns, context);
+}
+
+LogicalResult mlir::mesh::AllToAllOp::verify() {
+  auto mesh = ::getMesh(*this);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyAllToAllOperandAndResultShape(
+      getOperand(), getResult(), getSplitAxis().getSExtValue(),
+      getConcatAxis().getSExtValue(), getMeshAxes(),
+      mesh.value().canonicalDimSizes());
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.reduce_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::mesh::ReduceScatterOp::verifySymbolUses(
+    SymbolTableCollection &symbolTable) {
+  return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+void mlir::mesh::ReduceScatterOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  populateMeshAxesSetCanonicalizationPatterns<ReduceScatterOp>(patterns,
+                                                               context);
+}
+
+LogicalResult mlir::mesh::ReduceScatterOp::verify() {
+  auto mesh = ::getMesh(*this);
+  if (failed(mesh)) {
+    return failure();
+  }
+  return verifyReduceScatterOperandAndResultShape(
+      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+      mesh.value().canonicalDimSizes());
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
new file mode 100644
index 000000000000000..21fa1887d14a53e
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// CHECK-LABEL: func @all_reduce_mesh_axes
+func.func @all_reduce_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
+func.func @all_reduce_empty_mesh_axes_and_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>,
+// CHECK-NOT: reduction
+    reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_gather_empty_mesh_axes
+func.func @all_gather_empty_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  %0 = mesh.all_gather %arg0 {
+    gather_axis = 0 : index,
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_gather_mesh_axes
+func.func @all_gather_mesh_axes(
+    %arg0 : tensor<4xf32>) -> tensor<32xf32> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, gather_axis = 0
+    } : tensor<4xf32> -> tensor<32xf32>
+  return %0 : tensor<32xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_mesh_axes
+func.func @reduce_scatter_mesh_axes(
+    %arg0 : tensor<8xf32>) -> tensor<1xf64> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+  %0 = mesh.reduce_scatter %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, scatter_axis = 0
+    } : tensor<8xf32> -> tensor<1xf64>
+  return %0 : tensor<1xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_and_default_reduction
+func.func @reduce_scatter_empty_mesh_axes_and_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = mesh.reduce_scatter %arg0 {
+    mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+    mesh_axes = array<i16>,
+// CHECK-NOT: reduction
+    reduction = #mesh.partial<sum>,
+    scatter_axis = 0
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 246439dd4be7122..413b01459f507a2 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -67,3 +67,243 @@ func.func @mesh_axis_negtive_in_partial(
             tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
 }
+
+// -----
+
+func.func @all_reduce_invalid_mesh_symbol(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @this_mesh_symbol_does_not_exist, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_invalid_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_invalid_tensor_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<5xf64> {
+  // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
+  %0 = mesh.all_reduce %arg0 { mesh = @mesh0 } : tensor<4xf32> -> tensor<5xf64>
+  return %0 : tensor<5xf64>
+}
+
+// -----
+
+func.func @all_gather_invalid_mesh_symbol(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @this_mesh_symbol_does_not_exist, gather_axis = 0 : index
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_gather_invalid_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 2>, gather_axis = 0 : index
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_non_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 0 : index
+    } : tensor<3x4xf32> -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+
+func.func @all_gather_invalid_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+    } : tensor<3x4xf32> -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_gather_axis_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.all_gather %arg0 {
+    gather_axis = 0 : index,
+    mesh = @mesh0
+    } : tensor<?xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 1 : index
+    } : tensor<3xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_negative_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = -1 : index
+    } : tensor<3xf32> -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+func.func @all_to_all_gather_invalid_mesh_symbol(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @this_mesh_symbol_does_not_exist, split_axis = 1
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
+    %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
+    } : tensor<?x6xi8> -> tensor<3x?xi8>
+  return %0 : tensor<3x?xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
+    %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
+    } : tensor<3x?xi8> -> tensor<?x3xi8>
+  return %0 : tensor<?x3xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
+    %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x2xi8> -> tensor<1x7xi8>
+  return %0 : tensor<1x7xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
+    %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+    } : tensor<3x2xi8> -> tensor<2x6xi8>
+  return %0 : tensor<2x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<2xf64> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, scatter_axis = 0
+    } : tensor<?xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_static_dimension_size(
+    %arg0 : tensor<3xf32>) -> tensor<2xf64> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
+    } : tensor<3xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_operand_static_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<?xf64> {
+  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
+    } : tensor<4xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index ee5f8f67792b928..5ec6c4a439327f0 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -12,6 +12,8 @@ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
 // CHECK: mesh.cluster @mesh3
 mesh.cluster @mesh3(rank = 2)
 
+mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
@@ -126,3 +128,120 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
+
+// CHECK-LABEL: func @all_reduce
+func.func @all_reduce(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
+  // CHECK-NEXT: mesh.all_reduce %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
+  %0 = mesh.all_reduce %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+    } : tensor<3x4xf32> -> tensor<3x4xf64>
+  return %0 : tensor<3x4xf64>
+}
+
+// CHECK-LABEL: func @all_gather
+func.func @all_gather(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
+  %0 = mesh.all_gather %arg0 {
+      gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>
+    } : tensor<3x4xf32> -> tensor<3x16xf32>
+  return %0 : tensor<3x16xf32>
+}
+
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
+func.func @all_gather_dynamic_dims_in_tensor(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
+    %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+  // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
+  %0 = mesh.all_gather %arg0 {
+      gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>
+    } : tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
+func.func @all_gather_dynamic_dims_in_mesh(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
+    %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
+  // CHECK-NEXT: mesh.all_gather %[[ARG]]
+  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>}
+  // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
+  %0 = mesh.all_gather %arg0 {
+      gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>
+    } : tensor<5x6xf32> -> tensor<5x?xf32>
+  return %0 : tensor<5x?xf32>
+}
+
+// CHECK-LABEL: func @all_to_all
+func.func @all_to_all(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh4, split_axis = 1
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
+func.func @all_to_all_dynamic_dims_in_result(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+    %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh4, split_axis = 1
+    } : tensor<3x6xi8> -> tensor<3x?xi8>
+  return %0 : tensor<3x?xi8>
+}
+
+// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
+func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
+    %arg0 : tensor<3xi8>) -> tensor<3xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 0 : i64}
+  // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh4, split_axis = 0
+    } : tensor<3xi8> -> tensor<3xi8>
+  return %0 : tensor<3xi8>
+}
+
+// CHECK-LABEL: func @reduce_scatter_static_dimensions
+func.func @reduce_scatter_static_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
+  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1 : i64}
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1
+    } : tensor<3x4xf32> -> tensor<3x1xf64>
+  return %0 : tensor<3x1xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
+func.func @reduce_scatter_dynamic_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+  // CHECK-SAME: {mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0 : i64}
+  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0
+    } : tensor<?xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}

>From cb3347fa2c23d7d40ca685b92254088c6707a933 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Tue, 14 Nov 2023 10:57:51 -0800
Subject: [PATCH 02/10] Address comments in PR (to squash)

---
 mlir/docs/Dialects/Mesh.md                   |  13 +-
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |  43 ++----
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         | 138 ++++++++-----------
 mlir/test/Dialect/Mesh/canonicalization.mlir |  29 ----
 mlir/test/Dialect/Mesh/invalid.mlir          |  54 +++++++-
 5 files changed, 138 insertions(+), 139 deletions(-)

diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 6dd4f79022061ee..03877f1a6544817 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -14,16 +14,25 @@ It is assumed that the user is familiar with collective operations.
 explanation.
 The main addition is that the collectives in this dialect have mesh
 semantics.
-The operation attributes `mesh` and `mesh_axes` specify a set of device mesh
+
+The operation attributes `mesh` and `mesh_axes` specify a list of device mesh
 axes that partition the devices into disjoint groups.
 The collective operation is performed between devices in the same group.
 Devices that have the same coordinates along the axes not listed in
 `mesh_axes` are in the same group.
 For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
-axes set is `{0, 1}`, then devices are partitioned into the groups
+axes list is `[0, 1]`, then devices are partitioned into the groups
 `{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
 Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
 Device (1, 0, 2, 4) will be in another group.
+Some collective operations like all-to-all and all-gather care about the
+order of devices.
+The order of devices in a device group is induced by the order of axes in
+`mesh_axes`.
+The axes are ordered from outer to inner.
+If we have an axis list `[3, 1]`, then device `(i, 1, k, 0)` will precede
+both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
+
 
 ## Operations
 
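
To make the ordering rule added above concrete, here is a minimal standalone
C++ sketch (not part of this patch). A device's position within its group is
its multi-index restricted to `mesh_axes`, taken in the listed order (outer to
inner) and compared lexicographically; `inGroupCoord` is a hypothetical helper
name.

```cpp
#include <cstdint>
#include <vector>

using MeshAxis = int16_t;

// In-group coordinate of a device: its mesh multi-index restricted to the
// listed axes, in the order they appear in meshAxes.
std::vector<int64_t> inGroupCoord(const std::vector<int64_t> &deviceIndex,
                                  const std::vector<MeshAxis> &meshAxes) {
  std::vector<int64_t> coord;
  for (MeshAxis axis : meshAxes)
    coord.push_back(deviceIndex[axis]);
  return coord;
}

// With mesh_axes = [3, 1] (axis 3 outer, axis 1 inner):
//   inGroupCoord({i, 1, k, 0}, {3, 1}) == {0, 1}
//   inGroupCoord({i, 0, k, 1}, {3, 1}) == {1, 0}   // after {0, 1}
//   inGroupCoord({i, 2, k, 0}, {3, 1}) == {0, 2}   // after {0, 1}
// so device (i, 1, k, 0) precedes both, matching the example above.
```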
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 15354babe870599..34110d400d7017e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -190,12 +190,12 @@ class Mesh_CollectiveCommunicationOpBase<
     string mnemonic, list<Trait> traits = []> :
     Mesh_Op<mnemonic,
       !listconcat(traits,
-      [SymbolUserOpInterface])> {
+      [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
   let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
-  code extraClassDeclarationBase = [{
-    ::mlir::LogicalResult verifySymbolUses(
-          ::mlir::SymbolTableCollection &symbolTable);
-  }];
+  dag commonArgs = (ins
+    FlatSymbolRefAttr:$mesh,
+    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+  );
 }
 
 def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
@@ -237,18 +237,14 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     +-------------+
     ```
   }];
-  let arguments = (ins
+  let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
     APIntAttr:$gather_axis
-  );
+  ));
   let results = (outs
     AnyNon0RankedTensor:$result
   );
-  let hasCanonicalizer = 1;
   let hasVerifier = 1;
-  let extraClassDeclaration = extraClassDeclarationBase;
 }
 
 def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
@@ -270,17 +266,14 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
       } : tensor<3x4xf32> -> tensor<3x4xf64>
     ```
   }];
-  let arguments = (ins
+  let arguments = !con(commonArgs, (ins
     AnyRankedTensor:$input,
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
     DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
-  );
+  ));
   let results = (outs
     AnyRankedTensor:$result
   );
-  let hasCanonicalizer = 1;
-  let extraClassDeclaration = extraClassDeclarationBase;
+  let hasVerifier = 1;
 }
 
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
@@ -319,19 +312,15 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     +-------+-------+-------+
     ```
   }];
-  let arguments = (ins
+  let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
     APIntAttr:$split_axis,
     APIntAttr:$concat_axis
-  );
+  ));
   let results = (outs
     AnyNon0RankedTensor:$result
   );
-  let hasCanonicalizer = 1;
   let hasVerifier = 1;
-  let extraClassDeclaration = extraClassDeclarationBase;
 }
 
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -372,19 +361,15 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     +-------+
     ```
   }];
-  let arguments = (ins
+  let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
     DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
     APIntAttr:$scatter_axis
-  );
+  ));
   let results = (outs
     AnyRankedTensor:$result
   );
-  let hasCanonicalizer = 1;
   let hasVerifier = 1;
-  let extraClassDeclaration = extraClassDeclarationBase;
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 6efc4c4ecc326ad..bdbd3ad1514dfc6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
@@ -169,6 +170,7 @@ LogicalResult ClusterOp::verify() {
 SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
   SmallVector<int64_t> result;
+  result.reserve(getRank());
   canonicalDimSizes(std::back_inserter(result));
   return result;
 }
 
@@ -211,44 +213,6 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 
 namespace {
 
-std::optional<DenseI16ArrayAttr>
-canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
-  if (!attr) {
-    return std::nullopt;
-  }
-  SmallVector<int16_t> axes = llvm::to_vector(attr.asArrayRef());
-  canonicalizeSetAsVector(axes);
-  if (axes.empty()) {
-    return std::nullopt;
-  }
-  return DenseI16ArrayAttr::get(attr.getContext(), axes);
-}
-
-template <typename Op>
-struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
-  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
-      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
-  LogicalResult matchAndRewrite(Op op,
-                                PatternRewriter &rewriter) const override {
-    auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(
-        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr));
-    if (!canonicalMeshAxesAttr) {
-      op->removeAttr(axisSetAttr);
-    } else {
-      op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
-    }
-    return success();
-  }
-
-  std::string axisSetAttr;
-};
-
-template <typename Op>
-void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
-                                                 MLIRContext *context) {
-  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
-}
-
 template <typename Op>
 LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
   FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
@@ -280,10 +244,36 @@ LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
   return success();
 }
 
+template <typename It>
+bool isUnique(It begin, It end) {
+  if (begin == end) {
+    return true;
+  }
+  It next = std::next(begin);
+  if (next == end) {
+    return true;
+  }
+  for (; next != end; ++next, ++begin) {
+    if (*begin == *next) {
+      return false;
+    }
+  }
+  return true;
+}
+
+LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
+  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+  std::sort(sorted.begin(), sorted.end());
+  if (!isUnique(sorted.begin(), sorted.end())) {
+    return emitError(loc) << "Mesh axes contains duplicate elements.";
+  }
+  return success();
+}
+
 template <typename It>
 auto product(It begin, It end) {
   using ElementType = std::decay_t<decltype(*begin)>;
-  return std::accumulate(begin, end, ElementType(1),
+  return std::accumulate(begin, end, static_cast<ElementType>(1),
                          std::multiplies<ElementType>());
 }
 
@@ -295,15 +285,15 @@ auto product(R &&range) {
 int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
                                   ArrayRef<int64_t> meshShape) {
   int64_t res = 1;
-  for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) {
-    if (llvm::find(meshAxes, axis) == meshAxes.end()) {
-      continue;
-    }
+
+  for (MeshAxis axis : meshAxes) {
+    assert(size_t(axis) < meshShape.size());
     if (isMeshDimensionDynamic(meshShape[axis])) {
       return ShapedType::kDynamic;
     }
     res *= meshShape[axis];
   }
+
   return res;
 }
 
@@ -324,10 +314,9 @@ LogicalResult verifyDimensionCompatibility(Location loc,
   return success();
 }
 
-LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
-                                                int64_t gatherAxis,
-                                                ArrayRef<MeshAxis> meshAxes,
-                                                ArrayRef<int64_t> meshShape) {
+LogicalResult verifyAllGatherOperandAndResultShape(
+    Value operand, Value result, int64_t gatherAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   auto deviceGroupSize =
@@ -358,7 +347,7 @@ FailureOr<ClusterOp> getMesh(Op op) {
 }
 
 template <typename Op>
-LogicalResult verifyGather(Op op) {
+LogicalResult verifyAllGather(Op op) {
   auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
   auto gatherAxis = op.getGatherAxis().getSExtValue();
   if (gatherAxis < 0 || gatherAxis >= rank) {
@@ -370,9 +359,9 @@ LogicalResult verifyGather(Op op) {
   if (failed(mesh)) {
     return failure();
   }
-  return verifyGatherOperandAndResultShape(op.getOperand(), op.getResult(),
-                                           gatherAxis, op.getMeshAxes(),
-                                           mesh.value().canonicalDimSizes());
+  return verifyAllGatherOperandAndResultShape(op.getOperand(), op.getResult(),
+                                              gatherAxis, op.getMeshAxes(),
+                                              mesh.value().canonicalDimSizes());
 }
 
 LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
@@ -471,13 +460,12 @@ LogicalResult verifyReduceScatterOperandAndResultShape(
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-mlir::mesh::AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyMeshSymbolUses(*this, symbolTable);
 }
 
-void mlir::mesh::AllReduceOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  populateMeshAxesSetCanonicalizationPatterns<AllReduceOp>(patterns, context);
+LogicalResult mlir::mesh::AllReduceOp::verify() {
+  return verifyMeshAxes(getLoc(), getMeshAxes());
 }
 
 //===----------------------------------------------------------------------===//
@@ -485,32 +473,29 @@ void mlir::mesh::AllReduceOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-mlir::mesh::AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyMeshSymbolUses(*this, symbolTable);
 }
 
-void mlir::mesh::AllGatherOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  populateMeshAxesSetCanonicalizationPatterns<AllGatherOp>(patterns, context);
+LogicalResult mlir::mesh::AllGatherOp::verify() {
+  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
+    return failure();
+  }
+  return verifyAllGather(*this);
 }
 
-LogicalResult mlir::mesh::AllGatherOp::verify() { return verifyGather(*this); }
-
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-mlir::mesh::AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyMeshSymbolUses(*this, symbolTable);
 }
 
-void mlir::mesh::AllToAllOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  populateMeshAxesSetCanonicalizationPatterns<AllToAllOp>(patterns, context);
-}
-
-LogicalResult mlir::mesh::AllToAllOp::verify() {
+LogicalResult AllToAllOp::verify() {
+  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
+    return failure();
+  }
   auto mesh = ::getMesh(*this);
   if (failed(mesh)) {
     return failure();
@@ -525,18 +510,15 @@ LogicalResult mlir::mesh::AllToAllOp::verify() {
 // mesh.reduce_scatter op
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::mesh::ReduceScatterOp::verifySymbolUses(
-    SymbolTableCollection &symbolTable) {
+LogicalResult
+ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyMeshSymbolUses(*this, symbolTable);
 }
 
-void mlir::mesh::ReduceScatterOp::getCanonicalizationPatterns(
-    RewritePatternSet &patterns, MLIRContext *context) {
-  populateMeshAxesSetCanonicalizationPatterns<ReduceScatterOp>(patterns,
-                                                               context);
-}
-
-LogicalResult mlir::mesh::ReduceScatterOp::verify() {
+LogicalResult ReduceScatterOp::verify() {
+  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
+    return failure();
+  }
   auto mesh = ::getMesh(*this);
   if (failed(mesh)) {
     return failure();
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 21fa1887d14a53e..2bacfbe992525ee 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -2,15 +2,6 @@
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 
-// CHECK-LABEL: func @all_reduce_mesh_axes
-func.func @all_reduce_mesh_axes(
-    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-// CHECK: mesh_axes = array<i16: 0, 1>
-  %0 = mesh.all_reduce %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, reduction = #mesh.partial<sum>
-    } : tensor<4xf32> -> tensor<4xf64>
-  return %0 : tensor<4xf64>
-}
 
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
 func.func @all_reduce_empty_mesh_axes_and_default_reduction(
@@ -37,26 +28,6 @@ func.func @all_gather_empty_mesh_axes(
   return %0 : tensor<4xf32>
 }
 
-// CHECK-LABEL: func @all_gather_mesh_axes
-func.func @all_gather_mesh_axes(
-    %arg0 : tensor<4xf32>) -> tensor<32xf32> {
-// CHECK: mesh_axes = array<i16: 0, 1>
-  %0 = mesh.all_gather %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, gather_axis = 0
-    } : tensor<4xf32> -> tensor<32xf32>
-  return %0 : tensor<32xf32>
-}
-
-// CHECK-LABEL: func @reduce_scatter_mesh_axes
-func.func @reduce_scatter_mesh_axes(
-    %arg0 : tensor<8xf32>) -> tensor<1xf64> {
-// CHECK: mesh_axes = array<i16: 0, 1>
-  %0 = mesh.reduce_scatter %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, scatter_axis = 0
-    } : tensor<8xf32> -> tensor<1xf64>
-  return %0 : tensor<1xf64>
-}
-
 // CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_and_default_reduction
 func.func @reduce_scatter_empty_mesh_axes_and_default_reduction(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 413b01459f507a2..27d80e285d83778 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -96,6 +96,19 @@ func.func @all_reduce_invalid_mesh_axis(
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 
+func.func @all_reduce_duplicate_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_reduce %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 0, 1, 0>, reduction = #mesh.partial<sum>
+    } : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
 func.func @all_reduce_invalid_tensor_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<5xf64> {
   // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
@@ -129,6 +142,19 @@ func.func @all_gather_invalid_mesh_axis(
 
 // -----
 
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_duplicate_mesh_axis(
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_gather %arg0 {
+    mesh = @mesh0, mesh_axes = array<i16: 2, 2>, gather_axis = 0 : index
+    } : tensor<4xf32> -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
 mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
 
 func.func @all_gather_invalid_non_gather_axis_dimension_size(
@@ -195,7 +221,7 @@ func.func @all_gather_invalid_negative_gather_axis(
 
 // -----
 
-func.func @all_to_all_gather_invalid_mesh_symbol(
+func.func @all_to_all_invalid_mesh_symbol(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
   %0 = mesh.all_to_all %arg0 {
@@ -206,6 +232,19 @@ func.func @all_to_all_gather_invalid_mesh_symbol(
 
 // -----
 
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_to_all_duplicate_mesh_axis(
+    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_to_all %arg0 {
+      concat_axis = 0, mesh = @mesh0, mesh_axes = array<i16: 0, 0>, split_axis = 0
+    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  return %0 : tensor<3x6xi8>
+}
+
+// -----
+
 mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
 
 func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
@@ -273,6 +312,19 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
 
+func.func @reduce_scatter_duplicate_mesh_axis(
+    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.reduce_scatter %arg0 {
+      mesh = @mesh0, scatter_axis = 0, mesh_axes = array<i16: 0, 0>
+    } : tensor<?xf32> -> tensor<?xf64>
+  return %0 : tensor<?xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}

>From 06ac579da85e01bf58c79220d480eac706bed2bc Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Wed, 15 Nov 2023 10:29:52 -0800
Subject: [PATCH 03/10] Make attributes explicit in assembly format

Improve ops doc.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |  84 ++++++++-----
 mlir/test/Dialect/Mesh/canonicalization.mlir |  29 ++---
 mlir/test/Dialect/Mesh/invalid.mlir          | 120 ++++++++-----------
 mlir/test/Dialect/Mesh/ops.mlir              |  71 +++++------
 4 files changed, 150 insertions(+), 154 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 34110d400d7017e..eb890695c20fff0 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -191,10 +191,9 @@ class Mesh_CollectiveCommunicationOpBase<
     Mesh_Op<mnemonic,
       !listconcat(traits,
       [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
-  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
   dag commonArgs = (ins
     FlatSymbolRefAttr:$mesh,
-    DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
   );
 }
 
@@ -205,16 +204,13 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
   let summary = "All-gather over a device mesh.";
   let description = [{
     Gathers along the `gather_axis` tensor axis.
-    The order of input tensors in the resulting tensor is the same as the
-    order of the corresponding devices' multi-index in the mesh.
 
     Example:
     ```mlir
     mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
     ...
-    %1 = mesh.all_gather %0 {
-        mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
-      } : tensor<2x2xi8> -> tensor<2x4xi8>
+    %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
+      : tensor<2x2xi8> -> tensor<2x4xi8>
     ```
     Input:
     ```
@@ -228,6 +224,9 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     ```
     Result:
     ```
+    gather tensor
+    axis 1
+    ------------>
     +-------------+
     |  1  2  5  6 | <- devices (0, 0) and (0, 1)
     |  3  4  7  8 |
@@ -239,11 +238,15 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    APIntAttr:$gather_axis
+    IndexAttr:$gather_axis
   ));
   let results = (outs
     AnyNon0RankedTensor:$result
   );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
   let hasVerifier = 1;
 }
 
@@ -261,18 +264,21 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
 
     Example:
     ```
-    %1 = mesh.all_reduce %0 {
-        mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
-      } : tensor<3x4xf32> -> tensor<3x4xf64>
+    %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+      : tensor<3x4xf32> -> tensor<3x4xf64>
     ```
   }];
   let arguments = !con(commonArgs, (ins
     AnyRankedTensor:$input,
-    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
   ));
   let results = (outs
     AnyRankedTensor:$result
   );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
+    attr-dict `:` type($input) `->` type($result)
+  }];
   let hasVerifier = 1;
 }
 
@@ -287,17 +293,17 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     ```
     mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
     ...
-    %1 = mesh.all_to_all %0 {
-        mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
-      } : tensor<3x6xi8> -> tensor<3x6xi8>
+    %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
+      split_axis = 0 concat_axis = 0
+      : tensor<3x6xi8> -> tensor<3x6xi8>
     ```
     Input:
     ```
      device  device  device
      (0)     (1)     (2)
-    +-------+-------+-------+
-    | 11 12 | 21 22 | 31 32 |
-    | 13 14 | 23 24 | 33 34 |
+    +-------+-------+-------+  | split and concat
+    | 11 12 | 21 22 | 31 32 |  | tensor axis 0
+    | 13 14 | 23 24 | 33 34 |  ↓
     | 15 16 | 25 26 | 35 36 |
     +-------+-------+-------+
     ```
@@ -314,12 +320,18 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    APIntAttr:$split_axis,
-    APIntAttr:$concat_axis
+    IndexAttr:$split_axis,
+    IndexAttr:$concat_axis
   ));
   let results = (outs
     AnyNon0RankedTensor:$result
   );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `split_axis` `=` $split_axis
+    `concat_axis` `=` $concat_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
   let hasVerifier = 1;
 }
 
@@ -327,26 +339,32 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     SameOperandsAndResultRank]> {
   let summary = "Reduce-scatter over a device mesh.";
   let description = [{
-    After the reduction scatters the result within each device group.
+    After the reduction, the result is scattered within each device group.
     The tensor is split along `scatter_axis` and the pieces distributed
     across the device group.
     Example:
     ```
     mesh.cluster @mesh0(rank = 1, dim_sizes = [2, 2])
     ...
-    %1 = mesh.reduce_scatter %0 {
-        mesh = @mesh0, mesh_axes = array<i16: 1>, reduction = #mesh.partial<max>, scatter_axis = 0
-      } : tensor<3x4xf32> -> tensor<1x4xf64>
+    %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
+      reduction = <max> scatter_axis = 0
+      : tensor<3x4xf32> -> tensor<1x4xf64>
     ```
     Input:
     ```
+                              device
+                              (0, 1)
+                                 ↓
+                     +-------+-------+  | scatter tensor
+    device (0, 0) -> |  1  2 |  5  6 |  | axis 0
+                     |  3  4 |  7  8 |  ↓
                      +-------+-------+
-    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
-                     |  3  4 |  7  8 |
-                     +-------+-------+
-    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+    device (1, 0) -> |  9 10 | 13 14 |
                      | 11 12 | 15 16 |
                      +-------+-------+
+                                ↑
+                              device
+                              (1, 1)
     ```
     Result:
     ```
@@ -363,12 +381,18 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
-    APIntAttr:$scatter_axis
+    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    IndexAttr:$scatter_axis
   ));
   let results = (outs
     AnyRankedTensor:$result
   );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    (`reduction` `=` $reduction^)?
+    `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 2bacfbe992525ee..3383dc2bec2815c 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -6,38 +6,35 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
 func.func @all_reduce_empty_mesh_axes_and_default_reduction(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-  %0 = mesh.all_reduce %arg0 {
-    mesh = @mesh0,
+  %0 = mesh.all_reduce %arg0 on @mesh0
 // CHECK-NOT: mesh_axes
-    mesh_axes = array<i16>,
+    mesh_axes = []
 // CHECK-NOT: reduction
-    reduction = #mesh.partial<sum>
-    } : tensor<4xf32> -> tensor<4xf64>
+    reduction = <sum>
+    : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
 
 // CHECK-LABEL: func @all_gather_empty_mesh_axes
 func.func @all_gather_empty_mesh_axes(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-  %0 = mesh.all_gather %arg0 {
-    gather_axis = 0 : index,
-    mesh = @mesh0,
+  %0 = mesh.all_gather %arg0 on @mesh0
 // CHECK-NOT: mesh_axes
-    mesh_axes = array<i16>
-    } : tensor<4xf32> -> tensor<4xf32>
+    mesh_axes = []
+    gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
 // CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_and_default_reduction
 func.func @reduce_scatter_empty_mesh_axes_and_default_reduction(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-  %0 = mesh.reduce_scatter %arg0 {
-    mesh = @mesh0,
+  %0 = mesh.reduce_scatter %arg0 on @mesh0
 // CHECK-NOT: mesh_axes
-    mesh_axes = array<i16>,
+    mesh_axes = []
 // CHECK-NOT: reduction
-    reduction = #mesh.partial<sum>,
+    reduction = <sum>
     scatter_axis = 0
-    } : tensor<4xf32> -> tensor<4xf64>
+    : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 27d80e285d83778..2999668f770baa7 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -73,9 +73,8 @@ func.func @mesh_axis_negtive_in_partial(
 func.func @all_reduce_invalid_mesh_symbol(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.all_reduce %arg0 {
-    mesh = @this_mesh_symbol_does_not_exist, reduction = #mesh.partial<sum>
-    } : tensor<4xf32> -> tensor<4xf64>
+  %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = <sum>
+    : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
 
@@ -86,9 +85,8 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 func.func @all_reduce_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0 = mesh.all_reduce %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<sum>
-    } : tensor<4xf32> -> tensor<4xf64>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = <sum>
+    : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
 
@@ -99,9 +97,8 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
   // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_reduce %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 0, 1, 0>, reduction = #mesh.partial<sum>
-    } : tensor<4xf32> -> tensor<4xf64>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = <sum>
+    : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
 
@@ -112,7 +109,7 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 func.func @all_reduce_invalid_tensor_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<5xf64> {
   // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
-  %0 = mesh.all_reduce %arg0 { mesh = @mesh0 } : tensor<4xf32> -> tensor<5xf64>
+  %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
   return %0 : tensor<5xf64>
 }
 
@@ -121,9 +118,8 @@ func.func @all_reduce_invalid_tensor_dimension_size(
 func.func @all_gather_invalid_mesh_symbol(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
   // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.all_gather %arg0 {
-    mesh = @this_mesh_symbol_does_not_exist, gather_axis = 0 : index
-    } : tensor<4xf32> -> tensor<4xf32>
+  %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
@@ -134,9 +130,8 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 func.func @all_gather_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
   // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0 = mesh.all_gather %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 2>, gather_axis = 0 : index
-    } : tensor<4xf32> -> tensor<4xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
@@ -147,9 +142,8 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
   // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_gather %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 2, 2>, gather_axis = 0 : index
-    } : tensor<4xf32> -> tensor<4xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
@@ -160,9 +154,8 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
 func.func @all_gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
-  %0 = mesh.all_gather %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 0 : index
-    } : tensor<3x4xf32> -> tensor<3x5xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+    : tensor<3x4xf32> -> tensor<3x5xf32>
   return %0 : tensor<3x5xf32>
 }
 
@@ -173,9 +166,8 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
 func.func @all_gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
-  %0 = mesh.all_gather %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
-    } : tensor<3x4xf32> -> tensor<3x5xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
+    : tensor<3x4xf32> -> tensor<3x5xf32>
   return %0 : tensor<3x5xf32>
 }
 
@@ -186,10 +178,8 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
 func.func @all_gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
-  %0 = mesh.all_gather %arg0 {
-    gather_axis = 0 : index,
-    mesh = @mesh0
-    } : tensor<?xf32> -> tensor<3xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
+    : tensor<?xf32> -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
@@ -200,9 +190,8 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
 func.func @all_gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
-  %0 = mesh.all_gather %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 1 : index
-    } : tensor<3xf32> -> tensor<3xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
+    : tensor<3xf32> -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
@@ -213,9 +202,8 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
 func.func @all_gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
   // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
-  %0 = mesh.all_gather %arg0 {
-    mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = -1 : index
-    } : tensor<3xf32> -> tensor<3xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
+    : tensor<3xf32> -> tensor<3xf32>
   return %0 : tensor<3xf32>
 }
 
@@ -224,9 +212,9 @@ func.func @all_gather_invalid_negative_gather_axis(
 func.func @all_to_all_invalid_mesh_symbol(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 0, mesh = @this_mesh_symbol_does_not_exist, split_axis = 1
-    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
+    split_axis = 1 concat_axis = 0
+    : tensor<3x6xi8> -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
 }
 
@@ -237,9 +225,9 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
 func.func @all_to_all_duplicate_mesh_axis(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 0, mesh = @mesh0, mesh_axes = array<i16: 0, 0>, split_axis = 0
-    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
+    split_axis = 0 concat_axis = 0
+    : tensor<3x6xi8> -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
 }
 
@@ -250,9 +238,9 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
 func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
-    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+    split_axis = 0 concat_axis = 1
+    : tensor<3x6xi8> -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
 }
 
@@ -263,9 +251,9 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
-    } : tensor<?x6xi8> -> tensor<3x?xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+    split_axis = 0 concat_axis = 1
+    : tensor<?x6xi8> -> tensor<3x?xi8>
   return %0 : tensor<3x?xi8>
 }
 
@@ -276,9 +264,9 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
-    } : tensor<3x?xi8> -> tensor<?x3xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+    split_axis = 0 concat_axis = 1
+    : tensor<3x?xi8> -> tensor<?x3xi8>
   return %0 : tensor<?x3xi8>
 }
 
@@ -289,9 +277,9 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
-    } : tensor<3x2xi8> -> tensor<1x7xi8>
+  %0 = mesh.all_to_all %arg0  on @mesh0 mesh_axes = [0]
+    split_axis = 0 concat_axis = 1
+    : tensor<3x2xi8> -> tensor<1x7xi8>
   return %0 : tensor<1x7xi8>
 }
 
@@ -302,9 +290,9 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
-    } : tensor<3x2xi8> -> tensor<2x6xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+    split_axis = 0 concat_axis = 1
+    : tensor<3x2xi8> -> tensor<2x6xi8>
   return %0 : tensor<2x6xi8>
 }
 
@@ -315,9 +303,8 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
 func.func @reduce_scatter_duplicate_mesh_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
   // expected-error@+1 {{Mesh axes contains duplicate elements.}}
-  %0 = mesh.reduce_scatter %arg0 {
-      mesh = @mesh0, scatter_axis = 0, mesh_axes = array<i16: 0, 0>
-    } : tensor<?xf32> -> tensor<?xf64>
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
 
@@ -328,9 +315,8 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
-  %0 = mesh.reduce_scatter %arg0 {
-      mesh = @mesh0, scatter_axis = 0
-    } : tensor<?xf32> -> tensor<2xf64>
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
+    : tensor<?xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
 
@@ -341,9 +327,8 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
 func.func @reduce_scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf64> {
   // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
-  %0 = mesh.reduce_scatter %arg0 {
-      mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
-    } : tensor<3xf32> -> tensor<2xf64>
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+    : tensor<3xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
 
@@ -354,8 +339,7 @@ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
 func.func @reduce_scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf64> {
   // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
-  %0 = mesh.reduce_scatter %arg0 {
-      mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
-    } : tensor<4xf32> -> tensor<?xf64>
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+    : tensor<4xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 5ec6c4a439327f0..3ba9aabacdd7dcc 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -133,12 +133,10 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
 func.func @all_reduce(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
-  // CHECK-NEXT: mesh.all_reduce %[[ARG]]
-  // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>}
+  // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = <max>
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
-  %0 = mesh.all_reduce %arg0 {
-      mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
-    } : tensor<3x4xf32> -> tensor<3x4xf64>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+    : tensor<3x4xf32> -> tensor<3x4xf64>
   return %0 : tensor<3x4xf64>
 }
 
@@ -146,12 +144,10 @@ func.func @all_reduce(
 func.func @all_gather(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
-  // CHECK-NEXT: mesh.all_gather %[[ARG]]
-  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
-  %0 = mesh.all_gather %arg0 {
-      gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>
-    } : tensor<3x4xf32> -> tensor<3x16xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+    : tensor<3x4xf32> -> tensor<3x16xf32>
   return %0 : tensor<3x16xf32>
 }
 
@@ -159,12 +155,10 @@ func.func @all_gather(
 func.func @all_gather_dynamic_dims_in_tensor(
     // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
     %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  // CHECK-NEXT: mesh.all_gather %[[ARG]]
-  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
   // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
-  %0 = mesh.all_gather %arg0 {
-      gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>
-    } : tensor<?x?xf32> -> tensor<?x?xf32>
+  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+    : tensor<?x?xf32> -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 
@@ -172,12 +166,10 @@ func.func @all_gather_dynamic_dims_in_tensor(
 func.func @all_gather_dynamic_dims_in_mesh(
     // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
     %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
-  // CHECK-NEXT: mesh.all_gather %[[ARG]]
-  // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>}
+  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
   // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
-  %0 = mesh.all_gather %arg0 {
-      gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>
-    } : tensor<5x6xf32> -> tensor<5x?xf32>
+  %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
+    : tensor<5x6xf32> -> tensor<5x?xf32>
   return %0 : tensor<5x?xf32>
 }
 
@@ -186,11 +178,11 @@ func.func @all_to_all(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
   // CHECK-NEXT: mesh.all_to_all %[[ARG]]
-  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
   // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 0, mesh = @mesh4, split_axis = 1
-    } : tensor<3x6xi8> -> tensor<3x6xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh4
+    split_axis = 1 concat_axis = 0
+    : tensor<3x6xi8> -> tensor<3x6xi8>
   return %0 : tensor<3x6xi8>
 }
 
@@ -199,11 +191,11 @@ func.func @all_to_all_dynamic_dims_in_result(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
     %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
   // CHECK-NEXT: mesh.all_to_all %[[ARG]]
-  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
   // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 0, mesh = @mesh4, split_axis = 1
-    } : tensor<3x6xi8> -> tensor<3x?xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh4
+    split_axis = 1 concat_axis = 0
+    : tensor<3x6xi8> -> tensor<3x?xi8>
   return %0 : tensor<3x?xi8>
 }
 
@@ -212,11 +204,11 @@ func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
     %arg0 : tensor<3xi8>) -> tensor<3xi8> {
   // CHECK-NEXT: mesh.all_to_all %[[ARG]]
-  // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 0 : i64}
+  // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
   // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
-  %0 = mesh.all_to_all %arg0 {
-      concat_axis = 0, mesh = @mesh4, split_axis = 0
-    } : tensor<3xi8> -> tensor<3xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh4
+    split_axis = 0 concat_axis = 0
+    : tensor<3xi8> -> tensor<3xi8>
   return %0 : tensor<3xi8>
 }
 
@@ -225,11 +217,11 @@ func.func @reduce_scatter_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
   // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
-  // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1 : i64}
+  // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = <max> scatter_axis = 1
   // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
-  %0 = mesh.reduce_scatter %arg0 {
-      mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1
-    } : tensor<3x4xf32> -> tensor<3x1xf64>
+  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
+    reduction = <max> scatter_axis = 1
+    : tensor<3x4xf32> -> tensor<3x1xf64>
   return %0 : tensor<3x1xf64>
 }
 
@@ -238,10 +230,9 @@ func.func @reduce_scatter_dynamic_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
   // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
-  // CHECK-SAME: {mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0 : i64}
+  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
   // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
-  %0 = mesh.reduce_scatter %arg0 {
-      mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0
-    } : tensor<?xf32> -> tensor<?xf64>
+  %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }

>From fbcf70bbbc13b89cd2c03b2508b41dd7ce06240f Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Wed, 15 Nov 2023 16:20:13 -0800
Subject: [PATCH 04/10] Add documentation for ClusterOp::canonicalDimSizes

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index eb890695c20fff0..73de9aa3133f6fa 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -80,6 +80,9 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
       attr-dict
   }];
   let extraClassDeclaration = [{
+    // The `dim_sizes` attribute may have fewer elements than the rank of the
+    // mesh. Returns the mesh shape with any missing trailing dimensions
+    // explicitly set to dynamic.
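+    // For example, a rank-3 mesh with `dim_sizes = [4, 2]` yields the shape
+    // [4, 2, dynamic].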
     ::mlir::SmallVector<int64_t> canonicalDimSizes();
 
     template <typename OutIt>

>From fde4cb2e48a8adf2a10345369d109e5cf0a6ea8f Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Wed, 15 Nov 2023 16:21:13 -0800
Subject: [PATCH 05/10] Make getMesh and verifyMeshSymbolUses non-template

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 124 +++++++++++++++------------
 1 file changed, 68 insertions(+), 56 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index bdbd3ad1514dfc6..7ef250d3db2b808 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -38,33 +38,33 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
-namespace {
-
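+// Sorts the range and then applies std::unique, returning the new end
+// iterator of the de-duplicated prefix.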
 template <typename It>
-It canonicalizeSetAsArray(It begin, It end) {
-  std::sort(begin, end);
+static It canonicalizeSetAsArray(It begin, It end) {
+  llvm::sort(begin, end);
   return std::unique(begin, end);
 }
 
 template <typename R>
-auto canonicalizeSetAsArray(R &&range) {
+static auto canonicalizeSetAsArray(R &&range) {
   return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
 }
 
 template <typename T>
-SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
   auto newEnd = canonicalizeSetAsArray(vec);
   vec.resize(newEnd - vec.begin());
   return vec;
 }
 
 template <typename DimSize>
-bool isMeshDimensionDynamic(DimSize size) {
+static bool isMeshDimensionDynamic(DimSize size) {
   return size <= DimSize(0);
 }
 
 using MeshAxis = int16_t;
 
+namespace {
+
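+// Thin wrapper around an int64_t dimension size that models dynamic sizes;
+// the arithmetic operators defined below propagate "dynamic".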
 struct DimensionSize {
   static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
   DimensionSize(int64_t val) : val(val) {}
@@ -76,22 +76,22 @@ struct DimensionSize {
   int64_t val;
 };
 
-DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+} // namespace
+
+static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
   if (lhs.isDynamic() || rhs.isDynamic()) {
     return DimensionSize::dynamic();
   }
   return lhs.value() / rhs.value();
 }
 
-DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
   if (lhs.isDynamic() || rhs.isDynamic()) {
     return DimensionSize::dynamic();
   }
   return lhs.value() * rhs.value();
 }
 
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // Mesh dialect
 //===----------------------------------------------------------------------===//
@@ -211,33 +211,31 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 // collective communication ops
 //===----------------------------------------------------------------------===//
 
-namespace {
-
-template <typename Op>
-LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
-  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
-  if (!symbolAttr) {
-    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+static LogicalResult verifyMeshSymbolUses(Operation *op,
+                                          FlatSymbolRefAttr meshSymbol,
+                                          DenseI16ArrayAttr meshAxes,
+                                          SymbolTableCollection &symbolTable) {
+  if (!meshSymbol) {
+    return op->emitError() << "Unspecified \"mesh\" symbol attribute.";
   }
   SymbolTableCollection symbolTableCollection;
   mesh::ClusterOp mesh =
       symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
-          op.getOperation(), symbolAttr);
+          op, meshSymbol);
   if (!mesh) {
-    return op.emitError() << "Undefined required mesh symbol \""
-                          << symbolAttr.getValue() << "\".";
+    return op->emitError() << "Undefined required mesh symbol \""
+                           << meshSymbol.getValue() << "\".";
   }
-  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
   if (!meshAxes) {
     return success();
   }
   MeshAxis rank = mesh.getRank();
   for (auto axis : meshAxes.asArrayRef()) {
     if (axis >= rank || axis < 0) {
-      return op.emitError()
+      return op->emitError()
              << "0-based mesh axis index " << axis
              << " is out of bounds. The referenced mesh \""
-             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+             << meshSymbol.getValue() << "\" is of rank " << rank << ".";
     }
   }
 
@@ -261,9 +259,9 @@ bool isUnique(It begin, It end) {
   return true;
 }
 
-LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
   SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
-  std::sort(sorted.begin(), sorted.end());
+  llvm::sort(sorted);
   if (!isUnique(sorted.begin(), sorted.end())) {
     return emitError(loc) << "Mesh axes contains duplicate elements.";
   }
@@ -271,19 +269,19 @@ LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
 }
 
 template <typename It>
-auto product(It begin, It end) {
+static auto product(It begin, It end) {
   using ElementType = std::decay_t<decltype(*begin)>;
   return std::accumulate(begin, end, static_cast<ElementType>(1),
                          std::multiplies<ElementType>());
 }
 
 template <typename R>
-auto product(R &&range) {
+static auto product(R &&range) {
   return product(adl_begin(range), adl_end(range));
 }
 
-int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
-                                  ArrayRef<int64_t> meshShape) {
+static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                         ArrayRef<int64_t> meshShape) {
   int64_t res = 1;
 
   for (MeshAxis axis : meshAxes) {
@@ -297,10 +295,10 @@ int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
   return res;
 }
 
-LogicalResult verifyDimensionCompatibility(Location loc,
-                                           int64_t expectedDimSize,
-                                           int64_t resultDimSize,
-                                           int64_t resultAxis) {
+static LogicalResult verifyDimensionCompatibility(Location loc,
+                                                  int64_t expectedDimSize,
+                                                  int64_t resultDimSize,
+                                                  int64_t resultAxis) {
   if (!ShapedType::isDynamic(resultDimSize) &&
       expectedDimSize != resultDimSize) {
     return emitError(loc) << "Dimension size mismatch for result axis "
@@ -314,7 +312,7 @@ LogicalResult verifyDimensionCompatibility(Location loc,
   return success();
 }
 
-LogicalResult verifyAllGatherOperandAndResultShape(
+static LogicalResult verifyAllGatherOperandAndResultShape(
     Value operand, Value result, int64_t gatherAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -334,20 +332,20 @@ LogicalResult verifyAllGatherOperandAndResultShape(
   return success();
 }
 
-template <typename Op>
-FailureOr<ClusterOp> getMesh(Op op) {
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                    DenseI16ArrayAttr meshAxes) {
   SymbolTableCollection symbolTableCollection;
-  if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+  if (failed(verifyMeshSymbolUses(op, meshSymbol, meshAxes,
+                                  symbolTableCollection))) {
     // We need to check the symbol here since this runs before
     // SymbolUserOpInterface.
     return failure();
   }
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
-      op.getOperation(), op.getMeshAttr());
+      op, meshSymbol);
 }
 
-template <typename Op>
-LogicalResult verifyAllGather(Op op) {
+static LogicalResult verifyAllGather(AllGatherOp op) {
   auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
   auto gatherAxis = op.getGatherAxis().getSExtValue();
   if (gatherAxis < 0 || gatherAxis >= rank) {
@@ -355,7 +353,9 @@ LogicalResult verifyAllGather(Op op) {
                           << " is out of bounds [0, " << rank << ").";
   }
 
-  auto mesh = getMesh(op);
+  FlatSymbolRefAttr meshSymbol = op.getMeshAttr();
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  auto mesh = getMesh(op.getOperation(), meshSymbol, meshAxes);
   if (failed(mesh)) {
     return failure();
   }
@@ -364,11 +364,9 @@ LogicalResult verifyAllGather(Op op) {
                                               mesh.value().canonicalDimSizes());
 }
 
-LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
-                                                  int64_t splitAxis,
-                                                  int64_t concatAxis,
-                                                  ArrayRef<MeshAxis> meshAxes,
-                                                  ArrayRef<int64_t> meshShape) {
+static LogicalResult verifyAllToAllOperandAndResultShape(
+    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -415,7 +413,7 @@ LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
   return success();
 }
 
-LogicalResult verifyReduceScatterOperandAndResultShape(
+static LogicalResult verifyReduceScatterOperandAndResultShape(
     Value operand, Value result, int64_t scatterAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -453,15 +451,16 @@ LogicalResult verifyReduceScatterOperandAndResultShape(
   return success();
 }
 
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // mesh.all_reduce op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
 AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  return verifyMeshSymbolUses(*this, symbolTable);
+  FlatSymbolRefAttr meshSymbol = getMeshAttr();
+  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
+  return verifyMeshSymbolUses(getOperation(), meshSymbol, meshAxes,
+                              symbolTable);
 }
 
 LogicalResult mlir::mesh::AllReduceOp::verify() {
@@ -474,7 +473,10 @@ LogicalResult mlir::mesh::AllReduceOp::verify() {
 
 LogicalResult
 AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  return verifyMeshSymbolUses(*this, symbolTable);
+  FlatSymbolRefAttr meshSymbol = getMeshAttr();
+  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
+  return verifyMeshSymbolUses(getOperation(), meshSymbol, meshAxes,
+                              symbolTable);
 }
 
 LogicalResult mlir::mesh::AllGatherOp::verify() {
@@ -489,14 +491,19 @@ LogicalResult mlir::mesh::AllGatherOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  return verifyMeshSymbolUses(*this, symbolTable);
+  FlatSymbolRefAttr meshSymbol = getMeshAttr();
+  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
+  return verifyMeshSymbolUses(getOperation(), meshSymbol, meshAxes,
+                              symbolTable);
 }
 
 LogicalResult AllToAllOp::verify() {
   if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
     return failure();
   }
-  auto mesh = ::getMesh(*this);
+  FlatSymbolRefAttr meshSymbol = getMeshAttr();
+  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
+  auto mesh = ::getMesh(getOperation(), meshSymbol, meshAxes);
   if (failed(mesh)) {
     return failure();
   }
@@ -512,14 +519,19 @@ LogicalResult AllToAllOp::verify() {
 
 LogicalResult
 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  return verifyMeshSymbolUses(*this, symbolTable);
+  FlatSymbolRefAttr meshSymbol = getMeshAttr();
+  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
+  return verifyMeshSymbolUses(getOperation(), meshSymbol, meshAxes,
+                              symbolTable);
 }
 
 LogicalResult ReduceScatterOp::verify() {
   if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
     return failure();
   }
-  auto mesh = ::getMesh(*this);
+  FlatSymbolRefAttr meshSymbol = getMeshAttr();
+  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
+  auto mesh = ::getMesh(getOperation(), meshSymbol, meshAxes);
   if (failed(mesh)) {
     return failure();
   }

>From 530e30cfca2c17b1dda5d1fb1d36ca37bb93ef04 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Wed, 15 Nov 2023 16:35:45 -0800
Subject: [PATCH 06/10] Add canonicalization test for all-to-all

---
 mlir/test/Dialect/Mesh/canonicalization.mlir | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 3383dc2bec2815c..a06105c9b706a98 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -15,6 +15,18 @@ func.func @all_reduce_empty_mesh_axes_and_default_reduction(
   return %0 : tensor<4xf64>
 }
 
+// CHECK-LABEL: func @all_to_all_empty_mesh_axes
+func.func @all_to_all_empty_mesh_axes(
+    %arg0 : tensor<8xf32>) -> tensor<8xf32> {
+  %0 = mesh.all_to_all %arg0 on @mesh0
+// CHECK-NOT: mesh_axes
+    mesh_axes = []
+    split_axis = 0
+    concat_axis = 0
+    : tensor<8xf32> -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
 // CHECK-LABEL: func @all_gather_empty_mesh_axes
 func.func @all_gather_empty_mesh_axes(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {

>From b22669ee86eccfcaddb610889058a4ea8f15b4da Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Wed, 15 Nov 2023 17:44:57 -0800
Subject: [PATCH 07/10] Add removal of collectives with empty mesh axes

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h  |  1 +
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |  4 ++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         | 43 +++++++++++++
 mlir/test/Dialect/Mesh/canonicalization.mlir | 65 +++++++++++++++++---
 4 files changed, 105 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 7698d60813a8f10..9077d2eb0189b72 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 73de9aa3133f6fa..a72c3970501dddb 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -251,6 +251,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
@@ -283,6 +284,7 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
@@ -336,6 +338,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -397,6 +400,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 7ef250d3db2b808..6d29894917b0ba8 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -211,6 +211,29 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 // collective communication ops
 //===----------------------------------------------------------------------===//
 
+namespace {
+
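+// A collective operation over an empty set of mesh axes involves only a single
+// device per group and is a no-op when the operand and result types match.
+// Fold such operations by forwarding the input.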
+template <typename Op>
+struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto meshAxes = op.getMeshAxes();
+    if (!meshAxes.empty()) {
+      return failure();
+    }
+    if (op.getInput().getType() != op.getResult().getType()) {
+      return failure();
+    }
+
+    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
+} // namespace
+
 static LogicalResult verifyMeshSymbolUses(Operation *op,
                                           FlatSymbolRefAttr meshSymbol,
                                           DenseI16ArrayAttr meshAxes,
@@ -467,6 +490,11 @@ LogicalResult mlir::mesh::AllReduceOp::verify() {
   return verifyMeshAxes(getLoc(), getMeshAxes());
 }
 
+void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_gather op
 //===----------------------------------------------------------------------===//
@@ -486,6 +514,11 @@ LogicalResult mlir::mesh::AllGatherOp::verify() {
   return verifyAllGather(*this);
 }
 
+void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.all_to_all op
 //===----------------------------------------------------------------------===//
@@ -513,6 +546,11 @@ LogicalResult AllToAllOp::verify() {
       mesh.value().canonicalDimSizes());
 }
 
+void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                             MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.reduce_scatter op
 //===----------------------------------------------------------------------===//
@@ -540,6 +578,11 @@ LogicalResult ReduceScatterOp::verify() {
       mesh.value().canonicalDimSizes());
 }
 
+void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index a06105c9b706a98..5802d198d368149 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -2,13 +2,34 @@
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes
+func.func @all_reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_reduce
+  %0 = mesh.all_reduce %arg0 on @mesh0
+    mesh_axes = []
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
 
-// CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
-func.func @all_reduce_empty_mesh_axes_and_default_reduction(
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
+func.func @all_reduce_empty_mesh_axes_different_return_type(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh.all_reduce
   %0 = mesh.all_reduce %arg0 on @mesh0
 // CHECK-NOT: mesh_axes
     mesh_axes = []
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_reduce_default_reduction
+func.func @all_reduce_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+  %0 = mesh.all_reduce %arg0 on @mesh0
+    mesh_axes = [0]
 // CHECK-NOT: reduction
     reduction = <sum>
     : tensor<4xf32> -> tensor<4xf64>
@@ -17,36 +38,64 @@ func.func @all_reduce_empty_mesh_axes_and_default_reduction(
 
 // CHECK-LABEL: func @all_to_all_empty_mesh_axes
 func.func @all_to_all_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
     %arg0 : tensor<8xf32>) -> tensor<8xf32> {
+// CHECK-NOT: mesh.all_to_all
   %0 = mesh.all_to_all %arg0 on @mesh0
-// CHECK-NOT: mesh_axes
     mesh_axes = []
     split_axis = 0
     concat_axis = 0
     : tensor<8xf32> -> tensor<8xf32>
+// CHECK: return %[[ARG]]
   return %0 : tensor<8xf32>
 }
 
 // CHECK-LABEL: func @all_gather_empty_mesh_axes
 func.func @all_gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_gather
   %0 = mesh.all_gather %arg0 on @mesh0
-// CHECK-NOT: mesh_axes
     mesh_axes = []
     gather_axis = 0
     : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
+func.func @reduce_scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce_scatter
+  %0 = mesh.reduce_scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
   return %0 : tensor<4xf32>
 }
 
-// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_and_default_reduction
-func.func @reduce_scatter_empty_mesh_axes_and_default_reduction(
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
+func.func @reduce_scatter_empty_mesh_axes_different_return_type(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh.reduce_scatter
   %0 = mesh.reduce_scatter %arg0 on @mesh0
 // CHECK-NOT: mesh_axes
     mesh_axes = []
-// CHECK-NOT: reduction
-    reduction = <sum>
     scatter_axis = 0
     : tensor<4xf32> -> tensor<4xf64>
   return %0 : tensor<4xf64>
 }
+
+// CHECK-LABEL: func @reduce_scatter_default_reduction
+func.func @reduce_scatter_default_reduction(
+    %arg0 : tensor<4xf32>) -> tensor<2xf64> {
+  %0 = mesh.reduce_scatter %arg0 on @mesh0
+    mesh_axes = [0]
+// CHECK-NOT: reduction
+    reduction = <sum>
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<2xf64>
+  return %0 : tensor<2xf64>
+}

>From abbe6a6161f6d28652950498e6834d743eb10aa2 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Wed, 15 Nov 2023 18:04:19 -0800
Subject: [PATCH 08/10] Remove useless check for missing mesh axes attribute

The mesh_axes attribute is no longer optional, so the check can never trigger.
---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 6d29894917b0ba8..3f1fc3ec1be9f74 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -249,9 +249,6 @@ static LogicalResult verifyMeshSymbolUses(Operation *op,
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
   }
-  if (!meshAxes) {
-    return success();
-  }
   MeshAxis rank = mesh.getRank();
   for (auto axis : meshAxes.asArrayRef()) {
     if (axis >= rank || axis < 0) {

>From 4090e345b4a9587f511c4f7715c07f043772168d Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Thu, 16 Nov 2023 07:52:27 -0800
Subject: [PATCH 09/10] Move all verification into verifySymbolUses

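The op-local verify() hooks ran before the SymbolUserOpInterface check and had
to build their own SymbolTableCollection to resolve the mesh symbol. Moving
everything into verifySymbolUses resolves the mesh cluster once and reuses it
for the axis and shape checks. Each collective now verifies roughly like this
(a sketch paraphrasing the all_gather code in this patch):

  LogicalResult
  AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
    // Resolve @mesh and check that every entry of mesh_axes is within its rank.
    auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
    if (failed(mesh)) {
      return failure();
    }
    // Op-specific shape checks reuse the resolved mesh shape.
    return verifyAllGatherOperandAndResultShape(
        getOperand(), getResult(), getGatherAxis().getSExtValue(),
        getMeshAxes(), mesh.value().canonicalDimSizes());
  }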
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |   4 -
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         | 164 +++++++------------
 2 files changed, 58 insertions(+), 110 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a72c3970501dddb..46b298f9e9f5ed9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -250,7 +250,6 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
     attr-dict `:` type($input) `->` type($result)
   }];
-  let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -283,7 +282,6 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
     attr-dict `:` type($input) `->` type($result)
   }];
-  let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -337,7 +335,6 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
     `concat_axis` `=` $concat_axis
     attr-dict `:` type($input) `->` type($result)
   }];
-  let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -399,7 +396,6 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     `scatter_axis` `=` $scatter_axis
     attr-dict `:` type($input) `->` type($result)
   }];
-  let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3f1fc3ec1be9f74..923c37df4d052f1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -234,32 +234,16 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
 
 } // namespace
 
-static LogicalResult verifyMeshSymbolUses(Operation *op,
-                                          FlatSymbolRefAttr meshSymbol,
-                                          DenseI16ArrayAttr meshAxes,
-                                          SymbolTableCollection &symbolTable) {
-  if (!meshSymbol) {
-    return op->emitError() << "Unspecified \"mesh\" symbol attribute.";
-  }
-  SymbolTableCollection symbolTableCollection;
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                    SymbolTableCollection &symbolTable) {
   mesh::ClusterOp mesh =
-      symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
-          op, meshSymbol);
+      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
   }
-  MeshAxis rank = mesh.getRank();
-  for (auto axis : meshAxes.asArrayRef()) {
-    if (axis >= rank || axis < 0) {
-      return op->emitError()
-             << "0-based mesh axis index " << axis
-             << " is out of bounds. The referenced mesh \""
-             << meshSymbol.getValue() << "\" is of rank " << rank << ".";
-    }
-  }
 
-  return success();
+  return mesh;
 }
 
 template <typename It>
@@ -279,15 +263,40 @@ bool isUnique(It begin, It end) {
   return true;
 }
 
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes) {
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+                                    ClusterOp mesh) {
   SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
   llvm::sort(sorted);
   if (!isUnique(sorted.begin(), sorted.end())) {
     return emitError(loc) << "Mesh axes contains duplicate elements.";
   }
+
+  MeshAxis rank = mesh.getRank();
+  for (auto axis : axes) {
+    if (axis >= rank || axis < 0) {
+      return emitError(loc)
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+             << "\" is of rank " << rank << ".";
+    }
+  }
+
   return success();
 }
 
+template <typename Op>
+static FailureOr<ClusterOp>
+getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
+  auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
+    return failure();
+  }
+  return mesh;
+}
+
 template <typename It>
 static auto product(It begin, It end) {
   using ElementType = std::decay_t<decltype(*begin)>;
@@ -335,6 +344,13 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
 static LogicalResult verifyAllGatherOperandAndResultShape(
     Value operand, Value result, int64_t gatherAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  auto resultRank = result.getType().template cast<ShapedType>().getRank();
+  if (gatherAxis < 0 || gatherAxis >= resultRank) {
+    return emitError(result.getLoc())
+           << "Gather axis " << gatherAxis << " is out of bounds [0, "
+           << resultRank << ").";
+  }
+
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   auto deviceGroupSize =
@@ -352,38 +368,6 @@ static LogicalResult verifyAllGatherOperandAndResultShape(
   return success();
 }
 
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                    DenseI16ArrayAttr meshAxes) {
-  SymbolTableCollection symbolTableCollection;
-  if (failed(verifyMeshSymbolUses(op, meshSymbol, meshAxes,
-                                  symbolTableCollection))) {
-    // We need to check the symbol here since this runs before
-    // SymbolUserOpInterface.
-    return failure();
-  }
-  return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
-      op, meshSymbol);
-}
-
-static LogicalResult verifyAllGather(AllGatherOp op) {
-  auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
-  auto gatherAxis = op.getGatherAxis().getSExtValue();
-  if (gatherAxis < 0 || gatherAxis >= rank) {
-    return op.emitError() << "Gather axis " << gatherAxis
-                          << " is out of bounds [0, " << rank << ").";
-  }
-
-  FlatSymbolRefAttr meshSymbol = op.getMeshAttr();
-  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
-  auto mesh = getMesh(op.getOperation(), meshSymbol, meshAxes);
-  if (failed(mesh)) {
-    return failure();
-  }
-  return verifyAllGatherOperandAndResultShape(op.getOperand(), op.getResult(),
-                                              gatherAxis, op.getMeshAxes(),
-                                              mesh.value().canonicalDimSizes());
-}
-
 static LogicalResult verifyAllToAllOperandAndResultShape(
     Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
@@ -472,48 +456,38 @@ static LogicalResult verifyReduceScatterOperandAndResultShape(
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.all_reduce op
+// mesh.all_gather op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  FlatSymbolRefAttr meshSymbol = getMeshAttr();
-  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
-  return verifyMeshSymbolUses(getOperation(), meshSymbol, meshAxes,
-                              symbolTable);
-}
-
-LogicalResult mlir::mesh::AllReduceOp::verify() {
-  return verifyMeshAxes(getLoc(), getMeshAxes());
+AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto gatherAxis = getGatherAxis().getSExtValue();
+  return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
+                                              gatherAxis, getMeshAxes(),
+                                              mesh.value().canonicalDimSizes());
 }
 
-void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.all_gather op
+// mesh.all_reduce op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  FlatSymbolRefAttr meshSymbol = getMeshAttr();
-  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
-  return verifyMeshSymbolUses(getOperation(), meshSymbol, meshAxes,
-                              symbolTable);
-}
-
-LogicalResult mlir::mesh::AllGatherOp::verify() {
-  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
-    return failure();
-  }
-  return verifyAllGather(*this);
+AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return getMeshAndVerifyAxes(*this, symbolTable);
 }
 
-void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -521,22 +495,11 @@ void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  FlatSymbolRefAttr meshSymbol = getMeshAttr();
-  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
-  return verifyMeshSymbolUses(getOperation(), meshSymbol, meshAxes,
-                              symbolTable);
-}
-
-LogicalResult AllToAllOp::verify() {
-  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
-    return failure();
-  }
-  FlatSymbolRefAttr meshSymbol = getMeshAttr();
-  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
-  auto mesh = ::getMesh(getOperation(), meshSymbol, meshAxes);
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
   if (failed(mesh)) {
     return failure();
   }
+
   return verifyAllToAllOperandAndResultShape(
       getOperand(), getResult(), getSplitAxis().getSExtValue(),
       getConcatAxis().getSExtValue(), getMeshAxes(),
@@ -554,22 +517,11 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 
 LogicalResult
 ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  FlatSymbolRefAttr meshSymbol = getMeshAttr();
-  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
-  return verifyMeshSymbolUses(getOperation(), meshSymbol, meshAxes,
-                              symbolTable);
-}
-
-LogicalResult ReduceScatterOp::verify() {
-  if (failed(verifyMeshAxes(getLoc(), getMeshAxes()))) {
-    return failure();
-  }
-  FlatSymbolRefAttr meshSymbol = getMeshAttr();
-  DenseI16ArrayAttr meshAxes = getMeshAxesAttr();
-  auto mesh = ::getMesh(getOperation(), meshSymbol, meshAxes);
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
   if (failed(mesh)) {
     return failure();
   }
+
   return verifyReduceScatterOperandAndResultShape(
       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
       mesh.value().canonicalDimSizes());

>From 877149a23aaa1c3edb25e64931ee12b1cd313cd3 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Thu, 16 Nov 2023 16:46:07 -0800
Subject: [PATCH 10/10] In all-to-all, allow a split axis size that is not
 divisible by the device group size

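When the size of the split tensor axis is not statically divisible by the
collective device group size, the verifier now expects a dynamic result
dimension along the split axis instead of rejecting the op. A sketch matching
the new test case (@mesh0 refers to the cluster declared in ops.mlir):

  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
    split_axis = 0 concat_axis = 1
    : tensor<2x3xi8> -> tensor<?x12xi8>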
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |  5 +++--
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         | 12 ++++--------
 mlir/test/Dialect/Mesh/ops.mlir              | 15 ++++++++++++++-
 3 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 46b298f9e9f5ed9..5cce15dd1015ecc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -292,19 +292,20 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
   let description = [{
     Performs an all-to-all on tensor pieces split along `split_axis`.
     The resulting pieces are concatenated along `concat_axis` on each device.
+
     Example:
     ```
     mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
     ...
     %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
       split_axis = 0 concat_axis = 0
-      : tensor<3x6xi8> -> tensor<3x6xi8>
+      : tensor<3x2xi8> -> tensor<3x2xi8>
     ```
     Input:
     ```
      device  device  device
      (0)     (1)     (2)
-    +-------+-------+-------+  | split and concat
+    +-------+-------+-------+  | split and concat along
     | 11 12 | 21 22 | 31 32 |  | tensor axis 0
     | 13 14 | 23 24 | 33 34 |  ↓
     | 15 16 | 25 26 | 35 36 |
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 923c37df4d052f1..b45f7cd21ce9217 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -391,18 +391,14 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
       DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
-  if (!operandSplitDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
-      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
-    return emitError(result.getLoc())
-           << "Operand dimension size " << int64_t(operandSplitDimSize)
-           << " is not divisible by collective device group size "
-           << int64_t(deviceGroupSize) << " for split axis " << splitAxis
-           << ".";
-  }
   DimensionSize expectedResultConcatDimSize =
       operandConcatDimSize * deviceGroupSize;
   DimensionSize expectedResultSplitDimSize =
       operandSplitDimSize / deviceGroupSize;
+  if (!expectedResultSplitDimSize.isDynamic() &&
+      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
+    expectedResultSplitDimSize = DimensionSize::dynamic();
+  }
   if (failed(verifyDimensionCompatibility(
           result.getLoc(), expectedResultConcatDimSize.value(),
           resultType.getDimSize(concatAxis), concatAxis))) {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 3ba9aabacdd7dcc..5b264bc88dfc2a7 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -199,7 +199,7 @@ func.func @all_to_all_dynamic_dims_in_result(
   return %0 : tensor<3x?xi8>
 }
 
-// CHECK-LABEL: func @all_to_all
+// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
 func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
     %arg0 : tensor<3xi8>) -> tensor<3xi8> {
@@ -212,6 +212,19 @@ func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
   return %0 : tensor<3xi8>
 }
 
+// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
+func.func @all_to_all_non_divisible_split_axis_size(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
+    %arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
+  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+  // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
+  // CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
+  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
+    split_axis = 0 concat_axis = 1
+    : tensor<2x3xi8> -> tensor<?x12xi8>
+  return %0 : tensor<?x12xi8>
+}
+
 // CHECK-LABEL: func @reduce_scatter_static_dimensions
 func.func @reduce_scatter_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>


