[Mlir-commits] [mlir] [mlir][mesh] Rename cluster to mesh (PR #79484)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 25 10:54:38 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.
The name `mesh` is more specific to what it really represents. It is a mesh of devices.
The name `cluster` implies a broader possibility of device configurations. When just the word `mesh` is used, the meaning can often be inferred from the context whether it refers to the mesh dialect or a device mesh. The full name can be used when needed.
---
Patch is 65.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79484.diff
18 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+10-10)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+29-29)
- (modified) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+3-3)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h (+3-3)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+21-22)
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+10-12)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+10-12)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+18-18)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+4-5)
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+1-1)
- (modified) mlir/test/Dialect/Mesh/folding.mlir (+9-9)
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+78-78)
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+23-23)
- (modified) mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir (+3-3)
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+4-4)
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+3-3)
- (modified) mlir/test/Dialect/Mesh/simplifications.mlir (+2-2)
- (modified) mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp (+6-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 07f954459ca49d..c15c5f6bcf2f0f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let mnemonic = "shard";
let parameters = (ins
- AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
+ AttrParameter<"::mlir::FlatSymbolRefAttr", "mesh placed">:$mesh,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,9 +91,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
The MeshSharding attribute could be used in the encoding of a
`RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
- 1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
- cluster where the distributed tensor is placed. The symbol must resolve to a
- `mesh.cluster` operation.
+ 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
+ mesh where the distributed tensor is placed. The symbol must resolve to a
+ `mesh.mesh` operation.
2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
maximum size is the `rank` of the related tensor. For the i-th sub-array, if
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
Example:
```
- mesh.cluster @mesh0(shape = 2x2x4)
+ mesh.mesh @mesh0(shape = 2x2x4)
// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
@@ -140,12 +140,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
```
}];
let assemblyFormat = [{
- `<` $cluster `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
+ `<` $mesh `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
$partial_axes^ `]`)? `>`
}];
let builders = [
- AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
+ AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
"mesh::Partial": $partial_type), [{
@@ -153,12 +153,12 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
split_axes, [&](ArrayRef<MeshAxis> array) {
return MeshAxesAttr::get($_ctxt, array);
});
- return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+ return $_get($_ctxt, mesh, splitAxesAttr, partial_axes,
partial_type);
}]>,
- AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
+ AttrBuilder<(ins "FlatSymbolRefAttr":$mesh,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
- return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+ return MeshShardingAttr::get($_ctxt, mesh, split_axes, {}, Partial::Sum);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 78ff8bd0cac621..c557ec1da1b6d9 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -26,17 +26,17 @@ class Mesh_Op<string mnemonic, list<Trait> traits = []> :
Op<Mesh_Dialect, mnemonic, traits> {
}
-def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
- let summary = "representing a mesh cluster";
+def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
+ let summary = "representing a device/process mesh";
let description = [{
- The mesh.cluster operation is a symbol operation that identifies a specific
- mesh cluster. The operation has three attributes:
+ The mesh.mesh operation is a symbol operation that identifies a specific
+ mesh. The operation has three attributes:
- 1. `sym_name`: This attribute uniquely identifies the name of the mesh
- cluster. This name serves as a symbolic reference to the cluster throughout
+ 1. `sym_name`: This attribute uniquely identifies the name of the mesh.
+ This name serves as a symbolic reference to the mesh throughout
the MLIR module, allowing for consistent referencing and easier debugging.
- 2. `shape`: This attribute represents the shape of the device cluster.
+ 2. `shape`: This attribute represents the shape of the device mesh.
It uses the same notation as a tensor shape. Also allowing for dynamic
dimensions.
This flexibility allows for dynamic device assignment or configurations
@@ -46,21 +46,21 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
Example:
```
- // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
+ // A device mesh with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
- mesh.cluster @mesh0(shape = 4x8x12)
+ mesh.mesh @mesh0(shape = 4x8x12)
- // A device mesh cluster with 2 axes, the total device number is unknown
+ // A device mesh with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
- mesh.cluster @mesh1(shape = 4x?)
+ mesh.mesh @mesh1(shape = 4x?)
- // A device mesh cluster with 2 axes, the total device number is unknown
+ // A device mesh with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
- mesh.cluster @mesh2(shape = ?x4)
+ mesh.mesh @mesh2(shape = ?x4)
- // A device mesh cluster with 2 axes, the number of devices along both axes
+ // A device mesh with 2 axes, the number of devices along both axes
// is unknown
- mesh.cluster @mesh3(shape = ?x?)
+ mesh.mesh @mesh3(shape = ?x?)
// Used in the mesh sharding attribute to extend the standard tensor to
// distributed
@@ -81,9 +81,9 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}
-def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
+def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
- let summary = "Get the shape of the cluster.";
+ let summary = "Get the shape of the mesh.";
let arguments = (ins
FlatSymbolRefAttr:$mesh,
DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
@@ -99,13 +99,13 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [
}];
let builders = [
- OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
- let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
+ let summary = "Annotate on how a tensor is sharded across a mesh.";
let description = [{
The mesh.shard operation is designed to specify and guide the sharding
behavior of a tensor value across a mesh topology. This operation has one
@@ -115,7 +115,7 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
annotated for sharding.
2. `shard`: This attribute is type of `MeshSharding`, which is the core data
- structure to represent distributed tensor in mesh cluster.
+ structure to represent distribution of a tensor on a mesh.
3. `annotate_for_users`: A unit attribute addressing the scenario when a
tensor's sharding annotation differs based on its context of use (either as
@@ -217,7 +217,7 @@ def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
attr-dict `:` type($result)
}];
let builders = [
- OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
@@ -239,7 +239,7 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
let results = (outs Index:$result);
let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
let builders = [
- OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
];
}
@@ -268,7 +268,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Example:
```mlir
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
@@ -353,7 +353,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Example:
```
- mesh.cluster @mesh0(shape = 3)
+ mesh.mesh @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
@@ -410,7 +410,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
Example:
```
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.broadcast %0 on @mesh0
mesh_axes = [0]
@@ -466,7 +466,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
Example:
```mlir
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
gather_axis = 1 root = [1]
@@ -589,7 +589,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
across the device group.
Example:
```
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <max> scatter_axis = 0
@@ -652,7 +652,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
Example:
```
- mesh.cluster @mesh0(shape = 2x2)
+ mesh.mesh @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
scatter_axis = 0
root = [1]
@@ -748,7 +748,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
Example:
```
- mesh.cluster @mesh0(shape = 2x4)
+ mesh.mesh @mesh0(shape = 2x4)
%1 = mesh.shift on @mesh0 mesh_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index a32274d857f15d..3bef7e6babdec9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -25,14 +25,14 @@ struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
// mesh axes the i-th loop will be sharded on.
ShardingArray shardingArray = {};
- FlatSymbolRefAttr cluster = nullptr;
+ FlatSymbolRefAttr mesh = nullptr;
// `empty` being true indicates that no sharding information can be inferred
// at present. Note that it is different from the case where an operation is
// not sharded.
bool empty = false;
ShardingOption() = default;
- ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
- : shardingArray(std::move(shardingArray)), cluster(cluster) {}
+ ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
+ : shardingArray(std::move(shardingArray)), mesh(mesh) {}
};
// This method retrieves the 'MeshShardingAttr' attribute from a given operation
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
index f71bb9b262a380..7cb992aac019b3 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -17,14 +17,14 @@ namespace mlir {
namespace mesh {
// Return the sharded shape `shape` acording ot sharding `sharding`.
-ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
MeshShardingAttr sharding);
// Insert resharding spmdization of the value `sourceShardValue`
// from sharding `source` to sharding `target`.
// `sourceShardValue` is the already sharded value according to `source`.
-TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
- ShardOp source, ShardOp target,
+TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+ ShardOp target,
TypedValue<ShapedType> sourceShardValue);
void reshardingRegisterDependentDialects(DialectRegistry &registry);
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index f6b6b7c248c432..994a017a1f46c2 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -114,10 +114,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Mesh utilities
//===----------------------------------------------------------------------===//
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTable) {
- mesh::ClusterOp mesh =
- symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+ SymbolTableCollection &symbolTable) {
+ mesh::MeshOp mesh =
+ symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
if (!mesh) {
return op->emitError() << "Undefined required mesh symbol \""
<< meshSymbol.getValue() << "\".";
@@ -144,7 +144,7 @@ bool isUnique(It begin, It end) {
}
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
- ClusterOp mesh) {
+ MeshOp mesh) {
SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
llvm::sort(sorted);
if (!isUnique(sorted.begin(), sorted.end())) {
@@ -192,22 +192,22 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
}
//===----------------------------------------------------------------------===//
-// mesh.cluster op
+// mesh.mesh op
//===----------------------------------------------------------------------===//
-LogicalResult ClusterOp::verify() {
+LogicalResult MeshOp::verify() {
int64_t rank = getRank();
if (rank <= 0)
- return emitOpError("rank of cluster is expected to be a positive integer");
+ return emitOpError("rank of mesh is expected to be a positive integer");
if (getShape().size() > size_t(rank))
return emitOpError(
- "rank of shape is not expected to be larger than rank of cluster");
+ "rank of shape is not expected to be larger than rank of mesh");
for (int64_t dimSize : getShape()) {
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
- return emitOpError("dimension size of a mesh cluster is expected to be "
+ return emitOpError("dimension size of a mesh is expected to be "
"non-negative or dynamic");
}
@@ -215,11 +215,11 @@ LogicalResult ClusterOp::verify() {
}
//===----------------------------------------------------------------------===//
-// mesh.cluster_shape op
+// mesh.mesh_shape op
//===----------------------------------------------------------------------===//
LogicalResult
-ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
@@ -238,16 +238,16 @@ ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
-void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- ClusterOp mesh) {
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ MeshOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
mesh.getSymName(),
MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
}
-void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef mesh, ArrayRef<MeshAxis> axes) {
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef mesh, ArrayRef<MeshAxis> axes) {
build(odsBuilder, odsState,
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -261,7 +261,7 @@ LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
ArrayRef<MeshAxis> partialAxes, Partial) {
- // TODO: At present cluster symbol ref is not verified. This is due to the
+ // TODO: At present mesh symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
llvm::SmallSet<MeshAxis, 4> visitedAxes;
@@ -292,8 +292,7 @@ bool MeshShardingAttr::operator==(Attribute rhs) const {
}
bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
- if (getCluster() != rhs.getCluster() ||
- getPartialAxes() != rhs.getPartialAxes()) {
+ if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
return false;
}
@@ -342,7 +341,7 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- ClusterOp mesh) {
+ MeshOp mesh) {
build(odsBuilder, odsState,
SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
mesh.getSymName(), ArrayRef<MeshAxis>());
@@ -369,7 +368,7 @@ ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
- OperationState &odsState, ClusterOp mesh) {
+ OperationState &odsState, MeshOp mesh) {
build(odsBuilder, odsState, mesh.getSymName());
}
@@ -427,7 +426,7 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
}
template <typename Op>
-static FailureOr<ClusterOp>
+static FailureOr<MeshOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
if (failed(mesh)) {
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index dca7e86e6f07f5..5dc91ff1c02d20 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -215,11 +215,10 @@ namespace {
// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
- FlatSymbolRefAttr cluster,
+ FlatSymbolRefAttr mesh,
ArrayRef<MeshAxis> meshAxes,
unsigned loopIdx) {
- if ((shardingOption.cluster && cluster &&
- shardingOption.cluster != cluster) ||
+ if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
(!shardingOption.shardingArray[loopIdx].empty() &&
shardingOption.shardingArray[loopIdx] != meshAxes)) {
LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
@@ -238,8 +237,8 @@ static LogicalResult fillShardingOption(Operation *op,
}
}
}
- if (cluster)
- shardingOption.cluster = cluster;
+ if (mesh)
+ shardingOption.mesh = mesh;
if (shardingOption.shardingArray[loopIdx].empty())
shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/79484
More information about the Mlir-commits
mailing list