[Mlir-commits] [mlir] [mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (PR #77838)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 11 13:48:34 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

<details>
<summary>Changes</summary>

Remove the somewhat redundant rank attribute.
Before this change:
```
mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)
```
After:
```
mesh.cluster @mesh(shape = 2x3x?)
```

The rank is instead determined by the provided shape. With this change, `getDimSizes()` can no longer be wrongly assumed to have size equal to the cluster rank: `getShape().size()` now always equals `getRank()`.
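
As a quick illustration (the cluster names below are hypothetical; the shapes mirror the examples in the updated op documentation), the rank is now implied entirely by the number of entries in `shape`:
```mlir
// Hypothetical clusters showing the invariant getRank() == getShape().size().
mesh.cluster @mesh_static(shape = 4x8x12)  // rank 3, all dimension sizes static
mesh.cluster @mesh_mixed(shape = 4x?)      // rank 2, second dimension size dynamic
mesh.cluster @mesh_dynamic(shape = ?x?)    // rank 2, both dimension sizes dynamic
```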

---

Patch is 38.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77838.diff


13 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+15-25) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+23-33) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+7-9) 
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+1-1) 
- (modified) mlir/test/Dialect/Mesh/folding.mlir (+2-2) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+65-70) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+10-10) 
- (modified) mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir (+1-1) 
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+2-2) 
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+3-3) 
- (modified) mlir/test/Dialect/Mesh/simplifications.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index bda6467e9c5d4b..07f954459ca49d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     Example:
 
     ```
-    mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+    mesh.cluster @mesh0(shape = 2x2x4)
 
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a9068562f5c903..c25996cf122c1c 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -41,7 +41,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     determine the layout and the addressing space of the computation distributed
     across the mesh.
 
-    3. `dim_sizes`: This attribute represents the shape of the device cluster.
+    3. `shape`: This attribute represents the shape of the device cluster.
     It uses the same notation as a tensor shape. Also allowing for dynamic
     dimensions.
     This flexibility allows for dynamic device assignment or configurations
@@ -53,19 +53,19 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     ```
     // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
-    mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
+    mesh.cluster @mesh0(shape = 4x8x12)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
-    mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
+    mesh.cluster @mesh1(shape = 4x?)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
-    mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
+    mesh.cluster @mesh2(shape = ?x4)
 
     // A device mesh cluster with 2 axes, the number of devices along both axes
     // is unknown
-    mesh.cluster @mesh3(rank = 2)
+    mesh.cluster @mesh3(shape = ?x?)
 
     // Used in the mesh sharding attribute to extend the standard tensor to
     // distributed
@@ -74,24 +74,14 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   }];
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    I64Attr:$rank,
-    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
+    DenseI64ArrayAttr:$shape
   );
   let assemblyFormat = [{
-    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
+    $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
       attr-dict
   }];
   let extraClassDeclaration = [{
-    // The `dim_sizes` attribute may have size less than the rank of the mesh.
-    // Returns the shape of the mesh with missing trailing dimensions
-    // explicitly set as dynamic.
-    ::mlir::SmallVector<int64_t> canonicalDimSizes();
-
-    template <typename OutIt>
-    void canonicalDimSizes(OutIt outIt) {
-      std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
-      std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
-    }
+    int64_t getRank() { return getShape().size(); }
   }];
   let hasVerifier = 1;
 }
@@ -283,7 +273,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
       : tensor<2x2xi8> -> tensor<2x4xi8>
@@ -368,7 +358,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+    mesh.cluster @mesh0(shape = 3)
     ...
     %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
       split_axis = 0 concat_axis = 0
@@ -425,7 +415,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
 
     %1 = mesh.broadcast %0 on @mesh0
       mesh_axes = [0]
@@ -481,7 +471,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
       gather_axis = 1 root = [1]
@@ -604,7 +594,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     across the device group.
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
+    mesh.cluster @mesh0(shape = 2x2)
     ...
     %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
       reduction = <max> scatter_axis = 0
@@ -667,7 +657,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(shape = 2x2)
     %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
       scatter_axis = 0
       root = [1]
@@ -763,7 +753,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+    mesh.cluster @mesh0(shape = 2x4)
     %1 = mesh.shift on @mesh0 mesh_axes = [1]
       shift_axis = 1 offset = 2 rotate
       : tensor<2xi8> -> tensor<2xi8>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 957b380efd516b..fa9da596a34587 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -196,17 +196,16 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ClusterOp::verify() {
-  ArrayRef<int64_t> dimSizes = getDimSizes();
-  uint64_t rank = getRank();
+  int64_t rank = getRank();
 
-  if (rank == 0)
+  if (rank <= 0)
     return emitOpError("rank of cluster is expected to be a positive integer");
 
-  if (dimSizes.size() > rank)
+  if (getShape().size() > rank)
     return emitOpError(
-        "rank of dim_sizes is not expected to be larger than rank of cluster");
+        "rank of shape is not expected to be larger than rank of cluster");
 
-  for (int64_t dimSize : dimSizes) {
+  for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
       return emitOpError("dimension size of a mesh cluster is expected to be "
                          "non-negative or dynamic");
@@ -215,13 +214,6 @@ LogicalResult ClusterOp::verify() {
   return success();
 }
 
-SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
-  SmallVector<int64_t> result;
-  canonicalDimSizes(std::back_inserter(result));
-  result.reserve(getRank());
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // mesh.cluster_shape op
 //===----------------------------------------------------------------------===//
@@ -614,7 +606,7 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getOperand(), getResult(),
                                            gatherAxis, getMeshAxes(),
-                                           mesh.value().canonicalDimSizes());
+                                           mesh.value().getShape());
 }
 
 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -648,8 +640,7 @@ LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
   return verifyAllToAllOperandAndResultShape(
       getOperand(), getResult(), getSplitAxis().getSExtValue(),
-      getConcatAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().canonicalDimSizes());
+      getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
 }
 
 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -667,9 +658,9 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
@@ -690,16 +681,16 @@ LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
   auto gatherAxis = getGatherAxis().getSExtValue();
   return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
                                            getMeshAxes(),
-                                           mesh.value().canonicalDimSizes());
+                                           mesh.value().getShape());
 }
 
 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -716,10 +707,10 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
-  if (getSource() && failed(verifyInGroupDevice(
-                         getLoc(), getSourceAttrName(), getSource().value(),
-                         getSourceDynamic(), getMeshAxes(), meshShape))) {
+  if (getSource() &&
+      failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
+                                 getSource().value(), getSourceDynamic(),
+                                 getMeshAxes(), mesh.value().getShape()))) {
     return failure();
   }
   return success();
@@ -739,9 +730,9 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
@@ -766,7 +757,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 
   return verifyScatterOperandAndResultShape(
       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
-      mesh.value().canonicalDimSizes());
+      mesh.value().getShape());
 }
 
 void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -783,16 +774,16 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
-                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+                                 getRootDynamic(), getMeshAxes(),
+                                 mesh.value().getShape()))) {
     return failure();
   }
 
   auto scatterAxis = getScatterAxis().getSExtValue();
   return verifyScatterOperandAndResultShape(getInput(), getResult(),
                                             scatterAxis, getMeshAxes(),
-                                            mesh.value().canonicalDimSizes());
+                                            mesh.value().getShape());
 }
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -809,10 +800,9 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (failed(mesh)) {
     return failure();
   }
-  auto meshShape = mesh.value().canonicalDimSizes();
   if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
                                  getDestination(), getDestinationDynamic(),
-                                 getMeshAxes(), meshShape))) {
+                                 getMeshAxes(), mesh.value().getShape()))) {
     return failure();
   }
   return success();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c9275ad5ad4551..c478b6da4c27b3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -80,7 +80,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
       opMeshAxes = opAxesIota;
     }
     if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
-          return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+          return ShapedType::isDynamic(mesh.getShape()[axis]);
         })) {
       // All mesh dimensions are dynamic. Nothing to fold.
       return failure();
@@ -91,7 +91,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
     SmallVector<size_t> newToOldResultsIndexMap;
 
     for (size_t i = 0; i < opMeshAxes.size(); ++i) {
-      auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+      auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
       if (ShapedType::isDynamic(meshAxisSize)) {
         newToOldResultsIndexMap.push_back(i);
         newShapeOpMeshAxes.push_back(opMeshAxes[i]);
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 0e83c024fc08f8..9478b2e4ee5cb2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -88,8 +88,8 @@ ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
                            MeshShardingAttr sharding) {
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
-  shardShape(shape.getShape(), mesh.canonicalDimSizes(),
-             sharding.getSplitAxes(), resShapeArr);
+  shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+             resShapeArr);
   return shape.clone(resShapeArr);
 }
 
@@ -212,9 +212,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
 
   MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
       ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
-  ShapedType targetShape =
-      targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
-                                 mesh.canonicalDimSizes()[splitMeshAxis]);
+  ShapedType targetShape = targetShapeInSplitLastAxis(
+      sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
 
   Value meshAxisSize =
       builder
@@ -391,8 +390,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
   MeshShardingAttr targetSharding =
       targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
   ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
-      sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
-      splitTensorAxis);
+      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
   Value allGatherResult = builder.create<AllGatherOp>(
       RankedTensorType::get(allGatherResultShape.getShape(),
                             allGatherResultShape.getElementType()),
@@ -526,8 +524,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
   MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
       ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
   ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
-      sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
-      sourceTensorAxis, targetTensorAxis);
+      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
+      targetTensorAxis);
   Value allToAllResult = builder.create<AllToAllOp>(
       RankedTensorType::get(allToAllResultShape.getShape(),
                             allToAllResultShape.getElementType()),
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 0a00ab41268d01..4cc009ef24eb3c 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt --canonicalize %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes
 func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
index dd64d746341b83..9162dc57ecfdf4 100644
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
-mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+mesh.cluster @mesh0(shape = 4x?x2)
+mesh.cluster @mesh1(shape = 2x3)
 
 // CHECK-LABEL: func.func @cluster_shape_op_folding
 func.func @cluster_shape_op_folding() -> (index, index) {
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index f3524a82a1b9d2..8a1fb80065573b 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -1,21 +1,16 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
 // expected-error@+1 {{rank of cluster is expected to be a positive integer}}
-mesh.cluster @mesh0(rank = 0)
-
-// -----
-
-// expected-error@+1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
+mesh.cluster @mesh0(shape = [])
 
 // -----
 
 // expected-error@+1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
+mesh.cluster @mesh0(shape = -1)
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
     // expected-error@+1 {{mesh axis duplicated}}
@@ -26,7 +21,7 @@ func.func @mesh_axis_duplicated_different_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
     // expected-error@+1 {{mesh axis duplicated}}
@@ -37,7 +32,7 @@ func.func @mesh_axis_duplicated_same_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
     // expected-error@+1 {{mesh axis duplicated}}
@@ -48,7 +43,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
     // expected-error@+1 {{mesh axis is expected to be non-negative}}
@@ -59,7 +54,7 @@ func.func @mesh_axis_negtive_in_split_part(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
     // expected-error@+1 {{mesh axis is expected to be non-negative}}
@@ -78,7 +73,7 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
 
 func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
   // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -88,7 +83,7 @@ func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
 
 // -----
 
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
 
 func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/77838

