[Mlir-commits] [mlir] [mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (PR #77838)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 11 13:48:34 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
Remove the somewhat redundant rank attribute.
Before this change
```
mesh.cluster @<!-- -->mesh(rank = 3, dim_sizes = 2x3)
```
After
```
mesh.cluster @<!-- -->mesh(shape = 2x3x?)
```
The rank is instead determined by the provided shape. With this change, `getDimSizes()` can no longer be wrongly assumed to have size equal to the cluster rank.
Now `getShape().size()` will always equal `getRank()`.
---
Patch is 38.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77838.diff
13 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+1-1)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+15-25)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+23-33)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+2-2)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+7-9)
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+1-1)
- (modified) mlir/test/Dialect/Mesh/folding.mlir (+2-2)
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+65-70)
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+10-10)
- (modified) mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir (+1-1)
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+2-2)
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+3-3)
- (modified) mlir/test/Dialect/Mesh/simplifications.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index bda6467e9c5d4b..07f954459ca49d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -117,7 +117,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
Example:
```
- mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+ mesh.cluster @mesh0(shape = 2x2x4)
// The tensor is fully replicated on @mesh0.
// Currently, there must be at least one sub-array present in axes, even
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a9068562f5c903..c25996cf122c1c 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -41,7 +41,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
determine the layout and the addressing space of the computation distributed
across the mesh.
- 3. `dim_sizes`: This attribute represents the shape of the device cluster.
+ 3. `shape`: This attribute represents the shape of the device cluster.
It uses the same notation as a tensor shape. Also allowing for dynamic
dimensions.
This flexibility allows for dynamic device assignment or configurations
@@ -53,19 +53,19 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
```
// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
- mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
+ mesh.cluster @mesh0(shape = 4x8x12)
// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
- mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
+ mesh.cluster @mesh1(shape = 4x?)
// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
- mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
+ mesh.cluster @mesh2(shape = ?x4)
// A device mesh cluster with 2 axes, the number of devices along both axes
// is unknown
- mesh.cluster @mesh3(rank = 2)
+ mesh.cluster @mesh3(shape = ?x?)
// Used in the mesh sharding attribute to extend the standard tensor to
// distributed
@@ -74,24 +74,14 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
- I64Attr:$rank,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
+ DenseI64ArrayAttr:$shape
);
let assemblyFormat = [{
- $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
+ $sym_name `(` `shape` `=` custom<DimensionList>($shape) `)`
attr-dict
}];
let extraClassDeclaration = [{
- // The `dim_sizes` attribute may have size less than the rank of the mesh.
- // Returns the shape of the mesh with missing trailing dimensions
- // explicitly set as dynamic.
- ::mlir::SmallVector<int64_t> canonicalDimSizes();
-
- template <typename OutIt>
- void canonicalDimSizes(OutIt outIt) {
- std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
- std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
- }
+ int64_t getRank() { return getShape().size(); }
}];
let hasVerifier = 1;
}
@@ -283,7 +273,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Example:
```mlir
- mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
+ mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
@@ -368,7 +358,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Example:
```
- mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+ mesh.cluster @mesh0(shape = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
@@ -425,7 +415,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
Example:
```
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ mesh.cluster @mesh0(shape = 2x2)
%1 = mesh.broadcast %0 on @mesh0
mesh_axes = [0]
@@ -481,7 +471,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
Example:
```mlir
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
gather_axis = 1 root = [1]
@@ -604,7 +594,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
across the device group.
Example:
```
- mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
+ mesh.cluster @mesh0(shape = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <max> scatter_axis = 0
@@ -667,7 +657,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
Example:
```
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ mesh.cluster @mesh0(shape = 2x2)
%1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
scatter_axis = 0
root = [1]
@@ -763,7 +753,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
Example:
```
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+ mesh.cluster @mesh0(shape = 2x4)
%1 = mesh.shift on @mesh0 mesh_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 957b380efd516b..fa9da596a34587 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -196,17 +196,16 @@ Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
//===----------------------------------------------------------------------===//
LogicalResult ClusterOp::verify() {
- ArrayRef<int64_t> dimSizes = getDimSizes();
- uint64_t rank = getRank();
+ int64_t rank = getRank();
- if (rank == 0)
+ if (rank <= 0)
return emitOpError("rank of cluster is expected to be a positive integer");
- if (dimSizes.size() > rank)
+ if (getShape().size() > rank)
return emitOpError(
- "rank of dim_sizes is not expected to be larger than rank of cluster");
+ "rank of shape is not expected to be larger than rank of cluster");
- for (int64_t dimSize : dimSizes) {
+ for (int64_t dimSize : getShape()) {
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
return emitOpError("dimension size of a mesh cluster is expected to be "
"non-negative or dynamic");
@@ -215,13 +214,6 @@ LogicalResult ClusterOp::verify() {
return success();
}
-SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
- SmallVector<int64_t> result;
- canonicalDimSizes(std::back_inserter(result));
- result.reserve(getRank());
- return result;
-}
-
//===----------------------------------------------------------------------===//
// mesh.cluster_shape op
//===----------------------------------------------------------------------===//
@@ -614,7 +606,7 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getOperand(), getResult(),
gatherAxis, getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ mesh.value().getShape());
}
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -648,8 +640,7 @@ LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyAllToAllOperandAndResultShape(
getOperand(), getResult(), getSplitAxis().getSExtValue(),
- getConcatAxis().getSExtValue(), getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
}
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -667,9 +658,9 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(), meshShape))) {
+ getRootDynamic(), getMeshAxes(),
+ mesh.value().getShape()))) {
return failure();
}
@@ -690,16 +681,16 @@ LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(), meshShape))) {
+ getRootDynamic(), getMeshAxes(),
+ mesh.value().getShape()))) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ mesh.value().getShape());
}
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -716,10 +707,10 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
- if (getSource() && failed(verifyInGroupDevice(
- getLoc(), getSourceAttrName(), getSource().value(),
- getSourceDynamic(), getMeshAxes(), meshShape))) {
+ if (getSource() &&
+ failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
+ getSource().value(), getSourceDynamic(),
+ getMeshAxes(), mesh.value().getShape()))) {
return failure();
}
return success();
@@ -739,9 +730,9 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(), meshShape))) {
+ getRootDynamic(), getMeshAxes(),
+ mesh.value().getShape()))) {
return failure();
}
@@ -766,7 +757,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyScatterOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ mesh.value().getShape());
}
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -783,16 +774,16 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(), meshShape))) {
+ getRootDynamic(), getMeshAxes(),
+ mesh.value().getShape()))) {
return failure();
}
auto scatterAxis = getScatterAxis().getSExtValue();
return verifyScatterOperandAndResultShape(getInput(), getResult(),
scatterAxis, getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ mesh.value().getShape());
}
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -809,10 +800,9 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (failed(mesh)) {
return failure();
}
- auto meshShape = mesh.value().canonicalDimSizes();
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
getDestination(), getDestinationDynamic(),
- getMeshAxes(), meshShape))) {
+ getMeshAxes(), mesh.value().getShape()))) {
return failure();
}
return success();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c9275ad5ad4551..c478b6da4c27b3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -80,7 +80,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
opMeshAxes = opAxesIota;
}
if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
- return ShapedType::isDynamic(mesh.getDimSizes()[axis]);
+ return ShapedType::isDynamic(mesh.getShape()[axis]);
})) {
// All mesh dimensions are dynamic. Nothing to fold.
return failure();
@@ -91,7 +91,7 @@ struct ClusterShapeFolder : OpRewritePattern<ClusterShapeOp> {
SmallVector<size_t> newToOldResultsIndexMap;
for (size_t i = 0; i < opMeshAxes.size(); ++i) {
- auto meshAxisSize = mesh.getDimSizes()[opMeshAxes[i]];
+ auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
if (ShapedType::isDynamic(meshAxisSize)) {
newToOldResultsIndexMap.push_back(i);
newShapeOpMeshAxes.push_back(opMeshAxes[i]);
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 0e83c024fc08f8..9478b2e4ee5cb2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -88,8 +88,8 @@ ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
MeshShardingAttr sharding) {
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
- shardShape(shape.getShape(), mesh.canonicalDimSizes(),
- sharding.getSplitAxes(), resShapeArr);
+ shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+ resShapeArr);
return shape.clone(resShapeArr);
}
@@ -212,9 +212,8 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
- ShapedType targetShape =
- targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
- mesh.canonicalDimSizes()[splitMeshAxis]);
+ ShapedType targetShape = targetShapeInSplitLastAxis(
+ sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
Value meshAxisSize =
builder
@@ -391,8 +390,7 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr targetSharding =
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
- sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
- splitTensorAxis);
+ sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
Value allGatherResult = builder.create<AllGatherOp>(
RankedTensorType::get(allGatherResultShape.getShape(),
allGatherResultShape.getElementType()),
@@ -526,8 +524,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
- sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
- sourceTensorAxis, targetTensorAxis);
+ sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
+ targetTensorAxis);
Value allToAllResult = builder.create<AllToAllOp>(
RankedTensorType::get(allToAllResultShape.getShape(),
allToAllResultShape.getElementType()),
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 0a00ab41268d01..4cc009ef24eb3c 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
index dd64d746341b83..9162dc57ecfdf4 100644
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ b/mlir/test/Dialect/Mesh/folding.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
-mesh.cluster @mesh0(rank = 3, dim_sizes = 4x?x2)
-mesh.cluster @mesh1(rank = 2, dim_sizes = 2x3)
+mesh.cluster @mesh0(shape = 4x?x2)
+mesh.cluster @mesh1(shape = 2x3)
// CHECK-LABEL: func.func @cluster_shape_op_folding
func.func @cluster_shape_op_folding() -> (index, index) {
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index f3524a82a1b9d2..8a1fb80065573b 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -1,21 +1,16 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
// expected-error at +1 {{rank of cluster is expected to be a positive integer}}
-mesh.cluster @mesh0(rank = 0)
-
-// -----
-
-// expected-error at +1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
+mesh.cluster @mesh0(shape = [])
// -----
// expected-error at +1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
+mesh.cluster @mesh0(shape = -1)
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_different_subarray(
// expected-error at +1 {{mesh axis duplicated}}
@@ -26,7 +21,7 @@ func.func @mesh_axis_duplicated_different_subarray(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_same_subarray(
// expected-error at +1 {{mesh axis duplicated}}
@@ -37,7 +32,7 @@ func.func @mesh_axis_duplicated_same_subarray(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_duplicated_bewteen_split_and_partial(
// expected-error at +1 {{mesh axis duplicated}}
@@ -48,7 +43,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_negtive_in_split_part(
// expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -59,7 +54,7 @@ func.func @mesh_axis_negtive_in_split_part(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @mesh_axis_negtive_in_partial(
// expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -78,7 +73,7 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+mesh.cluster @mesh0(shape = 2x4)
func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
// expected-error at +1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
@@ -88,7 +83,7 @@ func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
// -----
-mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+mesh.cluster @mesh0(shape = 1x2x3)
func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/77838
More information about the Mlir-commits
mailing list