[Mlir-commits] [mlir] [mlir][mesh] Use tensor shape notation for the shape of a cluster (PR #73826)

Boian Petkantchin llvmlistbot at llvm.org
Fri Dec 8 08:56:24 PST 2023


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/73826

>From d573fae1212b1a5ed6276b8726e290df9c74b8df Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 29 Nov 2023 08:55:16 -0800
Subject: [PATCH 01/11] [mlir][mesh] Use tensor shape notation for the shape of
 a cluster

Example:

replace
mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 4])

with
mesh.cluster @mesh0(rank = 2, dim_sizes = ?x4)

The only difference is for 0-rank shapes.
With tensors, a 0-rank shape is an empty dimension list, e.g. `tensor<f32>`.
Here, to avoid matching an empty string, a 0-rank shape is denoted by `[]`.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 27 ++++-----
 mlir/include/mlir/IR/BuiltinAttributes.h      |  5 ++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 13 ++---
 mlir/lib/IR/BuiltinAttributes.cpp             | 41 ++++++++++++++
 mlir/test/Dialect/Mesh/canonicalization.mlir  |  2 +-
 mlir/test/Dialect/Mesh/invalid.mlir           | 56 +++++++++----------
 mlir/test/Dialect/Mesh/ops.mlir               |  8 +--
 .../Dialect/Mesh/sharding-propagation.mlir    |  2 +-
 8 files changed, 98 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 361e67fd1e19ac..2ef9c65255f74e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -40,26 +40,27 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     determine the layout and the addressing space of the computation distributed
     across the mesh.
 
-    3. `dim_sizes`: This attribute represents the device assignment along the
-    axes of the cluster. Each integer in the array corresponds to the number of
-    devices along a specific axis. If an integer value is 0, it implies that the
-    number of devices along that axis is unknown. This flexibility allows for
-    dynamic device assignment or configurations where the exact number of
-    devices might not be determined during compile time.
+    3. `dim_sizes`: This attribute represents the shape of the device cluster.
+    It uses the same notation as a tensor shape. Also allowing for dynamic
+    dimensions.
+    This flexibility allows for dynamic device assignment or configurations
+    where the exact number of devices might not be determined during compile
+    time.
+    For example `2x?x4`.
 
     Example:
     ```
     // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
-    mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
+    mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
-    mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+    mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
 
     // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
-    mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+    mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
 
     // A device mesh cluster with 2 axes, the number of devices along both axes
     // is unknown
@@ -76,7 +77,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
   );
   let assemblyFormat = [{
-    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
+    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<Shape>($dim_sizes)^)? `)`
       attr-dict
   }];
   let extraClassDeclaration = [{
@@ -210,7 +211,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
 
     Example:
     ```mlir
-    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
     ...
     %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
       : tensor<2x2xi8> -> tensor<2x4xi8>
@@ -295,7 +296,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
 
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+    mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
     ...
     %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
       split_axis = 0 concat_axis = 0
@@ -527,7 +528,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
     across the device group.
     Example:
     ```
-    mesh.cluster @mesh0(rank = 1, dim_sizes = [2, 2])
+    mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
     ...
     %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
       reduction = <max> scatter_axis = 0
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 901df3a25a46f1..6d0871e9279065 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -28,6 +28,8 @@ class FunctionType;
 class IntegerSet;
 class IntegerType;
 class Location;
+class OpAsmParser;
+class OpAsmPrinter;
 class Operation;
 class RankedTensorType;
 
@@ -1101,6 +1103,9 @@ namespace mlir {
 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
                                      MLIRContext *context);
 
+void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape);
+ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape);
+
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 3b89860c14d936..1ba95f21ec7f3d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -58,11 +58,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
   return vec;
 }
 
-template <typename DimSize>
-static bool isMeshDimensionDynamic(DimSize size) {
-  return size <= DimSize(0);
-}
-
 using MeshAxis = int16_t;
 
 namespace {
@@ -161,9 +156,9 @@ LogicalResult ClusterOp::verify() {
         "rank of dim_sizes is not expected to be larger than rank of cluster");
 
   for (int64_t dimSize : dimSizes) {
-    if (dimSize < 0)
-      return emitOpError(
-          "dimension size of a mesh cluster is expected to be non-negative");
+    if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
+      return emitOpError("dimension size of a mesh cluster is expected to be "
+                         "non-negative or dynamic");
   }
 
   return success();
@@ -316,7 +311,7 @@ static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
   int64_t res = 1;
 
   for (MeshAxis axis : meshAxes) {
-    if (isMeshDimensionDynamic(meshShape[axis])) {
+    if (ShapedType::isDynamic(meshShape[axis])) {
       return ShapedType::kDynamic;
     }
     assert(size_t(axis) < meshShape.size());
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 89b1ed67f5d067..1396cd6a90ffc6 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -1822,3 +1823,43 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
 
   return AffineMap::get(strides.size(), nSymbols, expr);
 }
+
+void mlir::printShape(OpAsmPrinter &printer, Operation *op,
+                      ArrayRef<int64_t> shape) {
+  if (!shape.empty())
+    printer << "[";
+
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (ShapedType::isDynamic(shape[i]))
+      printer << '?';
+    else
+      printer << shape[i];
+    if (i != shape.size() - 1) {
+      printer << 'x';
+    }
+  }
+
+  if (!shape.empty())
+    printer << "]";
+}
+
+ParseResult mlir::parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape) {
+  bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
+  SmallVector<int64_t> shapeArr;
+  if (failed(parser.parseDimensionList(shapeArr, true, false))) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "Failed parsing shape.";
+  }
+  if (shapeArr.empty() && !hasOpeningSquare) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "Failed parsing shape. Did you mean a 0-rank shape? It must be "
+              "denoted by \"[]\".";
+  }
+  if (hasOpeningSquare && failed(parser.parseRSquare())) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "Failed parsing shape.";
+  }
+
+  shape = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
+  return success();
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 5802d198d36814..baee9faa645c93 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt --canonicalize %s | FileCheck %s
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 // CHECK-LABEL: func @all_reduce_empty_mesh_axes
 func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 2999668f770baa..dc91ff54fd8391 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -6,16 +6,16 @@ mesh.cluster @mesh0(rank = 0)
 // -----
 
 // expected-error at +1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 3, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
 
 // -----
 
-// expected-error at +1 {{dimension size of a mesh cluster is expected to be non-negative}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = [-1])
+// expected-error at +1 {{unexpected error: custom op 'mesh.cluster' Failed parsing shape}}
+mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -26,7 +26,7 @@ func.func @mesh_axis_duplicated_different_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -37,7 +37,7 @@ func.func @mesh_axis_duplicated_same_subarray(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
     // expected-error at +1 {{mesh axis duplicated}}
@@ -48,7 +48,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -59,7 +59,7 @@ func.func @mesh_axis_negtive_in_split_part(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -80,7 +80,7 @@ func.func @all_reduce_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_reduce_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -92,7 +92,7 @@ func.func @all_reduce_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -104,7 +104,7 @@ func.func @all_reduce_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_reduce_invalid_tensor_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<5xf64> {
@@ -125,7 +125,7 @@ func.func @all_gather_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_gather_invalid_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -137,7 +137,7 @@ func.func @all_gather_invalid_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
 
 func.func @all_reduce_duplicate_mesh_axis(
     %arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -149,7 +149,7 @@ func.func @all_reduce_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_gather_invalid_non_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -161,7 +161,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
 
 func.func @all_gather_invalid_gather_axis_dimension_size(
     %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -173,7 +173,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_gather_invalid_gather_axis_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -185,7 +185,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_gather_invalid_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -197,7 +197,7 @@ func.func @all_gather_invalid_gather_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_gather_invalid_negative_gather_axis(
     %arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -220,7 +220,7 @@ func.func @all_to_all_invalid_mesh_symbol(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
 
 func.func @all_to_all_duplicate_mesh_axis(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -233,7 +233,7 @@ func.func @all_to_all_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = ?x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
     %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -246,7 +246,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
@@ -259,7 +259,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
     %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
@@ -272,7 +272,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
@@ -285,7 +285,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
     %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
@@ -298,7 +298,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_duplicate_mesh_axis(
     %arg0 : tensor<?xf32>) -> tensor<?xf64> {
@@ -310,7 +310,7 @@ func.func @reduce_scatter_duplicate_mesh_axis(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_invalid_dynamic_dimension(
     %arg0 : tensor<?xf32>) -> tensor<2xf64> {
@@ -322,7 +322,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_invalid_static_dimension_size(
     %arg0 : tensor<3xf32>) -> tensor<2xf64> {
@@ -334,7 +334,7 @@ func.func @reduce_scatter_invalid_static_dimension_size(
 
 // -----
 
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_invalid_operand_static_dimension_size(
     %arg0 : tensor<4xf32>) -> tensor<?xf64> {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 5b264bc88dfc2a..2e10c359f9ad89 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -1,18 +1,18 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
 // CHECK: mesh.cluster @mesh0
-mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+mesh.cluster @mesh0(rank = 3, dim_sizes = 2x2x4)
 
 // CHECK: mesh.cluster @mesh1
-mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
 
 // CHECK: mesh.cluster @mesh2
-mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
 
 // CHECK: mesh.cluster @mesh3
 mesh.cluster @mesh3(rank = 2)
 
-mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh4(rank = 1, dim_sizes = 3)
 
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index bda407b52bfd4f..30bbd5c6619e8a 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -sharding-propagation %s | FileCheck %s
 
 mesh.cluster @mesh_1d(rank = 1)
-mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh_2d(rank = 2, dim_sizes = 2x4)
 mesh.cluster @mesh_3d(rank = 3)
 
 // CHECK-LABEL: func.func @element_wise_empty_sharding_info

>From b78d1640e9043c28dafc975291d699734df48478 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 30 Nov 2023 08:43:33 -0800
Subject: [PATCH 02/11] Fix invalid cluster op test

---
 mlir/test/Dialect/Mesh/invalid.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index dc91ff54fd8391..875fd1839d3d19 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -10,7 +10,7 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
 
 // -----
 
-// expected-error at +1 {{unexpected error: custom op 'mesh.cluster' Failed parsing shape}}
+// expected-error at +1 {{custom op 'mesh.cluster' Failed parsing shape}}
 mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
 
 // -----

>From 1bbe4303c8ca0653b526960e2517d14bf4b828bc Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 30 Nov 2023 08:50:17 -0800
Subject: [PATCH 03/11] Refactor out a common printShape function

---
 mlir/include/mlir/IR/OpImplementation.h | 17 +++++++++++++++++
 mlir/lib/IR/AsmPrinter.cpp              | 16 ++++------------
 mlir/lib/IR/BuiltinAttributes.cpp       | 12 +-----------
 3 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index f1fabf95a68b7a..ca74ac2592a1db 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -13,12 +13,16 @@
 #ifndef MLIR_IR_OPIMPLEMENTATION_H
 #define MLIR_IR_OPIMPLEMENTATION_H
 
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/SMLoc.h"
+#include "llvm/Support/raw_ostream.h"
 #include <optional>
+#include <type_traits>
 
 namespace mlir {
 class AsmParsedResourceEntry;
@@ -1762,6 +1766,19 @@ class OpAsmDialectInterface
                  const SetVector<AsmDialectResourceHandle> &referencedResources,
                  AsmResourceBuilder &builder) const {}
 };
+
+template <typename Range>
+void printShape(raw_ostream& stream, Range&& shape) {
+  for (auto [idx, dimSize] : llvm::enumerate(shape)) {
+    if (ShapedType::isDynamic(dimSize))
+      stream << "?";
+    else
+      stream << dimSize;
+    if (static_cast<std::decay_t<decltype(range_size(shape))>>(idx) != range_size(shape) - 1)
+      stream << "x";
+  }
+}
+
 } // namespace mlir
 
 //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4b76dcf7f8a9f7..a69ee19690d0be 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2576,13 +2576,9 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       })
       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
         os << "tensor<";
-        for (int64_t dim : tensorTy.getShape()) {
-          if (ShapedType::isDynamic(dim))
-            os << '?';
-          else
-            os << dim;
+        printShape(os, tensorTy.getShape());
+        if (!tensorTy.getShape().empty())
           os << 'x';
-        }
         printType(tensorTy.getElementType());
         // Only print the encoding attribute value if set.
         if (tensorTy.getEncoding()) {
@@ -2598,13 +2594,9 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       })
       .Case<MemRefType>([&](MemRefType memrefTy) {
         os << "memref<";
-        for (int64_t dim : memrefTy.getShape()) {
-          if (ShapedType::isDynamic(dim))
-            os << '?';
-          else
-            os << dim;
+        printShape(os, memrefTy.getShape());
+        if (!memrefTy.getShape().empty())
           os << 'x';
-        }
         printType(memrefTy.getElementType());
         MemRefLayoutAttrInterface layout = memrefTy.getLayout();
         if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 1396cd6a90ffc6..bd2693b26a7dce 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1828,17 +1828,7 @@ void mlir::printShape(OpAsmPrinter &printer, Operation *op,
                       ArrayRef<int64_t> shape) {
   if (!shape.empty())
     printer << "[";
-
-  for (size_t i = 0; i < shape.size(); ++i) {
-    if (ShapedType::isDynamic(shape[i]))
-      printer << '?';
-    else
-      printer << shape[i];
-    if (i != shape.size() - 1) {
-      printer << 'x';
-    }
-  }
-
+  printShape(printer.getStream(), shape);
   if (!shape.empty())
     printer << "]";
 }

>From 3b02b7fcaa6d9471c5d6d35b0bd5b463ca6c2753 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 30 Nov 2023 10:06:49 -0800
Subject: [PATCH 04/11] Fix formatting

---
 mlir/include/mlir/IR/OpImplementation.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ca74ac2592a1db..3be75d45175165 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1768,13 +1768,14 @@ class OpAsmDialectInterface
 };
 
 template <typename Range>
-void printShape(raw_ostream& stream, Range&& shape) {
+void printShape(raw_ostream &stream, Range &&shape) {
   for (auto [idx, dimSize] : llvm::enumerate(shape)) {
     if (ShapedType::isDynamic(dimSize))
       stream << "?";
     else
       stream << dimSize;
-    if (static_cast<std::decay_t<decltype(range_size(shape))>>(idx) != range_size(shape) - 1)
+    if (static_cast<std::decay_t<decltype(range_size(shape))>>(idx) !=
+        range_size(shape) - 1)
       stream << "x";
   }
 }

>From e2db4f90ba2b233a3e6281ccaef85640c590ad20 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 09:44:26 -0800
Subject: [PATCH 05/11] Fix canonicalDimSizes in mesh.cluster

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 2ef9c65255f74e..607dd0c5de7beb 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -89,7 +89,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     template <typename OutIt>
     void canonicalDimSizes(OutIt outIt) {
       std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
-      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+      std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
     }
   }];
   let hasVerifier = 1;

>From cade3a366e761ec391b90e40bbbb2d2eec95256f Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 10:11:28 -0800
Subject: [PATCH 06/11] Use llvm::interleave in printShape

---
 mlir/include/mlir/IR/OpImplementation.h | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 3be75d45175165..cdc780081668e8 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1769,15 +1769,15 @@ class OpAsmDialectInterface
 
 template <typename Range>
 void printShape(raw_ostream &stream, Range &&shape) {
-  for (auto [idx, dimSize] : llvm::enumerate(shape)) {
-    if (ShapedType::isDynamic(dimSize))
-      stream << "?";
-    else
-      stream << dimSize;
-    if (static_cast<std::decay_t<decltype(range_size(shape))>>(idx) !=
-        range_size(shape) - 1)
-      stream << "x";
-  }
+  llvm::interleave(
+      shape, stream,
+      [&stream](const auto &dimSize) {
+        if (ShapedType::isDynamic(dimSize))
+          stream << "?";
+        else
+          stream << dimSize;
+      },
+      "x");
 }
 
 } // namespace mlir

>From e4e1bdc334a766d00667ec42e233d67f639184f1 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Mon, 4 Dec 2023 10:37:04 -0800
Subject: [PATCH 07/11] Move printShape and parseShape to OpImplementation.h

---
 mlir/include/mlir/IR/BuiltinAttributes.h |  5 --
 mlir/include/mlir/IR/OpImplementation.h  | 28 +++++-----
 mlir/lib/IR/AsmPrinter.cpp               | 67 +++++++++++++++++++++++-
 mlir/lib/IR/BuiltinAttributes.cpp        | 31 -----------
 4 files changed, 77 insertions(+), 54 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 6d0871e9279065..901df3a25a46f1 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -28,8 +28,6 @@ class FunctionType;
 class IntegerSet;
 class IntegerType;
 class Location;
-class OpAsmParser;
-class OpAsmPrinter;
 class Operation;
 class RankedTensorType;
 
@@ -1103,9 +1101,6 @@ namespace mlir {
 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
                                      MLIRContext *context);
 
-void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape);
-ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape);
-
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index cdc780081668e8..057c9dbca76557 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -13,16 +13,12 @@
 #ifndef MLIR_IR_OPIMPLEMENTATION_H
 #define MLIR_IR_OPIMPLEMENTATION_H
 
-#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/OpDefinition.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/SMLoc.h"
-#include "llvm/Support/raw_ostream.h"
 #include <optional>
-#include <type_traits>
 
 namespace mlir {
 class AsmParsedResourceEntry;
@@ -230,6 +226,8 @@ class AsmPrinter {
     printArrowTypeList(results);
   }
 
+  void printDimensionList(ArrayRef<int64_t> shape);
+
   /// Class used to automatically end a cyclic region on destruction.
   class CyclicPrintReset {
   public:
@@ -776,6 +774,9 @@ class AsmParser {
     return parseCommaSeparatedList(Delimiter::None, parseElementFn);
   }
 
+  template <typename OutIt>
+  ParseResult parseShape(SmallVectorImpl<int64_t> &dims);
+
   //===--------------------------------------------------------------------===//
   // Keyword Parsing
   //===--------------------------------------------------------------------===//
@@ -1767,18 +1768,13 @@ class OpAsmDialectInterface
                  AsmResourceBuilder &builder) const {}
 };
 
-template <typename Range>
-void printShape(raw_ostream &stream, Range &&shape) {
-  llvm::interleave(
-      shape, stream,
-      [&stream](const auto &dimSize) {
-        if (ShapedType::isDynamic(dimSize))
-          stream << "?";
-        else
-          stream << dimSize;
-      },
-      "x");
-}
+//===--------------------------------------------------------------------===//
+// Custom attribute printers and parsers.
+//===--------------------------------------------------------------------===//
+
+// Handles custom<Shape>(...) in TableGen.
+void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape);
+ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape);
 
 } // namespace mlir
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index a69ee19690d0be..8a54d594fe1e7b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -16,7 +16,9 @@
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -44,6 +46,7 @@
 #include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/Threading.h"
 #include "llvm/Support/raw_ostream.h"
+#include <type_traits>
 
 #include <optional>
 #include <tuple>
@@ -425,6 +428,8 @@ class AsmPrinter::Impl {
 
   void popCyclicPrinting();
 
+  void printDimensionList(ArrayRef<int64_t> shape);
+
 protected:
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                              ArrayRef<StringRef> elidedAttrs = {},
@@ -1860,6 +1865,20 @@ class AsmStateImpl {
   // Allow direct access to the impl fields.
   friend AsmState;
 };
+
+template <typename Range>
+void printDimensionList(raw_ostream &stream, Range &&shape) {
+  llvm::interleave(
+      shape, stream,
+      [&stream](const auto &dimSize) {
+        if (ShapedType::isDynamic(dimSize))
+          stream << "?";
+        else
+          stream << dimSize;
+      },
+      "x");
+}
+
 } // namespace detail
 } // namespace mlir
 
@@ -2576,7 +2595,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       })
       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
         os << "tensor<";
-        printShape(os, tensorTy.getShape());
+        printDimensionList(tensorTy.getShape());
         if (!tensorTy.getShape().empty())
           os << 'x';
         printType(tensorTy.getElementType());
@@ -2594,7 +2613,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       })
       .Case<MemRefType>([&](MemRefType memrefTy) {
         os << "memref<";
-        printShape(os, memrefTy.getShape());
+        printDimensionList(memrefTy.getShape());
         if (!memrefTy.getShape().empty())
           os << 'x';
         printType(memrefTy.getElementType());
@@ -2727,6 +2746,10 @@ LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
 
 void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
 
+void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
+  detail::printDimensionList(os, shape);
+}
+
 //===--------------------------------------------------------------------===//
 // AsmPrinter
 //===--------------------------------------------------------------------===//
@@ -2792,6 +2815,10 @@ void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
   impl->printResourceHandle(resource);
 }
 
+void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
+  detail::printDimensionList(getStream(), shape);
+}
+
 LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
   return impl->pushCyclicPrinting(opaquePointer);
 }
@@ -3903,3 +3930,39 @@ void Block::printAsOperand(raw_ostream &os, AsmState &state) {
   OperationPrinter printer(os, state.getImpl());
   printer.printBlockName(this);
 }
+
+//===--------------------------------------------------------------------===//
+// Custom attribute printers and parsers.
+//===--------------------------------------------------------------------===//
+namespace mlir {
+
+void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape) {
+  if (!shape.empty())
+    printer << "[";
+  printer.printDimensionList(shape);
+  if (!shape.empty())
+    printer << "]";
+}
+
+ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape) {
+  bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
+  SmallVector<int64_t> shapeArr;
+  if (failed(parser.parseDimensionList(shapeArr, true, false))) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "Failed parsing shape.";
+  }
+  if (shapeArr.empty() && !hasOpeningSquare) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "Failed parsing shape. Did you mean a 0-rank shape? It must be "
+              "denoted by \"[]\".";
+  }
+  if (hasOpeningSquare && failed(parser.parseRSquare())) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "Failed parsing shape.";
+  }
+
+  shape = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
+  return success();
+}
+
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index bd2693b26a7dce..89b1ed67f5d067 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -17,7 +17,6 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
-#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -1823,33 +1822,3 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
 
   return AffineMap::get(strides.size(), nSymbols, expr);
 }
-
-void mlir::printShape(OpAsmPrinter &printer, Operation *op,
-                      ArrayRef<int64_t> shape) {
-  if (!shape.empty())
-    printer << "[";
-  printShape(printer.getStream(), shape);
-  if (!shape.empty())
-    printer << "]";
-}
-
-ParseResult mlir::parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape) {
-  bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
-  SmallVector<int64_t> shapeArr;
-  if (failed(parser.parseDimensionList(shapeArr, true, false))) {
-    return parser.emitError(parser.getCurrentLocation())
-           << "Failed parsing shape.";
-  }
-  if (shapeArr.empty() && !hasOpeningSquare) {
-    return parser.emitError(parser.getCurrentLocation())
-           << "Failed parsing shape. Did you mean a 0-rank shape? It must be "
-              "denoted by \"[]\".";
-  }
-  if (hasOpeningSquare && failed(parser.parseRSquare())) {
-    return parser.emitError(parser.getCurrentLocation())
-           << "Failed parsing shape.";
-  }
-
-  shape = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
-  return success();
-}

>From c18ac63bc022894767c583cfeae52a82bab5ae24 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Mon, 4 Dec 2023 16:43:59 -0800
Subject: [PATCH 08/11] Remove unused function signature and fix comment

---
 mlir/include/mlir/IR/OpImplementation.h | 5 +----
 mlir/lib/IR/AsmPrinter.cpp              | 2 +-
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 057c9dbca76557..bb0c18c50439af 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -774,9 +774,6 @@ class AsmParser {
     return parseCommaSeparatedList(Delimiter::None, parseElementFn);
   }
 
-  template <typename OutIt>
-  ParseResult parseShape(SmallVectorImpl<int64_t> &dims);
-
   //===--------------------------------------------------------------------===//
   // Keyword Parsing
   //===--------------------------------------------------------------------===//
@@ -1769,7 +1766,7 @@ class OpAsmDialectInterface
 };
 
 //===--------------------------------------------------------------------===//
-// Custom attribute printers and parsers.
+// Custom printers and parsers.
 //===--------------------------------------------------------------------===//
 
 // Handles custom<Shape>(...) in TableGen.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 8a54d594fe1e7b..416d0795e9a638 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3932,7 +3932,7 @@ void Block::printAsOperand(raw_ostream &os, AsmState &state) {
 }
 
 //===--------------------------------------------------------------------===//
-// Custom attribute printers and parsers.
+// Custom printers and parsers.
 //===--------------------------------------------------------------------===//
 namespace mlir {
 

>From bf34237508a7033c8142a2d607d8a7705d6452b9 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 5 Dec 2023 09:54:54 -0800
Subject: [PATCH 09/11] Rename custom<Shape> -> custom<DimensionList> and add
 tests for it

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |  2 +-
 mlir/include/mlir/IR/OpImplementation.h      |  8 ++++---
 mlir/lib/IR/AsmPrinter.cpp                   | 25 +++++++++++---------
 mlir/test/Dialect/Mesh/invalid.mlir          |  2 +-
 mlir/test/Dialect/Mesh/ops.mlir              |  3 +++
 mlir/test/IR/custom-print-parse.mlir         | 18 ++++++++++++++
 mlir/test/IR/invalid-custom-print-parse.mlir | 10 ++++++++
 mlir/test/lib/Dialect/Test/TestOps.td        | 14 +++++++++++
 mlir/test/lib/Dialect/Test/TestOpsSyntax.td  |  2 ++
 9 files changed, 68 insertions(+), 16 deletions(-)
 create mode 100644 mlir/test/IR/custom-print-parse.mlir
 create mode 100644 mlir/test/IR/invalid-custom-print-parse.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 607dd0c5de7beb..e6cdba949b1721 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -77,7 +77,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
   );
   let assemblyFormat = [{
-    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<Shape>($dim_sizes)^)? `)`
+    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<DimensionList>($dim_sizes)^)? `)`
       attr-dict
   }];
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index bb0c18c50439af..86ed14e7ca8439 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1769,9 +1769,11 @@ class OpAsmDialectInterface
 // Custom printers and parsers.
 //===--------------------------------------------------------------------===//
 
-// Handles custom<Shape>(...) in TableGen.
-void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape);
-ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape);
+// Handles custom<DimensionList>(...) in TableGen.
+void printDimensionList(OpAsmPrinter &printer, Operation *op,
+                        ArrayRef<int64_t> dimensions);
+ParseResult parseDimensionList(OpAsmParser &parser,
+                               DenseI64ArrayAttr &dimensions);
 
 } // namespace mlir
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 416d0795e9a638..62628a9391a63c 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3932,37 +3932,40 @@ void Block::printAsOperand(raw_ostream &os, AsmState &state) {
 }
 
 //===--------------------------------------------------------------------===//
-// Custom printers and parsers.
+// Custom printers
 //===--------------------------------------------------------------------===//
 namespace mlir {
 
-void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape) {
-  if (!shape.empty())
+void printDimensionList(OpAsmPrinter &printer, Operation *op,
+                        ArrayRef<int64_t> dimensions) {
+  if (dimensions.empty())
     printer << "[";
-  printer.printDimensionList(shape);
-  if (!shape.empty())
+  printer.printDimensionList(dimensions);
+  if (dimensions.empty())
     printer << "]";
 }
 
-ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape) {
+ParseResult parseDimensionList(OpAsmParser &parser,
+                               DenseI64ArrayAttr &dimensions) {
   bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
   SmallVector<int64_t> shapeArr;
   if (failed(parser.parseDimensionList(shapeArr, true, false))) {
     return parser.emitError(parser.getCurrentLocation())
-           << "Failed parsing shape.";
+           << "Failed parsing dimension list.";
   }
   if (shapeArr.empty() && !hasOpeningSquare) {
     return parser.emitError(parser.getCurrentLocation())
-           << "Failed parsing shape. Did you mean a 0-rank shape? It must be "
+           << "Failed parsing dimension list. Did you mean an empty list? It "
+              "must be "
               "denoted by \"[]\".";
   }
   if (hasOpeningSquare && failed(parser.parseRSquare())) {
     return parser.emitError(parser.getCurrentLocation())
-           << "Failed parsing shape.";
+           << "Failed parsing dimension list.";
   }
 
-  shape = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
+  dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
   return success();
 }
 
-} // namespace mlir
\ No newline at end of file
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 875fd1839d3d19..a26e3950186e95 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -10,7 +10,7 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
 
 // -----
 
-// expected-error at +1 {{custom op 'mesh.cluster' Failed parsing shape}}
+// expected-error at +1 {{custom op 'mesh.cluster' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
 mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
 
 // -----
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 2e10c359f9ad89..2261a4d2ad5563 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -14,6 +14,9 @@ mesh.cluster @mesh3(rank = 2)
 
 mesh.cluster @mesh4(rank = 1, dim_sizes = 3)
 
+// CHECK: mesh.cluster @mesh5(rank = 1)
+mesh.cluster @mesh5(rank = 1, dim_sizes = [])
+
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
diff --git a/mlir/test/IR/custom-print-parse.mlir b/mlir/test/IR/custom-print-parse.mlir
new file mode 100644
index 00000000000000..ec98d53525703a
--- /dev/null
+++ b/mlir/test/IR/custom-print-parse.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: module @dimension_list
+module @dimension_list {
+  // CHECK: test.custom_dimension_list_attr dimension_list = []
+  test.custom_dimension_list_attr dimension_list = []
+  // CHECK: test.custom_dimension_list_attr dimension_list = 3
+  test.custom_dimension_list_attr dimension_list = 3
+  // CHECK: test.custom_dimension_list_attr dimension_list = 0
+  test.custom_dimension_list_attr dimension_list = 0
+  // CHECK: test.custom_dimension_list_attr dimension_list = 1x2
+  test.custom_dimension_list_attr dimension_list = 1x2
+  // CHECK: test.custom_dimension_list_attr dimension_list = ?
+  test.custom_dimension_list_attr dimension_list = ?
+  // CHECK: test.custom_dimension_list_attr dimension_list = ?x?
+  test.custom_dimension_list_attr dimension_list = ?x?
+
+}
diff --git a/mlir/test/IR/invalid-custom-print-parse.mlir b/mlir/test/IR/invalid-custom-print-parse.mlir
new file mode 100644
index 00000000000000..aa27d4817e44c9
--- /dev/null
+++ b/mlir/test/IR/invalid-custom-print-parse.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// expected-error at +2 {{invalid dimension}}
+// expected-error at +1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list.}}
+test.custom_dimension_list_attr dimension_list = 1x-1
+
+// -----
+
+// expected-error at +1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
+test.custom_dimension_list_attr dimension_list = -1
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1add9bd3c32943..96f66c2ca06ecf 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2020,6 +2020,20 @@ def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Custom printer/parser
+
+def CustomDimensionListAttrOp : TEST_Op<"custom_dimension_list_attr"> {
+  let description = [{
+    Test printing/parsing of dimension list attribute.
+  }];
+  let arguments = (ins DenseI64ArrayAttr:$dimension_list);
+  let assemblyFormat = [{
+    `dimension_list` `=` custom<DimensionList>($dimension_list)
+    attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test OpAsmInterface.
 
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
index 9522a775e247da..135927084ec6c4 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
@@ -568,6 +568,8 @@ def FormatLiteralFollowingOptionalGroup
   let assemblyFormat = "(`(` $value^ `)`)? `:` $type attr-dict";
 }
 
+
+
 //===----------------------------------------------------------------------===//
 // AllTypesMatch type inference
 

>From 2985352251fa92b0f64899720aaf8c3890a2a417 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 6 Dec 2023 08:43:05 -0800
Subject: [PATCH 10/11] Disallow non-empty dim lists of the form [1x2]

---
 mlir/lib/IR/AsmPrinter.cpp                   | 8 ++++----
 mlir/test/Dialect/Mesh/ops.mlir              | 4 ++--
 mlir/test/IR/invalid-custom-print-parse.mlir | 6 ++++++
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 62628a9391a63c..f448fc068844fb 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3948,6 +3948,10 @@ void printDimensionList(OpAsmPrinter &printer, Operation *op,
 ParseResult parseDimensionList(OpAsmParser &parser,
                                DenseI64ArrayAttr &dimensions) {
   bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
+  if (hasOpeningSquare && failed(parser.parseRSquare())) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "Failed parsing dimension list.";
+  }
   SmallVector<int64_t> shapeArr;
   if (failed(parser.parseDimensionList(shapeArr, true, false))) {
     return parser.emitError(parser.getCurrentLocation())
@@ -3959,10 +3963,6 @@ ParseResult parseDimensionList(OpAsmParser &parser,
               "must be "
               "denoted by \"[]\".";
   }
-  if (hasOpeningSquare && failed(parser.parseRSquare())) {
-    return parser.emitError(parser.getCurrentLocation())
-           << "Failed parsing dimension list.";
-  }
 
   dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
   return success();
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 2261a4d2ad5563..78ce276a7b33a3 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -3,10 +3,10 @@
 // CHECK: mesh.cluster @mesh0
 mesh.cluster @mesh0(rank = 3, dim_sizes = 2x2x4)
 
-// CHECK: mesh.cluster @mesh1
+// CHECK: mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
 mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
 
-// CHECK: mesh.cluster @mesh2
+// CHECK: mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
 mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
 
 // CHECK: mesh.cluster @mesh3
diff --git a/mlir/test/IR/invalid-custom-print-parse.mlir b/mlir/test/IR/invalid-custom-print-parse.mlir
index aa27d4817e44c9..456b16c91bc04b 100644
--- a/mlir/test/IR/invalid-custom-print-parse.mlir
+++ b/mlir/test/IR/invalid-custom-print-parse.mlir
@@ -8,3 +8,9 @@ test.custom_dimension_list_attr dimension_list = 1x-1
 
 // expected-error at +1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
 test.custom_dimension_list_attr dimension_list = -1
+
+// -----
+
+// expected-error at +2 {{expected ']'}}
+// expected-error at +1 {{custom op 'test.custom_dimension_list_attr' Failed parsing dimension list.}}
+test.custom_dimension_list_attr dimension_list = [2x3]

>From 79c9eb48c73214f343a14544cfb5367a849b4afd Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 7 Dec 2023 09:12:11 -0800
Subject: [PATCH 11/11] Parser logic fix for the empty list case

---
 mlir/lib/IR/AsmPrinter.cpp                  | 22 +++++++++++++--------
 mlir/test/IR/custom-print-parse.mlir        |  1 -
 mlir/test/lib/Dialect/Test/TestOpsSyntax.td |  2 --
 3 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f448fc068844fb..1f7cbf349255d5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -29,6 +29,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Verifier.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
@@ -3947,23 +3948,28 @@ void printDimensionList(OpAsmPrinter &printer, Operation *op,
 
 ParseResult parseDimensionList(OpAsmParser &parser,
                                DenseI64ArrayAttr &dimensions) {
-  bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
-  if (hasOpeningSquare && failed(parser.parseRSquare())) {
-    return parser.emitError(parser.getCurrentLocation())
-           << "Failed parsing dimension list.";
+  // Empty list case denoted by "[]".
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseRSquare())) {
+      return parser.emitError(parser.getCurrentLocation())
+             << "Failed parsing dimension list.";
+    }
+    dimensions =
+        DenseI64ArrayAttr::get(parser.getContext(), ArrayRef<int64_t>());
+    return success();
   }
+
+  // Non-empty list case.
   SmallVector<int64_t> shapeArr;
   if (failed(parser.parseDimensionList(shapeArr, true, false))) {
     return parser.emitError(parser.getCurrentLocation())
            << "Failed parsing dimension list.";
   }
-  if (shapeArr.empty() && !hasOpeningSquare) {
+  if (shapeArr.empty()) {
     return parser.emitError(parser.getCurrentLocation())
            << "Failed parsing dimension list. Did you mean an empty list? It "
-              "must be "
-              "denoted by \"[]\".";
+              "must be denoted by \"[]\".";
   }
-
   dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
   return success();
 }
diff --git a/mlir/test/IR/custom-print-parse.mlir b/mlir/test/IR/custom-print-parse.mlir
index ec98d53525703a..b157fd1b1ea338 100644
--- a/mlir/test/IR/custom-print-parse.mlir
+++ b/mlir/test/IR/custom-print-parse.mlir
@@ -14,5 +14,4 @@ module @dimension_list {
   test.custom_dimension_list_attr dimension_list = ?
   // CHECK: test.custom_dimension_list_attr dimension_list = ?x?
   test.custom_dimension_list_attr dimension_list = ?x?
-
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
index 135927084ec6c4..9522a775e247da 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
@@ -568,8 +568,6 @@ def FormatLiteralFollowingOptionalGroup
   let assemblyFormat = "(`(` $value^ `)`)? `:` $type attr-dict";
 }
 
-
-
 //===----------------------------------------------------------------------===//
 // AllTypesMatch type inference
 



More information about the Mlir-commits mailing list