[Mlir-commits] [mlir] 1a21196 - [MLIR] reverse int8 type's printing logic (#69361)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 18 10:30:18 PDT 2023


Author: Chengji Yao
Date: 2023-10-18T10:30:13-07:00
New Revision: 1a21196b9a53d43cabb3db4ac9473e35f3cbb21b

URL: https://github.com/llvm/llvm-project/commit/1a21196b9a53d43cabb3db4ac9473e35f3cbb21b
DIFF: https://github.com/llvm/llvm-project/commit/1a21196b9a53d43cabb3db4ac9473e35f3cbb21b.diff

LOG: [MLIR] reverse int8 type's printing logic (#69361)

Reverts the specialization for 8-bit integers that ensured values are printed as integers rather than as characters.

Fixes #69310

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index d761743a82bf86b..39d24595ec1c446 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let parameters = (ins
     AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
-    ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
-    OptionalArrayRefParameter<"int8_t">:$partial_axes,
+    ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
+    OptionalArrayRefParameter<"int32_t">:$partial_axes,
     OptionalParameter<"::mlir::mesh::Partial">:$partial_type
   );
 

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ca4b6653104221..a8aa0a694bee29f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -70,7 +70,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
   }];
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    I8Attr:$rank,
+    I64Attr:$rank,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
   );
   let assemblyFormat = [{

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 379392ace46961a..f1fabf95a68b7ad 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -350,8 +350,7 @@ template <typename AsmPrinterT, typename T,
                                !std::is_convertible<T &, Attribute &>::value &&
                                !std::is_convertible<T &, ValueRange>::value &&
                                !std::is_convertible<T &, APFloat &>::value &&
-                               !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
-                                                double>::value,
+                               !llvm::is_one_of<T, bool, float, double>::value,
                            T> * = nullptr>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
@@ -367,17 +366,6 @@ operator<<(AsmPrinterT &p, bool value) {
   return p << (value ? StringRef("true") : "false");
 }
 
-/// Specialization for 8-bit integers to ensure values are printed as integers
-// and not characters.
-template <
-    typename AsmPrinterT, typename T,
-    std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
-inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
-                        AsmPrinterT &>
-operator<<(AsmPrinterT &p, T value) {
-  return p << static_cast<int16_t>(value);
-}
-
 template <typename AsmPrinterT, typename ValueRangeT>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>

diff  --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b2a47102528758c..fc91fd994f12dc2 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -47,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 LogicalResult ClusterOp::verify() {
   ArrayRef<int64_t> dimSizes = getDimSizes();
-  uint8_t rank = getRank();
+  uint64_t rank = getRank();
 
   if (rank == 0)
     return emitOpError("rank of cluster is expected to be a positive integer");
@@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() {
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
-                         ArrayRef<int8_t> partialAxes, Partial) {
+                         SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
+                         ArrayRef<int32_t> partialAxes, Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
-  llvm::SmallSet<int8_t, 4> visitedAxes;
+  llvm::SmallSet<int32_t, 4> visitedAxes;
 
-  auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
-    for (int8_t axis : axesArray) {
+  auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
+    for (int32_t axis : axesArray) {
       if (axis < 0)
         return emitError() << "mesh axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
@@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return success();
   };
 
-  for (DenseI8ArrayAttr subAxes : splitAxes) {
-    ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+  for (DenseI32ArrayAttr subAxes : splitAxes) {
+    ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
     if (failed(checkMeshAxis(subAxesArray)))
       return failure();
   }


        


More information about the Mlir-commits mailing list