[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)
Boian Petkantchin
llvmlistbot at llvm.org
Fri Nov 10 09:18:19 PST 2023
https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/71960
Add all-gather, all-reduce, all-to-all, and reduce-scatter. These operations have device mesh semantics.
I have not included ops like reduce, gather, send, and recv, to first see if reviewers notice any systemic issues. Also, this PR is already big enough.
From 5b8c7f70202f7cac6b6b1a29e168fe6fa43fc720 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Mon, 6 Nov 2023 15:55:31 -0800
Subject: [PATCH] [mlir][mesh] Add collective communication operations
Add all-gather, all-reduce, all-to-all and reduce-scatter.
These operations have device mesh semantics.
---
mlir/docs/Dialects/Mesh.md | 34 ++
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 8 +-
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 2 +
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 216 +++++++++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 417 ++++++++++++++++++
mlir/test/Dialect/Mesh/canonicalization.mlir | 72 +++
mlir/test/Dialect/Mesh/invalid.mlir | 240 ++++++++++
mlir/test/Dialect/Mesh/ops.mlir | 119 +++++
8 files changed, 1105 insertions(+), 3 deletions(-)
create mode 100644 mlir/docs/Dialects/Mesh.md
create mode 100644 mlir/test/Dialect/Mesh/canonicalization.mlir
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..6dd4f79022061ee
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,34 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the reader is familiar with collective operations;
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition here is that the collectives in this dialect have device
+mesh semantics.
+The operation attributes `mesh` and `mesh_axes` specify a set of device mesh
+axes that partition the devices into disjoint groups.
+The collective operation is performed between the devices in the same group.
+Devices that have the same coordinates along the axes not listed in
+`mesh_axes` are in the same group.
+For example, if we have a device mesh of size `2x3x4x5` and the partition
+mesh axes set is `{0, 1}`, then the devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
let cppNamespace = "::mlir::mesh";
let description = [{
- The `mesh` dialect contains a set of attributes, operations, interfaces that
- are useful for representing sharding and communication on device mesh
- cluster.
+ See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
}];
let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..7698d60813a8f10 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,11 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..15354babe870599 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -77,6 +79,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
attr-dict
}];
+ let extraClassDeclaration = [{
+ ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+ template <typename OutIt>
+ void canonicalDimSizes(OutIt outIt) {
+ outIt = std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+ std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+ }
+ }];
let hasVerifier = 1;
}
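
The `canonicalDimSizes` helper above copies the listed dimension sizes and pads them with trailing zeros up to the mesh rank; by the dialect's convention, a non-positive size denotes a dynamic mesh dimension. A minimal standalone model of that behavior (illustrative names, not the MLIR API):

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// Model: extend dim_sizes with trailing 0 (dynamic) entries up to rank.
std::vector<int64_t> canonicalDimSizes(std::vector<int64_t> dimSizes,
                                       size_t rank) {
  dimSizes.resize(rank, 0);
  return dimSizes;
}

int main() {
  // mesh.cluster @mesh3(rank = 2) lists no dim_sizes: both are dynamic.
  assert((canonicalDimSizes({}, 2) == std::vector<int64_t>{0, 0}));
  // mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4]) is already complete.
  assert((canonicalDimSizes({2, 4}, 2) == std::vector<int64_t>{2, 4}));
}
```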
@@ -171,4 +182,209 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+class Mesh_CollectiveCommunicationOpBase<
+ string mnemonic, list<Trait> traits = []> :
+ Mesh_Op<mnemonic,
+ !listconcat(traits,
+ [SymbolUserOpInterface])> {
+ let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+ code extraClassDeclarationBase = [{
+ ::mlir::LogicalResult verifySymbolUses(
+ ::mlir::SymbolTableCollection &symbolTable);
+ }];
+}
+
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank
+ ]> {
+ let summary = "All-gather over a device mesh.";
+ let description = [{
+ Gathers along the `gather_axis` tensor axis.
+ The order of input tensors in the resulting tensor is the same as the
+ order of the corresponding devices' multi-index in the mesh.
+
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.all_gather %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+ } : tensor<2x2xi8> -> tensor<2x4xi8>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ +-------------+
+ | 1 2 5 6 | <- devices (0, 0) and (0, 1)
+ | 3 4 7 8 |
+ +-------------+
+ | 9 10 13 14 | <- devices (1, 0) and (1, 1)
+ | 11 12 15 16 |
+ +-------------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ APIntAttr:$gather_axis
+ );
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+ SameOperandsAndResultShape]> {
+ let summary = "All-reduce over a device mesh.";
+ let description = [{
+ The accumulation element type is specified by the result type and
+ does not need to match the input element type.
+ Each input element is converted to the result element type before
+ the reduction is performed.
+
+ Attributes:
+ `reduction`: Indicates the reduction method.
+
+ Example:
+ ```mlir
+ %1 = mesh.all_reduce %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+ } : tensor<3x4xf32> -> tensor<3x4xf64>
+ ```
+ }];
+ let arguments = (ins
+ AnyRankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+ );
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
+ let summary = "All-to-all over a device mesh.";
+ let description = [{
+ Performs an all-to-all on tensor pieces split along `split_axis`.
+ The resulting pieces are concatenated along `concat_axis` on each device.
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+ ...
+ %1 = mesh.all_to_all %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
+ } : tensor<3x2xi8> -> tensor<3x2xi8>
+ ```
+ Input:
+ ```
+ device device device
+ (0) (1) (2)
+ +-------+-------+-------+
+ | 11 12 | 21 22 | 31 32 |
+ | 13 14 | 23 24 | 33 34 |
+ | 15 16 | 25 26 | 35 36 |
+ +-------+-------+-------+
+ ```
+ Result:
+ ```
+ device device device
+ (0) (1) (2)
+ +-------+-------+-------+
+ | 11 12 | 13 14 | 15 16 |
+ | 21 22 | 23 24 | 25 26 |
+ | 31 32 | 33 34 | 35 36 |
+ +-------+-------+-------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ APIntAttr:$split_axis,
+ APIntAttr:$concat_axis
+ );
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+ SameOperandsAndResultRank]> {
+ let summary = "Reduce-scatter over a device mesh.";
+ let description = [{
+ After the reduction, the result is scattered within each device group.
+ The tensor is split along `scatter_axis` and the pieces are distributed
+ across the devices in the group.
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.reduce_scatter %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, reduction = #mesh.partial<sum>, scatter_axis = 0
+ } : tensor<2x2xf32> -> tensor<1x2xf64>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ +-------+
+ | 6 8 | <- device (0, 0)
+ +-------+
+ | 10 12 | <- device (0, 1)
+ +-------+
+ | 22 24 | <- device (1, 0)
+ +-------+
+ | 26 28 | <- device (1, 1)
+ +-------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+ APIntAttr:$scatter_axis
+ );
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
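
The static result-shape rules enforced by the verifiers in MeshOps.cpp below can be summarized in a short standalone C++ model (dynamic sizes and diagnostics are omitted; the function names are illustrative, not the MLIR API):

```c++
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// all_gather: the gather_axis dimension grows by the device group size.
Shape allGatherResultShape(Shape s, int64_t gatherAxis, int64_t groupSize) {
  s[gatherAxis] *= groupSize;
  return s;
}

// all_to_all: the split_axis dimension shrinks and the concat_axis dimension
// grows by the device group size; equal axes cancel out.
Shape allToAllResultShape(Shape s, int64_t splitAxis, int64_t concatAxis,
                          int64_t groupSize) {
  assert(s[splitAxis] % groupSize == 0); // enforced by the verifier
  s[splitAxis] /= groupSize;
  s[concatAxis] *= groupSize;
  return s;
}

// reduce_scatter: the scatter_axis dimension shrinks by the device group
// size; the element type may change independently of the shape.
Shape reduceScatterResultShape(Shape s, int64_t scatterAxis,
                               int64_t groupSize) {
  assert(s[scatterAxis] % groupSize == 0); // enforced by the verifier
  s[scatterAxis] /= groupSize;
  return s;
}

int main() {
  // all_gather example: 2x2 mesh, mesh_axes = {1} gives group size 2,
  // tensor<2x2xi8> gathered along axis 1 gives tensor<2x4xi8>.
  assert(allGatherResultShape(Shape({2, 2}), 1, 2) == Shape({2, 4}));
  // all_to_all example: split_axis == concat_axis leaves the shape unchanged.
  assert(allToAllResultShape(Shape({3, 2}), 0, 0, 3) == Shape({3, 2}));
  // reduce_scatter example: tensor<2x2xf32> becomes tensor<1x2xf64>.
  assert(reduceScatterResultShape(Shape({2, 2}), 0, 2) == Shape({1, 2}));
}
```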
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..6efc4c4ecc326ad 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,26 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +37,60 @@ using namespace mlir::mesh;
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+namespace {
+
+template <typename It>
+It canonicalizeSetAsArray(It begin, It end) {
+ std::sort(begin, end);
+ return std::unique(begin, end);
+}
+
+template <typename R>
+auto canonicalizeSetAsArray(R &&range) {
+ return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
+}
+
+template <typename T>
+SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+ auto newEnd = canonicalizeSetAsArray(vec);
+ vec.resize(newEnd - vec.begin());
+ return vec;
+}
+
+template <typename DimSize>
+bool isMeshDimensionDynamic(DimSize size) {
+ return size <= DimSize(0);
+}
+
+using MeshAxis = int16_t;
+
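+// Wrapper around a tensor dimension size that makes dynamic sizes absorbing:
+// if either operand of * or / is dynamic, so is the result.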
+struct DimensionSize {
+ static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
+ DimensionSize(int64_t val) : val(val) {}
+ int64_t value() const { return val; }
+ operator int64_t() const { return val; }
+ bool isDynamic() const { return ShapedType::isDynamic(val); }
+
+private:
+ int64_t val;
+};
+
+DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+ if (lhs.isDynamic() || rhs.isDynamic()) {
+ return DimensionSize::dynamic();
+ }
+ return lhs.value() / rhs.value();
+}
+
+DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+ if (lhs.isDynamic() || rhs.isDynamic()) {
+ return DimensionSize::dynamic();
+ }
+ return lhs.value() * rhs.value();
+}
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// Mesh dialect
//===----------------------------------------------------------------------===//
@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() {
return success();
}
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+ SmallVector<int64_t> result;
+ canonicalDimSizes(std::back_inserter(result));
+ return result;
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+ if (!attr) {
+ return std::nullopt;
+ }
+ SmallVector<int16_t> axes = llvm::to_vector(attr.asArrayRef());
+ canonicalizeSetAsVector(axes);
+ if (axes.empty()) {
+ return std::nullopt;
+ }
+ return DenseI16ArrayAttr::get(attr.getContext(), axes);
+}
+
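+// Pattern that canonicalizes an attribute holding a set of mesh axes:
+// the axes are sorted and deduplicated, and the attribute is removed
+// entirely when the canonical set is empty.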
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+ AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+ : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+ LogicalResult matchAndRewrite(Op op,
+ PatternRewriter &rewriter) const override {
+ DenseI16ArrayAttr oldAttr =
+ op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr);
+ auto canonicalMeshAxesAttr = canonicalizeAxesSetAttribute(oldAttr);
+ // Return failure when the attribute is already canonical. Otherwise the
+ // greedy pattern driver would keep re-applying this pattern forever.
+ if (!canonicalMeshAxesAttr) {
+ if (!oldAttr)
+ return failure();
+ rewriter.updateRootInPlace(op, [&]() { op->removeAttr(axisSetAttr); });
+ } else {
+ if (canonicalMeshAxesAttr.value() == oldAttr)
+ return failure();
+ rewriter.updateRootInPlace(op, [&]() {
+ op->setAttr(axisSetAttr, canonicalMeshAxesAttr.value());
+ });
+ }
+ return success();
+ }
+
+ std::string axisSetAttr;
+};
+
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<AxesSetCanonicalizationPattern<Op>>(context, "mesh_axes");
+}
+
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+ FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+ if (!symbolAttr) {
+ return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+ }
+ mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ op.getOperation(), symbolAttr);
+ if (!mesh) {
+ return op.emitError() << "Undefined required mesh symbol \""
+ << symbolAttr.getValue() << "\".";
+ }
+ DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+ if (!meshAxes) {
+ return success();
+ }
+ MeshAxis rank = mesh.getRank();
+ for (auto axis : meshAxes.asArrayRef()) {
+ if (axis >= rank || axis < 0) {
+ return op.emitError()
+ << "0-based mesh axis index " << axis
+ << " is out of bounds. The referenced mesh \""
+ << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+ }
+ }
+
+ return success();
+}
+
+template <typename It>
+auto product(It begin, It end) {
+ using ElementType = std::decay_t<decltype(*begin)>;
+ return std::accumulate(begin, end, ElementType(1),
+ std::multiplies<ElementType>());
+}
+
+template <typename R>
+auto product(R &&range) {
+ return product(adl_begin(range), adl_end(range));
+}
+
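+// Number of devices in each collective group: the product of the mesh
+// dimensions selected by meshAxes. E.g. for a mesh of shape 2x3x4x5 and
+// meshAxes = {0, 1} the group size is 2 * 3 = 6. A dynamic size on any
+// selected axis makes the group size dynamic.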
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ int64_t res = 1;
+ for (MeshAxis axis = 0; axis < MeshAxis(meshShape.size()); ++axis) {
+ if (llvm::find(meshAxes, axis) == meshAxes.end()) {
+ continue;
+ }
+ if (isMeshDimensionDynamic(meshShape[axis])) {
+ return ShapedType::kDynamic;
+ }
+ res *= meshShape[axis];
+ }
+ return res;
+}
+
+LogicalResult verifyDimensionCompatibility(Location loc,
+ int64_t expectedDimSize,
+ int64_t resultDimSize,
+ int64_t resultAxis) {
+ if (!ShapedType::isDynamic(resultDimSize) &&
+ expectedDimSize != resultDimSize) {
+ return emitError(loc) << "Dimension size mismatch for result axis "
+ << resultAxis << ". Expected "
+ << (ShapedType::isDynamic(expectedDimSize)
+ ? Twine("dynamic")
+ : Twine(expectedDimSize))
+ << ", but got " << resultDimSize << ".";
+ }
+
+ return success();
+}
+
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+ int64_t gatherAxis,
+ ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ ShapedType operandType = operand.getType().cast<ShapedType>();
+ ShapedType resultType = result.getType().cast<ShapedType>();
+ auto deviceGroupSize =
+ DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+ for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+ auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
+ auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
+ auto expectedResultDimSize =
+ axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
+ if (failed(verifyDimensionCompatibility(
+ result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
+ return failure();
+ }
+ }
+ return success();
+}
+
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+ SymbolTableCollection symbolTableCollection;
+ if (failed(verifyMeshSymbolUses(op, symbolTableCollection))) {
+ // We need to check the symbol here since this runs before
+ // SymbolUserOpInterface.
+ return failure();
+ }
+ return symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ op.getOperation(), op.getMeshAttr());
+}
+
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+ auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
+ auto gatherAxis = op.getGatherAxis().getSExtValue();
+ if (gatherAxis < 0 || gatherAxis >= rank) {
+ return op.emitError() << "Gather axis " << gatherAxis
+ << " is out of bounds [0, " << rank << ").";
+ }
+
+ auto mesh = getMesh(op);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return verifyGatherOperandAndResultShape(op.getOperand(), op.getResult(),
+ gatherAxis, op.getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
+LogicalResult verifyAllToAllOperandAndResultShape(Value operand, Value result,
+ int64_t splitAxis,
+ int64_t concatAxis,
+ ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ ShapedType operandType = operand.getType().cast<ShapedType>();
+ ShapedType resultType = result.getType().cast<ShapedType>();
+ for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+ if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
+ if (failed(verifyDimensionCompatibility(
+ result.getLoc(), operandType.getDimSize(axis),
+ resultType.getDimSize(axis), axis))) {
+ return failure();
+ }
+ }
+ }
+
+ if (splitAxis == concatAxis) {
+ return success();
+ }
+
+ auto deviceGroupSize =
+ DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+ auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
+ auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
+ if (!operandSplitDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+ int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
+ return emitError(result.getLoc())
+ << "Operand dimension size " << int64_t(operandSplitDimSize)
+ << " is not divisible by collective device group size "
+ << int64_t(deviceGroupSize) << " for split axis " << splitAxis
+ << ".";
+ }
+ DimensionSize expectedResultConcatDimSize =
+ operandConcatDimSize * deviceGroupSize;
+ DimensionSize expectedResultSplitDimSize =
+ operandSplitDimSize / deviceGroupSize;
+ if (failed(verifyDimensionCompatibility(
+ result.getLoc(), expectedResultConcatDimSize.value(),
+ resultType.getDimSize(concatAxis), concatAxis))) {
+ return failure();
+ }
+ if (failed(verifyDimensionCompatibility(
+ result.getLoc(), expectedResultSplitDimSize.value(),
+ resultType.getDimSize(splitAxis), splitAxis))) {
+ return failure();
+ }
+
+ return success();
+}
+
+LogicalResult verifyReduceScatterOperandAndResultShape(
+ Value operand, Value result, int64_t scatterAxis,
+ ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+ ShapedType operandType = operand.getType().cast<ShapedType>();
+ ShapedType resultType = result.getType().cast<ShapedType>();
+ for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
+ if (axis != scatterAxis) {
+ if (failed(verifyDimensionCompatibility(
+ result.getLoc(), operandType.getDimSize(axis),
+ resultType.getDimSize(axis), axis))) {
+ return failure();
+ }
+ }
+ }
+
+ auto deviceGroupSize =
+ DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+ auto operandScatterDimSize =
+ DimensionSize(operandType.getDimSize(scatterAxis));
+ if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
+ int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
+ return emitError(result.getLoc())
+ << "Operand dimension size " << int64_t(operandScatterDimSize)
+ << " is not divisible by collective device group size "
+ << int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
+ << ".";
+ }
+ DimensionSize expectedResultScatterDimSize =
+ operandScatterDimSize / deviceGroupSize;
+ if (failed(verifyDimensionCompatibility(
+ result.getLoc(), expectedResultScatterDimSize.value(),
+ resultType.getDimSize(scatterAxis), scatterAxis))) {
+ return failure();
+ }
+
+ return success();
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// mesh.all_reduce op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+mlir::mesh::AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+void mlir::mesh::AllReduceOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ populateMeshAxesSetCanonicalizationPatterns<AllReduceOp>(patterns, context);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.all_gather op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+mlir::mesh::AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+void mlir::mesh::AllGatherOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ populateMeshAxesSetCanonicalizationPatterns<AllGatherOp>(patterns, context);
+}
+
+LogicalResult mlir::mesh::AllGatherOp::verify() { return verifyGather(*this); }
+
+//===----------------------------------------------------------------------===//
+// mesh.all_to_all op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+mlir::mesh::AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+void mlir::mesh::AllToAllOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ populateMeshAxesSetCanonicalizationPatterns<AllToAllOp>(patterns, context);
+}
+
+LogicalResult mlir::mesh::AllToAllOp::verify() {
+ auto mesh = ::getMesh(*this);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return verifyAllToAllOperandAndResultShape(
+ getOperand(), getResult(), getSplitAxis().getSExtValue(),
+ getConcatAxis().getSExtValue(), getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.reduce_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::mesh::ReduceScatterOp::verifySymbolUses(
+ SymbolTableCollection &symbolTable) {
+ return verifyMeshSymbolUses(*this, symbolTable);
+}
+
+void mlir::mesh::ReduceScatterOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ populateMeshAxesSetCanonicalizationPatterns<ReduceScatterOp>(patterns,
+ context);
+}
+
+LogicalResult mlir::mesh::ReduceScatterOp::verify() {
+ auto mesh = ::getMesh(*this);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return verifyReduceScatterOperandAndResultShape(
+ getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
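
As a companion to the canonicalization tests that follow, here is a standalone sketch of the `mesh_axes` set canonicalization (sort, deduplicate, and drop the attribute when the set is empty; illustrative code, not the MLIR API):

```c++
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Model of the mesh_axes canonicalization: sort and deduplicate; an empty
// canonical set means the attribute is dropped from the op.
std::vector<int16_t> canonicalizeAxesSet(std::vector<int16_t> axes) {
  std::sort(axes.begin(), axes.end());
  axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
  return axes;
}

int main() {
  // Matches the tests below: array<i16: 1, 0, 0> becomes array<i16: 0, 1>.
  assert((canonicalizeAxesSet({1, 0, 0}) == std::vector<int16_t>{0, 1}));
  // array<i16> (empty) results in the mesh_axes attribute being removed.
  assert(canonicalizeAxesSet({}).empty());
}
```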
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
new file mode 100644
index 000000000000000..21fa1887d14a53e
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+// CHECK-LABEL: func @all_reduce_mesh_axes
+func.func @all_reduce_mesh_axes(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+ %0 = mesh.all_reduce %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, reduction = #mesh.partial<sum>
+ } : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_reduce_empty_mesh_axes_and_default_reduction
+func.func @all_reduce_empty_mesh_axes_and_default_reduction(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ %0 = mesh.all_reduce %arg0 {
+ mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+ mesh_axes = array<i16>,
+// CHECK-NOT: reduction
+ reduction = #mesh.partial<sum>
+ } : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_gather_empty_mesh_axes
+func.func @all_gather_empty_mesh_axes(
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+ %0 = mesh.all_gather %arg0 {
+ gather_axis = 0 : index,
+ mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+ mesh_axes = array<i16>
+ } : tensor<4xf32> -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_gather_mesh_axes
+func.func @all_gather_mesh_axes(
+ %arg0 : tensor<4xf32>) -> tensor<32xf32> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+ %0 = mesh.all_gather %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, gather_axis = 0
+ } : tensor<4xf32> -> tensor<32xf32>
+ return %0 : tensor<32xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_mesh_axes
+func.func @reduce_scatter_mesh_axes(
+ %arg0 : tensor<8xf32>) -> tensor<1xf64> {
+// CHECK: mesh_axes = array<i16: 0, 1>
+ %0 = mesh.reduce_scatter %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1, 0, 0>, scatter_axis = 0
+ } : tensor<8xf32> -> tensor<1xf64>
+ return %0 : tensor<1xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_and_default_reduction
+func.func @reduce_scatter_empty_mesh_axes_and_default_reduction(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ %0 = mesh.reduce_scatter %arg0 {
+ mesh = @mesh0,
+// CHECK-NOT: mesh_axes
+ mesh_axes = array<i16>,
+// CHECK-NOT: reduction
+ reduction = #mesh.partial<sum>,
+ scatter_axis = 0
+ } : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 246439dd4be7122..413b01459f507a2 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -67,3 +67,243 @@ func.func @mesh_axis_negtive_in_partial(
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
}
+
+// -----
+
+func.func @all_reduce_invalid_mesh_symbol(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.all_reduce %arg0 {
+ mesh = @this_mesh_symbol_does_not_exist, reduction = #mesh.partial<sum>
+ } : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_invalid_mesh_axis(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0 = mesh.all_reduce %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<sum>
+ } : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_reduce_invalid_tensor_dimension_size(
+ %arg0 : tensor<4xf32>) -> tensor<5xf64> {
+ // expected-error at +1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
+ %0 = mesh.all_reduce %arg0 { mesh = @mesh0 } : tensor<4xf32> -> tensor<5xf64>
+ return %0 : tensor<5xf64>
+}
+
+// -----
+
+func.func @all_gather_invalid_mesh_symbol(
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+ // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.all_gather %arg0 {
+ mesh = @this_mesh_symbol_does_not_exist, gather_axis = 0 : index
+ } : tensor<4xf32> -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @all_gather_invalid_mesh_axis(
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+ // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0 = mesh.all_gather %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 2>, gather_axis = 0 : index
+ } : tensor<4xf32> -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_non_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+ %0 = mesh.all_gather %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 0 : index
+ } : tensor<3x4xf32> -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+
+func.func @all_gather_invalid_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+ %0 = mesh.all_gather %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+ } : tensor<3x4xf32> -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_gather_axis_dynamic_dimension(
+ %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+ %0 = mesh.all_gather %arg0 {
+ gather_axis = 0 : index,
+ mesh = @mesh0
+ } : tensor<?xf32> -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error at +1 {{Gather axis 1 is out of bounds [0, 1).}}
+ %0 = mesh.all_gather %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = 1 : index
+ } : tensor<3xf32> -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+
+func.func @all_gather_invalid_negative_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error at +1 {{Gather axis -1 is out of bounds [0, 1).}}
+ %0 = mesh.all_gather %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, gather_axis = -1 : index
+ } : tensor<3xf32> -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+func.func @all_to_all_invalid_mesh_symbol(
+ %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+ // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 0, mesh = @this_mesh_symbol_does_not_exist, split_axis = 1
+ } : tensor<3x6xi8> -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
+ %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+ } : tensor<3x6xi8> -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
+ %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
+ } : tensor<?x6xi8> -> tensor<3x?xi8>
+ return %0 : tensor<3x?xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
+ %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 1>, split_axis = 0
+ } : tensor<3x?xi8> -> tensor<?x3xi8>
+ return %0 : tensor<?x3xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
+ %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+ } : tensor<3x2xi8> -> tensor<1x7xi8>
+ return %0 : tensor<1x7xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
+ %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 1, mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0
+ } : tensor<3x2xi8> -> tensor<2x6xi8>
+ return %0 : tensor<2x6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_dynamic_dimension(
+ %arg0 : tensor<?xf32>) -> tensor<2xf64> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+ %0 = mesh.reduce_scatter %arg0 {
+ mesh = @mesh0, scatter_axis = 0
+ } : tensor<?xf32> -> tensor<2xf64>
+ return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_static_dimension_size(
+ %arg0 : tensor<3xf32>) -> tensor<2xf64> {
+ // expected-error at +1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+ %0 = mesh.reduce_scatter %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
+ } : tensor<3xf32> -> tensor<2xf64>
+ return %0 : tensor<2xf64>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+
+func.func @reduce_scatter_invalid_operand_static_dimension_size(
+ %arg0 : tensor<4xf32>) -> tensor<?xf64> {
+ // expected-error at +1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+ %0 = mesh.reduce_scatter %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, scatter_axis = 0
+ } : tensor<4xf32> -> tensor<?xf64>
+ return %0 : tensor<?xf64>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index ee5f8f67792b928..5ec6c4a439327f0 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -12,6 +12,8 @@ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
// CHECK: mesh.cluster @mesh3
mesh.cluster @mesh3(rank = 2)
+// CHECK: mesh.cluster @mesh4
+mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+
// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
func.func @mesh_shard_encoding_fully_replicated(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
@@ -126,3 +128,120 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
%2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
+
+// CHECK-LABEL: func @all_reduce
+func.func @all_reduce(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
+ // CHECK-NEXT: mesh.all_reduce %[[ARG]]
+ // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>}
+ // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
+ %0 = mesh.all_reduce %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+ } : tensor<3x4xf32> -> tensor<3x4xf64>
+ return %0 : tensor<3x4xf64>
+}
+
+// CHECK-LABEL: func @all_gather
+func.func @all_gather(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
+ // CHECK-NEXT: mesh.all_gather %[[ARG]]
+ // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+ // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
+ %0 = mesh.all_gather %arg0 {
+ gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>
+ } : tensor<3x4xf32> -> tensor<3x16xf32>
+ return %0 : tensor<3x16xf32>
+}
+
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
+func.func @all_gather_dynamic_dims_in_tensor(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
+ %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK-NEXT: mesh.all_gather %[[ARG]]
+ // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>}
+ // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
+ %0 = mesh.all_gather %arg0 {
+ gather_axis = 1 : index, mesh = @mesh0, mesh_axes = array<i16: 2>
+ } : tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
+func.func @all_gather_dynamic_dims_in_mesh(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
+ %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
+ // CHECK-NEXT: mesh.all_gather %[[ARG]]
+ // CHECK-SAME: {gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>}
+ // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
+ %0 = mesh.all_gather %arg0 {
+ gather_axis = 1 : index, mesh = @mesh3, mesh_axes = array<i16: 1>
+ } : tensor<5x6xf32> -> tensor<5x?xf32>
+ return %0 : tensor<5x?xf32>
+}
+
+// CHECK-LABEL: func @all_to_all
+func.func @all_to_all(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+ %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+ // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+ // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+ // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 0, mesh = @mesh4, split_axis = 1
+ } : tensor<3x6xi8> -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
+func.func @all_to_all_dynamic_dims_in_result(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+ %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
+ // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+ // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 1 : i64}
+ // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 0, mesh = @mesh4, split_axis = 1
+ } : tensor<3x6xi8> -> tensor<3x?xi8>
+ return %0 : tensor<3x?xi8>
+}
+
+// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
+func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
+ %arg0 : tensor<3xi8>) -> tensor<3xi8> {
+ // CHECK-NEXT: mesh.all_to_all %[[ARG]]
+ // CHECK-SAME: {concat_axis = 0 : i64, mesh = @mesh4, split_axis = 0 : i64}
+ // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
+ %0 = mesh.all_to_all %arg0 {
+ concat_axis = 0, mesh = @mesh4, split_axis = 0
+ } : tensor<3xi8> -> tensor<3xi8>
+ return %0 : tensor<3xi8>
+}
+
+// CHECK-LABEL: func @reduce_scatter_static_dimensions
+func.func @reduce_scatter_static_dimensions(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
+ // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+ // CHECK-SAME: {mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1 : i64}
+ // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
+ %0 = mesh.reduce_scatter %arg0 {
+ mesh = @mesh0, mesh_axes = array<i16: 2>, reduction = #mesh.partial<max>, scatter_axis = 1
+ } : tensor<3x4xf32> -> tensor<3x1xf64>
+ return %0 : tensor<3x1xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
+func.func @reduce_scatter_dynamic_dimensions(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+ %arg0 : tensor<?xf32>) -> tensor<?xf64> {
+ // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
+ // CHECK-SAME: {mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0 : i64}
+ // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
+ %0 = mesh.reduce_scatter %arg0 {
+ mesh = @mesh3, mesh_axes = array<i16: 0, 1>, scatter_axis = 0
+ } : tensor<?xf32> -> tensor<?xf64>
+ return %0 : tensor<?xf64>
+}