[Mlir-commits] [mlir] 9a8437f - [mlir][mesh] Rename cluster to mesh (#79484)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 26 07:03:33 PST 2024
Author: Boian Petkantchin
Date: 2024-01-26T07:03:29-08:00
New Revision: 9a8437f50470e2658ca0b26bbc9f3da654c20dba
URL: https://github.com/llvm/llvm-project/commit/9a8437f50470e2658ca0b26bbc9f3da654c20dba
DIFF: https://github.com/llvm/llvm-project/commit/9a8437f50470e2658ca0b26bbc9f3da654c20dba.diff
LOG: [mlir][mesh] Rename cluster to mesh (#79484)
Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.
The name `mesh` is more specific to what it really represents. It is a
mesh of devices.
The name `cluster` implies a broader possibility of device
configurations. When just the word `mesh` is used the meaning can often
be inferred from the context whether it refers to the mesh dialect or a
device mesh. The full name can be used when needed.
Added:
Modified:
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
mlir/test/Dialect/Mesh/canonicalization.mlir
mlir/test/Dialect/Mesh/folding.mlir
mlir/test/Dialect/Mesh/invalid.mlir
mlir/test/Dialect/Mesh/ops.mlir
mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
mlir/test/Dialect/Mesh/resharding-spmdization.mlir
mlir/test/Dialect/Mesh/sharding-propagation.mlir
mlir/test/Dialect/Mesh/simplifications.mlir
mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 07f954459ca49d..e8353613cd0e12 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -79,7 +79,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let mnemonic = "shard";
let parameters = (ins
- AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
+ AttrParameter<"::mlir::FlatSymbolRefAttr",
+ "The mesh on which tensors are sharded.">:$mesh,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,9 +92,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
The MeshSharding attribute could be used in the encoding of a
`RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
- 1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
- cluster where the distributed tensor is placed. The symbol must resolve to a
- `mesh.cluster` operation.
+ 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+ mesh where the distributed tensor is placed. The symbol must resolve to a
+ `mesh.mesh` operation.
2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
maximum size is the `rank` of the related tensor. For the i-th sub-array, if
@@ -117,7 +118,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
Example:
```
- mesh.cluster @mesh0(shape = 2x2x4)
+ mesh.mesh @mesh0(shape = 2x2x4)
// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
@@ -140,12 +141,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
```
}];
let assemblyFormat = [{
- `<` $cluster `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
+ `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
$partial_axes^ `]`)? `>`
}];
let builders = [
- AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
+ AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
"mesh::Partial": $partial_type), [{
@@ -153,12 +154,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
split_axes, [&](ArrayRef<MeshAxis> array) {
return MeshAxesAttr::get($_ctxt, array);
});
- return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+ return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
partial_type);
}]>,
- AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
+ AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
- return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+ return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 78ff8bd0cac621..7b301025e687ae 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -26,17 +26,17 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}
-def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
- let summary = "representing a mesh cluster";
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+ let summary = "Description of a device/process mesh.";
let description = [{
- The mesh.cluster operation is a symbol operation that identifies a specific
- mesh cluster. The operation has three attributes:
+ The mesh.mesh operation is a symbol operation that identifies a specific
+ mesh. The operation has three attributes:
- 1. `sym_name`: This attribute uniquely identifies the name of the mesh
- cluster. This name serves as a symbolic reference to the cluster throughout
+ 1. `sym_name`: This attribute uniquely identifies the name of the mesh.
+ This name serves as a symbolic reference to the mesh throughout
the MLIR module, allowing for consistent referencing and easier debugging.
- 2. `shape`: This attribute represents the shape of the device cluster.
+ 2. `shape`: This attribute represents the shape of the device mesh.
It uses the same notation as a tensor shape. Also allowing for dynamic
dimensions.
This flexibility allows for dynamic device assignment or configurations
@@ -46,21 +46,21 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
Example:
```
- // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
+ // A device mesh with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
- mesh.cluster @mesh0(shape = 4x8x12)
+ mesh.mesh @mesh0(shape = 4x8x12)
- // A device mesh cluster with 2 axes, the total device number is unknown
+ // A device mesh with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
- mesh.cluster @mesh1(shape = 4x?)
+ mesh.mesh @mesh1(shape = 4x?)
- // A device mesh cluster with 2 axes, the total device number is unknown
+ // A device mesh with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
- mesh.cluster @mesh2(shape = ?x4)
+ mesh.mesh @mesh2(shape = ?x4)
- // A device mesh cluster with 2 axes, the number of devices along both axes
+ // A device mesh with 2 axes, the number of devices along both axes
// is unknown
- mesh.cluster @mesh3(shape = ?x?)
+ mesh.mesh @mesh3(shape = ?x?)
// Used in the mesh sharding attribute to extend the standard tensor to
// distributed
@@ -81,9 +81,9 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}
-def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
+def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
- let summary = "Get the shape of the cluster.";
+ let summary = "Get the shape of the mesh.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
@@ -99,13 +99,13 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
}];
let builders = [
- OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
- let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
+ let summary = "Annotate on how a tensor is sharded across a mesh.";
let description = [{
The mesh.shard operation is designed to specify and guide the sharding
behavior of a tensor value across a mesh topology. This operation has one
@@ -115,7 +115,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
annotated for sharding.
2. `shard`: This attribute is type of `MeshSharding`, which is the core data
- structure to represent distributed tensor in mesh cluster.
+ structure to represent distribution of a tensor on a mesh.
3. `annotate_for_users`: A unit attribute addressing the scenario when a
tensor's sharding annotation differs based on its context of use (either as
@@ -217,7 +217,7 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
attr-dict `:` type($result)
}];
let builders = [
- OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
@@ -239,7 +239,7 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
let results = (outs Index:$result);
let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
let builders = [
- OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
];
}
@@ -268,7 +268,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Example:
```mlir
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
@@ -353,7 +353,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Example:
```
- mesh.cluster @mesh0(shape = 3)
+ mesh.mesh @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
@@ -410,7 +410,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
Example:
```
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.broadcast %0 on @mesh0
mesh_axes = [0]
@@ -466,7 +466,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
Example:
```mlir
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
gather_axis = 1 root = [1]
@@ -589,7 +589,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
across the device group.
Example:
```
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <max> scatter_axis = 0
@@ -652,7 +652,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
Example:
```
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
scatter_axis = 0
root = [1]
@@ -748,7 +748,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
Example:
```
- mesh.cluster @mesh0(shape = 2x4)
+ mesh.mesh @mesh0(shape = 2x4)
%1 = mesh.shift on @mesh0 mesh_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index a32274d857f15d..3bef7e6babdec9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -25,14 +25,14 @@ struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
// mesh axes the i-th loop will be sharded on.
ShardingArray shardingArray = {};
- FlatSymbolRefAttr cluster = nullptr;
+ FlatSymbolRefAttr mesh = nullptr;
// `empty` being true indicates that no sharding information can be inferred
// at present. Note that it is different from the case where an operation is
// not sharded.
bool empty = false;
ShardingOption() = default;
- ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
- : shardingArray(std::move(shardingArray)), cluster(cluster) {}
+ ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
+ : shardingArray(std::move(shardingArray)), mesh(mesh) {}
};
// This method retrieves the 'MeshShardingAttr' attribute from a given operation
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index f71bb9b262a380..7cb992aac019b3 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -17,14 +17,14 @@ namespace mlir {
namespace mesh {
// Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
MeshShardingAttr sharding);
// Insert resharding spmdization of the value `sourceShardValue`
// from sharding `source` to sharding `target`.
// `sourceShardValue` is the already sharded value according to `source`.
-TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
- ShardOp source, ShardOp target,
+TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+ ShardOp target,
TypedValue<ShapedType> sourceShardValue);
void reshardingRegisterDependentDialects(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index f6b6b7c248c432..994a017a1f46c2 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -114,10 +114,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Mesh utilities
//===----------------------------------------------------------------------===//
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTable) {
- mesh::ClusterOp mesh =
- symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+ SymbolTableCollection &symbolTable) {
+ mesh::MeshOp mesh =
+ symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
if (!mesh) {
return op->emitError() << "Undefined required mesh symbol \""
<< meshSymbol.getValue() << "\".";
@@ -144,7 +144,7 @@ bool isUnique(It begin, It end) {
}
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
- ClusterOp mesh) {
+ MeshOp mesh) {
SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
llvm::sort(sorted);
if (!isUnique(sorted.begin(), sorted.end())) {
@@ -192,22 +192,22 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
}
//===----------------------------------------------------------------------===//
-// mesh.cluster op
+// mesh.mesh op
//===----------------------------------------------------------------------===//
-LogicalResult ClusterOp::verify() {
+LogicalResult MeshOp::verify() {
int64_t rank = getRank();
if (rank <= 0)
- return emitOpError("rank of cluster is expected to be a positive integer");
+ return emitOpError("rank of mesh is expected to be a positive integer");
if (getShape().size() > size_t(rank))
return emitOpError(
- "rank of shape is not expected to be larger than rank of cluster");
+ "rank of shape is not expected to be larger than rank of mesh");
for (int64_t dimSize : getShape()) {
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
- return emitOpError("dimension size of a mesh cluster is expected to be "
+ return emitOpError("dimension size of a mesh is expected to be "
"non-negative or dynamic");
}
@@ -215,11 +215,11 @@ LogicalResult ClusterOp::verify() {
}
//===----------------------------------------------------------------------===//
-// mesh.cluster_shape op
+// mesh.mesh_shape op
//===----------------------------------------------------------------------===//
LogicalResult
-ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
@@ -238,16 +238,16 @@ ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
-void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- ClusterOp mesh) {
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ MeshOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
mesh.getSymName(),
MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
}
-void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef mesh, ArrayRef<MeshAxis> axes) {
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef mesh, ArrayRef<MeshAxis> axes) {
build(odsBuilder, odsState,
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -261,7 +261,7 @@ LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
ArrayRef<MeshAxis> partialAxes, Partial) {
- // TODO: At present cluster symbol ref is not verified. This is due to the
+ // TODO: At present mesh symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
llvm::SmallSet<MeshAxis, 4> visitedAxes;
@@ -292,8 +292,7 @@ bool MeshShardingAttr::operator==(Attribute rhs) const {
}
bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
- if (getCluster() != rhs.getCluster() ||
- getPartialAxes() != rhs.getPartialAxes()) {
+ if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
return false;
}
@@ -342,7 +341,7 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- ClusterOp mesh) {
+ MeshOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
mesh.getSymName(), ArrayRef<MeshAxis>());
@@ -369,7 +368,7 @@ ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
- OperationState &odsState, ClusterOp mesh) {
+ OperationState &odsState, MeshOp mesh) {
build(odsBuilder, odsState, mesh.getSymName());
}
@@ -427,7 +426,7 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
}
template <typename Op>
-static FailureOr<ClusterOp>
+static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
if (failed(mesh)) {
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index dca7e86e6f07f5..5dc91ff1c02d20 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -215,11 +215,10 @@ namespace {
// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
- FlatSymbolRefAttr cluster,
+ FlatSymbolRefAttr mesh,
ArrayRef<MeshAxis> meshAxes,
unsigned loopIdx) {
- if ((shardingOption.cluster && cluster &&
- shardingOption.cluster != cluster) ||
+ if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
(!shardingOption.shardingArray[loopIdx].empty() &&
shardingOption.shardingArray[loopIdx] != meshAxes)) {
LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
@@ -238,8 +237,8 @@ static LogicalResult fillShardingOption(Operation *op,
}
}
}
- if (cluster)
- shardingOption.cluster = cluster;
+ if (mesh)
+ shardingOption.mesh = mesh;
if (shardingOption.shardingArray[loopIdx].empty())
shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
meshAxes.end());
@@ -281,7 +280,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
- if (failed(fillShardingOption(op, shardingOption, shardAttr.getCluster(),
+ if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
axes, index)))
return failure();
}
@@ -333,8 +332,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (loopIndices->size() == 1) {
unsigned loopIdx = *loopIndices->begin();
visitedLoopIndices.insert(loopIdx);
- if (failed(fillShardingOption(op, shardingOption,
- shardAttr.getCluster(), axes, loopIdx)))
+ if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
+ axes, loopIdx)))
return failure();
}
// If multiple loop indices correspond to a dimension of an operand, it is
@@ -437,9 +436,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
}
removeTrailingEmptySubArray(splitAxes);
- MeshShardingAttr shardAttr =
- MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes,
- partialAxes, partialType);
+ MeshShardingAttr shardAttr = MeshShardingAttr::get(
+ b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointAfterValue(result);
auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
@@ -485,7 +483,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
removeTrailingEmptySubArray(splitAxes);
MeshShardingAttr shardAttr =
- MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes);
+ MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(opOperand.getOwner());
auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 429e684c845fb1..c0273cdaef7144 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -55,20 +55,19 @@ namespace {
// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
// symbol tables.
// We can't use DialectFoldInterface since the cache may be invalidated by some
-// pass changing the referenced ClusterOp ops.
-struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
+// pass changing the referenced MeshOp ops.
+struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
template <typename... OpRewritePatternArgs>
- ClusterShapeFolder(SymbolTableCollection &symbolTableCollection,
- OpRewritePatternArgs &&...opRewritePatternArgs)
+ MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
+ OpRewritePatternArgs &&...opRewritePatternArgs)
: OpRewritePattern(
std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
symbolTableCollection(symbolTableCollection) {}
- LogicalResult matchAndRewrite(ClusterShapeOp op,
+ LogicalResult matchAndRewrite(MeshShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
- ClusterOp mesh =
- symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
- op.getOperation(), op.getMeshAttr());
+ MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+ op.getOperation(), op.getMeshAttr());
if (!mesh) {
return failure();
}
@@ -104,8 +103,8 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
// Leave only the dynamic mesh axes to be queried.
if (!newShapeOpMeshAxes.empty()) {
- ClusterShapeOp newShapeOp =
- builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+ MeshShapeOp newShapeOp =
+ builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
}
@@ -123,8 +122,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection) {
- patterns.add<ClusterShapeFolder>(symbolTableCollection,
- patterns.getContext());
+ patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
}
} // namespace mesh
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 9478b2e4ee5cb2..593158d5f6d293 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -84,7 +84,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
}
}
-ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
MeshShardingAttr sharding) {
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
@@ -141,7 +141,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
TypedValue<ShapedType> resultValue =
builder
.create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
- sourceSharding.getCluster().getLeafReference(),
+ sourceSharding.getMesh().getLeafReference(),
allReduceMeshAxes, sourceShard,
sourceSharding.getPartialType())
.getResult()
@@ -154,7 +154,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
return targetShardingPartialAxesSet.contains(a);
});
MeshShardingAttr resultSharding =
- MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
+ MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(),
sourceSharding.getSplitAxes(), remainingPartialAxes,
sourceSharding.getPartialType());
return {resultValue, resultSharding};
@@ -175,7 +175,7 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
targetShardingSplitAxes[splitTensorAxis] =
MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
- ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
@@ -197,7 +197,7 @@ static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
- TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+ TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
@@ -217,8 +217,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Value meshAxisSize =
builder
- .create<ClusterShapeOp>(mesh.getSymName(),
- SmallVector<MeshAxis>({splitMeshAxis}))
+ .create<MeshShapeOp>(mesh.getSymName(),
+ SmallVector<MeshAxis>({splitMeshAxis}))
.getResult()[0];
Value sourceAxisSize =
@@ -305,7 +305,7 @@ detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
}
static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
-trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshShardingAttr sourceSharding,
MeshShardingAttr targetSharding,
TypedValue<ShapedType> sourceShard) {
@@ -366,7 +366,7 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
targetShardingSplitAxes[splitTensorAxis] =
MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
- ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
@@ -382,7 +382,7 @@ static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
ShapedType sourceUnshardedShape,
- TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+ TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
@@ -406,7 +406,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
}
static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
-tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshShardingAttr sourceSharding,
MeshShardingAttr targetSharding,
ShapedType sourceUnshardedShape,
@@ -495,7 +495,7 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
- ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
@@ -512,7 +512,7 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
}
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
-moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshShardingAttr sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard,
@@ -541,7 +541,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
}
static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
-tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshShardingAttr sourceSharding,
MeshShardingAttr targetSharding,
ShapedType sourceUnshardedShape,
@@ -561,7 +561,7 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
// Currently the sharded tensor axes must be exactly divisible by the single
// mesh axis size.
static TypedValue<ShapedType>
-reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshShardingAttr sourceSharding,
MeshShardingAttr targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
@@ -604,7 +604,7 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
return targetShard;
}
-TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
MeshShardingAttr sourceSharding,
MeshShardingAttr targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
@@ -616,8 +616,8 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
sourceUnshardedValue, sourceShard);
}
-TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
- ShardOp source, ShardOp target,
+TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+ ShardOp target,
TypedValue<ShapedType> sourceShardValue) {
assert(!source.getAnnotateForUsers());
assert(target.getAnnotateForUsers());
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index c27e173d877d69..5c2344651bf60d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -24,7 +24,7 @@ namespace mlir::mesh {
namespace {
/// Lower `mesh.process_multi_index` into expression using
-/// `mesh.process_linear_index` and `mesh.cluster_shape`.
+/// `mesh.process_linear_index` and `mesh.mesh_shape`.
struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
template <typename... OpRewritePatternArgs>
ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
@@ -35,9 +35,8 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
PatternRewriter &rewriter) const override {
- ClusterOp mesh =
- symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
- op.getOperation(), op.getMeshAttr());
+ MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+ op.getOperation(), op.getMeshAttr());
if (!mesh) {
return failure();
}
@@ -45,7 +44,7 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
builder.setInsertionPointAfter(op.getOperation());
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
- ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
+ ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
SmallVector<Value> completeMultiIndex =
builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
.getMultiIndex();
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 4cc009ef24eb3c..23c5b253b4c073 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
index 9162dc57ecfdf4..369f316d0f797d 100644
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -1,22 +1,22 @@
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
-mesh.cluster @mesh0(shape = 4x?x2)
-mesh.cluster @mesh1(shape = 2x3)
+mesh.mesh @mesh0(shape = 4x?x2)
+mesh.mesh @mesh1(shape = 2x3)
-// CHECK-LABEL: func.func @cluster_shape_op_folding
-func.func @cluster_shape_op_folding() -> (index, index) {
+// CHECK-LABEL: func.func @mesh_shape_op_folding
+func.func @mesh_shape_op_folding() -> (index, index) {
// CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
- // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
- %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index
+ %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
// CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
return %0#0, %0#1 : index, index
}
-// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
-func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
+// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh
+func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) {
// CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
// CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
- %0:2 = mesh.cluster_shape @mesh1 : index, index
+ %0:2 = mesh.mesh_shape @mesh1 : index, index
// CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
return %0#0, %0#1 : index, index
}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 8a1fb80065573b..259e4ebf76757c 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -1,16 +1,16 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
-// expected-error@+1 {{rank of cluster is expected to be a positive integer}}
-mesh.cluster @mesh0(shape = [])
+// expected-error@+1 {{rank of mesh is expected to be a positive integer}}
+mesh.mesh @mesh0(shape = [])
// -----
-// expected-error@+1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.cluster @mesh0(shape = -1)
+// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
+mesh.mesh @mesh0(shape = -1)
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_different_subarray(
// expected-error@+1 {{mesh axis duplicated}}
@@ -21,7 +21,7 @@ func.func @mesh_axis_duplicated_different_subarray(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_same_subarray(
// expected-error@+1 {{mesh axis duplicated}}
@@ -32,7 +32,7 @@ func.func @mesh_axis_duplicated_same_subarray(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_bewteen_split_and_partial(
// expected-error@+1 {{mesh axis duplicated}}
@@ -43,7 +43,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_negtive_in_split_part(
// expected-error@+1 {{mesh axis is expected to be non-negative}}
@@ -54,7 +54,7 @@ func.func @mesh_axis_negtive_in_split_part(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @mesh_axis_negtive_in_partial(
// expected-error@+1 {{mesh axis is expected to be non-negative}}
@@ -67,61 +67,61 @@ func.func @mesh_axis_negtive_in_partial(
func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
// expected-error@+2 {{custom op 'mesh.shard' invalid kind of attribute specified}}
- // expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'cluster' which is to be a `::mlir::FlatSymbolRefAttr`}}
+ // expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'mesh' which is to be a `::mlir::FlatSymbolRefAttr`}}
%0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32>
}
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
-func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
+func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index
+ %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.cluster @mesh0(shape = 1x2x3)
+mesh.mesh @mesh0(shape = 1x2x3)
-func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
+func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index
+ %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
-func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
+func.func @mesh_shape_wrong_number_of_results() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
- %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index
+ %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.cluster @mesh0(shape = 1x2x3)
+mesh.mesh @mesh0(shape = 1x2x3)
-func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
- %0:2 = mesh.cluster_shape @mesh0 : index, index
+ %0:2 = mesh.mesh_shape @mesh0 : index, index
return %0#0, %0#1 : index, index
}
// -----
-func.func @cluster_shape_invalid_mesh_name() -> (index) {
+func.func @mesh_shape_invalid_mesh_name() -> (index) {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index
+ %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index
return %0#0 : index
}
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -131,7 +131,7 @@ func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
// -----
-mesh.cluster @mesh0(shape = 1x2x3)
+mesh.mesh @mesh0(shape = 1x2x3)
func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
@@ -141,7 +141,7 @@ func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
@@ -151,7 +151,7 @@ func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
// -----
-mesh.cluster @mesh0(shape = 1x2x3)
+mesh.mesh @mesh0(shape = 1x2x3)
func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
// expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
@@ -187,7 +187,7 @@ func.func @all_reduce_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @all_reduce_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -199,7 +199,7 @@ func.func @all_reduce_invalid_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @all_reduce_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -211,7 +211,7 @@ func.func @all_reduce_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @all_reduce_invalid_tensor_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<5xf64> {
@@ -232,7 +232,7 @@ func.func @all_gather_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @all_gather_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -244,7 +244,7 @@ func.func @all_gather_invalid_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @all_reduce_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -256,7 +256,7 @@ func.func @all_reduce_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @all_gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -268,7 +268,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 1x2)
+mesh.mesh @mesh0(shape = 1x2)
func.func @all_gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -280,7 +280,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @all_gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -292,7 +292,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @all_gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -304,7 +304,7 @@ func.func @all_gather_invalid_gather_axis(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @all_gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -327,7 +327,7 @@ func.func @all_to_all_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @all_to_all_duplicate_mesh_axis(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -340,7 +340,7 @@ func.func @all_to_all_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = ?x1)
+mesh.mesh @mesh0(shape = ?x1)
func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -353,7 +353,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
// -----
-mesh.cluster @mesh0(shape = 1x1)
+mesh.mesh @mesh0(shape = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
@@ -366,7 +366,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
// -----
-mesh.cluster @mesh0(shape = 1x1)
+mesh.mesh @mesh0(shape = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
@@ -379,7 +379,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
@@ -392,7 +392,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
@@ -405,7 +405,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @broadcast_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -418,7 +418,7 @@ func.func @broadcast_root_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @broadcast_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -431,7 +431,7 @@ func.func @broadcast_root_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @broadcast_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -444,7 +444,7 @@ func.func @broadcast_different_input_and_result_type(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @gather_wrong_return_element_type(
%arg0 : tensor<1xf32>) -> tensor<1xi8> {
@@ -456,7 +456,7 @@ func.func @gather_wrong_return_element_type(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -468,7 +468,7 @@ func.func @gather_invalid_non_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 1x2)
+mesh.mesh @mesh0(shape = 1x2)
func.func @gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -480,7 +480,7 @@ func.func @gather_invalid_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -492,7 +492,7 @@ func.func @gather_invalid_gather_axis_dynamic_dimension(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -504,7 +504,7 @@ func.func @gather_invalid_gather_axis(
// -----
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
func.func @gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -516,7 +516,7 @@ func.func @gather_invalid_negative_gather_axis(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @gather_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<6xi8> {
@@ -529,7 +529,7 @@ func.func @gather_root_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @gather_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -542,7 +542,7 @@ func.func @gather_root_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @receive_source_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -555,7 +555,7 @@ func.func @receive_source_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @receive_source_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -568,7 +568,7 @@ func.func @receive_source_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @receive_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -581,7 +581,7 @@ func.func @receive_different_input_and_result_type(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @reduce_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -594,7 +594,7 @@ func.func @reduce_root_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @reduce_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -607,7 +607,7 @@ func.func @reduce_root_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @reduce_different_input_and_result_shape(
%arg0 : tensor<2xi8>) -> tensor<3xi16> {
@@ -620,7 +620,7 @@ func.func @reduce_different_input_and_result_shape(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @reduce_scatter_duplicate_mesh_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
@@ -632,7 +632,7 @@ func.func @reduce_scatter_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @reduce_scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
@@ -644,7 +644,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @reduce_scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
@@ -656,7 +656,7 @@ func.func @reduce_scatter_invalid_static_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
@@ -668,7 +668,7 @@ func.func @reduce_scatter_invalid_operand_static_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @scatter_duplicate_mesh_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
@@ -681,7 +681,7 @@ func.func @scatter_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
@@ -694,7 +694,7 @@ func.func @scatter_invalid_dynamic_dimension(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
@@ -707,7 +707,7 @@ func.func @scatter_invalid_static_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
func.func @scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
@@ -720,7 +720,7 @@ func.func @scatter_invalid_operand_static_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @scatter_root_dimension_out_of_bounds(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
@@ -733,7 +733,7 @@ func.func @scatter_root_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @scatter_root_wrong_number_dimensions(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
@@ -746,7 +746,7 @@ func.func @scatter_root_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @send_destination_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -759,7 +759,7 @@ func.func @send_destination_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @send_destination_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -772,7 +772,7 @@ func.func @send_destination_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
func.func @send_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -796,7 +796,7 @@ func.func @shift_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @shift_invalid_mesh_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
@@ -809,7 +809,7 @@ func.func @shift_invalid_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @shift_duplicate_mesh_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
@@ -822,7 +822,7 @@ func.func @shift_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @shift_invalid_tensor_dimension_size(
%arg0 : tensor<4xi8>) -> tensor<5xi8> {
@@ -835,7 +835,7 @@ func.func @shift_invalid_tensor_dimension_size(
// -----
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
func.func @shift_invalid_shift_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 0aaa4bdee1db3a..dbaaff9c172fd9 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -1,21 +1,21 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// CHECK: mesh.cluster @mesh0
-mesh.cluster @mesh0(shape = 2x2x4)
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 2x2x4)
-// CHECK: mesh.cluster @mesh1(shape = 4x?)
-mesh.cluster @mesh1(shape = 4x?)
+// CHECK: mesh.mesh @mesh1(shape = 4x?)
+mesh.mesh @mesh1(shape = 4x?)
-// CHECK: mesh.cluster @mesh2(shape = ?x4)
-mesh.cluster @mesh2(shape = ?x4)
+// CHECK: mesh.mesh @mesh2(shape = ?x4)
+mesh.mesh @mesh2(shape = ?x4)
-// CHECK: mesh.cluster @mesh3(shape = ?x?)
-mesh.cluster @mesh3(shape = ?x?)
+// CHECK: mesh.mesh @mesh3(shape = ?x?)
+mesh.mesh @mesh3(shape = ?x?)
-mesh.cluster @mesh4(shape = 3)
+mesh.mesh @mesh4(shape = 3)
-// CHECK: mesh.cluster @mesh5(shape = ?)
-mesh.cluster @mesh5(shape = ?)
+// CHECK: mesh.mesh @mesh5(shape = ?)
+mesh.mesh @mesh5(shape = ?)
// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
func.func @mesh_shard_encoding_fully_replicated(
@@ -132,26 +132,26 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
-// CHECK-LABEL: func @cluster_shape
-func.func @cluster_shape() -> (index, index) {
- // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
- %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+// CHECK-LABEL: func @mesh_shape
+func.func @mesh_shape() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
return %0#0, %0#1 : index, index
}
-// CHECK-LABEL: func @cluster_shape_default_axes
-func.func @cluster_shape_default_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
- %0:3 = mesh.cluster_shape @mesh0 : index, index, index
+// CHECK-LABEL: func @mesh_shape_default_axes
+func.func @mesh_shape_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
+ %0:3 = mesh.mesh_shape @mesh0 : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
-// CHECK-LABEL: func @cluster_shape_empty_axes
-func.func @cluster_shape_empty_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
- %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+// CHECK-LABEL: func @mesh_shape_empty_axes
+func.func @mesh_shape_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
+ %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index aeeba4ea3f1979..677a5982ea2540 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -1,11 +1,11 @@
// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
-mesh.cluster @mesh2d(shape = ?x?)
+mesh.mesh @mesh2d(shape = ?x?)
// CHECK-LABEL: func.func @multi_index_2d_mesh
func.func @multi_index_2d_mesh() -> (index, index) {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
- // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+ // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0:2 = mesh.process_multi_index on @mesh2d : index, index
// CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
@@ -15,7 +15,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
func.func @multi_index_2d_mesh_single_inner_axis() -> index {
// CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
- // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+ // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
// CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
%0 = mesh.process_multi_index on @mesh2d axes = [0] : index
// CHECK: return %[[MULTI_IDX]]#0 : index
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index 3f5c7d80bf9c7e..cb98d31dd6692d 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
-mesh.cluster @mesh_1d(shape = 2)
-mesh.cluster @mesh_1d_dynamic(shape = ?)
+mesh.mesh @mesh_1d(shape = 2)
+mesh.mesh @mesh_1d_dynamic(shape = ?)
// CHECK-LABEL: func @same_source_and_target_sharding
func.func @same_source_and_target_sharding(
@@ -22,7 +22,7 @@ func.func @split_replicated_tensor_axis(
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
- // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
+ // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
// CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
// CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
@@ -44,7 +44,7 @@ func.func @split_replicated_tensor_axis_dynamic(
// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
// CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
- // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
+ // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d_dynamic axes = [0] : index
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
// CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
// CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 065ae9ca8c6b41..94f8d94073c5ef 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
-mesh.cluster @mesh_1d(shape = ?)
-mesh.cluster @mesh_2d(shape = 2x4)
-mesh.cluster @mesh_3d(shape = ?x?x?)
+mesh.mesh @mesh_1d(shape = ?)
+mesh.mesh @mesh_2d(shape = 2x4)
+mesh.mesh @mesh_3d(shape = ?x?x?)
// CHECK-LABEL: func.func @element_wise_empty_sharding_info
func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index 63ae9d528df1b3..d748be82c5a46a 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
-mesh.cluster @mesh0(shape = 4x2)
-mesh.cluster @mesh1(shape = 4)
+mesh.mesh @mesh0(shape = 4x2)
+mesh.mesh @mesh1(shape = 4)
// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
// `all_reduce(x + y)`.
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index 6fecbd48f15387..9b3082a819224f 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -37,15 +37,15 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
}
SymbolTableCollection symbolTable;
- mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
- op, op.getShard().getCluster());
+ mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+ op, op.getShard().getMesh());
bool foundUser = false;
for (auto user : op->getUsers()) {
if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
if (targetShardOp.getAnnotateForUsers() &&
- mesh == symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
- targetShardOp, targetShardOp.getShard().getCluster())) {
+ mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+ targetShardOp, targetShardOp.getShard().getMesh())) {
foundUser = true;
break;
}
@@ -59,8 +59,8 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
for (auto user : op->getUsers()) {
auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
- symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
- targetShardOp, targetShardOp.getShard().getCluster()) != mesh) {
+ symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+ targetShardOp, targetShardOp.getShard().getMesh()) != mesh) {
continue;
}
More information about the Mlir-commits
mailing list