[Mlir-commits] [mlir] [mlir][mesh] Use tensor shape notation for the shape of a cluster (PR #73826)

Boian Petkantchin llvmlistbot at llvm.org
Wed Nov 29 09:18:40 PST 2023


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/73826

Example:

substitute
```
mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 4])
```

with
```
mesh.cluster @mesh0(rank = 2, dim_sizes = ?x4)
```

The only difference is for 0-rank shapes.
With tensors you would have something like `tensor<f32>`. Here, to avoid matching an empty string, a 0-rank shape is denoted by `[]`.

-----

I am not quite happy with the 0-rank special case. An alternative is to use the same notation as arrays but handle `?` for dynamic dimensions. For example `[1, 2, ?]`. But then it would be different from the notation of the `tensor` type.

I wanted to rename the attribute `dim_sizes` to `shape`, but did not do that in this PR to avoid clutter.

>From 4c8a87bea7e9a288377b7d19cdad0a915de3ae85 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 29 Nov 2023 08:55:16 -0800
Subject: [PATCH] [mlir][mesh] Use tensor shape notation for the shape of a
 cluster

Example:

substitute
mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 4])

with
mesh.cluster @mesh0(rank = 2, dim_sizes = ?x4)

The only difference is for 0-rank shapes.
With tensors you would have something like `tensor<f32>`.
Here, to avoid matching an empty string, a 0-rank shape is denoted by `[]`.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 14 ++---
 mlir/include/mlir/IR/BuiltinAttributes.h      |  5 ++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 13 ++---
 mlir/lib/IR/BuiltinAttributes.cpp             | 41 ++++++++++++++
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  2 +-
 mlir/test/Dialect/Mesh/invalid.mlir           | 56 +++++++++----------
 mlir/test/Dialect/Mesh/ops.mlir               |  8 +--
 .../Dialect/Mesh/sharding-propagation.mlir    |  2 +-
 8 files changed, 91 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 5cce15dd1015ecc..4f258be39b92b65 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -51,15 +51,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     ```
     // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
-    mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
+    mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
-    mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+    mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
-    mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+    mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
 
     // A device mesh cluster with 2 axes, the number of devices along both axes
     // is unknown
@@ -76,7 +76,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
   );
   let assemblyFormat = [{
-    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
+    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<Shape>($dim_sizes)^)? `)`
       attr-dict
   }];
   let extraClassDeclaration = [{
@@ -210,7 +210,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
     ...
     %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
       : tensor<2x2xi8> -> tensor<2x4xi8>
@@ -295,7 +295,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+    mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
     ...
     %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
       split_axis = 0 concat_axis = 0
@@ -348,7 +348,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     across the device group.
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
     ...
     %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
       reduction = <max> scatter_axis = 0
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c8161604aad3503..9261c82eae5b776 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -28,6 +28,8 @@ class FunctionType;
 class IntegerSet;
 class IntegerType;
 class Location;
+class OpAsmParser;
+class OpAsmPrinter;
 class Operation;
 class RankedTensorType;
 
@@ -1099,6 +1101,9 @@ namespace mlir {
 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
                                      MLIRContext *context);
 
+void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape);
+ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape);
+
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b45f7cd21ce9217..4b08ecc69628b27 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -56,11 +56,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
   return vec;
 }
 
-template <typename DimSize>
-static bool isMeshDimensionDynamic(DimSize size) {
-  return size <= DimSize(0);
-}
-
 using MeshAxis = int16_t;
 
 namespace {
@@ -159,9 +154,9 @@ LogicalResult ClusterOp::verify() {
         "rank of dim_sizes is not expected to be larger than rank of cluster");
 
   for (int64_t dimSize : dimSizes) {
-    if (dimSize < 0)
-      return emitOpError(
-          "dimension size of a mesh cluster is expected to be non-negative");
+    if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
+      return emitOpError("dimension size of a mesh cluster is expected to be "
+                         "non-negative or dynamic");
   }
 
   return success();
@@ -314,7 +309,7 @@ static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
   int64_t res = 1;
 
   for (MeshAxis axis : meshAxes) {
-    if (isMeshDimensionDynamic(meshShape[axis])) {
+    if (ShapedType::isDynamic(meshShape[axis])) {
       return ShapedType::kDynamic;
     }
     assert(size_t(axis) < meshShape.size());
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 89b1ed67f5d067a..1396cd6a90ffc6d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -1822,3 +1823,43 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
 
   return AffineMap::get(strides.size(), nSymbols, expr);
 }
+
+/// Prints `shape` as e.g. `4x?x8` ('?' = dynamic dim). A 0-rank shape is
+/// printed as `[]`, since an empty string could not be parsed back.
+void mlir::printShape(OpAsmPrinter &printer, Operation *op,
+                      ArrayRef<int64_t> shape) {
+  if (shape.empty()) {
+    printer << "[]";
+    return;
+  }
+
+  for (size_t i = 0, e = shape.size(); i < e; ++i) {
+    if (i != 0)
+      printer << 'x';
+    if (ShapedType::isDynamic(shape[i]))
+      printer << '?';
+    else
+      printer << shape[i];
+  }
+}
+
+/// Parses the output of `printShape`. The dimension list may optionally be
+/// wrapped in `[` `]`; a 0-rank shape must be written as `[]`.
+ParseResult mlir::parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape) {
+  bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
+  SmallVector<int64_t> shapeArr;
+  // On failure the parser helpers have already emitted a diagnostic.
+  if (failed(parser.parseDimensionList(shapeArr, /*allowDynamic=*/true,
+                                       /*withTrailingX=*/false)))
+    return failure();
+  if (shapeArr.empty() && !hasOpeningSquare) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "Failed parsing shape. Did you mean a 0-rank shape? It must be "
+              "denoted by \"[]\".";
+  }
+  if (hasOpeningSquare && failed(parser.parseRSquare()))
+    return failure();
+
+  shape = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
+  return success();
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 5802d198d368149..baee9faa645c93a 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt --canonicalize %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes
 func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 2999668f770baa7..dc91ff54fd83915 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -6,16 +6,16 @@ mesh.cluster @mesh0(rank = 0)
 // -----
 
 // expected-error at +1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 3, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
 
 // -----
 
-// expected-error at +1 {{dimension size of a mesh cluster is expected to be non-negative}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = [-1])
+// expected-error at +1 {{custom op 'mesh.cluster' Failed parsing shape}}
+mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -26,7 +26,7 @@ func.func @mesh_axis_duplicated_different_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -37,7 +37,7 @@ func.func @mesh_axis_duplicated_same_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -48,7 +48,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -59,7 +59,7 @@ func.func @mesh_axis_negtive_in_split_part(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -80,7 +80,7 @@ func.func @all_reduce_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_reduce_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -92,7 +92,7 @@ func.func @all_reduce_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -104,7 +104,7 @@ func.func @all_reduce_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_reduce_invalid_tensor_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<5xf64> {
@@ -125,7 +125,7 @@ func.func @all_gather_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_gather_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -137,7 +137,7 @@ func.func @all_gather_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -149,7 +149,7 @@ func.func @all_reduce_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -161,7 +161,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
 
 func.func @all_gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -173,7 +173,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -185,7 +185,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -197,7 +197,7 @@ func.func @all_gather_invalid_gather_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -220,7 +220,7 @@ func.func @all_to_all_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_to_all_duplicate_mesh_axis(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -233,7 +233,7 @@ func.func @all_to_all_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = ?x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -246,7 +246,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
@@ -259,7 +259,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
@@ -272,7 +272,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
@@ -285,7 +285,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
@@ -298,7 +298,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_duplicate_mesh_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
@@ -310,7 +310,7 @@ func.func @reduce_scatter_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
@@ -322,7 +322,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf64> {
@@ -334,7 +334,7 @@ func.func @reduce_scatter_invalid_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf64> {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 5b264bc88dfc2a7..2e10c359f9ad899 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -1,18 +1,18 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
 // CHECK: mesh.cluster @mesh0
-mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+mesh.cluster @mesh0(rank = 3, dim_sizes = 2x2x4)
 
 // CHECK: mesh.cluster @mesh1
-mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
 
 // CHECK: mesh.cluster @mesh2
-mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
 
 // CHECK: mesh.cluster @mesh3
 mesh.cluster @mesh3(rank = 2)
 
-mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh4(rank = 1, dim_sizes = 3)
 
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index bda407b52bfd4f2..30bbd5c6619e8af 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -sharding-propagation %s | FileCheck %s
 
 mesh.cluster @mesh_1d(rank = 1)
-mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh_2d(rank = 2, dim_sizes = 2x4)
 mesh.cluster @mesh_3d(rank = 3)
 
 // CHECK-LABEL: func.func @element_wise_empty_sharding_info



More information about the Mlir-commits mailing list