[llvm-branch-commits] [mlir] [mlir][IR] Separate `DenseStringElementsAttr` from `DenseElementsAttr` (PR #181385)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Feb 15 06:42:26 PST 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/181385

>From de9ef4c7b90cdcecde51251f8ed5052c4185dd63 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 13 Feb 2026 16:58:41 +0000
Subject: [PATCH] [mlir][IR] Separate `DenseStringElementsAttr` from
 `DenseElementsAttr`

---
 mlir/include/mlir/IR/BuiltinAttributes.h      |  37 ++----
 mlir/include/mlir/IR/BuiltinAttributes.td     | 110 +++++++++++++++---
 mlir/include/mlir/IR/CommonAttrConstraints.td |   4 +-
 mlir/lib/AsmParser/AttributeParser.cpp        |  36 ++++--
 mlir/lib/CAPI/IR/BuiltinAttributes.cpp        |  22 ++--
 mlir/lib/IR/BuiltinAttributes.cpp             |  70 ++++-------
 mlir/test/IR/parser.mlir                      |   4 -
 mlir/test/mlir-tblgen/openmp-clause-ops.td    |   2 +-
 mlir/unittests/IR/AttributeTest.cpp           |  37 +++---
 9 files changed, 192 insertions(+), 130 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index ee6a8f4e4d948..3ba943c7ccd41 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -152,10 +152,6 @@ class DenseElementsAttr : public Attribute {
   /// Overload of the above 'get' method that is specialized for boolean values.
   static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
 
-  /// Overload of the above 'get' method that is specialized for StringRef
-  /// values.
-  static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values);
-
   /// Constructs a dense integer elements attribute from an array of APInt
   /// values. Each APInt value is expected to have the same bitwidth as the
   /// element type of 'type'. 'type' must be a vector or tensor with static
@@ -223,7 +219,8 @@ class DenseElementsAttr : public Attribute {
       decltype(std::declval<AttrT>().template getValues<T>());
 
   /// A utility iterator that allows walking over the internal Attribute values
-  /// of a DenseElementsAttr.
+  /// of a DenseElementsAttr. DenseStringElementsAttr has its own
+  /// StringAttributeElementIterator.
   class AttributeElementIterator
       : public llvm::indexed_accessor_iterator<AttributeElementIterator,
                                                const void *, Attribute,
@@ -232,11 +229,9 @@ class DenseElementsAttr : public Attribute {
     /// Accesses the Attribute value at this iterator position.
     Attribute operator*() const;
 
-  private:
-    friend DenseElementsAttr;
-
-    /// Constructs a new iterator.
-    AttributeElementIterator(DenseElementsAttr attr, size_t index);
+    /// Constructs a new iterator over `attr`. Dereferencing casts `attr` to
+    /// DenseElementsAttr, so `attr` must be a DenseElementsAttr.
+    AttributeElementIterator(Attribute attr, size_t index);
   };
 
   /// Iterator for walking raw element values of the specified type 'T', which
@@ -461,21 +456,6 @@ class DenseElementsAttr : public Attribute {
         ElementIterator<T>(rawData, splat, getNumElements()));
   }
 
-  /// Try to get the held element values as a range of StringRef.
-  template <typename T>
-  using StringRefValueTemplateCheckT =
-      std::enable_if_t<std::is_same<T, StringRef>::value>;
-  template <typename T, typename = StringRefValueTemplateCheckT<T>>
-  FailureOr<iterator_range_impl<ElementIterator<StringRef>>>
-  tryGetValues() const {
-    auto stringRefs = getRawStringData();
-    const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
-    bool splat = isSplat();
-    return iterator_range_impl<ElementIterator<StringRef>>(
-        getType(), ElementIterator<StringRef>(ptr, splat, 0),
-        ElementIterator<StringRef>(ptr, splat, getNumElements()));
-  }
-
   /// Try to get the held element values as a range of Attributes.
   template <typename T>
   using AttributeValueTemplateCheckT =
@@ -484,8 +464,8 @@ class DenseElementsAttr : public Attribute {
   FailureOr<iterator_range_impl<AttributeElementIterator>>
   tryGetValues() const {
     return iterator_range_impl<AttributeElementIterator>(
-        getType(), AttributeElementIterator(*this, 0),
-        AttributeElementIterator(*this, getNumElements()));
+        getType(), AttributeElementIterator(Attribute(*this), 0),
+        AttributeElementIterator(Attribute(*this), getNumElements()));
   }
 
   /// Try to get the held element values a range of T, where T is a derived
@@ -578,9 +558,6 @@ class DenseElementsAttr : public Attribute {
   /// form the user might expect.
   ArrayRef<char> getRawData() const;
 
-  /// Return the raw StringRef data held by this attribute.
-  ArrayRef<StringRef> getRawStringData() const;
-
   /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
   /// with static shape.
   ShapedType getType() const;
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index dced379d1f979..064783ae5f87a 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -395,8 +395,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseStringElementsAttr : Builtin_Attr<
-    "DenseStringElements", "dense_string_elements", [ElementsAttrInterface],
-    "DenseElementsAttr"
+    "DenseStringElements", "dense_string_elements", [ElementsAttrInterface]
   > {
   let summary = "An Attribute containing a dense multi-dimensional array of "
                 "strings";
@@ -431,13 +430,97 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
     }]>,
   ];
   let extraClassDeclaration = [{
-    using DenseElementsAttr::empty;
-    using DenseElementsAttr::getNumElements;
-    using DenseElementsAttr::getElementType;
-    using DenseElementsAttr::getValues;
-    using DenseElementsAttr::isSplat;
-    using DenseElementsAttr::size;
-    using DenseElementsAttr::value_begin;
+    /// Iterator for walking StringRef element values.
+    class StringRefElementIterator
+        : public detail::DenseElementIndexedIteratorImpl<StringRefElementIterator,
+                                                        const StringRef> {
+    public:
+      const StringRef &operator*() const {
+        return reinterpret_cast<const StringRef *>(this->getData())[this->getDataIndex()];
+      }
+      StringRefElementIterator(const char *data, bool isSplat, size_t dataIndex)
+          : detail::DenseElementIndexedIteratorImpl<StringRefElementIterator,
+                                                    const StringRef>(
+                data, isSplat, dataIndex) {}
+    };
+
+    /// Iterator for walking element values as Attribute (StringAttr).
+    class StringAttributeElementIterator
+        : public llvm::indexed_accessor_iterator<StringAttributeElementIterator,
+                                                 const void *, Attribute,
+                                                 Attribute, Attribute> {
+    public:
+      Attribute operator*() const;
+      StringAttributeElementIterator(const DenseStringElementsAttr *attr,
+                                     size_t index)
+          : llvm::indexed_accessor_iterator<StringAttributeElementIterator,
+                                            const void *, Attribute,
+                                            Attribute, Attribute>(
+                attr->getAsOpaquePointer(), index) {}
+    };
+
+    /// Return the type of this attribute (vector or tensor with static shape).
+    ShapedType getType() const;
+
+    /// Helper methods for ElementsAttr interface.
+    bool empty() const { return getNumElements() == 0; }
+    int64_t getNumElements() const { return getType().getNumElements(); }
+    Type getElementType() const { return getType().getElementType(); }
+    bool isSplat() const { return getRawStringData().size() == 1; }
+    int64_t size() const { return getNumElements(); }
+
+    /// Return the raw StringRef data held by this attribute.
+    ArrayRef<StringRef> getRawStringData() const;
+
+    /// Try to get the held element values as a range of StringRef.
+    template <typename T>
+    using StringRefValueTemplateCheckT =
+        std::enable_if_t<std::is_same<T, StringRef>::value>;
+    template <typename T, typename = StringRefValueTemplateCheckT<T>>
+    FailureOr<detail::ElementsAttrRange<StringRefElementIterator>>
+    tryGetValues() const {
+      auto stringRefs = getRawStringData();
+      const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
+      bool splat = isSplat();
+      return detail::ElementsAttrRange<StringRefElementIterator>(
+          getType(), StringRefElementIterator(ptr, splat, 0),
+          StringRefElementIterator(ptr, splat, getNumElements()));
+    }
+
+    /// Try to get the held element values as a range of Attributes.
+    template <typename T>
+    using AttributeValueTemplateCheckT =
+        std::enable_if_t<std::is_same<T, Attribute>::value>;
+    template <typename T, typename = AttributeValueTemplateCheckT<T>>
+    FailureOr<detail::ElementsAttrRange<StringAttributeElementIterator>>
+    tryGetValues() const {
+      return detail::ElementsAttrRange<StringAttributeElementIterator>(
+          getType(), StringAttributeElementIterator(this, 0),
+          StringAttributeElementIterator(this, getNumElements()));
+    }
+
+    template <typename T>
+    auto getValues() const {
+      auto range = tryGetValues<T>();
+      assert(succeeded(range) && "element type cannot be iterated");
+      return std::move(*range);
+    }
+    template <typename T>
+    auto value_begin() const { return getValues<T>().begin(); }
+    template <typename T>
+    auto value_end() const { return getValues<T>().end(); }
+    /// Return the splat value. Asserts that the attribute is a splat.
+    template <typename T>
+    auto getSplatValue() const {
+      assert(isSplat() && "expected the attribute to be a splat");
+      return *value_begin<T>();
+    }
+    template <typename T>
+    auto try_value_begin() const {
+      auto range = tryGetValues<T>();
+      using iterator = decltype(range->begin());
+      return failed(range) ? FailureOr<iterator>(failure()) : range->begin();
+    }
 
     /// The set of data types that can be iterated by this attribute.
     using ContiguousIterableTypesT = std::tuple<StringRef>;
@@ -449,11 +532,6 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
     auto try_value_begin_impl(OverloadToken<T>) const {
       return try_value_begin<T>();
     }
-
-  protected:
-    friend DenseElementsAttr;
-
-  public:
   }];
   let genAccessors = 0;
   let genStorageClass = 0;
@@ -931,9 +1009,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
       std::complex<int16_t>, std::complex<int32_t>, std::complex<int64_t>,
       // Float types.
       APFloat, float, double,
-      std::complex<APFloat>, std::complex<float>, std::complex<double>,
-      // String types.
-      StringRef
+      std::complex<APFloat>, std::complex<float>, std::complex<double>
     >;
     using ElementsAttr::Trait<SparseElementsAttr>::getValues;
     using ElementsAttr::Trait<SparseElementsAttr>::value_begin;
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index ba6cf55a8fb9e..634881f5813f3 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -565,8 +565,8 @@ def StringElementsAttr : ElementsAttrBase<
   CPred<"::llvm::isa<::mlir::DenseStringElementsAttr>($_self)" >,
   "string elements attribute"> {
 
-  let storageType = [{ ::mlir::DenseElementsAttr }];
-  let returnType = [{ ::mlir::DenseElementsAttr }];
+  let storageType = [{ ::mlir::DenseStringElementsAttr }];
+  let returnType = [{ ::mlir::DenseStringElementsAttr }];
 
   let convertFromStorage = "$_self";
 }
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index dc9744a42b730..15c2e0225f98b 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -472,8 +472,8 @@ class TensorLiteralParser {
   ParseResult parse(bool allowHex);
 
   /// Build a dense attribute instance with the parsed elements and the given
-  /// shaped type.
-  DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
+  /// shaped type. Returns DenseElementsAttr or DenseStringElementsAttr.
+  Attribute getAttr(SMLoc loc, ShapedType type);
 
   ArrayRef<int64_t> getShape() const { return shape; }
 
@@ -487,7 +487,7 @@ class TensorLiteralParser {
                                    std::vector<APFloat> &floatValues);
 
   /// Build a Dense String attribute for the given type.
-  DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
+  DenseStringElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
 
   /// Build a Dense attribute with hex data for the given type.
   DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
@@ -539,7 +539,7 @@ ParseResult TensorLiteralParser::parse(bool allowHex) {
 
 /// Build a dense attribute instance with the parsed elements and the given
 /// shaped type.
-DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
+Attribute TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
   Type eltType = type.getElementType();
 
   // Check to see if we parse the literal from a hex string.
@@ -679,8 +679,8 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
 }
 
 /// Build a Dense String attribute for the given type.
-DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
-                                                     Type eltTy) {
+DenseStringElementsAttr
+TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, Type eltTy) {
   if (hexStorage.has_value()) {
     auto stringValue = hexStorage->getStringValue();
     return DenseStringElementsAttr::get(type, {stringValue});
@@ -1174,6 +1174,13 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
     if (!type)
       return nullptr;
 
+    // Only int, index, and float element types are accepted here.
+    if (!type.getElementType().isIntOrIndexOrFloat()) {
+      emitError(loc) << "sparse elements attribute does not support string "
+                        "element type";
+      return nullptr;
+    }
+
     // Construct the sparse elements attr using zero element indice/value
     // attributes.
     ShapedType indicesType =
@@ -1219,9 +1226,10 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
     // Otherwise, set the shape to the one parsed by the literal parser.
     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
   }
-  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
-  if (!indices)
+  auto indicesAttr = indiceParser.getAttr(indicesLoc, indicesType);
+  if (!indicesAttr)
     return nullptr;
+  auto indices = llvm::cast<DenseIntElementsAttr>(indicesAttr);
 
   // If the values are a splat, set the shape explicitly based on the number of
   // indices. The number of indices is encoded in the first dimension of the
@@ -1231,10 +1239,18 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
       valuesParser.getShape().empty()
           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
-  auto values = valuesParser.getAttr(valuesLoc, valuesType);
-  if (!values)
+  auto valuesAttr = valuesParser.getAttr(valuesLoc, valuesType);
+  if (!valuesAttr)
     return nullptr;
 
+  // SparseElementsAttr only supports DenseElementsAttr for values (not string).
+  auto values = llvm::dyn_cast<DenseElementsAttr>(valuesAttr);
+  if (!values) {
+    emitError(valuesLoc)
+        << "dense string elements not supported in sparse elements attribute";
+    return nullptr;
+  }
+
   // Build the sparse elements attribute by the indices and values.
   return getChecked<SparseElementsAttr>(loc, type, indices, values);
 }
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 44a3deaf57db5..7325179c047c5 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -728,8 +728,8 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
   for (intptr_t i = 0; i < numElements; ++i)
     values.push_back(unwrap(strs[i]));
 
-  return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
-                                     values));
+  return wrap(DenseStringElementsAttr::get(
+      llvm::cast<ShapedType>(unwrap(shapedType)), values));
 }
 
 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
@@ -743,12 +743,18 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
 //===----------------------------------------------------------------------===//
 
 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
-  return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
+  Attribute a = unwrap(attr);
+  if (auto strAttr = llvm::dyn_cast<DenseStringElementsAttr>(a))
+    return strAttr.isSplat();
+  return llvm::cast<DenseElementsAttr>(a).isSplat();
 }
 
 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
+  mlir::Attribute a = unwrap(attr);
+  if (auto strAttr = llvm::dyn_cast<DenseStringElementsAttr>(a))
+    return wrap(strAttr.getSplatValue<mlir::Attribute>());
   return wrap(
-      llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>());
+      llvm::cast<DenseElementsAttr>(a).getSplatValue<mlir::Attribute>());
 }
 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>();
@@ -778,8 +784,8 @@ double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
   return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>();
 }
 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
-  return wrap(
-      llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>());
+  return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr))
+                  .getSplatValue<llvm::StringRef>());
 }
 
 //===----------------------------------------------------------------------===//
@@ -824,8 +830,8 @@ double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
 }
 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
                                                   intptr_t pos) {
-  return wrap(
-      llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]);
+  return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr))
+                  .getValues<StringRef>()[pos]);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index bbbc9198a68ab..e288be3271fab 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -589,23 +589,14 @@ static bool hasSameNumElementsOrSplat(ShapedType type, const Values &values) {
 //===----------------------------------------------------------------------===//
 
 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
-    DenseElementsAttr attr, size_t index)
+    Attribute attr, size_t index)
     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
                                       Attribute, Attribute, Attribute>(
           attr.getAsOpaquePointer(), index) {}
 
 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
   auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
-  Type eltTy = owner.getElementType();
-
-  // Handle strings specially.
-  if (llvm::isa<DenseStringElementsAttr>(owner)) {
-    ArrayRef<StringRef> vals = owner.getRawStringData();
-    return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
-  }
-
-  // All other types should implement DenseElementTypeInterface.
-  auto denseEltTy = llvm::cast<DenseElementType>(eltTy);
+  auto denseEltTy = llvm::cast<DenseElementType>(owner.getElementType());
   ArrayRef<char> rawData = owner.getRawData();
   // Storage is byte-aligned: align bit size up to next byte boundary.
   size_t bitSize = denseEltTy.getDenseElementBitSize();
@@ -864,28 +855,13 @@ template class DenseArrayAttrImpl<double>;
 
 /// Method for support type inquiry through isa, cast and dyn_cast.
 bool DenseElementsAttr::classof(Attribute attr) {
-  return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr);
+  return llvm::isa<DenseIntOrFPElementsAttr>(attr);
 }
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
   assert(hasSameNumElementsOrSplat(type, values));
-  Type eltType = type.getElementType();
-
-  // Handle strings specially.
-  if (!llvm::isa<DenseElementType>(eltType)) {
-    SmallVector<StringRef, 8> stringValues;
-    stringValues.reserve(values.size());
-    for (Attribute attr : values) {
-      assert(llvm::isa<StringAttr>(attr) &&
-             "expected string value for non-DenseElementType element");
-      stringValues.push_back(llvm::cast<StringAttr>(attr).getValue());
-    }
-    return get(type, stringValues);
-  }
-
-  // All other types go through DenseElementTypeInterface.
-  auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType);
+  auto denseEltType = llvm::dyn_cast<DenseElementType>(type.getElementType());
   assert(denseEltType &&
          "attempted to get DenseElementsAttr with unsupported element type");
   SmallVector<char> data;
@@ -906,12 +882,6 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                            values.size()));
 }
 
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
-                                         ArrayRef<StringRef> values) {
-  assert(!type.getElementType().isIntOrFloat());
-  return DenseStringElementsAttr::get(type, values);
-}
-
 /// Constructs a dense integer elements attribute from an array of APInt
 /// values. Each APInt value is expected to have the same bitwidth as the
 /// element type of 'type'.
@@ -1048,9 +1018,6 @@ bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
 /// values are the same.
 bool DenseElementsAttr::isSplat() const {
   // Splat iff the data array has exactly one element.
-  if (isa<DenseStringElementsAttr>(*this))
-    return getRawStringData().size() == 1;
-  // FP/Int case.
   size_t storageSize = llvm::divideCeil(
       getDenseElementBitWidth(getType().getElementType()), CHAR_BIT);
   return getRawData().size() == storageSize;
@@ -1100,10 +1067,6 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
   return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
 }
 
-ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
-  return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
-}
-
 /// Return a new DenseElementsAttr that has the same data as the current
 /// attribute, but has been reshaped to 'newType'. The new type must have the
 /// same total number of elements as well as element type.
@@ -1390,6 +1353,27 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+//===----------------------------------------------------------------------===//
+
+ShapedType DenseStringElementsAttr::getType() const {
+  return static_cast<const DenseStringElementsAttrStorage *>(impl)->type;
+}
+
+ArrayRef<StringRef> DenseStringElementsAttr::getRawStringData() const {
+  return static_cast<const DenseStringElementsAttrStorage *>(impl)->data;
+}
+
+Attribute
+DenseStringElementsAttr::StringAttributeElementIterator::operator*() const {
+  auto attr = llvm::cast<DenseStringElementsAttr>(
+      Attribute::getFromOpaquePointer(this->base));
+  auto data = attr.getRawStringData();
+  return StringAttr::get(attr.isSplat() ? data.front() : data[this->index],
+                         attr.getElementType());
+}
+
 //===----------------------------------------------------------------------===//
 // DenseResourceElementsAttr
 //===----------------------------------------------------------------------===//
@@ -1557,10 +1541,6 @@ Attribute SparseElementsAttr::getZeroAttr() const {
                           ArrayRef<Attribute>{zero, zero});
   }
 
-  // Handle string type.
-  if (llvm::isa<DenseStringElementsAttr>(getValues()))
-    return StringAttr::get("", eltType);
-
   // Otherwise, this is an integer.
   return IntegerAttr::get(eltType, 0);
 }
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 3bb6e38b4d613..c4a415e626760 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -797,10 +797,6 @@ func.func @sparsetensorattr() -> () {
 // CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> ()
   "foof321"(){bar = sparse<> : tensor<f32>} : () -> ()
 
-// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> ()
-  "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> ()
-// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> ()
-  "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> ()
   return
 }
 
diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td
index 3e5896a00182b..c502b21c3baf8 100644
--- a/mlir/test/mlir-tblgen/openmp-clause-ops.td
+++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td
@@ -59,7 +59,7 @@ def OpenMP_MyFirstClause : OpenMP_Clause<
 // CHECK-NEXT:   ::mlir::IntegerAttr complexOptIntAttr;
 
 // CHECK-NEXT:   ::mlir::ElementsAttr elementsAttr;
-// CHECK-NEXT:   ::mlir::DenseElementsAttr stringElementsAttr;
+// CHECK-NEXT:   ::mlir::DenseStringElementsAttr stringElementsAttr;
 // CHECK-NEXT: }
 
 def OpenMP_MySecondClause : OpenMP_Clause<
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 404aa8c0dcf3d..f72aebabce280 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -38,6 +38,21 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
     EXPECT_TRUE(newValue == splatElt);
 }
 
+template <>
+void testSplat<StringRef>(Type eltType, const StringRef &splatElt) {
+  RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
+
+  DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, splatElt);
+  EXPECT_TRUE(splat.isSplat());
+
+  auto detectedSplat =
+      DenseStringElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
+  EXPECT_EQ(detectedSplat, splat);
+
+  for (auto newValue : detectedSplat.getValues<StringRef>())
+    EXPECT_TRUE(newValue == splatElt);
+}
+
 namespace {
 TEST(DenseSplatTest, BoolSplat) {
   MLIRContext context;
@@ -184,8 +199,16 @@ TEST(DenseSplatTest, StringAttrSplat) {
   context.allowUnregisteredDialects();
   Type stringType =
       OpaqueType::get(StringAttr::get(&context, "test"), "string");
+  RankedTensorType shape = RankedTensorType::get({2, 1}, stringType);
   Attribute stringAttr = StringAttr::get("test-string", stringType);
-  testSplat(stringType, stringAttr);
+  StringRef value = llvm::cast<StringAttr>(stringAttr).getValue();
+  DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, value);
+  EXPECT_TRUE(splat.isSplat());
+  auto detectedSplat =
+      DenseStringElementsAttr::get(shape, llvm::ArrayRef({value, value}));
+  EXPECT_EQ(detectedSplat, splat);
+  for (auto newValue : detectedSplat.getValues<StringRef>())
+    EXPECT_TRUE(newValue == value);
 }
 
 TEST(DenseComplexTest, ComplexFloatSplat) {
@@ -396,11 +419,9 @@ TEST(SparseElementsAttrTest, GetZero) {
 
   IntegerType intTy = IntegerType::get(&context, 32);
   FloatType floatTy = Float32Type::get(&context);
-  Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
 
   ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
   ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
-  ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
 
   auto indicesType =
       RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
@@ -413,13 +434,8 @@ TEST(SparseElementsAttrTest, GetZero) {
   RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
   auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
 
-  RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
-  auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
-
   auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
   auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
-  auto sparseString =
-      SparseElementsAttr::get(tensorString, indices, stringValue);
 
   // Only index (0, 0) contains an element, others are supposed to return
   // the zero/empty value.
@@ -432,11 +448,6 @@ TEST(SparseElementsAttrTest, GetZero) {
       cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
   EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
   EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
-
-  auto zeroStringValue =
-      cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
-  EXPECT_TRUE(zeroStringValue.empty());
-  EXPECT_TRUE(zeroStringValue.getType() == stringTy);
 }
 
 //===----------------------------------------------------------------------===//



More information about the llvm-branch-commits mailing list