[Mlir-commits] [mlir] [MLIR] reverse int8 type's printing logic (PR #69361)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 17 10:57:02 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chengji Yao (yaochengji)

<details>
<summary>Changes</summary>

Specializing the printing of 8-bit integers in a generic way, to ensure values are printed as integers, causes a bug; see https://github.com/llvm/llvm-project/issues/69310

---
Full diff: https://github.com/llvm/llvm-project/pull/69361.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+2-2) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+1-1) 
- (modified) mlir/include/mlir/IR/OpImplementation.h (+1-13) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+8-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index d761743a82bf86b..867c98078ae5171 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let parameters = (ins
     AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
-    ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
-    OptionalArrayRefParameter<"int8_t">:$partial_axes,
+    ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$split_axes,
+    OptionalArrayRefParameter<"int64_t">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
   );
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ca4b6653104221..a8aa0a694bee29f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -70,7 +70,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   }];
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    I8Attr:$rank,
+    I64Attr:$rank,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
   );
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 379392ace46961a..f1fabf95a68b7ad 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -350,8 +350,7 @@ template <typename AsmPrinterT, typename T,
                                !std::is_convertible<T &, Attribute &>::value &&
                                !std::is_convertible<T &, ValueRange>::value &&
                                !std::is_convertible<T &, APFloat &>::value &&
-                               !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
-                                                double>::value,
+                               !llvm::is_one_of<T, bool, float, double>::value,
                            T> * = nullptr>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
@@ -367,17 +366,6 @@ operator<<(AsmPrinterT &p, bool value) {
   return p << (value ? StringRef("true") : "false");
 }
 
-/// Specialization for 8-bit integers to ensure values are printed as integers
-// and not characters.
-template <
-    typename AsmPrinterT, typename T,
-    std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
-inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
-                        AsmPrinterT &>
-operator<<(AsmPrinterT &p, T value) {
-  return p << static_cast<int16_t>(value);
-}
-
 template <typename AsmPrinterT, typename ValueRangeT>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b2a47102528758c..e8dc14cf0fa9c04 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -47,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 LogicalResult ClusterOp::verify() {
   ArrayRef<int64_t> dimSizes = getDimSizes();
-  uint8_t rank = getRank();
+  uint64_t rank = getRank();
 
   if (rank == 0)
     return emitOpError("rank of cluster is expected to be a positive integer");
@@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() {
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
-                         ArrayRef<int8_t> partialAxes, Partial) {
+                         SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> splitAxes,
+                         ArrayRef<int64_t> partialAxes, Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
-  llvm::SmallSet<int8_t, 4> visitedAxes;
+  llvm::SmallSet<int64_t, 4> visitedAxes;
 
-  auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
-    for (int8_t axis : axesArray) {
+  auto checkMeshAxis = [&](ArrayRef<int64_t> axesArray) -> LogicalResult {
+    for (int64_t axis : axesArray) {
       if (axis < 0)
         return emitError() << "mesh axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
@@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return success();
   };
 
-  for (DenseI8ArrayAttr subAxes : splitAxes) {
-    ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+  for (DenseI64ArrayAttr subAxes : splitAxes) {
+    ArrayRef<int64_t> subAxesArray = subAxes.asArrayRef();
     if (failed(checkMeshAxis(subAxesArray)))
       return failure();
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/69361


More information about the Mlir-commits mailing list