[Mlir-commits] [mlir] 1a21196 - [MLIR] reverse int8 type's printing logic (#69361)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 18 10:30:18 PDT 2023
Author: Chengji Yao
Date: 2023-10-18T10:30:13-07:00
New Revision: 1a21196b9a53d43cabb3db4ac9473e35f3cbb21b
URL: https://github.com/llvm/llvm-project/commit/1a21196b9a53d43cabb3db4ac9473e35f3cbb21b
DIFF: https://github.com/llvm/llvm-project/commit/1a21196b9a53d43cabb3db4ac9473e35f3cbb21b.diff
LOG: [MLIR] reverse int8 type's printing logic (#69361)
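Reverts the specialization for 8-bit integers that ensured values were printed as integers rather than characters (the Mesh dialect now uses int32/int64 types instead of int8).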
Fixes #69310
Added:
Modified:
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index d761743a82bf86b..39d24595ec1c446 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let parameters = (ins
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
- ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
- OptionalArrayRefParameter<"int8_t">:$partial_axes,
+ ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
+ OptionalArrayRefParameter<"int32_t">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ca4b6653104221..a8aa0a694bee29f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -70,7 +70,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
- I8Attr:$rank,
+ I64Attr:$rank,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
);
let assemblyFormat = [{
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 379392ace46961a..f1fabf95a68b7ad 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -350,8 +350,7 @@ template <typename AsmPrinterT, typename T,
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
!std::is_convertible<T &, APFloat &>::value &&
- !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
- double>::value,
+ !llvm::is_one_of<T, bool, float, double>::value,
T> * = nullptr>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
@@ -367,17 +366,6 @@ operator<<(AsmPrinterT &p, bool value) {
return p << (value ? StringRef("true") : "false");
}
-/// Specialization for 8-bit integers to ensure values are printed as integers
-// and not characters.
-template <
- typename AsmPrinterT, typename T,
- std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
-inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
- AsmPrinterT &>
-operator<<(AsmPrinterT &p, T value) {
- return p << static_cast<int16_t>(value);
-}
-
template <typename AsmPrinterT, typename ValueRangeT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b2a47102528758c..fc91fd994f12dc2 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -47,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
LogicalResult ClusterOp::verify() {
ArrayRef<int64_t> dimSizes = getDimSizes();
- uint8_t rank = getRank();
+ uint64_t rank = getRank();
if (rank == 0)
return emitOpError("rank of cluster is expected to be a positive integer");
@@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() {
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
- ArrayRef<int8_t> partialAxes, Partial) {
+ SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
+ ArrayRef<int32_t> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
- llvm::SmallSet<int8_t, 4> visitedAxes;
+ llvm::SmallSet<int32_t, 4> visitedAxes;
- auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
- for (int8_t axis : axesArray) {
+ auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
+ for (int32_t axis : axesArray) {
if (axis < 0)
return emitError() << "mesh axis is expected to be non-negative";
if (!visitedAxes.insert(axis).second)
@@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
};
- for (DenseI8ArrayAttr subAxes : splitAxes) {
- ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+ for (DenseI32ArrayAttr subAxes : splitAxes) {
+ ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
if (failed(checkMeshAxis(subAxesArray)))
return failure();
}
More information about the Mlir-commits
mailing list