[Mlir-commits] [mlir] 5f7c8c1 - [mlir][mesh] Add collective communication operations (#71960)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 21 06:50:29 PST 2023
Author: Boian Petkantchin
Date: 2023-11-21T06:50:24-08:00
New Revision: 5f7c8c1068d23d1f23643ddd141ca6f4f23f3578
URL: https://github.com/llvm/llvm-project/commit/5f7c8c1068d23d1f23643ddd141ca6f4f23f3578
DIFF: https://github.com/llvm/llvm-project/commit/5f7c8c1068d23d1f23643ddd141ca6f4f23f3578.diff
LOG: [mlir][mesh] Add collective communication operations (#71960)
Add all-gather, all-reduce, all-to-all and reduce-scatter. These
operations have device mesh semantics.
Added:
mlir/docs/Dialects/Mesh.md
mlir/test/Dialect/Mesh/canonicalization.mlir
Modified:
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/test/Dialect/Mesh/invalid.mlir
mlir/test/Dialect/Mesh/ops.mlir
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..03877f1a6544817
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,43 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the user is familiar with collective operations.
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition is that the collectives in this dialect have mesh
+semantics.
+
+The operation attributes `mesh` and `mesh_axes` specify the device mesh and a
+list of its axes that partition the devices into disjoint groups.
+The collective operation is performed between devices in the same group.
+Devices that have the same coordinates outside of axes `mesh_axes` are in the
+same group.
+For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
+axes list is `[0, 1]` then devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+Some collective operations like all-to-all and all-gather care about the
+order of devices.
+The order of devices in a device group is induced by the order of axes in
+`mesh_axes`.
+The axes are ordered from outer to inner.
+If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
+both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
+
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
let cppNamespace = "::mlir::mesh";
let description = [{
- The `mesh` dialect contains a set of attributes, operations, interfaces that
- are useful for representing sharding and communication on device mesh
- cluster.
+ See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
}];
let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
+// Attribute wrapper around the `Mesh_Partial` enum. Its assembly format
+// prints/parses the value in angle brackets, e.g. `<sum>` or `<max>`.
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
 // Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..9077d2eb0189b72 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,12 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..5cce15dd1015ecc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -77,6 +79,18 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
attr-dict
}];
+  let extraClassDeclaration = [{
+    // The `dim_sizes` attribute may have size less than the rank of the mesh.
+    // Returns the shape of the mesh with missing trailing dimensions
+    // explicitly set as dynamic (encoded as 0).
+    ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+    // Writes the canonical dim sizes to `outIt`: first the explicit
+    // `dim_sizes`, then `getRank() - getDimSizes().size()` zeros.
+    // The iterator returned by std::copy is reused for the fill so that
+    // position-tracking iterators (e.g. raw pointers) do not have the
+    // explicit sizes overwritten by the padding.
+    template <typename OutIt>
+    void canonicalDimSizes(OutIt outIt) {
+      outIt = std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+    }
+  }];
let hasVerifier = 1;
}
@@ -171,4 +185,219 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+// Common base of the collective communication ops. It appends
+// SymbolUserOpInterface (each op implements verifySymbolUses to resolve the
+// referenced mesh) to the given traits, and provides the arguments shared by
+// all collectives via `commonArgs`.
+class Mesh_CollectiveCommunicationOpBase<
+ string mnemonic, list<Trait> traits = []> :
+ Mesh_Op<mnemonic,
+ !listconcat(traits,
+ [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
+ // `mesh`: symbol of the mesh.cluster op to communicate over.
+ // `mesh_axes`: mesh axes defining the device groups; defaults to empty.
+ dag commonArgs = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+ );
+}
+
+// mesh.all_gather: the result matches the input on every tensor axis except
+// `gather_axis`, whose size is multiplied by the device group size (checked
+// in AllGatherOp::verifySymbolUses). An empty `mesh_axes` with equal input
+// and result types is folded away by the canonicalizer.
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank
+ ]> {
+ let summary = "All-gather over a device mesh.";
+ let description = [{
+ Gathers along the `gather_axis` tensor axis.
+
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
+ : tensor<2x2xi8> -> tensor<2x4xi8>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ gather tensor
+ axis 1
+ ------------>
+ +-------------+
+ | 1 2 5 6 | <- devices (0, 0) and (0, 1)
+ | 3 4 7 8 |
+ +-------------+
+ | 9 10 13 14 | <- devices (1, 0) and (1, 1)
+ | 11 12 15 16 |
+ +-------------+
+ ```
+ }];
+ let arguments = !con(commonArgs, (ins
+ AnyNon0RankedTensor:$input,
+ IndexAttr:$gather_axis
+ ));
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let assemblyFormat = [{
+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
+ attr-dict `:` type($input) `->` type($result)
+ }];
+ let hasCanonicalizer = 1;
+}
+
+// mesh.all_reduce: input and result share the same shape
+// (SameOperandsAndResultShape); only the element type may differ.
+// `reduction` defaults to sum and is omitted from the printed form when it
+// has that default value.
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+ SameOperandsAndResultShape]> {
+ let summary = "All-reduce over a device mesh.";
+ let description = [{
+ The accumulation element type is specified by the result type and
+ it does not need to match the input element type.
+ The input element is converted to the result element type before
+ performing the reduction.
+
+ Attributes:
+ `reduction`: Indicates the reduction method.
+
+ Example:
+ ```
+ %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+ : tensor<3x4xf32> -> tensor<3x4xf64>
+ ```
+ }];
+ let arguments = !con(commonArgs, (ins
+ AnyRankedTensor:$input,
+ DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+ ));
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let assemblyFormat = [{
+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
+ attr-dict `:` type($input) `->` type($result)
+ }];
+ let hasCanonicalizer = 1;
+}
+
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank]> {
+  let summary = "All-to-all over a device mesh.";
+  let description = [{
+    Performs an all-to-all on tensor pieces split along `split_axis`.
+    The resulting pieces are concatenated along `concat_axis` on each device.
+    When the two axes differ, the result size along `concat_axis` is the
+    input size multiplied by the device group size, and the result size
+    along `split_axis` is the input size divided by it.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+    ...
+    %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
+      split_axis = 0 concat_axis = 0
+      : tensor<3x2xi8> -> tensor<3x2xi8>
+    ```
+    Input:
+    ```
+     device  device  device
+       (0)     (1)     (2)
+    +-------+-------+-------+  | split and concat along
+    | 11 12 | 21 22 | 31 32 |  | tensor axis 0
+    | 13 14 | 23 24 | 33 34 |  ↓
+    | 15 16 | 25 26 | 35 36 |
+    +-------+-------+-------+
+    ```
+    Result:
+    ```
+     device  device  device
+       (0)     (1)     (2)
+    +-------+-------+-------+
+    | 11 12 | 13 14 | 15 16 |
+    | 21 22 | 23 24 | 25 26 |
+    | 31 32 | 33 34 | 35 36 |
+    +-------+-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$split_axis,
+    IndexAttr:$concat_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    `split_axis` `=` $split_axis
+    `concat_axis` `=` $concat_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+    SameOperandsAndResultRank]> {
+  let summary = "Reduce-scatter over a device mesh.";
+  let description = [{
+    After the reduction, the result is scattered within each device group.
+    The tensor is split along `scatter_axis` and the pieces distributed
+    across the device group.
+    The result has the same dimension sizes as the input except along
+    `scatter_axis`, where the (static) input size must be divisible by the
+    device group size and the result size is the quotient.
+
+    Example:
+    ```mlir
+    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    ...
+    %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
+      reduction = <sum> scatter_axis = 0
+      : tensor<2x2xf32> -> tensor<1x2xf64>
+    ```
+    Input:
+    ```
+                              device
+                              (0, 1)
+                                 ↓
+                     +-------+-------+  | scatter tensor
+    device (0, 0) -> |  1  2 |  5  6 |  | axis 0
+                     |  3  4 |  7  8 |  ↓
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 |
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+                                 ↑
+                              device
+                              (1, 1)
+    ```
+    Result:
+    ```
+    +-------+
+    |  6  8 | <- device (0, 0)
+    +-------+
+    | 10 12 | <- device (0, 1)
+    +-------+
+    | 22 24 | <- device (1, 0)
+    +-------+
+    | 26 28 | <- device (1, 1)
+    +-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+    IndexAttr:$scatter_axis
+  ));
+  let results = (outs
+    AnyRankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+    (`reduction` `=` $reduction^)?
+    `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+}
+
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..b45f7cd21ce9217 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,27 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +38,60 @@ using namespace mlir::mesh;
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+/// Sorts [first, last) and removes adjacent duplicates in place. Returns the
+/// new logical end: [first, result) holds the distinct values in ascending
+/// order; elements from `result` to `last` are left unspecified.
+template <typename It>
+static It canonicalizeSetAsArray(It first, It last) {
+  llvm::sort(first, last);
+  It uniqueEnd = std::unique(first, last);
+  return uniqueEnd;
+}
+
+/// Range overload: canonicalizes the whole of `range` as a sorted set and
+/// returns the new logical end iterator.
+template <typename R>
+static auto canonicalizeSetAsArray(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return canonicalizeSetAsArray(first, last);
+}
+
+/// Turns `vec` into a sorted vector without duplicates (in place) and
+/// returns a reference to it.
+template <typename T>
+static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+  vec.erase(canonicalizeSetAsArray(vec), vec.end());
+  return vec;
+}
+
+/// A mesh dimension size is considered dynamic when it is not positive
+/// (zero or negative).
+template <typename DimSize>
+static bool isMeshDimensionDynamic(DimSize size) {
+  const DimSize zero(0);
+  return size <= zero;
+}
+
+/// Mesh axis index type; 16-bit to match the DenseI16ArrayAttr `mesh_axes`.
+using MeshAxis = int16_t;
+
+namespace {
+
+/// An int64_t tensor dimension size that may be ShapedType::kDynamic.
+/// Converts implicitly from and to int64_t.
+struct DimensionSize {
+  DimensionSize(int64_t val) : sizeValue(val) {}
+
+  /// Returns a size holding ShapedType::kDynamic.
+  static DimensionSize dynamic() {
+    return DimensionSize(ShapedType::kDynamic);
+  }
+
+  bool isDynamic() const { return ShapedType::isDynamic(sizeValue); }
+  int64_t value() const { return sizeValue; }
+  operator int64_t() const { return sizeValue; }
+
+private:
+  int64_t sizeValue;
+};
+
+} // namespace
+
+/// Divides two dimension sizes; if either side is dynamic, so is the result.
+static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+  bool anyDynamic = lhs.isDynamic() || rhs.isDynamic();
+  return anyDynamic ? DimensionSize::dynamic()
+                    : DimensionSize(lhs.value() / rhs.value());
+}
+
+/// Multiplies two dimension sizes; if either side is dynamic, so is the
+/// result.
+static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+  bool anyDynamic = lhs.isDynamic() || rhs.isDynamic();
+  return anyDynamic ? DimensionSize::dynamic()
+                    : DimensionSize(lhs.value() * rhs.value());
+}
+
//===----------------------------------------------------------------------===//
// Mesh dialect
//===----------------------------------------------------------------------===//
@@ -96,6 +167,13 @@ LogicalResult ClusterOp::verify() {
return success();
}
+/// Returns the mesh shape padded to the full rank: the explicit `dim_sizes`
+/// followed by 0 (dynamic) for each missing trailing dimension.
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+  SmallVector<int64_t> result;
+  // Reserve before filling; reserving after the elements are appended (as
+  // previously done) has no effect.
+  result.reserve(getRank());
+  canonicalDimSizes(std::back_inserter(result));
+  return result;
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
@@ -129,6 +207,327 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Removes a collective op whose `mesh_axes` is empty (every device group
+/// then contains a single device) provided its result type equals its input
+/// type: all uses of the result are redirected to the input and the op is
+/// erased.
+template <typename Op>
+struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
+  using OpRewritePattern<Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    bool noMeshAxes = op.getMeshAxes().empty();
+    bool sameType = op.getInput().getType() == op.getResult().getType();
+    if (!noMeshAxes || !sameType)
+      return failure();
+
+    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Resolves `meshSymbol` to the nearest visible mesh.cluster op from `op`.
+/// If no such symbol exists, emits an error on `op` and returns failure.
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                    SymbolTableCollection &symbolTable) {
+  if (mesh::ClusterOp found =
+          symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op,
+                                                               meshSymbol))
+    return found;
+
+  return op->emitError() << "Undefined required mesh symbol \""
+                         << meshSymbol.getValue() << "\".";
+}
+
+/// Returns true if no two adjacent elements of [begin, end) compare equal.
+/// For a sorted range this means all elements are distinct. Marked `static`
+/// like the other file-local helpers so it does not get external linkage.
+template <typename It>
+static bool isUnique(It begin, It end) {
+  return std::adjacent_find(begin, end) == end;
+}
+
+/// Verifies that `axes` contains no duplicates and that every axis lies in
+/// [0, rank) of `mesh`. Errors are emitted at `loc`.
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+                                    ClusterOp mesh) {
+  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+  llvm::sort(sorted);
+  if (!isUnique(sorted.begin(), sorted.end())) {
+    return emitError(loc) << "Mesh axes contains duplicate elements.";
+  }
+
+  // Keep the rank in 64 bits: narrowing it to MeshAxis (int16_t) would wrap
+  // for very large ranks and make the bounds check below wrong.
+  int64_t rank = mesh.getRank();
+  for (auto axis : axes) {
+    if (axis >= rank || axis < 0) {
+      return emitError(loc)
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+             << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+/// Resolves the mesh referenced by collective op `op` and checks its
+/// `mesh_axes` against that mesh. Returns the mesh on success.
+template <typename Op>
+static FailureOr<ClusterOp>
+getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
+  FailureOr<ClusterOp> mesh =
+      ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
+  if (succeeded(mesh) &&
+      succeeded(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), *mesh)))
+    return mesh;
+  return failure();
+}
+
+/// Returns the product of the elements in [begin, end), computed in the
+/// decayed element type; an empty range yields 1.
+template <typename It>
+static auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  ElementType result(1);
+  for (It it = begin; it != end; ++it)
+    result = result * *it;
+  return result;
+}
+
+/// Range overload of product(): multiplies all elements of `range`.
+template <typename R>
+static auto product(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return product(first, last);
+}
+
+/// Returns the number of devices in each collective device group: the product
+/// of the mesh dimension sizes along `meshAxes`, or ShapedType::kDynamic if
+/// any of those dimensions is dynamic. An empty `meshAxes` yields 1.
+static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                         ArrayRef<int64_t> meshShape) {
+  int64_t res = 1;
+
+  for (MeshAxis axis : meshAxes) {
+    // Validate the index before it is used to read `meshShape`; previously
+    // the assertion ran only after the element had already been accessed.
+    assert(size_t(axis) < meshShape.size() && "mesh axis out of range");
+    if (isMeshDimensionDynamic(meshShape[axis])) {
+      return ShapedType::kDynamic;
+    }
+    res *= meshShape[axis];
+  }
+
+  return res;
+}
+
+/// Accepts a dynamic `resultDimSize` or one equal to `expectedDimSize`;
+/// otherwise emits a mismatch error for `resultAxis` at `loc` and fails.
+static LogicalResult verifyDimensionCompatibility(Location loc,
+                                                  int64_t expectedDimSize,
+                                                  int64_t resultDimSize,
+                                                  int64_t resultAxis) {
+  if (ShapedType::isDynamic(resultDimSize) ||
+      expectedDimSize == resultDimSize)
+    return success();
+
+  return emitError(loc) << "Dimension size mismatch for result axis "
+                        << resultAxis << ". Expected "
+                        << (ShapedType::isDynamic(expectedDimSize)
+                                ? Twine("dynamic")
+                                : Twine(expectedDimSize))
+                        << ", but got " << resultDimSize << ".";
+}
+
+/// Verifies all_gather shapes: `gatherAxis` must be a valid result axis.
+/// Along it the result size must be (device group size) * (operand size);
+/// on every other axis it must equal the operand size. Dynamic result
+/// dimensions are accepted.
+static LogicalResult verifyAllGatherOperandAndResultShape(
+    Value operand, Value result, int64_t gatherAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  int64_t resultRank = resultType.getRank();
+  if (gatherAxis < 0 || gatherAxis >= resultRank) {
+    return emitError(result.getLoc())
+           << "Gather axis " << gatherAxis << " is out of bounds [0, "
+           << resultRank << ").";
+  }
+
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  DimensionSize groupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  for (int64_t dim = 0, rank = operandType.getRank(); dim < rank; ++dim) {
+    DimensionSize operandDimSize = DimensionSize(operandType.getDimSize(dim));
+    DimensionSize expected =
+        dim == gatherAxis ? groupSize * operandDimSize : operandDimSize;
+    if (failed(verifyDimensionCompatibility(
+            result.getLoc(), expected, resultType.getDimSize(dim), dim)))
+      return failure();
+  }
+  return success();
+}
+
+/// Verifies all_to_all shapes. `splitAxis` and `concatAxis` must both be
+/// valid operand axes. If they are equal, the result shape must equal the
+/// operand shape. Otherwise:
+/// - along `concatAxis`, the result size must be (operand size) * (group size);
+/// - along `splitAxis`, it must be (operand size) / (group size), treated as
+///   dynamic when the division is not exact;
+/// - every other axis must match the operand.
+/// Dynamic result dimensions are always accepted.
+static LogicalResult verifyAllToAllOperandAndResultShape(
+    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  int64_t rank = operandType.getRank();
+  // Reject out-of-range axes up front (as all_gather does for its gather
+  // axis); they are used below to index the operand and result shapes.
+  if (splitAxis < 0 || splitAxis >= rank) {
+    return emitError(result.getLoc())
+           << "Split axis " << splitAxis << " is out of bounds [0, " << rank
+           << ").";
+  }
+  if (concatAxis < 0 || concatAxis >= rank) {
+    return emitError(result.getLoc())
+           << "Concat axis " << concatAxis << " is out of bounds [0, " << rank
+           << ").";
+  }
+
+  for (int64_t axis = 0; axis < rank; ++axis) {
+    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  if (splitAxis == concatAxis) {
+    return success();
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
+  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
+  DimensionSize expectedResultConcatDimSize =
+      operandConcatDimSize * deviceGroupSize;
+  DimensionSize expectedResultSplitDimSize =
+      operandSplitDimSize / deviceGroupSize;
+  // A split size that does not divide evenly has no exact static result size;
+  // it is treated as dynamic.
+  if (!expectedResultSplitDimSize.isDynamic() &&
+      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
+    expectedResultSplitDimSize = DimensionSize::dynamic();
+  }
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultConcatDimSize.value(),
+          resultType.getDimSize(concatAxis), concatAxis))) {
+    return failure();
+  }
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultSplitDimSize.value(),
+          resultType.getDimSize(splitAxis), splitAxis))) {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Verifies reduce_scatter shapes. `scatterAxis` must be a valid operand
+/// axis. Along it, the operand size must be divisible by the device group
+/// size (when both are static) and the result size must be the quotient.
+/// Every other axis must match the operand. Dynamic result dimensions are
+/// accepted.
+static LogicalResult verifyReduceScatterOperandAndResultShape(
+    Value operand, Value result, int64_t scatterAxis,
+    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+  ShapedType operandType = operand.getType().cast<ShapedType>();
+  ShapedType resultType = result.getType().cast<ShapedType>();
+  int64_t rank = operandType.getRank();
+  // Reject an out-of-range scatter axis up front (as all_gather does for its
+  // gather axis); it is used below to index the operand and result shapes.
+  if (scatterAxis < 0 || scatterAxis >= rank) {
+    return emitError(result.getLoc())
+           << "Scatter axis " << scatterAxis << " is out of bounds [0, "
+           << rank << ").";
+  }
+
+  for (int64_t axis = 0; axis < rank; ++axis) {
+    if (axis != scatterAxis) {
+      if (failed(verifyDimensionCompatibility(
+              result.getLoc(), operandType.getDimSize(axis),
+              resultType.getDimSize(axis), axis))) {
+        return failure();
+      }
+    }
+  }
+
+  auto deviceGroupSize =
+      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  auto operandScatterDimSize =
+      DimensionSize(operandType.getDimSize(scatterAxis));
+  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
+    return emitError(result.getLoc())
+           << "Operand dimension size " << int64_t(operandScatterDimSize)
+           << " is not divisible by collective device group size "
+           << int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
+           << ".";
+  }
+  DimensionSize expectedResultScatterDimSize =
+      operandScatterDimSize / deviceGroupSize;
+  if (failed(verifyDimensionCompatibility(
+          result.getLoc(), expectedResultScatterDimSize.value(),
+          resultType.getDimSize(scatterAxis), scatterAxis))) {
+    return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_gather op
+//===----------------------------------------------------------------------===//
+
+/// Verifies the referenced mesh and `mesh_axes`, then checks the
+/// operand/result shapes against the (rank-padded) mesh shape.
+LogicalResult
+AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  FailureOr<ClusterOp> mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh))
+    return failure();
+
+  int64_t gatherAxis = getGatherAxis().getSExtValue();
+  SmallVector<int64_t> meshShape = mesh->canonicalDimSizes();
+  return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
+                                              gatherAxis, getMeshAxes(),
+                                              meshShape);
+}
+
+/// Registers the pattern that folds away all_gather with empty `mesh_axes`.
+void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<AllGatherOp>;
+  patterns.add<EmptyAxesPattern>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_reduce op
+//===----------------------------------------------------------------------===//
+
+/// all_reduce only needs its mesh reference and `mesh_axes` verified here;
+/// its shape constraint comes from the SameOperandsAndResultShape trait.
+LogicalResult
+AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  if (failed(getMeshAndVerifyAxes(*this, symbolTable)))
+    return failure();
+  return success();
+}
+
+/// Registers the pattern that folds away all_reduce with empty `mesh_axes`.
+void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<AllReduceOp>;
+  patterns.add<EmptyAxesPattern>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_to_all op
+//===----------------------------------------------------------------------===//
+
+/// Verifies the referenced mesh and `mesh_axes`, then the operand/result
+/// shapes implied by `split_axis` and `concat_axis`.
+LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  FailureOr<ClusterOp> mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh))
+    return failure();
+
+  int64_t splitAxis = getSplitAxis().getSExtValue();
+  int64_t concatAxis = getConcatAxis().getSExtValue();
+  SmallVector<int64_t> meshShape = mesh->canonicalDimSizes();
+  return verifyAllToAllOperandAndResultShape(getOperand(), getResult(),
+                                             splitAxis, concatAxis,
+                                             getMeshAxes(), meshShape);
+}
+
+/// Registers the pattern that folds away all_to_all with empty `mesh_axes`.
+void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                             MLIRContext *context) {
+  using EmptyAxesPattern = EmptyMeshAxesCanonicalizationPattern<AllToAllOp>;
+  patterns.add<EmptyAxesPattern>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.reduce_scatter op
+//===----------------------------------------------------------------------===//
+
+/// Verifies the referenced mesh and `mesh_axes`, then the operand/result
+/// shapes implied by `scatter_axis`.
+LogicalResult
+ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  FailureOr<ClusterOp> mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh))
+    return failure();
+
+  int64_t scatterAxis = getScatterAxis().getSExtValue();
+  SmallVector<int64_t> meshShape = mesh->canonicalDimSizes();
+  return verifyReduceScatterOperandAndResultShape(
+      getOperand(), getResult(), scatterAxis, getMeshAxes(), meshShape);
+}
+
+/// Registers the pattern that folds away reduce_scatter with empty
+/// `mesh_axes`.
+void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  using EmptyAxesPattern =
+      EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>;
+  patterns.add<EmptyAxesPattern>(context);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
new file mode 100644
index 000000000000000..5802d198d368149
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+// Rank-2 mesh with dimension sizes [2, 4]; every test below refers to it.
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// Empty `mesh_axes` with identical input and result types: the op is folded
+// away and the function returns its argument directly.
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes
+func.func @all_reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_reduce
+ %0 = mesh.all_reduce %arg0 on @mesh0
+ mesh_axes = []
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// The result element type differs from the input, so the op is kept; its
+// empty `mesh_axes` is elided when printed.
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
+func.func @all_reduce_empty_mesh_axes_different_return_type(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh.all_reduce
+  %0 = mesh.all_reduce %arg0 on @mesh0
+// CHECK-NOT: mesh_axes
+    mesh_axes = []
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// An explicit `<sum>` equals the default reduction, so it is not printed.
+// CHECK-LABEL: func @all_reduce_default_reduction
+func.func @all_reduce_default_reduction(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ %0 = mesh.all_reduce %arg0 on @mesh0
+ mesh_axes = [0]
+// CHECK-NOT: reduction
+ reduction = <sum>
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// Empty `mesh_axes` with identical input and result types: all_to_all is
+// folded away and the argument is returned directly.
+// CHECK-LABEL: func @all_to_all_empty_mesh_axes
+func.func @all_to_all_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
+ %arg0 : tensor<8xf32>) -> tensor<8xf32> {
+// CHECK-NOT: mesh.all_to_all
+ %0 = mesh.all_to_all %arg0 on @mesh0
+ mesh_axes = []
+ split_axis = 0
+ concat_axis = 0
+ : tensor<8xf32> -> tensor<8xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<8xf32>
+}
+
+// Empty `mesh_axes` with identical input and result types: all_gather is
+// folded away and the argument is returned directly.
+// CHECK-LABEL: func @all_gather_empty_mesh_axes
+func.func @all_gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_gather
+ %0 = mesh.all_gather %arg0 on @mesh0
+ mesh_axes = []
+ gather_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// Empty `mesh_axes` with identical input and result types: reduce_scatter is
+// folded away and the argument is returned directly.
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
+func.func @reduce_scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce_scatter
+ %0 = mesh.reduce_scatter %arg0 on @mesh0
+ mesh_axes = []
+ scatter_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// The result element type differs from the input, so the op is kept; its
+// empty `mesh_axes` is elided when printed.
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
+func.func @reduce_scatter_empty_mesh_axes_different_return_type(
+    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh.reduce_scatter
+  %0 = mesh.reduce_scatter %arg0 on @mesh0
+// CHECK-NOT: mesh_axes
+    mesh_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf64>
+  return %0 : tensor<4xf64>
+}
+
+// An explicit `<sum>` equals the default reduction, so it is not printed.
+// CHECK-LABEL: func @reduce_scatter_default_reduction
+func.func @reduce_scatter_default_reduction(
+ %arg0 : tensor<4xf32>) -> tensor<2xf64> {
+ %0 = mesh.reduce_scatter %arg0 on @mesh0
+ mesh_axes = [0]
+// CHECK-NOT: reduction
+ reduction = <sum>
+ scatter_axis = 0
+ : tensor<4xf32> -> tensor<2xf64>
+ return %0 : tensor<2xf64>
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 246439dd4be7122..2999668f770baa7 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -67,3 +67,279 @@ func.func @mesh_axis_negtive_in_partial(
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
}
+
+// -----
+
+// all_reduce referencing a mesh symbol that is not defined in this split is rejected.
+func.func @all_reduce_invalid_mesh_symbol(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = <sum>
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// Mesh axis 2 is out of bounds for the rank-2 mesh @mesh0.
+func.func @all_reduce_invalid_mesh_axis(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = <sum>
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// Mesh axis 0 is listed twice in mesh_axes.
+func.func @all_reduce_duplicate_mesh_axis(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = <sum>
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// all_reduce requires operand and result shapes to match (4 vs 5 here).
+func.func @all_reduce_invalid_tensor_dimension_size(
+ %arg0 : tensor<4xf32>) -> tensor<5xf64> {
+ // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
+ %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
+ return %0 : tensor<5xf64>
+}
+
+// -----
+
+// all_gather referencing a mesh symbol that is not defined in this split is rejected.
+func.func @all_gather_invalid_mesh_symbol(
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+ // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// Mesh axis 2 is out of bounds for the rank-2 mesh @mesh0.
+func.func @all_gather_invalid_mesh_axis(
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+ // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// Mesh axis 0 is listed twice in mesh_axes. Both entries are in range for the
+// rank-2 mesh, so only the duplicate-axis check can trigger.
+func.func @all_gather_duplicate_mesh_axis(
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+ // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+ %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0, 0] gather_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// With gather_axis = 0, result axis 1 must keep the operand's size 4 (5 is given).
+func.func @all_gather_invalid_non_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+ %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ : tensor<3x4xf32> -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+
+// Gathering axis 1 over mesh axis 1 (size 2) must yield 4 * 2 = 8 elements, not 5.
+func.func @all_gather_invalid_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+ %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
+ : tensor<3x4xf32> -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// A dynamic operand gather dimension requires a dynamic result dimension.
+func.func @all_gather_invalid_gather_axis_dynamic_dimension(
+ %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+ %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
+ : tensor<?xf32> -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// gather_axis 1 does not index into the rank-1 operand.
+func.func @all_gather_invalid_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+ %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
+ : tensor<3xf32> -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// A negative gather_axis is rejected as out of bounds.
+func.func @all_gather_invalid_negative_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+ %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
+ : tensor<3xf32> -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+// all_to_all referencing a mesh symbol that is not defined in this split is rejected.
+func.func @all_to_all_invalid_mesh_symbol(
+ %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+ // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
+ split_axis = 1 concat_axis = 0
+ : tensor<3x6xi8> -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+// Mesh axis 0 is listed twice in mesh_axes.
+func.func @all_to_all_duplicate_mesh_axis(
+ %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+ // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+ %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
+ split_axis = 0 concat_axis = 0
+ : tensor<3x6xi8> -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+
+// Mesh axis 0 is declared with size 0; per the expected diagnostic it yields a
+// dynamic device-group size, so the concatenated result axis 1 must be dynamic.
+func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
+ %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
+ %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ split_axis = 0 concat_axis = 1
+ : tensor<3x6xi8> -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+// A dynamic operand split dimension (axis 0) requires a dynamic result split dimension.
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
+ %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+ %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+ split_axis = 0 concat_axis = 1
+ : tensor<?x6xi8> -> tensor<3x?xi8>
+ return %0 : tensor<3x?xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+// A dynamic operand concat dimension (axis 1) requires a dynamic result concat dimension.
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
+ %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
+ %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+ split_axis = 0 concat_axis = 1
+ : tensor<3x?xi8> -> tensor<?x3xi8>
+ return %0 : tensor<?x3xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// Concatenating axis 1 across a device group of 3 must yield 2 * 3 = 6, not 7.
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
+ %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
+ %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ split_axis = 0 concat_axis = 1
+ : tensor<3x2xi8> -> tensor<1x7xi8>
+ return %0 : tensor<1x7xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// Splitting axis 0 across a device group of 3 must yield 3 / 3 = 1, not 2.
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
+ %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+ %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ split_axis = 0 concat_axis = 1
+ : tensor<3x2xi8> -> tensor<2x6xi8>
+ return %0 : tensor<2x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// Mesh axis 0 is listed twice in mesh_axes.
+func.func @reduce_scatter_duplicate_mesh_axis(
+ %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+ // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+ %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
+ : tensor<?xf32> -> tensor<?xf64>
+ return %0 : tensor<?xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// A dynamic operand scatter dimension requires a dynamic result dimension.
+func.func @reduce_scatter_invalid_dynamic_dimension(
+ %arg0 : tensor<?xf32>) -> tensor<2xf64> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+ %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
+ : tensor<?xf32> -> tensor<2xf64>
+ return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// Scattering 3 elements across a device group of 3 must yield 1, not 2.
+func.func @reduce_scatter_invalid_static_dimension_size(
+ %arg0 : tensor<3xf32>) -> tensor<2xf64> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+ %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+ : tensor<3xf32> -> tensor<2xf64>
+ return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+// A static operand scatter dimension (4) must be divisible by the device-group size (3).
+func.func @reduce_scatter_invalid_operand_static_dimension_size(
+ %arg0 : tensor<4xf32>) -> tensor<?xf64> {
+ // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+ %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+ : tensor<4xf32> -> tensor<?xf64>
+ return %0 : tensor<?xf64>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index ee5f8f67792b928..5b264bc88dfc2a7 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -12,6 +12,8 @@ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
// CHECK: mesh.cluster @mesh3
mesh.cluster @mesh3(rank = 2)
+// Static rank-1 mesh of 3 devices used by the all_to_all tests below.
+// CHECK: mesh.cluster @mesh4
+mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+
// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
func.func @mesh_shard_encoding_fully_replicated(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
@@ -126,3 +128,124 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
%2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
+
+// Round-trips an all_reduce with explicit mesh_axes, a non-default <max> reduction and a widened f32 -> f64 element type.
+// CHECK-LABEL: func @all_reduce
+func.func @all_reduce(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
+ // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = <max>
+ // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+ : tensor<3x4xf32> -> tensor<3x4xf64>
+ return %0 : tensor<3x4xf64>
+}
+
+// Round-trips an all_gather along tensor axis 1 (4 -> 16 elements) over mesh axis 2 of @mesh0.
+// CHECK-LABEL: func @all_gather
+func.func @all_gather(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
+ // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+ // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
+ %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+ : tensor<3x4xf32> -> tensor<3x16xf32>
+ return %0 : tensor<3x16xf32>
+}
+
+// all_gather on a fully dynamic tensor: a dynamic operand shape is accepted with a dynamic result shape.
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
+func.func @all_gather_dynamic_dims_in_tensor(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
+ %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
+ %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+ : tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// @mesh3 declares no dim_sizes, so gathering over its axis 1 leaves the gathered result dimension dynamic.
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
+func.func @all_gather_dynamic_dims_in_mesh(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
+ %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
+ // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
+ // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
+ %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
+ : tensor<5x6xf32> -> tensor<5x?xf32>
+ return %0 : tensor<5x?xf32>
+}
+
+// Round-trips an all_to_all on @mesh4 without mesh_axes; the operand shape is preserved.
+// CHECK-LABEL: func @all_to_all
+func.func @all_to_all(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+ %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+ // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+ // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+ // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
+ %0 = mesh.all_to_all %arg0 on @mesh4
+ split_axis = 1 concat_axis = 0
+ : tensor<3x6xi8> -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// A dynamic result dimension is accepted for an all_to_all even though the operand shape is fully static.
+// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
+func.func @all_to_all_dynamic_dims_in_result(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+ %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
+ // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+ // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+ // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
+ %0 = mesh.all_to_all %arg0 on @mesh4
+ split_axis = 1 concat_axis = 0
+ : tensor<3x6xi8> -> tensor<3x?xi8>
+ return %0 : tensor<3x?xi8>
+}
+
+// @mesh3 declares no dim_sizes, so the device group formed by its mesh axis 0
+// has a dynamic size. With split_axis == concat_axis, a static result equal to
+// the operand shape is expected to still be accepted.
+// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
+func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
+ %arg0 : tensor<3xi8>) -> tensor<3xi8> {
+ // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+ // CHECK-SAME: @mesh3 mesh_axes = [0] split_axis = 0 concat_axis = 0
+ // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
+ %0 = mesh.all_to_all %arg0 on @mesh3 mesh_axes = [0]
+ split_axis = 0 concat_axis = 0
+ : tensor<3xi8> -> tensor<3xi8>
+ return %0 : tensor<3xi8>
+}
+
+// Per the function name, the operand's split dimension (2) is not divisible by the device-group size of mesh axes [0, 1]; the result split dimension is given as dynamic and accepted.
+// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
+func.func @all_to_all_non_divisible_split_axis_size(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
+ %arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
+ // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+ // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
+ // CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
+ %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
+ split_axis = 0 concat_axis = 1
+ : tensor<2x3xi8> -> tensor<?x12xi8>
+ return %0 : tensor<?x12xi8>
+}
+
+// Round-trips a reduce_scatter with static shapes (axis 1 scattered 4 -> 1) and a non-default <max> reduction, which is printed between mesh_axes and scatter_axis.
+// CHECK-LABEL: func @reduce_scatter_static_dimensions
+func.func @reduce_scatter_static_dimensions(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
+ // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = <max> scatter_axis = 1
+ // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
+ %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
+ reduction = <max> scatter_axis = 1
+ : tensor<3x4xf32> -> tensor<3x1xf64>
+ return %0 : tensor<3x1xf64>
+}
+
+// reduce_scatter with a dynamic operand over @mesh3 (no declared dim_sizes); a dynamic result is accepted.
+// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
+func.func @reduce_scatter_dynamic_dimensions(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+ %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+ // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+ // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
+ %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ : tensor<?xf32> -> tensor<?xf64>
+ return %0 : tensor<?xf64>
+}
More information about the Mlir-commits
mailing list