[Mlir-commits] [mlir] [mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (PR #77838)
Boian Petkantchin
llvmlistbot at llvm.org
Thu Jan 11 13:48:05 PST 2024
https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/77838
Remove the somewhat redundant rank attribute.
Before this change
```
mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)
```
After
```
mesh.cluster @mesh(shape = 2x3x?)
```
The rank is instead determined by the provided shape. With this change, `getDimSizes()` can no longer be wrongly assumed to have size equal to the cluster rank.
Now `getShape().size()` will always equal `getRank()`.
From 394ff443201d6ff70e2d2a3900c0f7da89198a26 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 11 Jan 2024 12:09:31 -0800
Subject: [PATCH] [mlir][mesh] Remove rank attribute and rename dim_sizes to
shape in ClusterOp
Remove the somewhat redundant rank attribute.
Before this change
```
mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)
```
After
```
mesh.cluster @mesh(shape = 2x3x?)
```
The rank is instead determined by the provided shape.
With this change, `getDimSizes()` can no longer be wrongly assumed to have size
equal to the cluster rank.
Now `getShape().size()` will always equal `getRank()`.
---
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 2 +-
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 40 ++----
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 56 +++-----
.../Mesh/Transforms/Simplifications.cpp | 4 +-
.../Dialect/Mesh/Transforms/Spmdization.cpp | 16 +--
mlir/test/Dialect/Mesh/canonicalization.mlir | 2 +-
mlir/test/Dialect/Mesh/folding.mlir | 4 +-
mlir/test/Dialect/Mesh/invalid.mlir | 135 +++++++++---------
mlir/test/Dialect/Mesh/ops.mlir | 20 +--
.../Mesh/process-multi-index-op-lowering.mlir | 2 +-
.../Dialect/Mesh/resharding-spmdization.mlir | 4 +-
.../Dialect/Mesh/sharding-propagation.mlir | 6 +-
mlir/test/Dialect/Mesh/simplifications.mlir | 4 +-
13 files changed, 134 insertions(+), 161 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index bda6467e9c5d4b..07f954459ca49d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
Example:
```
- mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+ mesh.cluster @mesh0(shape = 2x2x4)
// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a9068562f5c903..c25996cf122c1c 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -41,7 +41,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
determine the layout and the addressing space of the computation distributed
across the mesh.
- 3. `dim_sizes`: This attribute represents the shape of the device cluster.
+ 3. `shape`: This attribute represents the shape of the device cluster.
It uses the same notation as a tensor shape. Also allowing for dynamic
dimensions.
This flexibility allows for dynamic device assignment or configurations
@@ -53,19 +53,19 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
```
// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
- mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
+ mesh.cluster @mesh0(shape = 4x8x12)
// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
- mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
+ mesh.cluster @mesh1(shape = 4x?)
// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
- mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
+ mesh.cluster @mesh2(shape = ?x4)
// A device mesh cluster with 2 axes, the number of devices along both axes
// is unknown
- mesh.cluster @mesh3(rank = 2)
+ mesh.cluster @mesh3(shape = ?x?)
// Used in the mesh sharding attribute to extend the standard tensor to
// distributed
@@ -74,24 +74,14 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
- I64Attr:$rank,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
+ DenseI64ArrayAttr:$shape
);
let assemblyFormat = [{
- $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
+ $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
attr-dict
}];
let extraClassDeclaration = [{
- // The `dim_sizes` attribute may have size less than the rank of the mesh.
- // Returns the shape of the mesh with missing trailing dimensions
- // explicitly set as dynamic.
- ::mlir::SmallVector<int64_t> canonicalDimSizes();
-
- template <typename OutIt>
- void canonicalDimSizes(OutIt outIt) {
- std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
- std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
- }
+ int64_t getRank() { return getShape().size(); }
}];
let hasVerifier = 1;
}
@@ -283,7 +273,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Example:
```mlir
- mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
+ mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
@@ -368,7 +358,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Example:
```
- mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+ mesh.cluster @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
@@ -425,7 +415,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
Example:
```
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ mesh.cluster @mesh0(shape = 2x2)
%1 = mesh.broadcast %0 on @mesh0
mesh_axes = [0]
@@ -481,7 +471,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
Example:
```mlir
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
gather_axis = 1 root = [1]
@@ -604,7 +594,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
across the device group.
Example:
```
- mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
+ mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <max> scatter_axis = 0
@@ -667,7 +657,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
Example:
```
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ mesh.cluster @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
scatter_axis = 0
root = [1]
@@ -763,7 +753,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
Example:
```
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+ mesh.cluster @mesh0(shape = 2x4)
%1 = mesh.shift on @mesh0 mesh_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 957b380efd516b..fa9da596a34587 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -196,17 +196,16 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
//===----------------------------------------------------------------------===//
LogicalResult ClusterOp::verify() {
- ArrayRef<int64_t> dimSizes = getDimSizes();
- uint64_t rank = getRank();
+ int64_t rank = getRank();
- if (rank == 0)
+ if (rank <= 0)
return emitOpError("rank of cluster is expected to be a positive integer");
- if (dimSizes.size() > rank)
+ if (getShape().size() > rank)
return emitOpError(
- "rank of dim_sizes is not expected to be larger than rank of cluster");
+ "rank of shape is not expected to be larger than rank of cluster");
- for (int64_t dimSize : dimSizes) {
+ for (int64_t dimSize : getShape()) {
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
return emitOpError("dimension size of a mesh cluster is expected to be "
"non-negative or dynamic");
@@ -215,13 +214,6 @@ LogicalResult ClusterOp::verify() {
return success();
}
-SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
- SmallVector<int64_t> result;
- canonicalDimSizes(std::back_inserter(result));
- result.reserve(getRank());
- return result;
-}
-
//===----------------------------------------------------------------------===//
// mesh.cluster_shape op
//===----------------------------------------------------------------------===//
@@ -614,7 +606,7 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getOperand(), getResult(),
gatherAxis, getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ mesh.value().getShape());
}
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -648,8 +640,7 @@ LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyAllToAllOperandAndResultShape(
getOperand(), getResult(), getSplitAxis().getSExtValue(),
- getConcatAxis().getSExtValue(), getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
}
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -667,9 +658,9 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(), meshShape))) {
+ getRootDynamic(), getMeshAxes(),
+ mesh.value().getShape()))) {
return failure();
}
@@ -690,16 +681,16 @@ LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(), meshShape))) {
+ getRootDynamic(), getMeshAxes(),
+ mesh.value().getShape()))) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ mesh.value().getShape());
}
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -716,10 +707,10 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
- if (getSource() && failed(verifyInGroupDevice(
- getLoc(), getSourceAttrName(), getSource().value(),
- getSourceDynamic(), getMeshAxes(), meshShape))) {
+ if (getSource() &&
+ failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
+ getSource().value(), getSourceDynamic(),
+ getMeshAxes(), mesh.value().getShape()))) {
return failure();
}
return success();
@@ -739,9 +730,9 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(), meshShape))) {
+ getRootDynamic(), getMeshAxes(),
+ mesh.value().getShape()))) {
return failure();
}
@@ -766,7 +757,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyScatterOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ mesh.value().getShape());
}
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -783,16 +774,16 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(), meshShape))) {
+ getRootDynamic(), getMeshAxes(),
+ mesh.value().getShape()))) {
return failure();
}
auto scatterAxis = getScatterAxis().getSExtValue();
return verifyScatterOperandAndResultShape(getInput(), getResult(),
scatterAxis, getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ mesh.value().getShape());
}
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -809,10 +800,9 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
getDestination(), getDestinationDynamic(),
- getMeshAxes(), meshShape))) {
+ getMeshAxes(), mesh.value().getShape()))) {
return failure();
}
return success();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c9275ad5ad4551..c478b6da4c27b3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -80,7 +80,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
opMeshAxes = opAxesIota;
}
if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
- return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+ return ShapedType::isDynamic(mesh.getShape()[axis]);
})) {
// All mesh dimensions are dynamic. Nothing to fold.
return failure();
@@ -91,7 +91,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
SmallVector<size_t> newToOldResultsIndexMap;
for (size_t i = 0; i < opMeshAxes.size(); ++i) {
- auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+ auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
if (ShapedType::isDynamic(meshAxisSize)) {
newToOldResultsIndexMap.push_back(i);
newShapeOpMeshAxes.push_back(opMeshAxes[i]);
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 0e83c024fc08f8..9478b2e4ee5cb2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -88,8 +88,8 @@ ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
MeshShardingAttr sharding) {
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
- shardShape(shape.getShape(), mesh.canonicalDimSizes(),
- sharding.getSplitAxes(), resShapeArr);
+ shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+ resShapeArr);
return shape.clone(resShapeArr);
}
@@ -212,9 +212,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
- ShapedType targetShape =
- targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
- mesh.canonicalDimSizes()[splitMeshAxis]);
+ ShapedType targetShape = targetShapeInSplitLastAxis(
+ sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
Value meshAxisSize =
builder
@@ -391,8 +390,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr targetSharding =
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
- sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
- splitTensorAxis);
+ sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
Value allGatherResult = builder.create<AllGatherOp>(
RankedTensorType::get(allGatherResultShape.getShape(),
allGatherResultShape.getElementType()),
@@ -526,8 +524,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
- sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
- sourceTensorAxis, targetTensorAxis);
+ sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
+ targetTensorAxis);
Value allToAllResult = builder.create<AllToAllOp>(
RankedTensorType::get(allToAllResultShape.getShape(),
allToAllResultShape.getElementType()),
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 0a00ab41268d01..4cc009ef24eb3c 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
index dd64d746341b83..9162dc57ecfdf4 100644
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
-mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
-mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+mesh.cluster @mesh0(shape = 4x?x2)
+mesh.cluster @mesh1(shape = 2x3)
// CHECK-LABEL: func.func @cluster_shape_op_folding
func.func @cluster_shape_op_folding() -> (index, index) {
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index f3524a82a1b9d2..8a1fb80065573b 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -1,21 +1,16 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
// expected-error at +1 {{rank of cluster is expected to be a positive integer}}
-mesh.cluster @mesh0(rank = 0)
-
-// -----
-
-// expected-error at +1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
+mesh.cluster @mesh0(shape = [])
// -----
// expected-error at +1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
+mesh.cluster @mesh0(shape = -1)
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_different_subarray(
// expected-error at +1 {{mesh axis duplicated}}
@@ -26,7 +21,7 @@ func.func @mesh_axis_duplicated_different_subarray(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_same_subarray(
// expected-error at +1 {{mesh axis duplicated}}
@@ -37,7 +32,7 @@ func.func @mesh_axis_duplicated_same_subarray(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_bewteen_split_and_partial(
// expected-error at +1 {{mesh axis duplicated}}
@@ -48,7 +43,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_negtive_in_split_part(
// expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -59,7 +54,7 @@ func.func @mesh_axis_negtive_in_split_part(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_negtive_in_partial(
// expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -78,7 +73,7 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
// expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -88,7 +83,7 @@ func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
// -----
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
// expected-error at +1 {{Mesh axes contains duplicate elements.}}
@@ -98,7 +93,7 @@ func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
// expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
@@ -108,7 +103,7 @@ func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
// -----
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
// expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
@@ -126,7 +121,7 @@ func.func @cluster_shape_invalid_mesh_name() -> (index) {
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
// expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -136,7 +131,7 @@ func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
// -----
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
// expected-error at +1 {{Mesh axes contains duplicate elements.}}
@@ -146,7 +141,7 @@ func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
// expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
@@ -156,7 +151,7 @@ func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
// -----
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
// expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
@@ -192,7 +187,7 @@ func.func @all_reduce_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @all_reduce_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -204,7 +199,7 @@ func.func @all_reduce_invalid_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @all_reduce_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -216,7 +211,7 @@ func.func @all_reduce_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @all_reduce_invalid_tensor_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<5xf64> {
@@ -237,7 +232,7 @@ func.func @all_gather_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @all_gather_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -249,7 +244,7 @@ func.func @all_gather_invalid_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @all_reduce_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -261,7 +256,7 @@ func.func @all_reduce_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @all_gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -273,7 +268,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+mesh.cluster @mesh0(shape = 1x2)
func.func @all_gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -285,7 +280,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @all_gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -297,7 +292,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @all_gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -309,7 +304,7 @@ func.func @all_gather_invalid_gather_axis(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @all_gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -332,7 +327,7 @@ func.func @all_to_all_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @all_to_all_duplicate_mesh_axis(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -345,7 +340,7 @@ func.func @all_to_all_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = ?x1)
+mesh.cluster @mesh0(shape = ?x1)
func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -358,7 +353,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
+mesh.cluster @mesh0(shape = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
@@ -371,7 +366,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
+mesh.cluster @mesh0(shape = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
@@ -384,7 +379,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
@@ -397,7 +392,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
@@ -410,7 +405,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @broadcast_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -423,7 +418,7 @@ func.func @broadcast_root_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @broadcast_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -436,7 +431,7 @@ func.func @broadcast_root_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @broadcast_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -449,7 +444,7 @@ func.func @broadcast_different_input_and_result_type(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @gather_wrong_return_element_type(
%arg0 : tensor<1xf32>) -> tensor<1xi8> {
@@ -461,7 +456,7 @@ func.func @gather_wrong_return_element_type(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -473,7 +468,7 @@ func.func @gather_invalid_non_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+mesh.cluster @mesh0(shape = 1x2)
func.func @gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -485,7 +480,7 @@ func.func @gather_invalid_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -497,7 +492,7 @@ func.func @gather_invalid_gather_axis_dynamic_dimension(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -509,7 +504,7 @@ func.func @gather_invalid_gather_axis(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
func.func @gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -521,7 +516,7 @@ func.func @gather_invalid_negative_gather_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @gather_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<6xi8> {
@@ -534,7 +529,7 @@ func.func @gather_root_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @gather_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -547,7 +542,7 @@ func.func @gather_root_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @receive_source_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -560,7 +555,7 @@ func.func @receive_source_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @receive_source_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -573,7 +568,7 @@ func.func @receive_source_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @receive_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -586,7 +581,7 @@ func.func @receive_different_input_and_result_type(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @reduce_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -599,7 +594,7 @@ func.func @reduce_root_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @reduce_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -612,7 +607,7 @@ func.func @reduce_root_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @reduce_different_input_and_result_shape(
%arg0 : tensor<2xi8>) -> tensor<3xi16> {
@@ -625,7 +620,7 @@ func.func @reduce_different_input_and_result_shape(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @reduce_scatter_duplicate_mesh_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
@@ -637,7 +632,7 @@ func.func @reduce_scatter_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @reduce_scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
@@ -649,7 +644,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @reduce_scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
@@ -661,7 +656,7 @@ func.func @reduce_scatter_invalid_static_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
@@ -673,7 +668,7 @@ func.func @reduce_scatter_invalid_operand_static_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @scatter_duplicate_mesh_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
@@ -686,7 +681,7 @@ func.func @scatter_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
@@ -699,7 +694,7 @@ func.func @scatter_invalid_dynamic_dimension(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
@@ -712,7 +707,7 @@ func.func @scatter_invalid_static_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
func.func @scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
@@ -725,7 +720,7 @@ func.func @scatter_invalid_operand_static_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @scatter_root_dimension_out_of_bounds(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
@@ -738,7 +733,7 @@ func.func @scatter_root_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @scatter_root_wrong_number_dimensions(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
@@ -751,7 +746,7 @@ func.func @scatter_root_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @send_destination_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -764,7 +759,7 @@ func.func @send_destination_dimension_out_of_bounds(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @send_destination_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -777,7 +772,7 @@ func.func @send_destination_wrong_number_dimensions(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
func.func @send_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -801,7 +796,7 @@ func.func @shift_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @shift_invalid_mesh_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
@@ -814,7 +809,7 @@ func.func @shift_invalid_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @shift_duplicate_mesh_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
@@ -827,7 +822,7 @@ func.func @shift_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @shift_invalid_tensor_dimension_size(
%arg0 : tensor<4xi8>) -> tensor<5xi8> {
@@ -840,7 +835,7 @@ func.func @shift_invalid_tensor_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @shift_invalid_shift_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 0abe31c8b3f7db..0aaa4bdee1db3a 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -1,21 +1,21 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// CHECK: mesh.cluster @mesh0
-mesh.cluster @mesh0(rank = 3, dim_sizes = 2x2x4)
+mesh.cluster @mesh0(shape = 2x2x4)
-// CHECK: mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
-mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
+// CHECK: mesh.cluster @mesh1(shape = 4x?)
+mesh.cluster @mesh1(shape = 4x?)
-// CHECK: mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
-mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
+// CHECK: mesh.cluster @mesh2(shape = ?x4)
+mesh.cluster @mesh2(shape = ?x4)
-// CHECK: mesh.cluster @mesh3
-mesh.cluster @mesh3(rank = 2)
+// CHECK: mesh.cluster @mesh3(shape = ?x?)
+mesh.cluster @mesh3(shape = ?x?)
-mesh.cluster @mesh4(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh4(shape = 3)
-// CHECK: mesh.cluster @mesh5(rank = 1)
-mesh.cluster @mesh5(rank = 1, dim_sizes = [])
+// CHECK: mesh.cluster @mesh5(shape = ?)
+mesh.cluster @mesh5(shape = ?)
// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
func.func @mesh_shard_encoding_fully_replicated(
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index 9602fb729c2681..aeeba4ea3f1979 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
-mesh.cluster @mesh2d(rank = 2)
+mesh.cluster @mesh2d(shape = ?x?)
// CHECK-LABEL: func.func @multi_index_2d_mesh
func.func @multi_index_2d_mesh() -> (index, index) {
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index 786ea386df815a..3f5c7d80bf9c7e 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
-mesh.cluster @mesh_1d(rank = 1, dim_sizes = 2)
-mesh.cluster @mesh_1d_dynamic(rank = 1, dim_sizes = ?)
+mesh.cluster @mesh_1d(shape = 2)
+mesh.cluster @mesh_1d_dynamic(shape = ?)
// CHECK-LABEL: func @same_source_and_target_sharding
func.func @same_source_and_target_sharding(
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 30bbd5c6619e8a..065ae9ca8c6b41 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
-mesh.cluster @mesh_1d(rank = 1)
-mesh.cluster @mesh_2d(rank = 2, dim_sizes = 2x4)
-mesh.cluster @mesh_3d(rank = 3)
+mesh.cluster @mesh_1d(shape = ?)
+mesh.cluster @mesh_2d(shape = 2x4)
+mesh.cluster @mesh_3d(shape = ?x?x?)
// CHECK-LABEL: func.func @element_wise_empty_sharding_info
func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index e716940f2301e9..63ae9d528df1b3 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
-mesh.cluster @mesh0(rank = 2, dim_sizes = 4x2)
-mesh.cluster @mesh1(rank = 1, dim_sizes = 4)
+mesh.cluster @mesh0(shape = 4x2)
+mesh.cluster @mesh1(shape = 4)
// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
// `all_reduce(x + y)`.
More information about the Mlir-commits
mailing list