[Mlir-commits] [mlir] [mlir][mesh] Use tensor shape notation for the shape of a cluster (PR #73826)
Boian Petkantchin
llvmlistbot at llvm.org
Mon Dec 4 15:44:05 PST 2023
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/73826
>From 965102ddbf67f7592c57270560894df202986b60 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 29 Nov 2023 08:55:16 -0800
Subject: [PATCH 1/7] [mlir][mesh] Use tensor shape notation for the shape of a
cluster
Example:
substitute
mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 4])
with
mesh.cluster @mesh0(rank = 2, dim_sizes = ?x4)
The only difference is for 0-rank shapes.
With tensors you would have something like `tensor<f32>`.
Here, to avoid matching an empty string, a 0-rank shape is denoted by `[]`.
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 27 ++++-----
mlir/include/mlir/IR/BuiltinAttributes.h | 5 ++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 13 ++---
mlir/lib/IR/BuiltinAttributes.cpp | 41 ++++++++++++++
mlir/test/Dialect/Mesh/canonicalization.mlir | 2 +-
mlir/test/Dialect/Mesh/invalid.mlir | 56 +++++++++----------
mlir/test/Dialect/Mesh/ops.mlir | 8 +--
.../Dialect/Mesh/sharding-propagation.mlir | 2 +-
8 files changed, 98 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 5cce15dd1015e..83139035eda52 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -40,26 +40,27 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
determine the layout and the addressing space of the computation distributed
across the mesh.
- 3. `dim_sizes`: This attribute represents the device assignment along the
- axes of the cluster. Each integer in the array corresponds to the number of
- devices along a specific axis. If an integer value is 0, it implies that the
- number of devices along that axis is unknown. This flexibility allows for
- dynamic device assignment or configurations where the exact number of
- devices might not be determined during compile time.
+ 3. `dim_sizes`: This attribute represents the shape of the device cluster.
+ It uses the same notation as a tensor shape. Also allowing for dynamic
+ dimensions.
+ This flexibility allows for dynamic device assignment or configurations
+ where the exact number of devices might not be determined during compile
+ time.
+ For example `2x?x4`.
Example:
```
// A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
- mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
+ mesh.cluster @mesh0(rank = 3, dim_sizes = 4x8x12)
// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
- mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+ mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
// A device mesh cluster with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
- mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+ mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
// A device mesh cluster with 2 axes, the number of devices along both axes
// is unknown
@@ -76,7 +77,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
);
let assemblyFormat = [{
- $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
+ $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` custom<Shape>($dim_sizes)^)? `)`
attr-dict
}];
let extraClassDeclaration = [{
@@ -210,7 +211,7 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
Example:
```mlir
- mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ mesh.cluster @mesh0(rank = 2, dim_sizes = 2x2)
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
@@ -295,7 +296,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Example:
```
- mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+ mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
@@ -348,7 +349,7 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
across the device group.
Example:
```
- mesh.cluster @mesh0(rank = 1, dim_sizes = [2, 2])
+ mesh.cluster @mesh0(rank = 1, dim_sizes = 2x2)
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <max> scatter_axis = 0
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c8161604aad35..9261c82eae5b7 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -28,6 +28,8 @@ class FunctionType;
class IntegerSet;
class IntegerType;
class Location;
+class OpAsmParser;
+class OpAsmPrinter;
class Operation;
class RankedTensorType;
@@ -1099,6 +1101,9 @@ namespace mlir {
AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
MLIRContext *context);
+void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape);
+ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape);
+
} // namespace mlir
namespace llvm {
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b45f7cd21ce92..4b08ecc69628b 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -56,11 +56,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
return vec;
}
-template <typename DimSize>
-static bool isMeshDimensionDynamic(DimSize size) {
- return size <= DimSize(0);
-}
-
using MeshAxis = int16_t;
namespace {
@@ -159,9 +154,9 @@ LogicalResult ClusterOp::verify() {
"rank of dim_sizes is not expected to be larger than rank of cluster");
for (int64_t dimSize : dimSizes) {
- if (dimSize < 0)
- return emitOpError(
- "dimension size of a mesh cluster is expected to be non-negative");
+ if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
+ return emitOpError("dimension size of a mesh cluster is expected to be "
+ "non-negative or dynamic");
}
return success();
@@ -314,7 +309,7 @@ static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
int64_t res = 1;
for (MeshAxis axis : meshAxes) {
- if (isMeshDimensionDynamic(meshShape[axis])) {
+ if (ShapedType::isDynamic(meshShape[axis])) {
return ShapedType::kDynamic;
}
assert(size_t(axis) < meshShape.size());
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 89b1ed67f5d06..1396cd6a90ffc 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1822,3 +1823,43 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
return AffineMap::get(strides.size(), nSymbols, expr);
}
+
+void mlir::printShape(OpAsmPrinter &printer, Operation *op,
+ ArrayRef<int64_t> shape) {
+ if (!shape.empty())
+ printer << "[";
+
+ for (size_t i = 0; i < shape.size(); ++i) {
+ if (ShapedType::isDynamic(shape[i]))
+ printer << '?';
+ else
+ printer << shape[i];
+ if (i != shape.size() - 1) {
+ printer << 'x';
+ }
+ }
+
+ if (!shape.empty())
+ printer << "]";
+}
+
+ParseResult mlir::parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape) {
+ bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
+ SmallVector<int64_t> shapeArr;
+ if (failed(parser.parseDimensionList(shapeArr, true, false))) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "Failed parsing shape.";
+ }
+ if (shapeArr.empty() && !hasOpeningSquare) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "Failed parsing shape. Did you mean a 0-rank shape? It must be "
+ "denoted by \"[]\".";
+ }
+ if (hasOpeningSquare && failed(parser.parseRSquare())) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "Failed parsing shape.";
+ }
+
+ shape = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
+ return success();
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 5802d198d3681..baee9faa645c9 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 2999668f770ba..dc91ff54fd839 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -6,16 +6,16 @@ mesh.cluster @mesh0(rank = 0)
// -----
// expected-error at +1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 3, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
// -----
-// expected-error at +1 {{dimension size of a mesh cluster is expected to be non-negative}}
-mesh.cluster @mesh0(rank = 2, dim_sizes = [-1])
+// expected-error at +1 {{unexpected error: custom op 'mesh.cluster' Failed parsing shape}}
+mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @mesh_axis_duplicated_different_subarray(
// expected-error at +1 {{mesh axis duplicated}}
@@ -26,7 +26,7 @@ func.func @mesh_axis_duplicated_different_subarray(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @mesh_axis_duplicated_same_subarray(
// expected-error at +1 {{mesh axis duplicated}}
@@ -37,7 +37,7 @@ func.func @mesh_axis_duplicated_same_subarray(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @mesh_axis_duplicated_bewteen_split_and_partial(
// expected-error at +1 {{mesh axis duplicated}}
@@ -48,7 +48,7 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @mesh_axis_negtive_in_split_part(
// expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -59,7 +59,7 @@ func.func @mesh_axis_negtive_in_split_part(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @mesh_axis_negtive_in_partial(
// expected-error at +1 {{mesh axis is expected to be non-negative}}
@@ -80,7 +80,7 @@ func.func @all_reduce_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @all_reduce_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -92,7 +92,7 @@ func.func @all_reduce_invalid_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @all_reduce_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
@@ -104,7 +104,7 @@ func.func @all_reduce_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @all_reduce_invalid_tensor_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<5xf64> {
@@ -125,7 +125,7 @@ func.func @all_gather_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @all_gather_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -137,7 +137,7 @@ func.func @all_gather_invalid_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @all_reduce_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
@@ -149,7 +149,7 @@ func.func @all_reduce_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
func.func @all_gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -161,7 +161,7 @@ func.func @all_gather_invalid_non_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
func.func @all_gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
@@ -173,7 +173,7 @@ func.func @all_gather_invalid_gather_axis_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
func.func @all_gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
@@ -185,7 +185,7 @@ func.func @all_gather_invalid_gather_axis_dynamic_dimension(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
func.func @all_gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -197,7 +197,7 @@ func.func @all_gather_invalid_gather_axis(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
func.func @all_gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
@@ -220,7 +220,7 @@ func.func @all_to_all_invalid_mesh_symbol(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
func.func @all_to_all_duplicate_mesh_axis(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -233,7 +233,7 @@ func.func @all_to_all_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = ?x1)
func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
@@ -246,7 +246,7 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
@@ -259,7 +259,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
// -----
-mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
@@ -272,7 +272,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
@@ -285,7 +285,7 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
@@ -298,7 +298,7 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
func.func @reduce_scatter_duplicate_mesh_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
@@ -310,7 +310,7 @@ func.func @reduce_scatter_duplicate_mesh_axis(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
func.func @reduce_scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
@@ -322,7 +322,7 @@ func.func @reduce_scatter_invalid_dynamic_dimension(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
func.func @reduce_scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
@@ -334,7 +334,7 @@ func.func @reduce_scatter_invalid_static_dimension_size(
// -----
-mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 5b264bc88dfc2..2e10c359f9ad8 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -1,18 +1,18 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// CHECK: mesh.cluster @mesh0
-mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+mesh.cluster @mesh0(rank = 3, dim_sizes = 2x2x4)
// CHECK: mesh.cluster @mesh1
-mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+mesh.cluster @mesh1(rank = 2, dim_sizes = 4)
// CHECK: mesh.cluster @mesh2
-mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+mesh.cluster @mesh2(rank = 2, dim_sizes = ?x4)
// CHECK: mesh.cluster @mesh3
mesh.cluster @mesh3(rank = 2)
-mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
+mesh.cluster @mesh4(rank = 1, dim_sizes = 3)
// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
func.func @mesh_shard_encoding_fully_replicated(
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index bda407b52bfd4..30bbd5c6619e8 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
mesh.cluster @mesh_1d(rank = 1)
-mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh_2d(rank = 2, dim_sizes = 2x4)
mesh.cluster @mesh_3d(rank = 3)
// CHECK-LABEL: func.func @element_wise_empty_sharding_info
>From c923af9714fcdc2bb0270e0f25ccf20473633a29 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 30 Nov 2023 08:43:33 -0800
Subject: [PATCH 2/7] Fix invalid cluster op test
---
mlir/test/Dialect/Mesh/invalid.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index dc91ff54fd839..875fd1839d3d1 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -10,7 +10,7 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = 2x3x4)
// -----
-// expected-error at +1 {{unexpected error: custom op 'mesh.cluster' Failed parsing shape}}
+// expected-error at +1 {{custom op 'mesh.cluster' Failed parsing shape}}
mesh.cluster @mesh0(rank = 2, dim_sizes = -1)
// -----
>From 0971cd7a5ef6ffb5a572d772a58aada23f4bb2f4 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 30 Nov 2023 08:50:17 -0800
Subject: [PATCH 3/7] Refactor out a common printShape function
---
mlir/include/mlir/IR/OpImplementation.h | 17 +++++++++++++++++
mlir/lib/IR/AsmPrinter.cpp | 16 ++++------------
mlir/lib/IR/BuiltinAttributes.cpp | 12 +-----------
3 files changed, 22 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index f1fabf95a68b7..ca74ac2592a1d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -13,12 +13,16 @@
#ifndef MLIR_IR_OPIMPLEMENTATION_H
#define MLIR_IR_OPIMPLEMENTATION_H
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
+#include "llvm/Support/raw_ostream.h"
#include <optional>
+#include <type_traits>
namespace mlir {
class AsmParsedResourceEntry;
@@ -1762,6 +1766,19 @@ class OpAsmDialectInterface
const SetVector<AsmDialectResourceHandle> &referencedResources,
AsmResourceBuilder &builder) const {}
};
+
+template <typename Range>
+void printShape(raw_ostream& stream, Range&& shape) {
+ for (auto [idx, dimSize] : llvm::enumerate(shape)) {
+ if (ShapedType::isDynamic(dimSize))
+ stream << "?";
+ else
+ stream << dimSize;
+ if (static_cast<std::decay_t<decltype(range_size(shape))>>(idx) != range_size(shape) - 1)
+ stream << "x";
+ }
+}
+
} // namespace mlir
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4b76dcf7f8a9f..a69ee19690d0b 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2576,13 +2576,9 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
os << "tensor<";
- for (int64_t dim : tensorTy.getShape()) {
- if (ShapedType::isDynamic(dim))
- os << '?';
- else
- os << dim;
+ printShape(os, tensorTy.getShape());
+ if (!tensorTy.getShape().empty())
os << 'x';
- }
printType(tensorTy.getElementType());
// Only print the encoding attribute value if set.
if (tensorTy.getEncoding()) {
@@ -2598,13 +2594,9 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
})
.Case<MemRefType>([&](MemRefType memrefTy) {
os << "memref<";
- for (int64_t dim : memrefTy.getShape()) {
- if (ShapedType::isDynamic(dim))
- os << '?';
- else
- os << dim;
+ printShape(os, memrefTy.getShape());
+ if (!memrefTy.getShape().empty())
os << 'x';
- }
printType(memrefTy.getElementType());
MemRefLayoutAttrInterface layout = memrefTy.getLayout();
if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 1396cd6a90ffc..bd2693b26a7dc 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1828,17 +1828,7 @@ void mlir::printShape(OpAsmPrinter &printer, Operation *op,
ArrayRef<int64_t> shape) {
if (!shape.empty())
printer << "[";
-
- for (size_t i = 0; i < shape.size(); ++i) {
- if (ShapedType::isDynamic(shape[i]))
- printer << '?';
- else
- printer << shape[i];
- if (i != shape.size() - 1) {
- printer << 'x';
- }
- }
-
+ printShape(printer.getStream(), shape);
if (!shape.empty())
printer << "]";
}
>From 2beda8ef118a3e82d4caac60050bf9132a99b0a8 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 30 Nov 2023 10:06:49 -0800
Subject: [PATCH 4/7] Fix formatting
---
mlir/include/mlir/IR/OpImplementation.h | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ca74ac2592a1d..3be75d4517516 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1768,13 +1768,14 @@ class OpAsmDialectInterface
};
template <typename Range>
-void printShape(raw_ostream& stream, Range&& shape) {
+void printShape(raw_ostream &stream, Range &&shape) {
for (auto [idx, dimSize] : llvm::enumerate(shape)) {
if (ShapedType::isDynamic(dimSize))
stream << "?";
else
stream << dimSize;
- if (static_cast<std::decay_t<decltype(range_size(shape))>>(idx) != range_size(shape) - 1)
+ if (static_cast<std::decay_t<decltype(range_size(shape))>>(idx) !=
+ range_size(shape) - 1)
stream << "x";
}
}
>From a0ca14c06f45c1ab75a62df5a15431d1b2b77a80 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 09:44:26 -0800
Subject: [PATCH 5/7] Fix canonicalDimSizes in mesh.cluster
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 83139035eda52..85fe9a8ef9949 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -89,7 +89,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
template <typename OutIt>
void canonicalDimSizes(OutIt outIt) {
std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
- std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+ std::fill_n(outIt, getRank() - getDimSizes().size(), ::mlir::ShapedType::kDynamic);
}
}];
let hasVerifier = 1;
>From adec6dadaef0561a7e7f9af8b2250bb028feb0f6 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 10:11:28 -0800
Subject: [PATCH 6/7] Use llvm::interleave in printShape
---
mlir/include/mlir/IR/OpImplementation.h | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 3be75d4517516..cdc780081668e 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1769,15 +1769,15 @@ class OpAsmDialectInterface
template <typename Range>
void printShape(raw_ostream &stream, Range &&shape) {
- for (auto [idx, dimSize] : llvm::enumerate(shape)) {
- if (ShapedType::isDynamic(dimSize))
- stream << "?";
- else
- stream << dimSize;
- if (static_cast<std::decay_t<decltype(range_size(shape))>>(idx) !=
- range_size(shape) - 1)
- stream << "x";
- }
+ llvm::interleave(
+ shape, stream,
+ [&stream](const auto &dimSize) {
+ if (ShapedType::isDynamic(dimSize))
+ stream << "?";
+ else
+ stream << dimSize;
+ },
+ "x");
}
} // namespace mlir
>From 0336f933319e61d258bdd7c1f8102437ace200fe Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Mon, 4 Dec 2023 10:37:04 -0800
Subject: [PATCH 7/7] Move printShape and parseShape to OpImplementation.h
---
mlir/include/mlir/IR/BuiltinAttributes.h | 5 --
mlir/include/mlir/IR/OpImplementation.h | 28 +++++-----
mlir/lib/IR/AsmPrinter.cpp | 67 +++++++++++++++++++++++-
mlir/lib/IR/BuiltinAttributes.cpp | 31 -----------
4 files changed, 77 insertions(+), 54 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 9261c82eae5b7..c8161604aad35 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -28,8 +28,6 @@ class FunctionType;
class IntegerSet;
class IntegerType;
class Location;
-class OpAsmParser;
-class OpAsmPrinter;
class Operation;
class RankedTensorType;
@@ -1101,9 +1099,6 @@ namespace mlir {
AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
MLIRContext *context);
-void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape);
-ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape);
-
} // namespace mlir
namespace llvm {
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index cdc780081668e..057c9dbca7655 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -13,16 +13,12 @@
#ifndef MLIR_IR_OPIMPLEMENTATION_H
#define MLIR_IR_OPIMPLEMENTATION_H
-#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/raw_ostream.h"
#include <optional>
-#include <type_traits>
namespace mlir {
class AsmParsedResourceEntry;
@@ -230,6 +226,8 @@ class AsmPrinter {
printArrowTypeList(results);
}
+ void printDimensionList(ArrayRef<int64_t> shape);
+
/// Class used to automatically end a cyclic region on destruction.
class CyclicPrintReset {
public:
@@ -776,6 +774,9 @@ class AsmParser {
return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
+ template <typename OutIt>
+ ParseResult parseShape(SmallVectorImpl<int64_t> &dims);
+
//===--------------------------------------------------------------------===//
// Keyword Parsing
//===--------------------------------------------------------------------===//
@@ -1767,18 +1768,13 @@ class OpAsmDialectInterface
AsmResourceBuilder &builder) const {}
};
-template <typename Range>
-void printShape(raw_ostream &stream, Range &&shape) {
- llvm::interleave(
- shape, stream,
- [&stream](const auto &dimSize) {
- if (ShapedType::isDynamic(dimSize))
- stream << "?";
- else
- stream << dimSize;
- },
- "x");
-}
+//===--------------------------------------------------------------------===//
+// Custom attribute printers and parsers.
+//===--------------------------------------------------------------------===//
+
+// Handles custom<Shape>(...) in TableGen.
+void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape);
+ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape);
} // namespace mlir
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index a69ee19690d0b..8a54d594fe1e7 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -16,7 +16,9 @@
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
@@ -44,6 +46,7 @@
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/Threading.h"
#include "llvm/Support/raw_ostream.h"
+#include <type_traits>
#include <optional>
#include <tuple>
@@ -425,6 +428,8 @@ class AsmPrinter::Impl {
void popCyclicPrinting();
+ void printDimensionList(ArrayRef<int64_t> shape);
+
protected:
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {},
@@ -1860,6 +1865,20 @@ class AsmStateImpl {
// Allow direct access to the impl fields.
friend AsmState;
};
+
+template <typename Range>
+void printDimensionList(raw_ostream &stream, Range &&shape) {
+ llvm::interleave(
+ shape, stream,
+ [&stream](const auto &dimSize) {
+ if (ShapedType::isDynamic(dimSize))
+ stream << "?";
+ else
+ stream << dimSize;
+ },
+ "x");
+}
+
} // namespace detail
} // namespace mlir
@@ -2576,7 +2595,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
os << "tensor<";
- printShape(os, tensorTy.getShape());
+ printDimensionList(tensorTy.getShape());
if (!tensorTy.getShape().empty())
os << 'x';
printType(tensorTy.getElementType());
@@ -2594,7 +2613,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
})
.Case<MemRefType>([&](MemRefType memrefTy) {
os << "memref<";
- printShape(os, memrefTy.getShape());
+ printDimensionList(memrefTy.getShape());
if (!memrefTy.getShape().empty())
os << 'x';
printType(memrefTy.getElementType());
@@ -2727,6 +2746,10 @@ LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
+void AsmPrinter::Impl::printDimensionList(ArrayRef<int64_t> shape) {
+ detail::printDimensionList(os, shape);
+}
+
//===--------------------------------------------------------------------===//
// AsmPrinter
//===--------------------------------------------------------------------===//
@@ -2792,6 +2815,10 @@ void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
impl->printResourceHandle(resource);
}
+void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
+ detail::printDimensionList(getStream(), shape);
+}
+
LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
return impl->pushCyclicPrinting(opaquePointer);
}
@@ -3903,3 +3930,39 @@ void Block::printAsOperand(raw_ostream &os, AsmState &state) {
OperationPrinter printer(os, state.getImpl());
printer.printBlockName(this);
}
+
+//===--------------------------------------------------------------------===//
+// Custom attribute printers and parsers.
+//===--------------------------------------------------------------------===//
+namespace mlir {
+
+void printShape(OpAsmPrinter &printer, Operation *op, ArrayRef<int64_t> shape) {
+ if (!shape.empty())
+ printer << "[";
+ printer.printDimensionList(shape);
+ if (!shape.empty())
+ printer << "]";
+}
+
+ParseResult parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape) {
+ bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
+ SmallVector<int64_t> shapeArr;
+ if (failed(parser.parseDimensionList(shapeArr, true, false))) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "Failed parsing shape.";
+ }
+ if (shapeArr.empty() && !hasOpeningSquare) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "Failed parsing shape. Did you mean a 0-rank shape? It must be "
+ "denoted by \"[]\".";
+ }
+ if (hasOpeningSquare && failed(parser.parseRSquare())) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "Failed parsing shape.";
+ }
+
+ shape = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
+ return success();
+}
+
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index bd2693b26a7dc..89b1ed67f5d06 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -17,7 +17,6 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"
-#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -1823,33 +1822,3 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
return AffineMap::get(strides.size(), nSymbols, expr);
}
-
-void mlir::printShape(OpAsmPrinter &printer, Operation *op,
- ArrayRef<int64_t> shape) {
- if (!shape.empty())
- printer << "[";
- printShape(printer.getStream(), shape);
- if (!shape.empty())
- printer << "]";
-}
-
-ParseResult mlir::parseShape(OpAsmParser &parser, DenseI64ArrayAttr &shape) {
- bool hasOpeningSquare = succeeded(parser.parseOptionalLSquare());
- SmallVector<int64_t> shapeArr;
- if (failed(parser.parseDimensionList(shapeArr, true, false))) {
- return parser.emitError(parser.getCurrentLocation())
- << "Failed parsing shape.";
- }
- if (shapeArr.empty() && !hasOpeningSquare) {
- return parser.emitError(parser.getCurrentLocation())
- << "Failed parsing shape. Did you mean a 0-rank shape? It must be "
- "denoted by \"[]\".";
- }
- if (hasOpeningSquare && failed(parser.parseRSquare())) {
- return parser.emitError(parser.getCurrentLocation())
- << "Failed parsing shape.";
- }
-
- shape = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
- return success();
-}
More information about the Mlir-commits
mailing list