[Mlir-commits] [mlir] de5a81b - [mlir] Update several usages of IntegerType to properly handled unsignedness.

River Riddle llvmlistbot at llvm.org
Mon Mar 2 09:22:59 PST 2020


Author: River Riddle
Date: 2020-03-02T09:19:26-08:00
New Revision: de5a81b1023e95a06f0e40b8ef9cdfc2e38b6223

URL: https://github.com/llvm/llvm-project/commit/de5a81b1023e95a06f0e40b8ef9cdfc2e38b6223
DIFF: https://github.com/llvm/llvm-project/commit/de5a81b1023e95a06f0e40b8ef9cdfc2e38b6223.diff

LOG: [mlir] Update several usages of IntegerType to properly handle unsignedness.

Summary: For example, DenseElementsAttr currently does not properly round-trip unsigned integer values.

Differential Revision: https://reviews.llvm.org/D75374

Added: 
    

Modified: 
    mlir/include/mlir/IR/Matchers.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/StandardTypes.h
    mlir/include/mlir/IR/Types.h
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Transforms/LoopFusion.cpp
    mlir/test/IR/parser.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 6321e88c9c10..d9979b8467ee 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -93,9 +93,8 @@ struct constant_int_op_binder {
       return false;
     auto type = op->getResult(0).getType();
 
-    if (type.isSignlessIntOrIndex()) {
+    if (type.isa<IntegerType>() || type.isa<IndexType>())
       return attr_value_binder<IntegerAttr>(bind_value).match(attr);
-    }
     if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
       if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
         return attr_value_binder<IntegerAttr>(bind_value)

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 25c0238946a9..d431d4ebabf4 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -339,6 +339,30 @@ def I16 : I<16>;
 def I32 : I<32>;
 def I64 : I<64>;
 
+// Unsigned integer types.
+// Any unsigned integer type irrespective of its width.
+def AnyUnsignedInteger : Type<
+  CPred<"$_self.isUnsignedInteger()">, "unsigned integer">;
+
+// Unsigned integer type of a specific width.
+class UI<int width>
+    : Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
+                  width # "-bit unsigned integer">,
+      BuildableType<"$_builder.getIntegerType(" # width #
+                    ", /*isSigned=*/false)"> {
+  int bitwidth = width;
+}
+
+class UnsignedIntOfWidths<list<int> widths> :
+    AnyTypeOf<!foreach(w, widths, UI<w>),
+              StrJoinInt<widths, "/">.result # "-bit unsigned integer">;
+
+def UI1  : UI<1>;
+def UI8  : UI<8>;
+def UI16 : UI<16>;
+def UI32 : UI<32>;
+def UI64 : UI<64>;
+
 // Floating point types.
 
 // Any float type irrespective of its width.

diff  --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 9bb9a8c06234..cd5ba07b689d 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -328,8 +328,9 @@ class TensorType : public ShapedType {
     // Note: Non standard/builtin types are allowed to exist within tensor
     // types. Dialects are expected to verify that tensor types have a valid
     // element type within that dialect.
-    return type.isSignlessIntOrFloat() || type.isa<ComplexType>() ||
-           type.isa<VectorType>() || type.isa<OpaqueType>() ||
+    return type.isa<ComplexType>() || type.isa<FloatType>() ||
+           type.isa<IntegerType>() || type.isa<OpaqueType>() ||
+           type.isa<VectorType>() ||
            (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
   }
 

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 40f1d4818769..eccc90cdae0c 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -169,6 +169,9 @@ class Type {
   /// Return true of this is a signless integer or a float type.
   bool isSignlessIntOrFloat();
 
+  /// Return true if this is an integer (of any signedness) or a float type.
+  bool isIntOrFloat();
+
   /// Print the current type.
   void print(raw_ostream &os);
   void dump();

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index b76c0c0770a3..14635a144735 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -314,7 +314,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
 
   unsigned sizeInBits;
-  if (elementType.isSignlessIntOrFloat()) {
+  if (elementType.isIntOrFloat()) {
     sizeInBits = elementType.getIntOrFloatBitWidth();
   } else {
     auto vectorType = elementType.cast<VectorType>();
@@ -358,7 +358,7 @@ Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
   if (!memRefType.hasStaticShape())
     return None;
   auto elementType = memRefType.getElementType();
-  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>())
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
     return None;
 
   uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 140f533e0b15..ac2648846b24 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1372,17 +1372,18 @@ void ModulePrinter::printAttribute(Attribute attr,
 
 /// Print the integer element of the given DenseElementsAttr at 'index'.
 static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
-                                 unsigned index) {
+                                 unsigned index, bool isSigned) {
   APInt value = *std::next(attr.int_value_begin(), index);
   if (value.getBitWidth() == 1)
     os << (value.getBoolValue() ? "true" : "false");
   else
-    value.print(os, /*isSigned=*/true);
+    value.print(os, isSigned);
 }
 
 /// Print the float element of the given DenseElementsAttr at 'index'.
 static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
-                                   unsigned index) {
+                                   unsigned index, bool isSigned) {
+  assert(isSigned && "floating point values are always signed");
   APFloat value = *std::next(attr.float_value_begin(), index);
   printFloatValue(value, os);
 }
@@ -1392,6 +1393,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
   auto type = attr.getType();
   auto shape = type.getShape();
   auto rank = type.getRank();
+  bool isSigned = !type.getElementType().isUnsignedInteger();
 
   // The function used to print elements of this attribute.
   auto printEltFn = type.getElementType().isa<IntegerType>()
@@ -1400,7 +1402,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
 
   // Special case for 0-d and splat tensors.
   if (attr.isSplat()) {
-    printEltFn(attr, os, 0);
+    printEltFn(attr, os, 0, isSigned);
     return;
   }
 
@@ -1452,7 +1454,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
     while (openBrackets++ < rank)
       os << '[';
     openBrackets = rank;
-    printEltFn(attr, os, idx);
+    printEltFn(attr, os, idx, isSigned);
     bumpCounter();
   }
   while (openBrackets-- > 0)

diff  --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 5beb12a59940..4526d7dc10be 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -608,7 +608,7 @@ DenseElementsAttr::FloatElementIterator::FloatElementIterator(
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
-  assert(type.getElementType().isSignlessIntOrFloat() &&
+  assert(type.getElementType().isIntOrFloat() &&
          "expected int or float element type");
   assert(hasSameElementsOrSplat(type, values));
 

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 30d5bbcc7b3c..774f80a46de3 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -84,6 +84,8 @@ bool Type::isSignlessIntOrFloat() {
   return isSignlessInteger() || isa<FloatType>();
 }
 
+bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
+
 //===----------------------------------------------------------------------===//
 // Integer Type
 //===----------------------------------------------------------------------===//
@@ -147,13 +149,10 @@ const llvm::fltSemantics &FloatType::getFloatSemantics() {
 }
 
 unsigned Type::getIntOrFloatBitWidth() {
-  assert(isSignlessIntOrFloat() && "only ints and floats have a bitwidth");
-  if (auto intType = dyn_cast<IntegerType>()) {
+  assert(isIntOrFloat() && "only integers and floats have a bitwidth");
+  if (auto intType = dyn_cast<IntegerType>())
     return intType.getWidth();
-  }
-
-  auto floatType = cast<FloatType>();
-  return floatType.getWidth();
+  return cast<FloatType>().getWidth();
 }
 
 //===----------------------------------------------------------------------===//
@@ -202,7 +201,7 @@ int64_t ShapedType::getSizeInBits() const {
          "cannot get the bit size of an aggregate with a dynamic shape");
 
   auto elementType = getElementType();
-  if (elementType.isSignlessIntOrFloat())
+  if (elementType.isIntOrFloat())
     return elementType.getIntOrFloatBitWidth() * getNumElements();
 
   // Tensors can have vectors and other tensors as elements, other shaped types
@@ -373,7 +372,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
   auto *context = elementType.getContext();
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
       !elementType.isa<ComplexType>())
     return emitOptionalError(location, "invalid memref element type"),
            MemRefType();
@@ -451,7 +450,7 @@ LogicalResult
 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                  unsigned memorySpace) {
   // Check that memref is formed from allowed types.
-  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
       !elementType.isa<ComplexType>())
     return emitError(loc, "invalid memref element type");
   return success();

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 668fb694d8fd..661bddf8107a 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1102,7 +1102,7 @@ Type Parser::parseMemRefType() {
     return nullptr;
 
   // Check that memref is formed from allowed types.
-  if (!elementType.isSignlessIntOrFloat() && !elementType.isa<VectorType>() &&
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
       !elementType.isa<ComplexType>())
     return emitError(typeLoc, "invalid memref element type"), nullptr;
 

diff  --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp
index ef1af5d71aa8..bcb0c16ba77b 100644
--- a/mlir/lib/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Transforms/LoopFusion.cpp
@@ -869,7 +869,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
 
   unsigned sizeInBits;
-  if (elementType.isSignlessIntOrFloat()) {
+  if (elementType.isIntOrFloat()) {
     sizeInBits = elementType.getIntOrFloatBitWidth();
   } else {
     auto vectorType = elementType.cast<VectorType>();

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bec1fbd4aca6..3baf0642e8b0 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -616,6 +616,9 @@ func @splattensorattr() -> () {
   // CHECK: "splatBoolTensor"() {bar = dense<false> : tensor<i1>} : () -> ()
   "splatBoolTensor"(){bar = dense<false> : tensor<i1>} : () -> ()
 
+  // CHECK: "splatUIntTensor"() {bar = dense<222> : tensor<2x1x4xui8>} : () -> ()
+  "splatUIntTensor"(){bar = dense<222> : tensor<2x1x4xui8>} : () -> ()
+
   // CHECK: "splatIntTensor"() {bar = dense<5> : tensor<2x1x4xi32>} : () -> ()
   "splatIntTensor"(){bar = dense<5> : tensor<2x1x4xi32>} : () -> ()
 


        


More information about the Mlir-commits mailing list