[Mlir-commits] [mlir] [mlir][mesh] Rename cluster to mesh (PR #79484)

Boian Petkantchin llvmlistbot at llvm.org
Thu Jan 25 10:54:09 PST 2024


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/79484

Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.

The name `mesh` describes more precisely what this op represents: a mesh of devices.
The name `cluster` suggests a broader range of possible device configurations. When just the word `mesh` is used, whether it refers to the mesh dialect or to a device mesh can usually be inferred from context. The full name can be used when needed.

>From 03b07b074cad00a41e27463cd810c0c3900235b1 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 25 Jan 2024 10:21:55 -0800
Subject: [PATCH] [mlir][mesh] Rename cluster to mesh

Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.

The name `mesh` describes more precisely what this op represents:
a mesh of devices.
The name `cluster` suggests a broader range of possible device configurations.
When just the word `mesh` is used, whether it refers to the mesh dialect or
to a device mesh can usually be inferred from context.
The full name can be used when needed.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  20 +--
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  58 +++----
 .../Mesh/Interfaces/ShardingInterface.h       |   6 +-
 .../Dialect/Mesh/Transforms/Spmdization.h     |   6 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          |  43 +++--
 .../Mesh/Interfaces/ShardingInterface.cpp     |  22 ++-
 .../Mesh/Transforms/Simplifications.cpp       |  22 ++-
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  36 ++--
 .../Dialect/Mesh/Transforms/Transforms.cpp    |   9 +-
 mlir/test/Dialect/Mesh/canonicalization.mlir  |   2 +-
 mlir/test/Dialect/Mesh/folding.mlir           |  18 +-
 mlir/test/Dialect/Mesh/invalid.mlir           | 156 +++++++++---------
 mlir/test/Dialect/Mesh/ops.mlir               |  46 +++---
 .../Mesh/process-multi-index-op-lowering.mlir |   6 +-
 .../Dialect/Mesh/resharding-spmdization.mlir  |   8 +-
 .../Dialect/Mesh/sharding-propagation.mlir    |   6 +-
 mlir/test/Dialect/Mesh/simplifications.mlir   |   4 +-
 .../Mesh/TestReshardingSpmdization.cpp        |  12 +-
 18 files changed, 237 insertions(+), 243 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 07f954459ca49d..c15c5f6bcf2f0f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
   let mnemonic = "shard";
 
   let parameters = (ins
-    AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
+    AttrParameter<"::mlir::FlatSymbolRefAttr", "mesh placed">:$mesh,
     ArrayRefParameter<"MeshAxesAttr">:$split_axes,
     OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,9 +91,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     The MeshSharding attribute could be used in the encoding of a
     `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
 
-    1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
-    cluster where the distributed tensor is placed. The symbol must resolve to a
-    `mesh.cluster` operation.
+    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+    mesh where the distributed tensor is placed. The symbol must resolve to a
+    `mesh.mesh` operation.
 
     2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
     maximum size is the `rank` of the related tensor. For the i-th sub-array, if
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     Example:
 
     ```
-    mesh.cluster @mesh0(shape = 2x2x4)
+    mesh.mesh @mesh0(shape = 2x2x4)
 
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
@@ -140,12 +140,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     ```
   }];
   let assemblyFormat = [{
-    `<` $cluster `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
+    `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
        $partial_axes^ `]`)? `>`
   }];
 
   let builders = [
-    AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
+    AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
                      "ArrayRef<MeshAxis>": $partial_axes,
                      "mesh::Partial": $partial_type), [{
@@ -153,12 +153,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
                   split_axes, [&](ArrayRef<MeshAxis> array) {
           return MeshAxesAttr::get($_ctxt, array);
       });
-      return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+      return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
                    partial_type);
     }]>,
-    AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
+    AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
-      return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+      return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
     }]>
   ];
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 78ff8bd0cac621..c557ec1da1b6d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -26,17 +26,17 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
     Op<Mesh_Dialect, mnemonic, traits> {
 }
 
-def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
-  let summary = "representing a mesh cluster";
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+  let summary = "representing a device/process mesh";
   let description = [{
-    The mesh.cluster operation is a symbol operation that identifies a specific
-    mesh cluster. The operation has three attributes:
+    The mesh.mesh operation is a symbol operation that identifies a specific
+    mesh. The operation has three attributes:
 
-    1. `sym_name`: This attribute uniquely identifies the name of the mesh
-    cluster. This name serves as a symbolic reference to the cluster throughout
+    1. `sym_name`: This attribute uniquely identifies the name of the mesh.
+    This name serves as a symbolic reference to the mesh throughout
     the MLIR module, allowing for consistent referencing and easier debugging.
 
-    2. `shape`: This attribute represents the shape of the device cluster.
+    2. `shape`: This attribute represents the shape of the device mesh.
     It uses the same notation as a tensor shape. Also allowing for dynamic
     dimensions.
     This flexibility allows for dynamic device assignment or configurations
@@ -46,21 +46,21 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
 
     Example:
     ```
-    // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
+    // A device mesh with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
-    mesh.cluster @mesh0(shape = 4x8x12)
+    mesh.mesh @mesh0(shape = 4x8x12)
 
-    // A device mesh cluster with 2 axes, the total device number is unknown
+    // A device mesh with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
-    mesh.cluster @mesh1(shape = 4x?)
+    mesh.mesh @mesh1(shape = 4x?)
 
-    // A device mesh cluster with 2 axes, the total device number is unknown
+    // A device mesh with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
-    mesh.cluster @mesh2(shape = ?x4)
+    mesh.mesh @mesh2(shape = ?x4)
 
-    // A device mesh cluster with 2 axes, the number of devices along both axes
+    // A device mesh with 2 axes, the number of devices along both axes
     // is unknown
-    mesh.cluster @mesh3(shape = ?x?)
+    mesh.mesh @mesh3(shape = ?x?)
 
     // Used in the mesh sharding attribute to extend the standard tensor to
     // distributed
@@ -81,9 +81,9 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   let hasVerifier = 1;
 }
 
-def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
+def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
-  let summary = "Get the shape of the cluster.";
+  let summary = "Get the shape of the mesh.";
   let arguments = (ins
     FlatSymbolRefAttr:$mesh,
     DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
@@ -99,13 +99,13 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
   }];
 
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
 
 def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
-  let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
+  let summary = "Annotate on how a tensor is sharded across a mesh.";
   let description = [{
     The mesh.shard operation is designed to specify and guide the sharding
     behavior of a tensor value across a mesh topology. This operation has one
@@ -115,7 +115,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
     annotated for sharding.
 
     2. `shard`: This attribute is type of `MeshSharding`, which is the core data
-    structure to represent distributed tensor in mesh cluster.
+    structure to represent distribution of a tensor on a mesh.
 
     3. `annotate_for_users`: A unit attribute addressing the scenario when a
     tensor's sharding annotation differs based on its context of use (either as
@@ -217,7 +217,7 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
     attr-dict `:` type($result)
   }];
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
@@ -239,7 +239,7 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   let results = (outs Index:$result);
   let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
   ];
 }
 
@@ -268,7 +268,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
     ...
     %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
       : tensor<2x2xi8> -> tensor<2x4xi8>
@@ -353,7 +353,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 
     Example:
     ```
-    mesh.cluster @mesh0(shape = 3)
+    mesh.mesh @mesh0(shape = 3)
     ...
     %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
       split_axis = 0 concat_axis = 0
@@ -410,7 +410,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     
     Example:
     ```
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
 
     %1 = mesh.broadcast %0 on @mesh0
       mesh_axes = [0]
@@ -466,7 +466,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
     ...
     %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
       gather_axis = 1 root = [1]
@@ -589,7 +589,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     across the device group.
     Example:
     ```
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
     ...
     %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
       reduction = <max> scatter_axis = 0
@@ -652,7 +652,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
 
     Example:
     ```
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
     %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
       scatter_axis = 0
       root = [1]
@@ -748,7 +748,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 
     Example:
     ```
-    mesh.cluster @mesh0(shape = 2x4)
+    mesh.mesh @mesh0(shape = 2x4)
     %1 = mesh.shift on @mesh0 mesh_axes = [1]
       shift_axis = 1 offset = 2 rotate
       : tensor<2xi8> -> tensor<2xi8>
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index a32274d857f15d..3bef7e6babdec9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -25,14 +25,14 @@ struct ShardingOption {
   // An array of int array. The sub-array at the i-th position signifies the
   // mesh axes the i-th loop will be sharded on.
   ShardingArray shardingArray = {};
-  FlatSymbolRefAttr cluster = nullptr;
+  FlatSymbolRefAttr mesh = nullptr;
   // `empty` being true indicates that no sharding information can be inferred
   // at present. Note that it is different from the case where an operation is
   // not sharded.
   bool empty = false;
   ShardingOption() = default;
-  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
-      : shardingArray(std::move(shardingArray)), cluster(cluster) {}
+  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
+      : shardingArray(std::move(shardingArray)), mesh(mesh) {}
 };
 
 // This method retrieves the 'MeshShardingAttr' attribute from a given operation
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index f71bb9b262a380..7cb992aac019b3 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -17,14 +17,14 @@ namespace mlir {
 namespace mesh {
 
 // Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                            MeshShardingAttr sharding);
 
 // Insert resharding spmdization of the value `sourceShardValue`
 // from sharding `source` to sharding `target`.
 // `sourceShardValue` is the already sharded value according to `source`.
-TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
-                               ShardOp source, ShardOp target,
+TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+                               ShardOp target,
                                TypedValue<ShapedType> sourceShardValue);
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index f6b6b7c248c432..994a017a1f46c2 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -114,10 +114,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                    SymbolTableCollection &symbolTable) {
-  mesh::ClusterOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                 SymbolTableCollection &symbolTable) {
+  mesh::MeshOp mesh =
+      symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -144,7 +144,7 @@ bool isUnique(It begin, It end) {
 }
 
 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
-                                    ClusterOp mesh) {
+                                    MeshOp mesh) {
   SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
   llvm::sort(sorted);
   if (!isUnique(sorted.begin(), sorted.end())) {
@@ -192,22 +192,22 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.cluster op
+// mesh.mesh op
 //===----------------------------------------------------------------------===//
 
-LogicalResult ClusterOp::verify() {
+LogicalResult MeshOp::verify() {
   int64_t rank = getRank();
 
   if (rank <= 0)
-    return emitOpError("rank of cluster is expected to be a positive integer");
+    return emitOpError("rank of mesh is expected to be a positive integer");
 
   if (getShape().size() > size_t(rank))
     return emitOpError(
-        "rank of shape is not expected to be larger than rank of cluster");
+        "rank of shape is not expected to be larger than rank of mesh");
 
   for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
-      return emitOpError("dimension size of a mesh cluster is expected to be "
+      return emitOpError("dimension size of a mesh is expected to be "
                          "non-negative or dynamic");
   }
 
@@ -215,11 +215,11 @@ LogicalResult ClusterOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.cluster_shape op
+// mesh.mesh_shape op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
@@ -238,16 +238,16 @@ ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
-void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                           ClusterOp mesh) {
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        MeshOp mesh) {
   build(odsBuilder, odsState,
         SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
         mesh.getSymName(),
         MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
 }
 
-void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        StringRef mesh, ArrayRef<MeshAxis> axes) {
   build(odsBuilder, odsState,
         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -261,7 +261,7 @@ LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                          FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
                          ArrayRef<MeshAxis> partialAxes, Partial) {
-  // TODO: At present cluster symbol ref is not verified. This is due to the
+  // TODO: At present mesh symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
   llvm::SmallSet<MeshAxis, 4> visitedAxes;
@@ -292,8 +292,7 @@ bool MeshShardingAttr::operator==(Attribute rhs) const {
 }
 
 bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
-  if (getCluster() != rhs.getCluster() ||
-      getPartialAxes() != rhs.getPartialAxes()) {
+  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
     return false;
   }
 
@@ -342,7 +341,7 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 }
 
 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                                ClusterOp mesh) {
+                                MeshOp mesh) {
   build(odsBuilder, odsState,
         SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
         mesh.getSymName(), ArrayRef<MeshAxis>());
@@ -369,7 +368,7 @@ ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 }
 
 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
-                                 OperationState &odsState, ClusterOp mesh) {
+                                 OperationState &odsState, MeshOp mesh) {
   build(odsBuilder, odsState, mesh.getSymName());
 }
 
@@ -427,7 +426,7 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
 }
 
 template <typename Op>
-static FailureOr<ClusterOp>
+static FailureOr<MeshOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
   auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
   if (failed(mesh)) {
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index dca7e86e6f07f5..5dc91ff1c02d20 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -215,11 +215,10 @@ namespace {
 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
 static LogicalResult fillShardingOption(Operation *op,
                                         ShardingOption &shardingOption,
-                                        FlatSymbolRefAttr cluster,
+                                        FlatSymbolRefAttr mesh,
                                         ArrayRef<MeshAxis> meshAxes,
                                         unsigned loopIdx) {
-  if ((shardingOption.cluster && cluster &&
-       shardingOption.cluster != cluster) ||
+  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
       (!shardingOption.shardingArray[loopIdx].empty() &&
        shardingOption.shardingArray[loopIdx] != meshAxes)) {
     LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
@@ -238,8 +237,8 @@ static LogicalResult fillShardingOption(Operation *op,
       }
     }
   }
-  if (cluster)
-    shardingOption.cluster = cluster;
+  if (mesh)
+    shardingOption.mesh = mesh;
   if (shardingOption.shardingArray[loopIdx].empty())
     shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
                                                  meshAxes.end());
@@ -281,7 +280,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
       auto dim = cast<AffineDimExpr>(expr);
       unsigned index = dim.getPosition();
       visitedLoopIndices.insert(index);
-      if (failed(fillShardingOption(op, shardingOption, shardAttr.getCluster(),
+      if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
                                     axes, index)))
         return failure();
     }
@@ -333,8 +332,8 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
       if (loopIndices->size() == 1) {
         unsigned loopIdx = *loopIndices->begin();
         visitedLoopIndices.insert(loopIdx);
-        if (failed(fillShardingOption(op, shardingOption,
-                                      shardAttr.getCluster(), axes, loopIdx)))
+        if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
+                                      axes, loopIdx)))
           return failure();
       }
       // If multiple loop indices correspond to a dimension of an operand, it is
@@ -437,9 +436,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
   }
 
   removeTrailingEmptySubArray(splitAxes);
-  MeshShardingAttr shardAttr =
-      MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes,
-                            partialAxes, partialType);
+  MeshShardingAttr shardAttr = MeshShardingAttr::get(
+      b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPointAfterValue(result);
   auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
@@ -485,7 +483,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
 
   removeTrailingEmptySubArray(splitAxes);
   MeshShardingAttr shardAttr =
-      MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes);
+      MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(opOperand.getOwner());
   auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 429e684c845fb1..c0273cdaef7144 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -55,20 +55,19 @@ namespace {
 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
 // symbol tables.
 // We can't use DialectFoldInterface since the cache may be invalidated by some
-// pass changing the referenced ClusterOp ops.
-struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
+// pass changing the referenced MeshOp ops.
+struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
   template <typename... OpRewritePatternArgs>
-  ClusterShapeFolder(SymbolTableCollection &symbolTableCollection,
-                     OpRewritePatternArgs &&...opRewritePatternArgs)
+  MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
+                  OpRewritePatternArgs &&...opRewritePatternArgs)
       : OpRewritePattern(
             std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
         symbolTableCollection(symbolTableCollection) {}
-  LogicalResult matchAndRewrite(ClusterShapeOp op,
+  LogicalResult matchAndRewrite(MeshShapeOp op,
                                 PatternRewriter &rewriter) const override {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
-    ClusterOp mesh =
-        symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
-            op.getOperation(), op.getMeshAttr());
+    MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+        op.getOperation(), op.getMeshAttr());
     if (!mesh) {
       return failure();
     }
@@ -104,8 +103,8 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
 
     // Leave only the dynamic mesh axes to be queried.
     if (!newShapeOpMeshAxes.empty()) {
-      ClusterShapeOp newShapeOp =
-          builder.create<ClusterShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
+      MeshShapeOp newShapeOp =
+          builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
       for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
         newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
       }
@@ -123,8 +122,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
 
 void populateFoldingPatterns(RewritePatternSet &patterns,
                              SymbolTableCollection &symbolTableCollection) {
-  patterns.add<ClusterShapeFolder>(symbolTableCollection,
-                                   patterns.getContext());
+  patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
 }
 
 } // namespace mesh
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 9478b2e4ee5cb2..593158d5f6d293 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -84,7 +84,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
   }
 }
 
-ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                            MeshShardingAttr sharding) {
   using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
   SmallVector<Dim> resShapeArr(shape.getShape().size());
@@ -141,7 +141,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
   TypedValue<ShapedType> resultValue =
       builder
           .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
-                               sourceSharding.getCluster().getLeafReference(),
+                               sourceSharding.getMesh().getLeafReference(),
                                allReduceMeshAxes, sourceShard,
                                sourceSharding.getPartialType())
           .getResult()
@@ -154,7 +154,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
                   return targetShardingPartialAxesSet.contains(a);
                 });
   MeshShardingAttr resultSharding =
-      MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
+      MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(),
                             sourceSharding.getSplitAxes(), remainingPartialAxes,
                             sourceSharding.getPartialType());
   return {resultValue, resultSharding};
@@ -175,7 +175,7 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
   targetShardingSplitAxes[splitTensorAxis] =
       MeshAxesAttr::get(ctx, targetSplitAxes);
   return MeshShardingAttr::get(
-      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+      ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
@@ -197,7 +197,7 @@ static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
 splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshShardingAttr sourceSharding,
-                          TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
   MLIRContext *ctx = builder.getContext();
   builder.setInsertionPointAfterValue(sourceShard);
@@ -217,8 +217,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
 
   Value meshAxisSize =
       builder
-          .create<ClusterShapeOp>(mesh.getSymName(),
-                                  SmallVector<MeshAxis>({splitMeshAxis}))
+          .create<MeshShapeOp>(mesh.getSymName(),
+                               SmallVector<MeshAxis>({splitMeshAxis}))
           .getResult()[0];
 
   Value sourceAxisSize =
@@ -305,7 +305,7 @@ detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
 }
 
 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
-trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                              MeshShardingAttr sourceSharding,
                              MeshShardingAttr targetSharding,
                              TypedValue<ShapedType> sourceShard) {
@@ -366,7 +366,7 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
   targetShardingSplitAxes[splitTensorAxis] =
       MeshAxesAttr::get(ctx, targetSplitAxes);
   return MeshShardingAttr::get(
-      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+      ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
@@ -382,7 +382,7 @@ static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
 unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                             MeshShardingAttr sourceSharding,
                             ShapedType sourceUnshardedShape,
-                            TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                             int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
   MLIRContext *ctx = builder.getContext();
   builder.setInsertionPointAfterValue(sourceShard);
@@ -406,7 +406,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
 }
 
 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
-tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshShardingAttr sourceSharding,
                                MeshShardingAttr targetSharding,
                                ShapedType sourceUnshardedShape,
@@ -495,7 +495,7 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
       MeshAxesAttr::get(ctx, targetSplitAxes);
 
   return MeshShardingAttr::get(
-      ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+      ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
       sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
 }
 
@@ -512,7 +512,7 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
 }
 
 static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
-moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshShardingAttr sourceSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard,
@@ -541,7 +541,7 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
 }
 
 static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
-tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                  MeshShardingAttr sourceSharding,
                                  MeshShardingAttr targetSharding,
                                  ShapedType sourceUnshardedShape,
@@ -561,7 +561,7 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
 // Currently the sharded tensor axes must be exactly divisible by the single
 // mesh axis size.
 static TypedValue<ShapedType>
-reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
                 MeshShardingAttr sourceSharding,
                 MeshShardingAttr targetSharding,
                 TypedValue<ShapedType> sourceUnshardedValue,
@@ -604,7 +604,7 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
   return targetShard;
 }
 
-TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                MeshShardingAttr sourceSharding,
                                MeshShardingAttr targetSharding,
                                TypedValue<ShapedType> sourceUnshardedValue,
@@ -616,8 +616,8 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
                          sourceUnshardedValue, sourceShard);
 }
 
-TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
-                               ShardOp source, ShardOp target,
+TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+                               ShardOp target,
                                TypedValue<ShapedType> sourceShardValue) {
   assert(!source.getAnnotateForUsers());
   assert(target.getAnnotateForUsers());
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index c27e173d877d69..5c2344651bf60d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -24,7 +24,7 @@ namespace mlir::mesh {
 namespace {
 
 /// Lower `mesh.process_multi_index` into expression using
-/// `mesh.process_linear_index` and `mesh.cluster_shape`.
+/// `mesh.process_linear_index` and `mesh.mesh_shape`.
 struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
   template <typename... OpRewritePatternArgs>
   ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
@@ -35,9 +35,8 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
 
   LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    ClusterOp mesh =
-        symbolTableCollection.lookupNearestSymbolFrom<mesh::ClusterOp>(
-            op.getOperation(), op.getMeshAttr());
+    MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+        op.getOperation(), op.getMeshAttr());
     if (!mesh) {
       return failure();
     }
@@ -45,7 +44,7 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
     builder.setInsertionPointAfter(op.getOperation());
     Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
-    ValueRange meshShape = builder.create<ClusterShapeOp>(mesh).getResults();
+    ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
     SmallVector<Value> completeMultiIndex =
         builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
             .getMultiIndex();
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 4cc009ef24eb3c..23c5b253b4c073 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt --canonicalize %s | FileCheck %s
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes
 func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
index 9162dc57ecfdf4..369f316d0f797d 100644
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -1,22 +1,22 @@
 // RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
 
-mesh.cluster @mesh0(shape = 4x?x2)
-mesh.cluster @mesh1(shape = 2x3)
+mesh.mesh @mesh0(shape = 4x?x2)
+mesh.mesh @mesh1(shape = 2x3)
 
-// CHECK-LABEL: func.func @cluster_shape_op_folding
-func.func @cluster_shape_op_folding() -> (index, index) {
+// CHECK-LABEL: func.func @mesh_shape_op_folding
+func.func @mesh_shape_op_folding() -> (index, index) {
   // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
-  // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.cluster_shape @mesh0 axes = [1] : index
-  %0:2 = mesh.cluster_shape @mesh0 axes = [2, 1] : index, index
+  // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index
+  %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
   // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
   return %0#0, %0#1 : index, index
 }
 
-// CHECK-LABEL: func.func @cluster_shape_op_folding_all_axes_static_mesh
-func.func @cluster_shape_op_folding_all_axes_static_mesh() -> (index, index) {
+// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh
+func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) {
   // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
   // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
-  %0:2 = mesh.cluster_shape @mesh1 : index, index
+  %0:2 = mesh.mesh_shape @mesh1 : index, index
   // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
   return %0#0, %0#1 : index, index
 }
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 8a1fb80065573b..259e4ebf76757c 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -1,16 +1,16 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
-// expected-error at +1 {{rank of cluster is expected to be a positive integer}}
-mesh.cluster @mesh0(shape = [])
+// expected-error at +1 {{rank of mesh is expected to be a positive integer}}
+mesh.mesh @mesh0(shape = [])
 
 // -----
 
-// expected-error at +1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.cluster @mesh0(shape = -1)
+// expected-error at +1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
+mesh.mesh @mesh0(shape = -1)
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -21,7 +21,7 @@ func.func @mesh_axis_duplicated_different_subarray(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -32,7 +32,7 @@ func.func @mesh_axis_duplicated_same_subarray(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -43,7 +43,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -54,7 +54,7 @@ func.func @mesh_axis_negtive_in_split_part(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -67,61 +67,61 @@ func.func @mesh_axis_negtive_in_partial(
 
 func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
   // expected-error at +2 {{custom op 'mesh.shard' invalid kind of attribute specified}}
-  // expected-error at +1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'cluster' which is to be a `::mlir::FlatSymbolRefAttr`}}
+  // expected-error at +1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'mesh' which is to be a `::mlir::FlatSymbolRefAttr`}}
   %0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32>
 }
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
-func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
+func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {
   // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
-  %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index
+  %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-mesh.cluster @mesh0(shape = 1x2x3)
+mesh.mesh @mesh0(shape = 1x2x3)
 
-func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
+func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) {
   // expected-error at +1 {{Mesh axes contains duplicate elements.}}
-  %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index
+  %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
-func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
+func.func @mesh_shape_wrong_number_of_results() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
-  %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index
+  %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-mesh.cluster @mesh0(shape = 1x2x3)
+mesh.mesh @mesh0(shape = 1x2x3)
 
-func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
-  %0:2 = mesh.cluster_shape @mesh0 : index, index
+  %0:2 = mesh.mesh_shape @mesh0 : index, index
   return %0#0, %0#1 : index, index
 }
 
 // -----
 
-func.func @cluster_shape_invalid_mesh_name() -> (index) {
+func.func @mesh_shape_invalid_mesh_name() -> (index) {
   // expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
-  %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index
+  %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index
   return %0#0 : index
 }
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
   // expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -131,7 +131,7 @@ func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
 
 // -----
 
-mesh.cluster @mesh0(shape = 1x2x3)
+mesh.mesh @mesh0(shape = 1x2x3)
 
 func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
   // expected-error at +1 {{Mesh axes contains duplicate elements.}}
@@ -141,7 +141,7 @@ func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 1.}}
@@ -151,7 +151,7 @@ func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
 
 // -----
 
-mesh.cluster @mesh0(shape = 1x2x3)
+mesh.mesh @mesh0(shape = 1x2x3)
 
 func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
   // expected-error at +1 {{Unexpected number of results 2. Expected 3.}}
@@ -187,7 +187,7 @@ func.func @all_reduce_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @all_reduce_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -199,7 +199,7 @@ func.func @all_reduce_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -211,7 +211,7 @@ func.func @all_reduce_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @all_reduce_invalid_tensor_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<5xf64> {
@@ -232,7 +232,7 @@ func.func @all_gather_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @all_gather_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -244,7 +244,7 @@ func.func @all_gather_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -256,7 +256,7 @@ func.func @all_reduce_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @all_gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -268,7 +268,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1x2)
+mesh.mesh @mesh0(shape = 1x2)
 
 func.func @all_gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -280,7 +280,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @all_gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -292,7 +292,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @all_gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -304,7 +304,7 @@ func.func @all_gather_invalid_gather_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @all_gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -327,7 +327,7 @@ func.func @all_to_all_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @all_to_all_duplicate_mesh_axis(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -340,7 +340,7 @@ func.func @all_to_all_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = ?x1)
+mesh.mesh @mesh0(shape = ?x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -353,7 +353,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
 
 // -----
 
-mesh.cluster @mesh0(shape = 1x1)
+mesh.mesh @mesh0(shape = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
@@ -366,7 +366,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
 
 // -----
 
-mesh.cluster @mesh0(shape = 1x1)
+mesh.mesh @mesh0(shape = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
@@ -379,7 +379,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
@@ -392,7 +392,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
@@ -405,7 +405,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @broadcast_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -418,7 +418,7 @@ func.func @broadcast_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @broadcast_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -431,7 +431,7 @@ func.func @broadcast_root_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @broadcast_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -444,7 +444,7 @@ func.func @broadcast_different_input_and_result_type(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @gather_wrong_return_element_type(
     %arg0 : tensor<1xf32>) -> tensor<1xi8> {
@@ -456,7 +456,7 @@ func.func @gather_wrong_return_element_type(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -468,7 +468,7 @@ func.func @gather_invalid_non_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1x2)
+mesh.mesh @mesh0(shape = 1x2)
 
 func.func @gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -480,7 +480,7 @@ func.func @gather_invalid_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -492,7 +492,7 @@ func.func @gather_invalid_gather_axis_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -504,7 +504,7 @@ func.func @gather_invalid_gather_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 1)
+mesh.mesh @mesh0(shape = 1)
 
 func.func @gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -516,7 +516,7 @@ func.func @gather_invalid_negative_gather_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @gather_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<6xi8> {
@@ -529,7 +529,7 @@ func.func @gather_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @gather_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -542,7 +542,7 @@ func.func @gather_root_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @receive_source_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -555,7 +555,7 @@ func.func @receive_source_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @receive_source_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -568,7 +568,7 @@ func.func @receive_source_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @receive_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -581,7 +581,7 @@ func.func @receive_different_input_and_result_type(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @reduce_root_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -594,7 +594,7 @@ func.func @reduce_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @reduce_root_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -607,7 +607,7 @@ func.func @reduce_root_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @reduce_different_input_and_result_shape(
     %arg0 : tensor<2xi8>) -> tensor<3xi16> {
@@ -620,7 +620,7 @@ func.func @reduce_different_input_and_result_shape(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @reduce_scatter_duplicate_mesh_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
@@ -632,7 +632,7 @@ func.func @reduce_scatter_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
@@ -644,7 +644,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @reduce_scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf64> {
@@ -656,7 +656,7 @@ func.func @reduce_scatter_invalid_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @reduce_scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf64> {
@@ -668,7 +668,7 @@ func.func @reduce_scatter_invalid_operand_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @scatter_duplicate_mesh_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf32> {
@@ -681,7 +681,7 @@ func.func @scatter_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf32> {
@@ -694,7 +694,7 @@ func.func @scatter_invalid_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf32> {
@@ -707,7 +707,7 @@ func.func @scatter_invalid_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3)
+mesh.mesh @mesh0(shape = 3)
 
 func.func @scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf32> {
@@ -720,7 +720,7 @@ func.func @scatter_invalid_operand_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @scatter_root_dimension_out_of_bounds(
     %arg0 : tensor<3xi8>) -> tensor<1xi8> {
@@ -733,7 +733,7 @@ func.func @scatter_root_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @scatter_root_wrong_number_dimensions(
     %arg0 : tensor<3xi8>) -> tensor<1xi8> {
@@ -746,7 +746,7 @@ func.func @scatter_root_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @send_destination_dimension_out_of_bounds(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -759,7 +759,7 @@ func.func @send_destination_dimension_out_of_bounds(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @send_destination_wrong_number_dimensions(
     %arg0 : tensor<2xi8>) -> tensor<2xi8> {
@@ -772,7 +772,7 @@ func.func @send_destination_wrong_number_dimensions(
 
 // -----
 
-mesh.cluster @mesh0(shape = 3x?)
+mesh.mesh @mesh0(shape = 3x?)
 
 func.func @send_different_input_and_result_type(
     %arg0 : tensor<2xi8>) -> tensor<2xi16> {
@@ -796,7 +796,7 @@ func.func @shift_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @shift_invalid_mesh_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
@@ -809,7 +809,7 @@ func.func @shift_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @shift_duplicate_mesh_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
@@ -822,7 +822,7 @@ func.func @shift_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @shift_invalid_tensor_dimension_size(
     %arg0 : tensor<4xi8>) -> tensor<5xi8> {
@@ -835,7 +835,7 @@ func.func @shift_invalid_tensor_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(shape = 2x4)
+mesh.mesh @mesh0(shape = 2x4)
 
 func.func @shift_invalid_shift_axis(
     %arg0 : tensor<4xi8>) -> tensor<4xi8> {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 0aaa4bdee1db3a..dbaaff9c172fd9 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -1,21 +1,21 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
-// CHECK: mesh.cluster @mesh0
-mesh.cluster @mesh0(shape = 2x2x4)
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 2x2x4)
 
-// CHECK: mesh.cluster @mesh1(shape = 4x?)
-mesh.cluster @mesh1(shape = 4x?)
+// CHECK: mesh.mesh @mesh1(shape = 4x?)
+mesh.mesh @mesh1(shape = 4x?)
 
-// CHECK: mesh.cluster @mesh2(shape = ?x4)
-mesh.cluster @mesh2(shape = ?x4)
+// CHECK: mesh.mesh @mesh2(shape = ?x4)
+mesh.mesh @mesh2(shape = ?x4)
 
-// CHECK: mesh.cluster @mesh3(shape = ?x?)
-mesh.cluster @mesh3(shape = ?x?)
+// CHECK: mesh.mesh @mesh3(shape = ?x?)
+mesh.mesh @mesh3(shape = ?x?)
 
-mesh.cluster @mesh4(shape = 3)
+mesh.mesh @mesh4(shape = 3)
 
-// CHECK: mesh.cluster @mesh5(shape = ?)
-mesh.cluster @mesh5(shape = ?)
+// CHECK: mesh.mesh @mesh5(shape = ?)
+mesh.mesh @mesh5(shape = ?)
 
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
@@ -132,26 +132,26 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }
 
-// CHECK-LABEL: func @cluster_shape
-func.func @cluster_shape() -> (index, index) {
-  // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
-  %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+// CHECK-LABEL: func @mesh_shape
+func.func @mesh_shape() -> (index, index) {
+  // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
+  %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
   return %0#0, %0#1 : index, index
 }
 
-// CHECK-LABEL: func @cluster_shape_default_axes
-func.func @cluster_shape_default_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
-  %0:3 = mesh.cluster_shape @mesh0 : index, index, index
+// CHECK-LABEL: func @mesh_shape_default_axes
+func.func @mesh_shape_default_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
+  %0:3 = mesh.mesh_shape @mesh0 : index, index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
 
-// CHECK-LABEL: func @cluster_shape_empty_axes
-func.func @cluster_shape_empty_axes() -> (index, index, index) {
-  // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
-  %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+// CHECK-LABEL: func @mesh_shape_empty_axes
+func.func @mesh_shape_empty_axes() -> (index, index, index) {
+  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
+  %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
   // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
   return %0#0, %0#1, %0#2 : index, index, index
 }
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
index aeeba4ea3f1979..677a5982ea2540 100644
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
 
-mesh.cluster @mesh2d(shape = ?x?)
+mesh.mesh @mesh2d(shape = ?x?)
 
 // CHECK-LABEL: func.func @multi_index_2d_mesh
 func.func @multi_index_2d_mesh() -> (index, index) {
   // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
-  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
   // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
   %0:2 = mesh.process_multi_index on @mesh2d : index, index
   // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
@@ -15,7 +15,7 @@ func.func @multi_index_2d_mesh() -> (index, index) {
 // CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
 func.func @multi_index_2d_mesh_single_inner_axis() -> index {
   // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
-  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.cluster_shape @mesh2d : index, index
+  // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
   // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %0 into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
   %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
   // CHECK: return %[[MULTI_IDX]]#0 : index
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index 3f5c7d80bf9c7e..cb98d31dd6692d 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
 
-mesh.cluster @mesh_1d(shape = 2)
-mesh.cluster @mesh_1d_dynamic(shape = ?)
+mesh.mesh @mesh_1d(shape = 2)
+mesh.mesh @mesh_1d_dynamic(shape = ?)
 
 // CHECK-LABEL: func @same_source_and_target_sharding
 func.func @same_source_and_target_sharding(
@@ -22,7 +22,7 @@ func.func @split_replicated_tensor_axis(
   // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
   // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
-  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
   // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
   // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
   // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
@@ -44,7 +44,7 @@ func.func @split_replicated_tensor_axis_dynamic(
   // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
   // CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
   // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
-  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
+  // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d_dynamic axes = [0] : index
   // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
   // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
   // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 065ae9ca8c6b41..94f8d94073c5ef 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt -sharding-propagation %s | FileCheck %s
 
-mesh.cluster @mesh_1d(shape = ?)
-mesh.cluster @mesh_2d(shape = 2x4)
-mesh.cluster @mesh_3d(shape = ?x?x?)
+mesh.mesh @mesh_1d(shape = ?)
+mesh.mesh @mesh_2d(shape = 2x4)
+mesh.mesh @mesh_3d(shape = ?x?x?)
 
 // CHECK-LABEL: func.func @element_wise_empty_sharding_info
 func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index 63ae9d528df1b3..d748be82c5a46a 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
 
-mesh.cluster @mesh0(shape = 4x2)
-mesh.cluster @mesh1(shape = 4)
+mesh.mesh @mesh0(shape = 4x2)
+mesh.mesh @mesh1(shape = 4)
 
 // Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
 // `all_reduce(x + y)`.
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
index 6fecbd48f15387..9b3082a819224f 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -37,15 +37,15 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
     }
 
     SymbolTableCollection symbolTable;
-    mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
-        op, op.getShard().getCluster());
+    mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+        op, op.getShard().getMesh());
 
     bool foundUser = false;
     for (auto user : op->getUsers()) {
       if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
         if (targetShardOp.getAnnotateForUsers() &&
-            mesh == symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
-                        targetShardOp, targetShardOp.getShard().getCluster())) {
+            mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+                        targetShardOp, targetShardOp.getShard().getMesh())) {
           foundUser = true;
           break;
         }
@@ -59,8 +59,8 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
     for (auto user : op->getUsers()) {
       auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
       if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
-          symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
-              targetShardOp, targetShardOp.getShard().getCluster()) != mesh) {
+          symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+              targetShardOp, targetShardOp.getShard().getMesh()) != mesh) {
         continue;
       }
 



More information about the Mlir-commits mailing list