[Mlir-commits] [mlir] [mlir][mesh] Rename cluster to mesh (PR #79484)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 25 10:54:38 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

<details>
<summary>Changes</summary>

Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.

The name `mesh` describes more precisely what the op represents: a mesh of devices.
The name `cluster` suggests a broader range of possible device configurations. When the bare word `mesh` is used, whether it refers to the mesh dialect or to a device mesh can usually be inferred from context. The full name can be used when needed.

---

Patch is 65.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79484.diff


18 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+10-10) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+29-29) 
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+3-3) 
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h (+3-3) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+21-22) 
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+10-12) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+10-12) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+18-18) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+4-5) 
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+1-1) 
- (modified) mlir/test/Dialect/Mesh/folding.mlir (+9-9) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+78-78) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+23-23) 
- (modified) mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir (+3-3) 
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+4-4) 
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+3-3) 
- (modified) mlir/test/Dialect/Mesh/simplifications.mlir (+2-2) 
- (modified) mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp (+6-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 07f954459ca49d..c15c5f6bcf2f0f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
   let mnemonic = "shard";
 
   let parameters = (ins
-    AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
+    AttrParameter<"::mlir::FlatSymbolRefAttr", "mesh placed">:$mesh,
     ArrayRefParameter<"MeshAxesAttr">:$split_axes,
     OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,9 +91,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     The MeshSharding attribute could be used in the encoding of a
     `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
 
-    1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
-    cluster where the distributed tensor is placed. The symbol must resolve to a
-    `mesh.cluster` operation.
+    1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+    mesh where the distributed tensor is placed. The symbol must resolve to a
+    `mesh.mesh` operation.
 
     2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
     maximum size is the `rank` of the related tensor. For the i-th sub-array, if
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     Example:
 
     ```
-    mesh.cluster @mesh0(shape = 2x2x4)
+    mesh.mesh @mesh0(shape = 2x2x4)
 
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
@@ -140,12 +140,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     ```
   }];
   let assemblyFormat = [{
-    `<` $cluster `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
+    `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
        $partial_axes^ `]`)? `>`
   }];
 
   let builders = [
-    AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
+    AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
                      "ArrayRef<MeshAxis>": $partial_axes,
                      "mesh::Partial": $partial_type), [{
@@ -153,12 +153,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
                   split_axes, [&](ArrayRef<MeshAxis> array) {
           return MeshAxesAttr::get($_ctxt, array);
       });
-      return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+      return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
                    partial_type);
     }]>,
-    AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
+    AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
                      "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
-      return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+      return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
     }]>
   ];
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 78ff8bd0cac621..c557ec1da1b6d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -26,17 +26,17 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
     Op<Mesh_Dialect, mnemonic, traits> {
 }
 
-def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
-  let summary = "representing a mesh cluster";
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+  let summary = "representing a device/process mesh";
   let description = [{
-    The mesh.cluster operation is a symbol operation that identifies a specific
-    mesh cluster. The operation has three attributes:
+    The mesh.mesh operation is a symbol operation that identifies a specific
+    mesh. The operation has three attributes:
 
-    1. `sym_name`: This attribute uniquely identifies the name of the mesh
-    cluster. This name serves as a symbolic reference to the cluster throughout
+    1. `sym_name`: This attribute uniquely identifies the name of the mesh.
+    This name serves as a symbolic reference to the mesh throughout
     the MLIR module, allowing for consistent referencing and easier debugging.
 
-    2. `shape`: This attribute represents the shape of the device cluster.
+    2. `shape`: This attribute represents the shape of the device mesh.
     It uses the same notation as a tensor shape. Also allowing for dynamic
     dimensions.
     This flexibility allows for dynamic device assignment or configurations
@@ -46,21 +46,21 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
 
     Example:
     ```
-    // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
+    // A device mesh with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
-    mesh.cluster @mesh0(shape = 4x8x12)
+    mesh.mesh @mesh0(shape = 4x8x12)
 
-    // A device mesh cluster with 2 axes, the total device number is unknown
+    // A device mesh with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
-    mesh.cluster @mesh1(shape = 4x?)
+    mesh.mesh @mesh1(shape = 4x?)
 
-    // A device mesh cluster with 2 axes, the total device number is unknown
+    // A device mesh with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
-    mesh.cluster @mesh2(shape = ?x4)
+    mesh.mesh @mesh2(shape = ?x4)
 
-    // A device mesh cluster with 2 axes, the number of devices along both axes
+    // A device mesh with 2 axes, the number of devices along both axes
     // is unknown
-    mesh.cluster @mesh3(shape = ?x?)
+    mesh.mesh @mesh3(shape = ?x?)
 
     // Used in the mesh sharding attribute to extend the standard tensor to
     // distributed
@@ -81,9 +81,9 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   let hasVerifier = 1;
 }
 
-def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
+def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
   Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
-  let summary = "Get the shape of the cluster.";
+  let summary = "Get the shape of the mesh.";
   let arguments = (ins
     FlatSymbolRefAttr:$mesh,
     DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
@@ -99,13 +99,13 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
   }];
 
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
 
 def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
-  let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
+  let summary = "Annotate on how a tensor is sharded across a mesh.";
   let description = [{
     The mesh.shard operation is designed to specify and guide the sharding
     behavior of a tensor value across a mesh topology. This operation has one
@@ -115,7 +115,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
     annotated for sharding.
 
     2. `shard`: This attribute is type of `MeshSharding`, which is the core data
-    structure to represent distributed tensor in mesh cluster.
+    structure to represent distribution of a tensor on a mesh.
 
     3. `annotate_for_users`: A unit attribute addressing the scenario when a
     tensor's sharding annotation differs based on its context of use (either as
@@ -217,7 +217,7 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
     attr-dict `:` type($result)
   }];
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
     OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
   ];
 }
@@ -239,7 +239,7 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
   let results = (outs Index:$result);
   let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
   let builders = [
-    OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
+    OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
   ];
 }
 
@@ -268,7 +268,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
     ...
     %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
       : tensor<2x2xi8> -> tensor<2x4xi8>
@@ -353,7 +353,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 
     Example:
     ```
-    mesh.cluster @mesh0(shape = 3)
+    mesh.mesh @mesh0(shape = 3)
     ...
     %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
       split_axis = 0 concat_axis = 0
@@ -410,7 +410,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     
     Example:
     ```
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
 
     %1 = mesh.broadcast %0 on @mesh0
       mesh_axes = [0]
@@ -466,7 +466,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
     ...
     %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
       gather_axis = 1 root = [1]
@@ -589,7 +589,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     across the device group.
     Example:
     ```
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
     ...
     %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
       reduction = <max> scatter_axis = 0
@@ -652,7 +652,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
 
     Example:
     ```
-    mesh.cluster @mesh0(shape = 2x2)
+    mesh.mesh @mesh0(shape = 2x2)
     %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
       scatter_axis = 0
       root = [1]
@@ -748,7 +748,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
 
     Example:
     ```
-    mesh.cluster @mesh0(shape = 2x4)
+    mesh.mesh @mesh0(shape = 2x4)
     %1 = mesh.shift on @mesh0 mesh_axes = [1]
       shift_axis = 1 offset = 2 rotate
       : tensor<2xi8> -> tensor<2xi8>
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index a32274d857f15d..3bef7e6babdec9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -25,14 +25,14 @@ struct ShardingOption {
   // An array of int array. The sub-array at the i-th position signifies the
   // mesh axes the i-th loop will be sharded on.
   ShardingArray shardingArray = {};
-  FlatSymbolRefAttr cluster = nullptr;
+  FlatSymbolRefAttr mesh = nullptr;
   // `empty` being true indicates that no sharding information can be inferred
   // at present. Note that it is different from the case where an operation is
   // not sharded.
   bool empty = false;
   ShardingOption() = default;
-  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
-      : shardingArray(std::move(shardingArray)), cluster(cluster) {}
+  ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
+      : shardingArray(std::move(shardingArray)), mesh(mesh) {}
 };
 
 // This method retrieves the 'MeshShardingAttr' attribute from a given operation
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index f71bb9b262a380..7cb992aac019b3 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -17,14 +17,14 @@ namespace mlir {
 namespace mesh {
 
 // Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
                            MeshShardingAttr sharding);
 
 // Insert resharding spmdization of the value `sourceShardValue`
 // from sharding `source` to sharding `target`.
 // `sourceShardValue` is the already sharded value according to `source`.
-TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
-                               ShardOp source, ShardOp target,
+TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+                               ShardOp target,
                                TypedValue<ShapedType> sourceShardValue);
 
 void reshardingRegisterDependentDialects(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index f6b6b7c248c432..994a017a1f46c2 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -114,10 +114,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                    SymbolTableCollection &symbolTable) {
-  mesh::ClusterOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                                 SymbolTableCollection &symbolTable) {
+  mesh::MeshOp mesh =
+      symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -144,7 +144,7 @@ bool isUnique(It begin, It end) {
 }
 
 static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
-                                    ClusterOp mesh) {
+                                    MeshOp mesh) {
   SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
   llvm::sort(sorted);
   if (!isUnique(sorted.begin(), sorted.end())) {
@@ -192,22 +192,22 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.cluster op
+// mesh.mesh op
 //===----------------------------------------------------------------------===//
 
-LogicalResult ClusterOp::verify() {
+LogicalResult MeshOp::verify() {
   int64_t rank = getRank();
 
   if (rank <= 0)
-    return emitOpError("rank of cluster is expected to be a positive integer");
+    return emitOpError("rank of mesh is expected to be a positive integer");
 
   if (getShape().size() > size_t(rank))
     return emitOpError(
-        "rank of shape is not expected to be larger than rank of cluster");
+        "rank of shape is not expected to be larger than rank of mesh");
 
   for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
-      return emitOpError("dimension size of a mesh cluster is expected to be "
+      return emitOpError("dimension size of a mesh is expected to be "
                          "non-negative or dynamic");
   }
 
@@ -215,11 +215,11 @@ LogicalResult ClusterOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.cluster_shape op
+// mesh.mesh_shape op
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
@@ -238,16 +238,16 @@ ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
-void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                           ClusterOp mesh) {
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        MeshOp mesh) {
   build(odsBuilder, odsState,
         SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
         mesh.getSymName(),
         MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
 }
 
-void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                           StringRef mesh, ArrayRef<MeshAxis> axes) {
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        StringRef mesh, ArrayRef<MeshAxis> axes) {
   build(odsBuilder, odsState,
         SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
         MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -261,7 +261,7 @@ LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                          FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
                          ArrayRef<MeshAxis> partialAxes, Partial) {
-  // TODO: At present cluster symbol ref is not verified. This is due to the
+  // TODO: At present mesh symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
   llvm::SmallSet<MeshAxis, 4> visitedAxes;
@@ -292,8 +292,7 @@ bool MeshShardingAttr::operator==(Attribute rhs) const {
 }
 
 bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
-  if (getCluster() != rhs.getCluster() ||
-      getPartialAxes() != rhs.getPartialAxes()) {
+  if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
     return false;
   }
 
@@ -342,7 +341,7 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 }
 
 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
-                                ClusterOp mesh) {
+                                MeshOp mesh) {
   build(odsBuilder, odsState,
         SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
         mesh.getSymName(), ArrayRef<MeshAxis>());
@@ -369,7 +368,7 @@ ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 }
 
 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
-                                 OperationState &odsState, ClusterOp mesh) {
+                                 OperationState &odsState, MeshOp mesh) {
   build(odsBuilder, odsState, mesh.getSymName());
 }
 
@@ -427,7 +426,7 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
 }
 
 template <typename Op>
-static FailureOr<ClusterOp>
+static FailureOr<MeshOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
   auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
   if (failed(mesh)) {
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index dca7e86e6f07f5..5dc91ff1c02d20 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -215,11 +215,10 @@ namespace {
 // Update the given `shardingOption` according to `meshAxes` and `loopIdx`
 static LogicalResult fillShardingOption(Operation *op,
                                         ShardingOption &shardingOption,
-                                        FlatSymbolRefAttr cluster,
+                                        FlatSymbolRefAttr mesh,
                                         ArrayRef<MeshAxis> meshAxes,
                                         unsigned loopIdx) {
-  if ((shardingOption.cluster && cluster &&
-       shardingOption.cluster != cluster) ||
+  if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
       (!shardingOption.shardingArray[loopIdx].empty() &&
        shardingOption.shardingArray[loopIdx] != meshAxes)) {
     LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
@@ -238,8 +237,8 @@ static LogicalResult fillShardingOption(Operation *op,
       }
     }
   }
-  if (cluster)
-    shardingOption.cluster = cluster;
+  if (mesh)
+    shardingOption.mesh = mesh;
   if (shardingOption.shardingArray[loopIdx].empty())
     shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
                                                 ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/79484


More information about the Mlir-commits mailing list