[Mlir-commits] [mlir] 5df2c00 - [mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (#77838)

llvmlistbot at llvm.org
Mon Jan 15 07:39:13 PST 2024


Author: Boian Petkantchin
Date: 2024-01-15T07:39:09-08:00
New Revision: 5df2c00af381326340dd2f75615c4b4222ae6d96

URL: https://github.com/llvm/llvm-project/commit/5df2c00af381326340dd2f75615c4b4222ae6d96
DIFF: https://github.com/llvm/llvm-project/commit/5df2c00af381326340dd2f75615c4b4222ae6d96.diff

LOG: [mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (#77838)

Remove the somewhat redundant rank attribute.
Before this change:
```
mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)
```
After:
```
mesh.cluster @mesh(shape = 2x3x?)
```

The rank is instead determined by the provided shape. With this change,
`getDimSizes()` can no longer be wrongly assumed to have a size equal to
the cluster rank: `getShape().size()` now always equals `getRank()`.
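
For downstream code the migration is mechanical; a minimal sketch (the call
site and the `mesh` variable are hypothetical, assuming `mesh` holds a
`mesh::ClusterOp`):

```c++
// Hypothetical call site, not part of this patch.
// Before: the dim_sizes attribute could be shorter than the rank, so
// callers had to go through the padded helper:
//   SmallVector<int64_t> shape = mesh.canonicalDimSizes();
// After: the shape attribute always has getRank() entries, with dynamic
// axes stored as ShapedType::kDynamic.
ArrayRef<int64_t> shape = mesh.getShape();
assert(static_cast<int64_t>(shape.size()) == mesh.getRank());
```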

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
    mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
    mlir/test/Dialect/Mesh/canonicalization.mlir
    mlir/test/Dialect/Mesh/folding.mlir
    mlir/test/Dialect/Mesh/invalid.mlir
    mlir/test/Dialect/Mesh/ops.mlir
    mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
    mlir/test/Dialect/Mesh/resharding-spmdization.mlir
    mlir/test/Dialect/Mesh/sharding-propagation.mlir
    mlir/test/Dialect/Mesh/simplifications.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index bda6467e9c5d4b..07f954459ca49d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     Example:
 
     ```
-    mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+    mesh.cluster @mesh0(shape = 2x2x4)
 
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a9068562f5c903..78ff8bd0cac621 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -36,12 +36,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     cluster. This name serves as a symbolic reference to the cluster throughout
     the MLIR module, allowing for consistent referencing and easier debugging.
 
-    2. `rank`: This attribute specifies the number of axes of the cluster. The
-    rank indicates the dimensionality of the mesh cluster and can be used to
-    determine the layout and the addressing space of the computation distributed
-    across the mesh.
-
-    3. `dim_sizes`: This attribute represents the shape of the device cluster.
+    2. `shape`: This attribute represents the shape of the device cluster.
     It uses the same notation as a tensor shape. Also allowing for dynamic
     dimensions.
     This flexibility allows for dynamic device assignment or configurations
@@ -53,19 +48,19 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     ```
     // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
-    mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
+    mesh.cluster @mesh0(shape = 4x8x12)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
-    mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
+    mesh.cluster @mesh1(shape = 4x?)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
-    mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
+    mesh.cluster @mesh2(shape = ?x4)
 
     // A device mesh cluster with 2 axes, the number of devices along both axes
     // is unknown
-    mesh.cluster @mesh3(rank = 2)
+    mesh.cluster @mesh3(shape = ?x?)
 
     // Used in the mesh sharding attribute to extend the standard tensor to
     // distributed
@@ -74,24 +69,14 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   }];
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    I64Attr:$rank,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
+    DenseI64ArrayAttr:$shape
   );
   let assemblyFormat = [{
-    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
+    $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
       attr-dict
   }];
   let extraClassDeclaration = [{
-    // The `dim_sizes` attribute may have size less than the rank of the mesh.
-    // Returns the shape of the mesh with missing trailing dimensions
-    // explicitly set as dynamic.
-    ::mlir::SmallVector<int64_t> canonicalDimSizes();
-
-    template <typename OutIt>
-    void canonicalDimSizes(OutIt outIt) {
-      std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
-      std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
-    }
+    int64_t getRank() { return getShape().size(); }
   }];
   let hasVerifier = 1;
 }
@@ -283,7 +268,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
       : tensor<2x2xi8> -> tensor<2x4xi8>
@@ -368,7 +353,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+    mesh.cluster @mesh0(shape = 3)
     ...
     %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
       split_axis = 0 concat_axis = 0
@@ -425,7 +410,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
 
     %1 = mesh.broadcast %0 on @mesh0
       mesh_axes = [0]
@@ -481,7 +466,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
       gather_axis = 1 root = [1]
@@ -604,7 +589,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     across the device group.
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
       reduction = <max> scatter_axis = 0
@@ -667,7 +652,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
     %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
       scatter_axis = 0
       root = [1]
@@ -763,7 +748,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+    mesh.cluster @mesh0(shape = 2x4)
     %1 = mesh.shift on @mesh0 mesh_axes = [1]
       shift_axis = 1 offset = 2 rotate
       : tensor<2xi8> -> tensor<2xi8>

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 957b380efd516b..fa9da596a34587 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -196,17 +196,16 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ClusterOp::verify() {
-  ArrayRef<int64_t> dimSizes = getDimSizes();
-  uint64_t rank = getRank();
+  int64_t rank = getRank();
 
-  if (rank == 0)
+  if (rank <= 0)
     return emitOpError("rank of cluster is expected to be a positive integer");
 
-  if (dimSizes.size() > rank)
+  if (getShape().size() > rank)
     return emitOpError(
-        "rank of dim_sizes is not expected to be larger than rank of cluster");
+        "rank of shape is not expected to be larger than rank of cluster");
 
-  for (int64_t dimSize : dimSizes) {
+  for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
       return emitOpError("dimension size of a mesh cluster is expected to be "
                          "non-negative or dynamic");
@@ -215,13 +214,6 @@ LogicalResult ClusterOp::verify() {
   return success();
 }
 
-SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
-  SmallVector<int64_t> result;
-  canonicalDimSizes(std::back_inserter(result));
-  result.reserve(getRank());
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // mesh.cluster_shape op
 //===----------------------------------------------------------------------===//
@@ -614,7 +606,7 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                            gatherAxis, getMeshAxes(),
-                                           mesh.value().canonicalDimSizes());
+                                           mesh.value().getShape());
 }
 
 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -648,8 +640,7 @@ LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
   return verifyAllToAllOperandAndResultShape(
       getOperand(), getResult(), getSplitAxis().getSExtValue(),
-      getConcatAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().canonicalDimSizes());
+      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
 }
 
 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -667,9 +658,9 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
@@ -690,16 +681,16 @@ LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                            getMeshAxes(),
-                                           mesh.value().canonicalDimSizes());
+                                           mesh.value().getShape());
 }
 
 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -716,10 +707,10 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
-  if (getSource() && failed(verifyInGroupDevice(
-                         getLoc(), getSourceAttrName(), getSource().value(),
-                         getSourceDynamic(), getMeshAxes(), meshShape))) {
+  if (getSource() &&
+      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
+                                 getSource().value(), getSourceDynamic(),
+                                 getMeshAxes(), mesh.value().getShape()))) {
     return failure();
   }
   return success();
@@ -739,9 +730,9 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
@@ -766,7 +757,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
   return verifyScatterOperandAndResultShape(
       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().canonicalDimSizes());
+      mesh.value().getShape());
 }
 
 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -783,16 +774,16 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
   auto scatterAxis = getScatterAxis().getSExtValue();
   return verifyScatterOperandAndResultShape(getInput(), getResult(),
                                             scatterAxis, getMeshAxes(),
-                                            mesh.value().canonicalDimSizes());
+                                            mesh.value().getShape());
 }
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -809,10 +800,9 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                  getDestination(), getDestinationDynamic(),
-                                 getMeshAxes(), meshShape))) {
+                                 getMeshAxes(), mesh.value().getShape()))) {
     return failure();
   }
   return success();

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 67e1bf6320dbf3..429e684c845fb1 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -80,7 +80,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
       opMeshAxes = opAxesIota;
     }
     if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
-          return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+          return ShapedType::isDynamic(mesh.getShape()[axis]);
         })) {
       // All mesh dimensions are dynamic. Nothing to fold.
       return failure();
@@ -91,7 +91,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
     SmallVector<size_t> newToOldResultsIndexMap;
 
     for (size_t i = 0; i < opMeshAxes.size(); ++i) {
-      auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+      auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
       if (ShapedType::isDynamic(meshAxisSize)) {
         newToOldResultsIndexMap.push_back(i);
         newShapeOpMeshAxes.push_back(opMeshAxes[i]);

diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 0e83c024fc08f8..9478b2e4ee5cb2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -88,8 +88,8 @@ ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
                            MeshShardingAttr sharding) {
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
-  shardShape(shape.getShape(), mesh.canonicalDimSizes(),
-             sharding.getSplitAxes(), resShapeArr);
+  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+             resShapeArr);
   return shape.clone(resShapeArr);
 }
 
@@ -212,9 +212,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
 
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
       ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
-  ShapedType targetShape =
-      targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
-                                 mesh.canonicalDimSizes()[splitMeshAxis]);
+  ShapedType targetShape = targetShapeInSplitLastAxis(
+      sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
 
   Value meshAxisSize =
       builder
@@ -391,8 +390,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
   MeshShardingAttr targetSharding =
       targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
   ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
-      sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
-      splitTensorAxis);
+      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
   Value allGatherResult = builder.create<AllGatherOp>(
       RankedTensorType::get(allGatherResultShape.getShape(),
                             allGatherResultShape.getElementType()),
@@ -526,8 +524,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
   MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
       ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
   ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
-      sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
-      sourceTensorAxis, targetTensorAxis);
+      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
+      targetTensorAxis);
   Value allToAllResult = builder.create<AllToAllOp>(
       RankedTensorType::get(allToAllResultShape.getShape(),
                             allToAllResultShape.getElementType()),

diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 0a00ab41268d01..4cc009ef24eb3c 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt --canonicalize %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes
 func.func @all_reduce_empty_mesh_axes(

diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
index dd64d746341b83..9162dc57ecfdf4 100644
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
-mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+mesh.cluster @mesh0(shape = 4x?x2)
+mesh.cluster @mesh1(shape = 2x3)
 
 // CHECK-LABEL: func.func @cluster_shape_op_folding
 func.func @cluster_shape_op_folding() -> (index, index) {

diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index f3524a82a1b9d2..8a1fb80065573b 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -1,21 +1,16 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
 // expected-error at +1 {{rank of cluster is expected to be a positive integer}}
-mesh.cluster @mesh0(rank = 0)
-
-// -----
-
-// expected-error at +1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
+mesh.cluster @mesh0(shape = [])
 
 // -----
 
 // expected-error at +1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
+mesh.cluster @mesh0(shape = -1)
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -26,7 +21,7 @@ func.func @mesh_axis_duplicated_different_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -37,7 +32,7 @@ func.func @mesh_axis_duplicated_same_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -48,7 +43,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -59,7 +54,7 @@ func.func @mesh_axis_negtive_in_split_part(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -78,7 +73,7 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
   // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -88,7 +83,7 @@ func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
 
 func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
   // expected-error at +1 {{Mesh axes contains duplicate elements.}}
@@ -98,7 +93,7 @@ func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
@@ -108,7 +103,7 @@ func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
 
 func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
@@ -126,7 +121,7 @@ func.func @cluster_shape_invalid_mesh_name() -> (index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
   // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -136,7 +131,7 @@ func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
 
 func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
   // expected-error at +1 {{Mesh axes contains duplicate elements.}}
@@ -146,7 +141,7 @@ func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
@@ -156,7 +151,7 @@ func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
 
 func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
@@ -192,7 +187,7 @@ func.func @all_reduce_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @all_reduce_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -204,7 +199,7 @@ func.func @all_reduce_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -216,7 +211,7 @@ func.func @all_reduce_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @all_reduce_invalid_tensor_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<5xf64> {
@@ -237,7 +232,7 @@ func.func @all_gather_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @all_gather_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -249,7 +244,7 @@ func.func @all_gather_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -261,7 +256,7 @@ func.func @all_reduce_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @all_gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -273,7 +268,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+mesh.cluster @mesh0(shape = 1x2)
 
 func.func @all_gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -285,7 +280,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @all_gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -297,7 +292,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @all_gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -309,7 +304,7 @@ func.func @all_gather_invalid_gather_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @all_gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -332,7 +327,7 @@ func.func @all_to_all_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @all_to_all_duplicate_mesh_axis(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -345,7 +340,7 @@ func.func @all_to_all_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = ?x1)
+mesh.cluster @mesh0(shape = ?x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -358,7 +353,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
+mesh.cluster @mesh0(shape = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
@@ -371,7 +366,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
+mesh.cluster @mesh0(shape = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
@@ -384,7 +379,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
@@ -397,7 +392,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
@@ -410,7 +405,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @broadcast_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -423,7 +418,7 @@ func.func @broadcast_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @broadcast_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -436,7 +431,7 @@ func.func @broadcast_root_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @broadcast_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -449,7 +444,7 @@ func.func @broadcast_different_input_and_result_type(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @gather_wrong_return_element_type(
     %arg0 : tensor<1xf32>) -> tensor<1xi8> {
@@ -461,7 +456,7 @@ func.func @gather_wrong_return_element_type(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -473,7 +468,7 @@ func.func @gather_invalid_non_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+mesh.cluster @mesh0(shape = 1x2)
 
 func.func @gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -485,7 +480,7 @@ func.func @gather_invalid_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -497,7 +492,7 @@ func.func @gather_invalid_gather_axis_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -509,7 +504,7 @@ func.func @gather_invalid_gather_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+mesh.cluster @mesh0(shape = 1)
 
 func.func @gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -521,7 +516,7 @@ func.func @gather_invalid_negative_gather_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @gather_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<6xi8> {
@@ -534,7 +529,7 @@ func.func @gather_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @gather_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -547,7 +542,7 @@ func.func @gather_root_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @receive_source_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -560,7 +555,7 @@ func.func @receive_source_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @receive_source_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -573,7 +568,7 @@ func.func @receive_source_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @receive_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -586,7 +581,7 @@ func.func @receive_different_input_and_result_type(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @reduce_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -599,7 +594,7 @@ func.func @reduce_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @reduce_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -612,7 +607,7 @@ func.func @reduce_root_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @reduce_different_input_and_result_shape(
     %arg0 : tensor<2xi8>) -> tensor<3xi16> {
@@ -625,7 +620,7 @@ func.func @reduce_different_input_and_result_shape(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @reduce_scatter_duplicate_mesh_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
@@ -637,7 +632,7 @@ func.func @reduce_scatter_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
@@ -649,7 +644,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @reduce_scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf64> {
@@ -661,7 +656,7 @@ func.func @reduce_scatter_invalid_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @reduce_scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf64> {
@@ -673,7 +668,7 @@ func.func @reduce_scatter_invalid_operand_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @scatter_duplicate_mesh_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
@@ -686,7 +681,7 @@ func.func @scatter_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf32> {
@@ -699,7 +694,7 @@ func.func @scatter_invalid_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf32> {
@@ -712,7 +707,7 @@ func.func @scatter_invalid_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh0(shape = 3)
 
 func.func @scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf32> {
@@ -725,7 +720,7 @@ func.func @scatter_invalid_operand_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @scatter_root_dimension_out_of_bounds(
     %arg0 : tensor<3xi8>) -> tensor<1xi8> {
@@ -738,7 +733,7 @@ func.func @scatter_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @scatter_root_wrong_number_dimensions(
     %arg0 : tensor<3xi8>) -> tensor<1xi8> {
@@ -751,7 +746,7 @@ func.func @scatter_root_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @send_destination_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -764,7 +759,7 @@ func.func @send_destination_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @send_destination_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -777,7 +772,7 @@ func.func @send_destination_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+mesh.cluster @mesh0(shape = 3x?)
 
 func.func @send_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -801,7 +796,7 @@ func.func @shift_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @shift_invalid_mesh_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
@@ -814,7 +809,7 @@ func.func @shift_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @shift_duplicate_mesh_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
@@ -827,7 +822,7 @@ func.func @shift_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @shift_invalid_tensor_dimension_size(
     %arg0 : tensor<4xi8>) -> tensor<5xi8> {
@@ -840,7 +835,7 @@ func.func @shift_invalid_tensor_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @shift_invalid_shift_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {

diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 0abe31c8b3f7db..0aaa4bdee1db3a 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -1,21 +1,21 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
 // CHECK: mesh.cluster @mesh0
-mesh.cluster @mesh0(rank = 3, dim_sizes = 2x2x4)
+mesh.cluster @mesh0(shape = 2x2x4)
 
-// CHECK: mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
-mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
+// CHECK: mesh.cluster @mesh1(shape = 4x?)
+mesh.cluster @mesh1(shape = 4x?)
 
-// CHECK: mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
-mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
+// CHECK: mesh.cluster @mesh2(shape = ?x4)
+mesh.cluster @mesh2(shape = ?x4)
 
-// CHECK: mesh.cluster @mesh3
-mesh.cluster @mesh3(rank = 2)
+// CHECK: mesh.cluster @mesh3(shape = ?x?)
+mesh.cluster @mesh3(shape = ?x?)
 
-mesh.cluster @mesh4(rank = 1, dim_sizes = 3)
+mesh.cluster @mesh4(shape = 3)
 
-// CHECK: mesh.cluster @mesh5(rank = 1)
-mesh.cluster @mesh5(rank = 1, dim_sizes = [])
+// CHECK: mesh.cluster @mesh5(shape = ?)
+mesh.cluster @mesh5(shape = ?)
 
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(

diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index 9602fb729c2681..aeeba4ea3f1979 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
 
-mesh.cluster @mesh2d(rank = 2)
+mesh.cluster @mesh2d(shape = ?x?)
 
 // CHECK-LABEL: func.func @multi_index_2d_mesh
 func.func @multi_index_2d_mesh() -> (index, index) {

diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index 786ea386df815a..3f5c7d80bf9c7e 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
 
-mesh.cluster @mesh_1d(rank = 1, dim_sizes = 2)
-mesh.cluster @mesh_1d_dynamic(rank = 1, dim_sizes = ?)
+mesh.cluster @mesh_1d(shape = 2)
+mesh.cluster @mesh_1d_dynamic(shape = ?)
 
 // CHECK-LABEL: func @same_source_and_target_sharding
 func.func @same_source_and_target_sharding(

diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 30bbd5c6619e8a..065ae9ca8c6b41 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt -sharding-propagation %s | FileCheck %s
 
-mesh.cluster @mesh_1d(rank = 1)
-mesh.cluster @mesh_2d(rank = 2, dim_sizes = 2x4)
-mesh.cluster @mesh_3d(rank = 3)
+mesh.cluster @mesh_1d(shape = ?)
+mesh.cluster @mesh_2d(shape = 2x4)
+mesh.cluster @mesh_3d(shape = ?x?x?)
 
 // CHECK-LABEL: func.func @element_wise_empty_sharding_info
 func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {

diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index e716940f2301e9..63ae9d528df1b3 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 4x2)
-mesh.cluster @mesh1(rank = 1, dim_sizes = 4)
+mesh.cluster @mesh0(shape = 4x2)
+mesh.cluster @mesh1(shape = 4)
 
 // Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
 // `all_reduce(x + y)`.


        

