[Mlir-commits] [mlir] [MLIR] revert int8 type's printing logic (PR #69361)
Chengji Yao
llvmlistbot at llvm.org
Tue Oct 17 15:22:12 PDT 2023
https://github.com/yaochengji updated https://github.com/llvm/llvm-project/pull/69361
>From 8a632f815645e12b553f341fcaca517a99bf00c1 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmial.com>
Date: Tue, 17 Oct 2023 17:48:28 +0000
Subject: [PATCH 1/2] [MLIR] revert int8 type's printing logic
---
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 4 ++--
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 2 +-
mlir/include/mlir/IR/OpImplementation.h | 14 +-------------
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 16 ++++++++--------
4 files changed, 12 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index d761743a82bf86b..867c98078ae5171 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let parameters = (ins
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
- ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
- OptionalArrayRefParameter<"int8_t">:$partial_axes,
+ ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$split_axes,
+ OptionalArrayRefParameter<"int64_t">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ca4b6653104221..a8aa0a694bee29f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -70,7 +70,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
- I8Attr:$rank,
+ I64Attr:$rank,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
);
let assemblyFormat = [{
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 379392ace46961a..f1fabf95a68b7ad 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -350,8 +350,7 @@ template <typename AsmPrinterT, typename T,
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
!std::is_convertible<T &, APFloat &>::value &&
- !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
- double>::value,
+ !llvm::is_one_of<T, bool, float, double>::value,
T> * = nullptr>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
@@ -367,17 +366,6 @@ operator<<(AsmPrinterT &p, bool value) {
return p << (value ? StringRef("true") : "false");
}
-/// Specialization for 8-bit integers to ensure values are printed as integers
-// and not characters.
-template <
- typename AsmPrinterT, typename T,
- std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
-inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
- AsmPrinterT &>
-operator<<(AsmPrinterT &p, T value) {
- return p << static_cast<int16_t>(value);
-}
-
template <typename AsmPrinterT, typename ValueRangeT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b2a47102528758c..e8dc14cf0fa9c04 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -47,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
LogicalResult ClusterOp::verify() {
ArrayRef<int64_t> dimSizes = getDimSizes();
- uint8_t rank = getRank();
+ uint64_t rank = getRank();
if (rank == 0)
return emitOpError("rank of cluster is expected to be a positive integer");
@@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() {
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
- ArrayRef<int8_t> partialAxes, Partial) {
+ SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> splitAxes,
+ ArrayRef<int64_t> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
- llvm::SmallSet<int8_t, 4> visitedAxes;
+ llvm::SmallSet<int64_t, 4> visitedAxes;
- auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
- for (int8_t axis : axesArray) {
+ auto checkMeshAxis = [&](ArrayRef<int64_t> axesArray) -> LogicalResult {
+ for (int64_t axis : axesArray) {
if (axis < 0)
return emitError() << "mesh axis is expected to be non-negative";
if (!visitedAxes.insert(axis).second)
@@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
};
- for (DenseI8ArrayAttr subAxes : splitAxes) {
- ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+ for (DenseI64ArrayAttr subAxes : splitAxes) {
+ ArrayRef<int64_t> subAxesArray = subAxes.asArrayRef();
if (failed(checkMeshAxis(subAxesArray)))
return failure();
}
>From 8a45b1ded95596ea6ad9aa72a844f237577da86b Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmial.com>
Date: Tue, 17 Oct 2023 22:19:26 +0000
Subject: [PATCH 2/2] address review comment: use int32_t for mesh axes
---
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 4 ++--
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 14 +++++++-------
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 867c98078ae5171..39d24595ec1c446 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let parameters = (ins
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
- ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$split_axes,
- OptionalArrayRefParameter<"int64_t">:$partial_axes,
+ ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
+ OptionalArrayRefParameter<"int32_t">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
);
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index e8dc14cf0fa9c04..fc91fd994f12dc2 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() {
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> splitAxes,
- ArrayRef<int64_t> partialAxes, Partial) {
+ SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
+ ArrayRef<int32_t> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
- llvm::SmallSet<int64_t, 4> visitedAxes;
+ llvm::SmallSet<int32_t, 4> visitedAxes;
- auto checkMeshAxis = [&](ArrayRef<int64_t> axesArray) -> LogicalResult {
- for (int64_t axis : axesArray) {
+ auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
+ for (int32_t axis : axesArray) {
if (axis < 0)
return emitError() << "mesh axis is expected to be non-negative";
if (!visitedAxes.insert(axis).second)
@@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
};
- for (DenseI64ArrayAttr subAxes : splitAxes) {
- ArrayRef<int64_t> subAxesArray = subAxes.asArrayRef();
+ for (DenseI32ArrayAttr subAxes : splitAxes) {
+ ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
if (failed(checkMeshAxis(subAxesArray)))
return failure();
}
More information about the Mlir-commits
mailing list