[Mlir-commits] [mlir] 5b89c1d - [mlir] DenseStringElementsAttr added to default attribute types

River Riddle llvmlistbot at llvm.org
Thu Apr 23 19:05:44 PDT 2020


Author: Rob Suderman
Date: 2020-04-23T19:02:15-07:00
New Revision: 5b89c1dd68966b7fc8d19a0197da4f95eaab066a

URL: https://github.com/llvm/llvm-project/commit/5b89c1dd68966b7fc8d19a0197da4f95eaab066a
DIFF: https://github.com/llvm/llvm-project/commit/5b89c1dd68966b7fc8d19a0197da4f95eaab066a.diff

LOG: [mlir] DenseStringElementsAttr added to default attribute types

Summary:
Implemented a DenseStringElementsAttr for handling arrays / tensors of strings. This includes the
necessary logic for parsing and printing the attribute in MLIR's text format.

To store the attribute we perform a single allocation that holds all of the wrapped string data, tightly
packed. This means no padding characters and no null terminators (as such characters could be present
in the string). The buffer begins with an array of StringRefs, each of which holds a pointer into the
string data and the length of the string it wraps. At this point there is no sparse representation;
however, strings are not typically represented sparsely.

Differential Revision: https://reviews.llvm.org/D78600

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/AttributeDetail.h
    mlir/lib/IR/Attributes.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/attribute.mlir
    mlir/test/IR/dense-elements-hex.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index ab8cc6ee0a00..e65f4a0b0624 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -40,7 +40,8 @@ struct SymbolRefAttributeStorage;
 struct TypeAttributeStorage;
 
 /// Elements Attributes.
-struct DenseElementsAttributeStorage;
+struct DenseIntOrFPElementsAttributeStorage;
+struct DenseStringElementsAttributeStorage;
 struct OpaqueElementsAttributeStorage;
 struct SparseElementsAttributeStorage;
 } // namespace detail
@@ -141,10 +142,11 @@ enum Kind {
   Unit,
 
   /// Elements Attributes.
-  DenseElements,
+  DenseIntOrFPElements,
+  DenseStringElements,
   OpaqueElements,
   SparseElements,
-  FIRST_ELEMENTS_ATTR = DenseElements,
+  FIRST_ELEMENTS_ATTR = DenseIntOrFPElements,
   LAST_ELEMENTS_ATTR = SparseElements,
 
   /// Locations.
@@ -671,15 +673,14 @@ class DenseElementIndexedIteratorImpl
 
 /// An attribute that represents a reference to a dense vector or tensor object.
 ///
-class DenseElementsAttr
-    : public Attribute::AttrBase<DenseElementsAttr, ElementsAttr,
-                                 detail::DenseElementsAttributeStorage> {
+class DenseElementsAttr : public ElementsAttr {
 public:
-  using Base::Base;
+  using ElementsAttr::ElementsAttr;
 
   /// Method for support type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr) {
-    return attr.getKind() == StandardAttributes::DenseElements;
+    return attr.getKind() == StandardAttributes::DenseIntOrFPElements ||
+           attr.getKind() == StandardAttributes::DenseStringElements;
   }
 
   /// Constructs a dense elements attribute from an array of element values.
@@ -712,6 +713,10 @@ class DenseElementsAttr
   /// Overload of the above 'get' method that is specialized for boolean values.
   static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
 
+  /// Overload of the above 'get' method that is specialized for StringRef
+  /// values.
+  static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values);
+
   /// Constructs a dense integer elements attribute from an array of APInt
   /// values. Each APInt value is expected to have the same bitwidth as the
   /// element type of 'type'. 'type' must be a vector or tensor with static
@@ -882,6 +887,14 @@ class DenseElementsAttr
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
 
+  llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
+    auto stringRefs = getRawStringData();
+    const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
+    bool splat = isSplat();
+    return {ElementIterator<StringRef>(ptr, splat, 0),
+            ElementIterator<StringRef>(ptr, splat, getNumElements())};
+  }
+
   /// Return the held element values as a range of Attributes.
   llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
   template <typename T, typename = typename std::enable_if<
@@ -942,6 +955,9 @@ class DenseElementsAttr
   /// form the user might expect.
   ArrayRef<char> getRawData() const;
 
+  /// Return the raw StringRef data held by this attribute.
+  ArrayRef<StringRef> getRawStringData() const;
+
   //===--------------------------------------------------------------------===//
   // Mutation Utilities
   //===--------------------------------------------------------------------===//
@@ -973,6 +989,60 @@ class DenseElementsAttr
     return IntElementIterator(*this, getNumElements());
   }
 
+  /// Overload of the raw 'get' method that asserts that the given type is of
+  /// integer or floating-point type. This method is used to verify type
+  /// invariants that the templatized 'get' method cannot.
+  static DenseElementsAttr getRawIntOrFloat(ShapedType type,
+                                            ArrayRef<char> data,
+                                            int64_t dataEltSize, bool isInt,
+                                            bool isSigned);
+
+  /// Check the information for a C++ data type, check if this type is valid for
+  /// the current attribute. This method is used to verify specific type
+  /// invariants that the templatized 'getValues' method cannot.
+  bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
+};
+
+/// An attribute class for representing dense arrays of strings. The elements
+/// are held as a list of densely packed strings.
+class DenseStringElementsAttr
+    : public Attribute::AttrBase<DenseStringElementsAttr, DenseElementsAttr,
+                                 detail::DenseStringElementsAttributeStorage> {
+
+public:
+  using Base::Base;
+
+  /// Method supporting type inquiry through isa, cast and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::DenseStringElements;
+  }
+
+  /// Constructs a dense string elements attribute of the given shaped 'type'
+  /// from an array of StringRef values. A 'data' array holding exactly one
+  /// element is tagged as a known splat.
+  static DenseStringElementsAttr get(ShapedType type, ArrayRef<StringRef> data);
+
+protected:
+  friend DenseElementsAttr;
+};
+
+/// An attribute class for specializing the behavior of densely packed arrays of
+/// integer and floating-point values.
+class DenseIntOrFPElementsAttr
+    : public Attribute::AttrBase<DenseIntOrFPElementsAttr, DenseElementsAttr,
+                                 detail::DenseIntOrFPElementsAttributeStorage> {
+
+public:
+  using Base::Base;
+
+  /// Method for support type inquiry through isa, cast and dyn_cast.
+  static bool kindof(unsigned kind) {
+    return kind == StandardAttributes::DenseIntOrFPElements;
+  }
+
+protected:
+  friend DenseElementsAttr;
+
   /// Constructs a dense elements attribute from an array of raw APInt values.
   /// Each APInt value is expected to have the same bitwidth as the element type
   /// of 'type'. 'type' must be a vector or tensor with static shape.
@@ -990,20 +1060,15 @@ class DenseElementsAttr
                                             ArrayRef<char> data,
                                             int64_t dataEltSize, bool isInt,
                                             bool isSigned);
-
-  /// Check the information for a C++ data type, check if this type is valid for
-  /// the current attribute. This method is used to verify specific type
-  /// invariants that the templatized 'getValues' method cannot.
-  bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
 };
 
 /// An attribute that represents a reference to a dense float vector or tensor
 /// object. Each element is stored as a double.
-class DenseFPElementsAttr : public DenseElementsAttr {
+class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
 public:
   using iterator = DenseElementsAttr::FloatElementIterator;
 
-  using DenseElementsAttr::DenseElementsAttr;
+  using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
 
   /// Get an instance of a DenseFPElementsAttr with the given arguments. This
   /// simply wraps the DenseElementsAttr::get calls.
@@ -1035,13 +1100,13 @@ class DenseFPElementsAttr : public DenseElementsAttr {
 
 /// An attribute that represents a reference to a dense integer vector or tensor
 /// object.
-class DenseIntElementsAttr : public DenseElementsAttr {
+class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
 public:
   /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
   /// iterator directly.
   using iterator = DenseElementsAttr::IntElementIterator;
 
-  using DenseElementsAttr::DenseElementsAttr;
+  using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;
 
   /// Get an instance of a DenseIntElementsAttr with the given arguments. This
   /// simply wraps the DenseElementsAttr::get calls.
@@ -1266,7 +1331,7 @@ class ElementsAttrIterator
             typename... Args>
   RetT process(Args &... args) const {
     switch (attrKind) {
-    case StandardAttributes::DenseElements:
+    case StandardAttributes::DenseIntOrFPElements:
       return ProcessFn<DenseIteratorT>()(args...);
     case StandardAttributes::SparseElements:
       return ProcessFn<SparseIteratorT>()(args...);

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 8c6561b73f63..46eaa3b895d4 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1307,6 +1307,16 @@ class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
 class RankedF32ElementsAttr<list<int> dims> : RankedFloatElementsAttr<32, dims>;
 class RankedF64ElementsAttr<list<int> dims> : RankedFloatElementsAttr<64, dims>;
 
+def StringElementsAttr : ElementsAttrBase<
+  CPred<"$_self.isa<DenseStringElementsAttr>()" >,
+  "string elements attribute"> {
+
+  let storageType = [{ DenseElementsAttr }];
+  let returnType = [{ DenseElementsAttr }];
+
+  let convertFromStorage = "$_self";
+}
+
 // Base class for array attributes.
 class ArrayAttrBase<Pred condition, string description> :
     Attr<condition, description> {

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index a9c19cd7262b..6deae62c6987 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1316,7 +1316,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
              << opType << ") does not match value type (" << valueType << ")";
     return success();
   } break;
-  case StandardAttributes::DenseElements:
+  case StandardAttributes::DenseIntOrFPElements:
   case StandardAttributes::SparseElements: {
     if (valueType == opType)
       break;

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index ac92b707fac5..bdaf15c6e5c5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -973,6 +973,9 @@ class ModulePrinter {
   /// used instead of individual elements when the elements attr is large.
   void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
 
+  /// Print a dense string elements attribute.
+  void printDenseStringElementsAttr(DenseStringElementsAttr attr);
+
   void printDialectAttribute(Attribute attr);
   void printDialectType(Type type);
 
@@ -1392,7 +1395,7 @@ void ModulePrinter::printAttribute(Attribute attr,
     os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
     break;
   }
-  case StandardAttributes::DenseElements: {
+  case StandardAttributes::DenseIntOrFPElements: {
     auto eltsAttr = attr.cast<DenseElementsAttr>();
     if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
       printElidedElementsAttr(os);
@@ -1403,6 +1406,17 @@ void ModulePrinter::printAttribute(Attribute attr,
     os << '>';
     break;
   }
+  case StandardAttributes::DenseStringElements: {
+    auto eltsAttr = attr.cast<DenseStringElementsAttr>();
+    if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
+      printElidedElementsAttr(os);
+      break;
+    }
+    os << "dense<";
+    printDenseStringElementsAttr(eltsAttr);
+    os << '>';
+    break;
+  }
   case StandardAttributes::SparseElements: {
     auto elementsAttr = attr.cast<SparseElementsAttr>();
     if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
@@ -1454,6 +1468,13 @@ static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
   printFloatValue(value, os);
 }
 
+static void printDenseStringElement(DenseStringElementsAttr attr,
+                                    raw_ostream &os, unsigned index) {
+  os << "\"";
+  printEscapedString(attr.getRawStringData()[index], os);
+  os << "\"";
+}
+
 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
                                            bool allowHex) {
   auto type = attr.getType();
@@ -1526,6 +1547,63 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
     os << ']';
 }
 
+void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
+  auto type = attr.getType();
+  auto shape = type.getShape();
+  auto rank = type.getRank();
+
+  // Special case for 0-d and splat tensors.
+  if (attr.isSplat()) {
+    printDenseStringElement(attr, os, 0);
+    return;
+  }
+
+  // Special case for degenerate tensors.
+  auto numElements = type.getNumElements();
+  if (numElements == 0) {
+    for (int i = 0; i < rank; ++i)
+      os << '[';
+    for (int i = 0; i < rank; ++i)
+      os << ']';
+    return;
+  }
+
+  // We use a mixed-radix counter to iterate through the shape. When we bump a
+  // non-least-significant digit, we emit a close bracket. When we next emit an
+  // element we re-open all closed brackets.
+
+  // The mixed-radix counter, with radices in 'shape'.
+  SmallVector<unsigned, 4> counter(rank, 0);
+  // The number of brackets that have been opened and not closed.
+  unsigned openBrackets = 0;
+
+  auto bumpCounter = [&]() {
+    // Bump the least significant digit.
+    ++counter[rank - 1];
+    // Iterate backwards bubbling back the increment.
+    for (unsigned i = rank - 1; i > 0; --i)
+      if (counter[i] >= shape[i]) {
+        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
+        counter[i] = 0;
+        ++counter[i - 1];
+        --openBrackets;
+        os << ']';
+      }
+  };
+
+  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
+    if (idx != 0)
+      os << ", ";
+    while (openBrackets++ < rank)
+      os << '[';
+    openBrackets = rank;
+    printDenseStringElement(attr, os, idx);
+    bumpCounter();
+  }
+  while (openBrackets-- > 0)
+    os << ']';
+}
+
 void ModulePrinter::printType(Type type) {
   if (!type) {
     os << "<<NULL TYPE>>";

diff  --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 0119830e1ab9..49995da012e8 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -385,6 +385,20 @@ inline size_t getDenseElementBitWidth(Type eltType) {
 
 /// An attribute representing a reference to a dense vector or tensor object.
 struct DenseElementsAttributeStorage : public AttributeStorage {
+public:
+  DenseElementsAttributeStorage(ShapedType ty, bool isSplat)
+      : AttributeStorage(ty), isSplat(isSplat) {}
+
+  bool isSplat;
+};
+
+/// An attribute representing a reference to a dense vector or tensor object.
+struct DenseIntOrFPElementsAttributeStorage
+    : public DenseElementsAttributeStorage {
+  DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
+                                       bool isSplat = false)
+      : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
+
   struct KeyTy {
     KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
           bool isSplat = false)
@@ -403,10 +417,6 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
     bool isSplat;
   };
 
-  DenseElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
-                                bool isSplat = false)
-      : AttributeStorage(ty), data(data), isSplat(isSplat) {}
-
   /// Compare this storage instance with the provided key.
   bool operator==(const KeyTy &key) const {
     if (key.type != getType())
@@ -512,7 +522,7 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
   }
 
   /// Construct a new storage instance.
-  static DenseElementsAttributeStorage *
+  static DenseIntOrFPElementsAttributeStorage *
   construct(AttributeStorageAllocator &allocator, KeyTy key) {
     // If the data buffer is non-empty, we copy it into the allocator with a
     // 64-bit alignment.
@@ -528,12 +538,129 @@ struct DenseElementsAttributeStorage : public AttributeStorage {
       copy = ArrayRef<char>(rawData, data.size());
     }
 
-    return new (allocator.allocate<DenseElementsAttributeStorage>())
-        DenseElementsAttributeStorage(key.type, copy, key.isSplat);
+    return new (allocator.allocate<DenseIntOrFPElementsAttributeStorage>())
+        DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat);
   }
 
   ArrayRef<char> data;
-  bool isSplat;
+};
+
+/// An attribute representing a reference to a dense vector or tensor object
+/// containing strings.
+struct DenseStringElementsAttributeStorage
+    : public DenseElementsAttributeStorage {
+  DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef<StringRef> data,
+                                      bool isSplat = false)
+      : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
+
+  struct KeyTy {
+    KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
+          bool isSplat = false)
+        : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
+
+    /// The type of the dense elements.
+    ShapedType type;
+
+    /// A non-owning view of the string elements used to build the storage.
+    ArrayRef<StringRef> data;
+
+    /// The computed hash code for the storage data.
+    llvm::hash_code hashCode;
+
+    /// A boolean that indicates if this data is a splat or not.
+    bool isSplat;
+  };
+
+  /// Compare this storage instance with the provided key.
+  bool operator==(const KeyTy &key) const {
+    if (key.type != getType())
+      return false;
+
+    // Otherwise, we can default to just checking the data. StringRefs compare
+    // by contents.
+    return key.data == data;
+  }
+
+  /// Construct a key from a shaped type, StringRef data buffer, and a flag that
+  /// signals if the data is already known to be a splat. Callers to this
+  /// function are expected to tag preknown splat values when possible, e.g. one
+  /// element shapes.
+  static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
+                      bool isKnownSplat) {
+    // Handle an empty storage instance.
+    if (data.empty())
+      return KeyTy(ty, data, 0);
+
+    // If the data is already known to be a splat, hash just its first element
+    // so the hash agrees with that of a splat detected by the loop below.
+    if (isKnownSplat)
+      return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
+
+    // A single-element type must already have been tagged as a known splat.
+    assert(ty.getNumElements() != 1 &&
+           "splat of 1 element should already be detected");
+
+    // Create the initial hash value with just the first element.
+    const auto &firstElt = data.front();
+    auto hashVal = llvm::hash_value(firstElt);
+
+    // Check to see if this storage represents a splat. If it doesn't then
+    // combine the hash for the data starting with the first non splat element.
+    for (size_t i = 1, e = data.size(); i != e; i++)
+      if (!firstElt.equals(data[i]))
+        return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
+
+    // Otherwise, this is a splat: key on a one-element view of 'data'.
+    return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
+  }
+
+  /// Hash the key for the storage.
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_combine(key.type, key.hashCode);
+  }
+
+  /// Construct a new storage instance.
+  static DenseStringElementsAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, KeyTy key) {
+    // An empty data buffer needs no copy: build the storage directly with an
+    // empty ArrayRef.
+    ArrayRef<StringRef> copy, data = key.data;
+    if (data.empty()) {
+      return new (allocator.allocate<DenseStringElementsAttributeStorage>())
+          DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
+    }
+
+    int numEntries = key.isSplat ? 1 : data.size();
+
+    // Compute the amount of data needed to store the StringRef array followed
+    // by the tightly packed string contents.
+    size_t dataSize = sizeof(StringRef) * numEntries;
+    for (int i = 0; i < numEntries; i++)
+      dataSize += data[i].size();
+
+    char *rawData = reinterpret_cast<char *>(
+        allocator.allocate(dataSize, alignof(uint64_t)));
+
+    // Set up a mutable array ref of our string refs so that we can update their
+    // contents.
+    auto mutableCopy = MutableArrayRef<StringRef>(
+        reinterpret_cast<StringRef *>(rawData), numEntries);
+    auto stringData = rawData + numEntries * sizeof(StringRef);
+
+    for (int i = 0; i < numEntries; i++) {
+      memcpy(stringData, data[i].data(), data[i].size());
+      mutableCopy[i] = StringRef(stringData, data[i].size());
+      stringData += data[i].size();
+    }
+
+    copy =
+        ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
+
+    return new (allocator.allocate<DenseStringElementsAttributeStorage>())
+        DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
+  }
+
+  ArrayRef<StringRef> data;
 };
 
 /// An attribute representing a reference to a tensor constant with opaque

diff  --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index be31d95657ff..26cb0513c099 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -411,7 +411,7 @@ int64_t ElementsAttr::getNumElements() const {
 /// element, then a null attribute is returned.
 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
   switch (getKind()) {
-  case StandardAttributes::DenseElements:
+  case StandardAttributes::DenseIntOrFPElements:
     return cast<DenseElementsAttr>().getValue(index);
   case StandardAttributes::OpaqueElements:
     return cast<OpaqueElementsAttr>().getValue(index);
@@ -442,7 +442,7 @@ ElementsAttr
 ElementsAttr::mapValues(Type newElementType,
                         function_ref<APInt(const APInt &)> mapping) const {
   switch (getKind()) {
-  case StandardAttributes::DenseElements:
+  case StandardAttributes::DenseIntOrFPElements:
     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
   default:
     llvm_unreachable("unsupported ElementsAttr subtype");
@@ -453,7 +453,7 @@ ElementsAttr
 ElementsAttr::mapValues(Type newElementType,
                         function_ref<APInt(const APFloat &)> mapping) const {
   switch (getKind()) {
-  case StandardAttributes::DenseElements:
+  case StandardAttributes::DenseIntOrFPElements:
     return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
   default:
     llvm_unreachable("unsupported ElementsAttr subtype");
@@ -643,7 +643,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
            "expected value to have same bitwidth as element type");
     writeBits(data.data(), i * storageBitWidth, intVal);
   }
-  return getRaw(type, data, /*isSplat=*/(values.size() == 1));
+  return DenseIntOrFPElementsAttr::getRaw(type, data,
+                                          /*isSplat=*/(values.size() == 1));
 }
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@@ -654,7 +655,14 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
   for (int i = 0, e = values.size(); i != e; ++i)
     setBit(buff.data(), i, values[i]);
-  return getRaw(type, buff, /*isSplat=*/(values.size() == 1));
+  return DenseIntOrFPElementsAttr::getRaw(type, buff,
+                                          /*isSplat=*/(values.size() == 1));
+}
+
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<StringRef> values) {
+  assert(!type.getElementType().isIntOrFloat());
+  return DenseStringElementsAttr::get(type, values);
 }
 
 /// Constructs a dense integer elements attribute from an array of APInt
@@ -663,7 +671,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<APInt> values) {
   assert(type.getElementType().isIntOrIndex());
-  return getRaw(type, values);
+  return DenseIntOrFPElementsAttr::getRaw(type, values);
 }
 
 // Constructs a dense float elements attribute from an array of APFloat
@@ -677,7 +685,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   std::vector<APInt> intValues(values.size());
   for (unsigned i = 0, e = values.size(); i != e; ++i)
     intValues[i] = values[i].bitcastToAPInt();
-  return getRaw(type, intValues);
+  return DenseIntOrFPElementsAttr::getRaw(type, intValues);
 }
 
 /// Construct a dense elements attribute from a raw buffer representing the
@@ -686,34 +694,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
                                                       ArrayRef<char> rawBuffer,
                                                       bool isSplatBuffer) {
-  return getRaw(type, rawBuffer, isSplatBuffer);
-}
-
-/// Constructs a dense elements attribute from an array of raw APInt values.
-/// Each APInt value is expected to have the same bitwidth as the element type
-/// of 'type'.
-DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
-                                            ArrayRef<APInt> values) {
-  assert(hasSameElementsOrSplat(type, values));
-
-  size_t bitWidth = getDenseElementBitWidth(type.getElementType());
-  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
-  std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
-                                values.size());
-  for (unsigned i = 0, e = values.size(); i != e; ++i) {
-    assert(values[i].getBitWidth() == bitWidth);
-    writeBits(elementData.data(), i * storageBitWidth, values[i]);
-  }
-  return getRaw(type, elementData, /*isSplat=*/(values.size() == 1));
-}
-
-DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
-                                            ArrayRef<char> data, bool isSplat) {
-  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
-         "type must be ranked tensor or vector");
-  assert(type.hasStaticShape() && "type must have static shape");
-  return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
-                   data, isSplat);
+  return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
 }
 
 /// Check the information for a C++ data type, check if this type is valid for
@@ -743,19 +724,14 @@ static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
   return intType.isSigned() ? isSigned : !isSigned;
 }
 
-/// Overload of the 'getRaw' method that asserts that the given type is of
-/// integer type. This method is used to verify type invariants that the
-/// templatized 'get' method cannot.
+/// Defaults down the subclass implementation.
 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
                                                       ArrayRef<char> data,
                                                       int64_t dataEltSize,
                                                       bool isInt,
                                                       bool isSigned) {
-  assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned));
-
-  int64_t numElements = data.size() / dataEltSize;
-  assert(numElements == 1 || numElements == type.getNumElements());
-  return getRaw(type, data, /*isSplat=*/numElements == 1);
+  return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
+                                                    isInt, isSigned);
 }
 
 /// A method used to verify specific type invariants that the templatized 'get'
@@ -767,7 +743,9 @@ bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
 
 /// Returns if this attribute corresponds to a splat, i.e. if all element
 /// values are the same.
-bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
+bool DenseElementsAttr::isSplat() const {
+  return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
+}
 
 /// Return the held element values as a range of Attributes.
 auto DenseElementsAttr::getAttributeValues() const
@@ -827,7 +805,11 @@ auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
 
 /// Return the raw storage data held by this attribute.
 ArrayRef<char> DenseElementsAttr::getRawData() const {
-  return static_cast<ImplType *>(impl)->data;
+  return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
+}
+
+ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
+  return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
 }
 
 /// Return a new DenseElementsAttr that has the same data as the current
@@ -843,7 +825,7 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
          "expected the same element type");
   assert(newType.getNumElements() == curType.getNumElements() &&
          "expected the same number of elements");
-  return getRaw(newType, getRawData(), isSplat());
+  return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
 }
 
 DenseElementsAttr
@@ -857,6 +839,63 @@ DenseElementsAttr DenseElementsAttr::mapValues(
   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
 }
 
+//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+//===----------------------------------------------------------------------===//
+
+DenseStringElementsAttr
+DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
+  return Base::get(type.getContext(), StandardAttributes::DenseStringElements,
+                   type, values, (values.size() == 1));
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+//===----------------------------------------------------------------------===//
+
+/// Constructs a dense elements attribute from an array of raw APInt values.
+/// Each APInt value is expected to have the same bitwidth as the element type
+/// of 'type'.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
+                                                   ArrayRef<APInt> values) {
+  assert(hasSameElementsOrSplat(type, values));
+
+  size_t bitWidth = getDenseElementBitWidth(type.getElementType());
+  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
+  std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
+                                values.size());
+  for (unsigned i = 0, e = values.size(); i != e; ++i) {
+    assert(values[i].getBitWidth() == bitWidth);
+    writeBits(elementData.data(), i * storageBitWidth, values[i]);
+  }
+  return DenseIntOrFPElementsAttr::getRaw(type, elementData,
+                                          /*isSplat=*/(values.size() == 1));
+}
+
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
+                                                   ArrayRef<char> data,
+                                                   bool isSplat) {
+  assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
+         "type must be ranked tensor or vector");
+  assert(type.hasStaticShape() && "type must have static shape");
+  return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements,
+                   type, data, isSplat);
+}
+
+/// Overload of the 'getRaw' method that asserts that the given type is of
+/// integer or floating-point type. This method is used to verify type
+/// invariants that the templatized 'get' method cannot.
+DenseElementsAttr
+DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
+                                           int64_t dataEltSize, bool isInt,
+                                           bool isSigned) {
+  assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned));
+
+  int64_t numElements = data.size() / dataEltSize;
+  assert(numElements == 1 || numElements == type.getNumElements());
+  return getRaw(type, data, /*isSplat=*/numElements == 1);
+}
+
 //===----------------------------------------------------------------------===//
 // DenseFPElementsAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index f0f3cc72d03a..b25c5111d8dc 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -82,10 +82,11 @@ namespace {
 /// the IR.
 struct BuiltinDialect : public Dialect {
   BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
-    addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseElementsAttr,
-                  DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
-                  IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
-                  SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
+    addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseIntOrFPElementsAttr,
+                  DenseStringElementsAttr, DictionaryAttr, FloatAttr,
+                  SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
+                  OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
+                  UnitAttr>();
     addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
                   UnknownLoc>();
 

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 8eabe3c46f4c..c5b68ddff1b0 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1953,7 +1953,7 @@ class TensorLiteralParser {
   ArrayRef<int64_t> getShape() const { return shape; }
 
 private:
-  enum class ElementKind { Boolean, Integer, Float };
+  enum class ElementKind { Boolean, Integer, Float, String };
 
   /// Return a string to represent the given element kind.
   const char *getElementKindStr(ElementKind kind) {
@@ -1964,6 +1964,8 @@ class TensorLiteralParser {
       return "'integer'";
     case ElementKind::Float:
       return "'float'";
+    case ElementKind::String:
+      return "'string'";
     }
     llvm_unreachable("unknown element kind");
   }
@@ -1975,6 +1977,9 @@ class TensorLiteralParser {
   DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
                                  FloatType eltTy);
 
+  /// Build a Dense String attribute for the given type.
+  DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
+
   /// Build a Dense attribute with hex data for the given type.
   DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
 
@@ -2030,8 +2035,10 @@ ParseResult TensorLiteralParser::parse(bool allowHex) {
 /// shaped type.
 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
                                                ShapedType type) {
-  // Check to see if we parsed the literal from a hex string.
-  if (hexStorage.hasValue())
+  Type eltType = type.getElementType();
+
+  // Check to see if we parsed the literal from a hex string.
+  if (hexStorage.hasValue() && eltType.isIntOrFloat())
     return getHexAttr(loc, type);
 
   // Check that the parsed storage size has the same number of elements to the
@@ -2044,20 +2051,17 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
 
   // If the type is an integer, build a set of APInt values from the storage
   // with the correct bitwidth.
-  Type eltType = type.getElementType();
   if (auto intTy = eltType.dyn_cast<IntegerType>())
     return getIntAttr(loc, type, intTy);
   if (auto indexTy = eltType.dyn_cast<IndexType>())
     return getIntAttr(loc, type, indexTy);
 
-  // Otherwise, this must be a floating point type.
-  auto floatTy = eltType.dyn_cast<FloatType>();
-  if (!floatTy) {
-    p.emitError(loc) << "expected floating-point or integer element type, got "
-                     << eltType;
-    return nullptr;
-  }
-  return getFloatAttr(loc, type, floatTy);
+  // If parsing a floating point type.
+  if (auto floatTy = eltType.dyn_cast<FloatType>())
+    return getFloatAttr(loc, type, floatTy);
+
+  // Other types are assumed to be string representations.
+  return getStringAttr(loc, type, type.getElementType());
 }
 
 /// Build a Dense Integer attribute for the given type.
@@ -2163,6 +2167,28 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
   return DenseElementsAttr::get(type, floatValues);
 }
 
+/// Build a Dense String attribute for the given type.
+DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
+                                                     ShapedType type,
+                                                     Type eltTy) {
+  if (hexStorage.hasValue()) {
+    auto stringValue = hexStorage.getValue().getStringValue();
+    return DenseStringElementsAttr::get(type, {stringValue});
+  }
+
+  std::vector<std::string> stringValues;
+  std::vector<StringRef> stringRefValues;
+  stringValues.reserve(storage.size());
+  stringRefValues.reserve(storage.size());
+
+  for (auto val : storage) {
+    stringValues.push_back(val.second.getStringValue());
+    stringRefValues.push_back(stringValues.back());
+  }
+
+  return DenseStringElementsAttr::get(type, stringRefValues);
+}
+
 /// Build a Dense attribute with hex data for the given type.
 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
                                                   ShapedType type) {
@@ -2214,6 +2240,10 @@ ParseResult TensorLiteralParser::parseElement() {
     p.consumeToken();
     break;
 
+  case Token::string:
+    storage.emplace_back(/*isNegative=*/ false, p.getToken());
+    p.consumeToken();
+    break;
   default:
     return p.emitError("expected element literal of primitive type");
   }

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 81edebd796b4..2a43f8aef127 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -390,6 +390,40 @@ func @correct_type_pass() {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// Test StringElementsAttr
+//===----------------------------------------------------------------------===//
+
+func @simple_scalar_example() {
+  "test.string_elements_attr"() {
+    // CHECK: dense<"example">
+    scalar_string_attr = dense<"example"> : tensor<2x!unknown<"">>
+  } : () -> ()
+  return
+}
+
+// -----
+
+func @escape_string_example() {
+  "test.string_elements_attr"() {
+    // CHECK: dense<"new\0Aline">
+    scalar_string_attr = dense<"new\nline"> : tensor<2x!unknown<"">>
+  } : () -> ()
+  return
+}
+
+// -----
+
+func @simple_scalar_example() {
+  "test.string_elements_attr"() {
+    // CHECK: dense<["example1", "example2"]>
+    scalar_string_attr = dense<["example1", "example2"]> : tensor<2x!unknown<"">>
+  } : () -> ()
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Test SymbolRefAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
index bfd1f1b27e7f..0375004e1d84 100644
--- a/mlir/test/IR/dense-elements-hex.mlir
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -22,10 +22,5 @@
 
 // -----
 
-// expected-error@+1 {{expected floating-point or integer element type, got '!unknown<"">'}}
-"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2x!unknown<"">>} : () -> ()
-
-// -----
-
 // expected-error@+1 {{elements hex data size is invalid for provided type}}
 "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<4xf64>} : () -> ()

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c633bde2d769..d5259639f4b5 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -245,6 +245,12 @@ def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
     "$_builder.getI32IntegerAttr($_self)">;
 }
 
+def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
+  let arguments = (ins
+      StringElementsAttr:$scalar_string_attr
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Test Attribute Constraints
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list