[Mlir-commits] [mlir] 1b2c787 - Add support for IndexType inside DenseIntElementsAttr.

Sean Silva llvmlistbot at llvm.org
Thu Apr 23 17:42:51 PDT 2020


Author: Sean Silva
Date: 2020-04-23T17:42:33-07:00
New Revision: 1b2c7877a4dd3da2ba7462d65a48341f59c6cc61

URL: https://github.com/llvm/llvm-project/commit/1b2c7877a4dd3da2ba7462d65a48341f59c6cc61
DIFF: https://github.com/llvm/llvm-project/commit/1b2c7877a4dd3da2ba7462d65a48341f59c6cc61.diff

LOG: Add support for IndexType inside DenseIntElementsAttr.

This also fixes issues discovered in the parsing/printing path.

Added: 
    

Modified: 
    mlir/include/mlir/IR/StandardTypes.h
    mlir/include/mlir/IR/Types.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/AttributeDetail.h
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/parser.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 4c5bbba0aa6a..3241a70f2675 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -76,6 +76,9 @@ class IndexType : public Type::TypeBase<IndexType, Type> {
 
   /// Support method to enable LLVM-style type casting.
   static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
+
+  /// Storage bit width used for IndexType by internal compiler data structures.
+  static constexpr unsigned kInternalStorageBitWidth = 64;
 };
 
 /// Integer types can have arbitrary bitwidth up to a large fixed limit.

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index e45fa9037470..32ae13f86dc9 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -169,6 +169,8 @@ class Type {
   /// Return true of this is a signless integer or a float type.
   bool isSignlessIntOrFloat();
 
+  /// Return true if this is an integer (of any signedness) or an index type.
+  bool isIntOrIndex();
   /// Return true if this is an integer (of any signedness) or a float type.
   bool isIntOrFloat();
   /// Return true if this is an integer (of any signedness), index, or float

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 75ea0612bb3a..ac92b707fac5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1462,7 +1462,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
   bool isSigned = !type.getElementType().isUnsignedInteger();
 
   // The function used to print elements of this attribute.
-  auto printEltFn = type.getElementType().isa<IntegerType>()
+  auto printEltFn = type.getElementType().isIntOrIndex()
                         ? printDenseIntElement
                         : printDenseFloatElement;
 

diff  --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 8908416b3efa..0119830e1ab9 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -372,6 +372,17 @@ struct TypeAttributeStorage : public AttributeStorage {
 // Elements Attributes
 //===----------------------------------------------------------------------===//
 
+/// Return the bit width which DenseElementsAttr should use for this type.
+inline size_t getDenseElementBitWidth(Type eltType) {
+  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
+  // with double semantics.
+  if (eltType.isBF16())
+    return 64;
+  if (eltType.isIndex())
+    return IndexType::kInternalStorageBitWidth;
+  return eltType.getIntOrFloatBitWidth();
+}
+
 /// An attribute representing a reference to a dense vector or tensor object.
 struct DenseElementsAttributeStorage : public AttributeStorage {
   struct KeyTy {
@@ -405,7 +416,7 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
     // same. Boolean values are packed at the bit level, and even though a splat
     // is detected the rest of the bits in the first byte may differ from the
     // splat value.
-    if (key.type.getElementTypeBitWidth() == 1) {
+    if (key.type.getElementType().isInteger(1)) {
       if (key.isSplat != isSplat)
         return false;
       if (isSplat)
@@ -437,15 +448,10 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
     assert(numElements != 1 && "splat of 1 element should already be detected");
 
     // Handle boolean values directly as they are packed to 1-bit.
-    size_t elementWidth = ty.getElementTypeBitWidth();
-    if (elementWidth == 1)
+    if (ty.getElementType().isInteger(1) == 1)
       return getKeyForBoolData(ty, data, numElements);
 
-    // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
-    // with double semantics.
-    if (ty.getElementType().isBF16())
-      elementWidth = 64;
-
+    size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
     // Non 1-bit dense elements are padded to 8-bits.
     size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
     assert(((data.size() / storageSize) == numElements) &&
@@ -517,7 +523,7 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
       std::memcpy(rawData, data.data(), data.size());
 
       // If this is a boolean splat, make sure only the first bit is used.
-      if (key.isSplat && key.type.getElementTypeBitWidth() == 1)
+      if (key.isSplat && key.type.getElementType().isInteger(1))
         rawData[0] &= 1;
       copy = ArrayRef<char>(rawData, data.size());
     }

diff  --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 657df5752ade..be31d95657ff 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -275,7 +275,7 @@ IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
   // This uses 64 bit APInts by default for index type.
   if (type.isIndex())
-    return get(type, APInt(64, value));
+    return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
 
   auto intType = type.cast<IntegerType>();
   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
@@ -483,12 +483,6 @@ uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
 // DenseElementAttr Utilities
 //===----------------------------------------------------------------------===//
 
-static size_t getDenseElementBitwidth(Type eltType) {
-  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
-  // with double semantics.
-  return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
-}
-
 /// Get the bitwidth of a dense element type within the buffer.
 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
 static size_t getDenseElementStorageWidth(size_t origWidth) {
@@ -592,7 +586,7 @@ DenseElementsAttr::IntElementIterator::IntElementIterator(
     DenseElementsAttr attr, size_t dataIndex)
     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
           attr.getRawData().data(), attr.isSplat(), dataIndex),
-      bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
+      bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
 
 /// Accesses the raw APInt value at this iterator position.
 APInt DenseElementsAttr::IntElementIterator::operator*() const {
@@ -613,12 +607,12 @@ DenseElementsAttr::FloatElementIterator::FloatElementIterator(
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
-  assert(type.getElementType().isIntOrFloat() &&
-         "expected int or float element type");
+  assert(type.getElementType().isIntOrIndexOrFloat() &&
+         "expected int or index or float element type");
   assert(hasSameElementsOrSplat(type, values));
 
   auto eltType = type.getElementType();
-  size_t bitWidth = getDenseElementBitwidth(eltType);
+  size_t bitWidth = getDenseElementBitWidth(eltType);
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
 
   // Compress the attribute values into a character buffer.
@@ -637,6 +631,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
       break;
     case StandardTypes::Integer:
+    case StandardTypes::Index:
       intVal = values[i].isa<BoolAttr>()
                    ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
                    : values[i].cast<IntegerAttr>().getValue();
@@ -667,7 +662,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 /// element type of 'type'.
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<APInt> values) {
-  assert(type.getElementType().isa<IntegerType>());
+  assert(type.getElementType().isIntOrIndex());
   return getRaw(type, values);
 }
 
@@ -701,7 +696,7 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
                                             ArrayRef<APInt> values) {
   assert(hasSameElementsOrSplat(type, values));
 
-  size_t bitWidth = getDenseElementBitwidth(type.getElementType());
+  size_t bitWidth = getDenseElementBitWidth(type.getElementType());
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
   std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
                                 values.size());
@@ -727,14 +722,17 @@ DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
                               bool isSigned) {
   // Make sure that the data element size is the same as the type element width.
-  if (getDenseElementBitwidth(type.getElementType()) !=
+  if (getDenseElementBitWidth(type.getElementType()) !=
       static_cast<size_t>(dataEltSize * CHAR_BIT))
     return false;
 
-  // Check that the element type is either float or integer.
+  // Check that the element type is either float or integer or index.
   if (!isInt)
     return type.getElementType().isa<FloatType>();
 
+  if (type.getElementType().isIndex())
+    return true;
+
   auto intType = type.getElementType().dyn_cast<IntegerType>();
   if (!intType)
     return false;
@@ -798,18 +796,15 @@ auto DenseElementsAttr::getBoolValues() const
 /// this attribute must be of integer type.
 auto DenseElementsAttr::getIntValues() const
     -> llvm::iterator_range<IntElementIterator> {
-  assert(getType().getElementType().isa<IntegerType>() &&
-         "expected integer type");
+  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
   return {raw_int_begin(), raw_int_end()};
 }
 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
-  assert(getType().getElementType().isa<IntegerType>() &&
-         "expected integer type");
+  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
   return raw_int_begin();
 }
 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
-  assert(getType().getElementType().isa<IntegerType>() &&
-         "expected integer type");
+  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
   return raw_int_end();
 }
 
@@ -870,7 +865,7 @@ template <typename Fn, typename Attr>
 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
                                 Type newElementType,
                                 llvm::SmallVectorImpl<char> &data) {
-  size_t bitWidth = getDenseElementBitwidth(newElementType);
+  size_t bitWidth = getDenseElementBitWidth(newElementType);
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
 
   ShapedType newArrayType;
@@ -937,7 +932,7 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
 /// Method for supporting type inquiry through isa, cast and dyn_cast.
 bool DenseIntElementsAttr::classof(Attribute attr) {
   return attr.isa<DenseElementsAttr>() &&
-         attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
+         attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 94156c358eb0..c61e1d9e2ab9 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -83,6 +83,8 @@ bool Type::isSignlessIntOrFloat() {
   return isSignlessInteger() || isa<FloatType>();
 }
 
+bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
+
 bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
 
 bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index e8825ff0fb27..8eabe3c46f4c 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1797,7 +1797,8 @@ static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
     return llvm::None;
 
   // Extend or truncate the bitwidth to the right size.
-  unsigned width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
+  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
+                                  : type.getIntOrFloatBitWidth();
   if (width > result.getBitWidth()) {
     result = result.zext(width);
   } else if (width < result.getBitWidth()) {
@@ -1968,8 +1969,7 @@ class TensorLiteralParser {
   }
 
   /// Build a Dense Integer attribute for the given type.
-  DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type,
-                               IntegerType eltTy);
+  DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
 
   /// Build a Dense Float attribute for the given type.
   DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
@@ -2044,14 +2044,17 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
 
   // If the type is an integer, build a set of APInt values from the storage
   // with the correct bitwidth.
-  if (auto intTy = type.getElementType().dyn_cast<IntegerType>())
+  Type eltType = type.getElementType();
+  if (auto intTy = eltType.dyn_cast<IntegerType>())
     return getIntAttr(loc, type, intTy);
+  if (auto indexTy = eltType.dyn_cast<IndexType>())
+    return getIntAttr(loc, type, indexTy);
 
   // Otherwise, this must be a floating point type.
-  auto floatTy = type.getElementType().dyn_cast<FloatType>();
+  auto floatTy = eltType.dyn_cast<FloatType>();
   if (!floatTy) {
     p.emitError(loc) << "expected floating-point or integer element type, got "
-                     << type.getElementType();
+                     << eltType;
     return nullptr;
   }
   return getFloatAttr(loc, type, floatTy);
@@ -2059,8 +2062,7 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
 
 /// Build a Dense Integer attribute for the given type.
 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
-                                                  ShapedType type,
-                                                  IntegerType eltTy) {
+                                                  ShapedType type, Type eltTy) {
   std::vector<APInt> intElements;
   intElements.reserve(storage.size());
   auto isUintType = type.getElementType().isUnsignedInteger();
@@ -2085,11 +2087,12 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
            "unexpected token type");
     if (token.isAny(Token::kw_true, Token::kw_false)) {
-      if (!eltTy.isInteger(1))
+      if (!eltTy.isInteger(1)) {
         p.emitError(tokenLoc)
             << "expected i1 type for 'true' or 'false' values";
-      APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
-                  /*isSigned=*/false);
+        return nullptr;
+      }
+      APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
       intElements.push_back(apInt);
       continue;
     }

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 585f7de40aaf..7a02d776c172 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -697,6 +697,11 @@ func @densetensorattr() -> () {
   "intscalar"(){bar = dense<1> : tensor<i32>} : () -> ()
 // CHECK: "floatscalar"() {bar = dense<5.000000e+00> : tensor<f32>} : () -> ()
   "floatscalar"(){bar = dense<5.0> : tensor<f32>} : () -> ()
+
+// CHECK: "index"() {bar = dense<1> : tensor<index>} : () -> ()
+  "index"(){bar = dense<1> : tensor<index>} : () -> ()
+// CHECK: "index"() {bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
+  "index"(){bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
   return
 }
 


        


More information about the Mlir-commits mailing list