[llvm-branch-commits] [mlir] [mlir][IR] Separate `DenseStringElementsAttr` from `DenseElementsAttr` (PR #181385)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Feb 15 06:42:26 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/181385
From de9ef4c7b90cdcecde51251f8ed5052c4185dd63 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 13 Feb 2026 16:58:41 +0000
Subject: [PATCH] [mlir][IR] Separate `DenseStringElementsAttr` from
`DenseElementsAttr`
---
mlir/include/mlir/IR/BuiltinAttributes.h | 37 ++----
mlir/include/mlir/IR/BuiltinAttributes.td | 110 +++++++++++++++---
mlir/include/mlir/IR/CommonAttrConstraints.td | 4 +-
mlir/lib/AsmParser/AttributeParser.cpp | 36 ++++--
mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 22 ++--
mlir/lib/IR/BuiltinAttributes.cpp | 70 ++++-------
mlir/test/IR/parser.mlir | 4 -
mlir/test/mlir-tblgen/openmp-clause-ops.td | 2 +-
mlir/unittests/IR/AttributeTest.cpp | 37 +++---
9 files changed, 192 insertions(+), 130 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index ee6a8f4e4d948..3ba943c7ccd41 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -152,10 +152,6 @@ class DenseElementsAttr : public Attribute {
/// Overload of the above 'get' method that is specialized for boolean values.
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
- /// Overload of the above 'get' method that is specialized for StringRef
- /// values.
- static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values);
-
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
@@ -223,7 +219,8 @@ class DenseElementsAttr : public Attribute {
decltype(std::declval<AttrT>().template getValues<T>());
/// A utility iterator that allows walking over the internal Attribute values
- /// of a DenseElementsAttr.
+ /// of a DenseElementsAttr. DenseStringElementsAttr provides its own
+ /// StringAttributeElementIterator.
class AttributeElementIterator
: public llvm::indexed_accessor_iterator<AttributeElementIterator,
const void *, Attribute,
@@ -232,11 +229,9 @@ class DenseElementsAttr : public Attribute {
/// Accesses the Attribute value at this iterator position.
Attribute operator*() const;
- private:
- friend DenseElementsAttr;
-
- /// Constructs a new iterator.
- AttributeElementIterator(DenseElementsAttr attr, size_t index);
+ /// Constructs a new iterator. `attr` must be a DenseElementsAttr; it is
+ /// cast back to DenseElementsAttr when the iterator is dereferenced.
+ AttributeElementIterator(Attribute attr, size_t index);
};
/// Iterator for walking raw element values of the specified type 'T', which
@@ -461,21 +456,6 @@ class DenseElementsAttr : public Attribute {
ElementIterator<T>(rawData, splat, getNumElements()));
}
- /// Try to get the held element values as a range of StringRef.
- template <typename T>
- using StringRefValueTemplateCheckT =
- std::enable_if_t<std::is_same<T, StringRef>::value>;
- template <typename T, typename = StringRefValueTemplateCheckT<T>>
- FailureOr<iterator_range_impl<ElementIterator<StringRef>>>
- tryGetValues() const {
- auto stringRefs = getRawStringData();
- const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
- bool splat = isSplat();
- return iterator_range_impl<ElementIterator<StringRef>>(
- getType(), ElementIterator<StringRef>(ptr, splat, 0),
- ElementIterator<StringRef>(ptr, splat, getNumElements()));
- }
-
/// Try to get the held element values as a range of Attributes.
template <typename T>
using AttributeValueTemplateCheckT =
@@ -484,8 +464,8 @@ class DenseElementsAttr : public Attribute {
FailureOr<iterator_range_impl<AttributeElementIterator>>
tryGetValues() const {
return iterator_range_impl<AttributeElementIterator>(
- getType(), AttributeElementIterator(*this, 0),
- AttributeElementIterator(*this, getNumElements()));
+ getType(), AttributeElementIterator(Attribute(*this), 0),
+ AttributeElementIterator(Attribute(*this), getNumElements()));
}
/// Try to get the held element values a range of T, where T is a derived
@@ -578,9 +558,6 @@ class DenseElementsAttr : public Attribute {
/// form the user might expect.
ArrayRef<char> getRawData() const;
- /// Return the raw StringRef data held by this attribute.
- ArrayRef<StringRef> getRawStringData() const;
-
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
/// with static shape.
ShapedType getType() const;
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index dced379d1f979..064783ae5f87a 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -395,8 +395,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DenseStringElementsAttr : Builtin_Attr<
- "DenseStringElements", "dense_string_elements", [ElementsAttrInterface],
- "DenseElementsAttr"
+ "DenseStringElements", "dense_string_elements", [ElementsAttrInterface]
> {
let summary = "An Attribute containing a dense multi-dimensional array of "
"strings";
@@ -431,13 +430,97 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
}]>,
];
let extraClassDeclaration = [{
- using DenseElementsAttr::empty;
- using DenseElementsAttr::getNumElements;
- using DenseElementsAttr::getElementType;
- using DenseElementsAttr::getValues;
- using DenseElementsAttr::isSplat;
- using DenseElementsAttr::size;
- using DenseElementsAttr::value_begin;
+ /// Iterator for walking StringRef element values.
+ class StringRefElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<StringRefElementIterator,
+ const StringRef> {
+ public:
+ const StringRef &operator*() const {
+ return reinterpret_cast<const StringRef *>(this->getData())[this->getDataIndex()];
+ }
+ StringRefElementIterator(const char *data, bool isSplat, size_t dataIndex)
+ : detail::DenseElementIndexedIteratorImpl<StringRefElementIterator,
+ const StringRef>(
+ data, isSplat, dataIndex) {}
+ };
+
+ /// Iterator for walking element values as Attribute (StringAttr).
+ class StringAttributeElementIterator
+ : public llvm::indexed_accessor_iterator<StringAttributeElementIterator,
+ const void *, Attribute,
+ Attribute, Attribute> {
+ public:
+ Attribute operator*() const;
+ StringAttributeElementIterator(const DenseStringElementsAttr *attr,
+ size_t index)
+ : llvm::indexed_accessor_iterator<StringAttributeElementIterator,
+ const void *, Attribute,
+ Attribute, Attribute>(
+ attr->getAsOpaquePointer(), index) {}
+ };
+
+ /// Return the type of this attribute (vector or tensor with static shape).
+ ShapedType getType() const;
+
+ /// Helper methods for ElementsAttr interface.
+ bool empty() const { return getNumElements() == 0; }
+ int64_t getNumElements() const { return getType().getNumElements(); }
+ Type getElementType() const { return getType().getElementType(); }
+ bool isSplat() const { return getRawStringData().size() == 1; }
+ int64_t size() const { return getNumElements(); }
+
+ /// Return the raw StringRef data held by this attribute.
+ ArrayRef<StringRef> getRawStringData() const;
+
+ /// Try to get the held element values as a range of StringRef.
+ template <typename T>
+ using StringRefValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, StringRef>::value>;
+ template <typename T, typename = StringRefValueTemplateCheckT<T>>
+ FailureOr<detail::ElementsAttrRange<StringRefElementIterator>>
+ tryGetValues() const {
+ auto stringRefs = getRawStringData();
+ const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
+ bool splat = isSplat();
+ return detail::ElementsAttrRange<StringRefElementIterator>(
+ getType(), StringRefElementIterator(ptr, splat, 0),
+ StringRefElementIterator(ptr, splat, getNumElements()));
+ }
+
+ /// Try to get the held element values as a range of Attributes.
+ template <typename T>
+ using AttributeValueTemplateCheckT =
+ std::enable_if_t<std::is_same<T, Attribute>::value>;
+ template <typename T, typename = AttributeValueTemplateCheckT<T>>
+ FailureOr<detail::ElementsAttrRange<StringAttributeElementIterator>>
+ tryGetValues() const {
+ return detail::ElementsAttrRange<StringAttributeElementIterator>(
+ getType(), StringAttributeElementIterator(this, 0),
+ StringAttributeElementIterator(this, getNumElements()));
+ }
+
+ template <typename T>
+ auto getValues() const {
+ auto range = tryGetValues<T>();
+ assert(succeeded(range) && "element type cannot be iterated");
+ return std::move(*range);
+ }
+ template <typename T>
+ auto value_begin() const { return getValues<T>().begin(); }
+ template <typename T>
+ auto value_end() const { return getValues<T>().end(); }
+ /// Return the splat value. Asserts that the attribute is a splat.
+ template <typename T>
+ auto getSplatValue() const {
+ assert(isSplat() && "expected the attribute to be a splat");
+ return *value_begin<T>();
+ }
+ template <typename T>
+ auto try_value_begin() const {
+ auto range = tryGetValues<T>();
+ using iterator = decltype(range->begin());
+ return failed(range) ? FailureOr<iterator>(failure()) : range->begin();
+ }
/// The set of data types that can be iterated by this attribute.
using ContiguousIterableTypesT = std::tuple<StringRef>;
@@ -449,11 +532,6 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
auto try_value_begin_impl(OverloadToken<T>) const {
return try_value_begin<T>();
}
-
- protected:
- friend DenseElementsAttr;
-
- public:
}];
let genAccessors = 0;
let genStorageClass = 0;
@@ -931,9 +1009,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
std::complex<int16_t>, std::complex<int32_t>, std::complex<int64_t>,
// Float types.
APFloat, float, double,
- std::complex<APFloat>, std::complex<float>, std::complex<double>,
- // String types.
- StringRef
+ std::complex<APFloat>, std::complex<float>, std::complex<double>
>;
using ElementsAttr::Trait<SparseElementsAttr>::getValues;
using ElementsAttr::Trait<SparseElementsAttr>::value_begin;
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index ba6cf55a8fb9e..634881f5813f3 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -565,8 +565,8 @@ def StringElementsAttr : ElementsAttrBase<
CPred<"::llvm::isa<::mlir::DenseStringElementsAttr>($_self)" >,
"string elements attribute"> {
- let storageType = [{ ::mlir::DenseElementsAttr }];
- let returnType = [{ ::mlir::DenseElementsAttr }];
+ let storageType = [{ ::mlir::DenseStringElementsAttr }];
+ let returnType = [{ ::mlir::DenseStringElementsAttr }];
let convertFromStorage = "$_self";
}
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index dc9744a42b730..15c2e0225f98b 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -472,8 +472,8 @@ class TensorLiteralParser {
ParseResult parse(bool allowHex);
/// Build a dense attribute instance with the parsed elements and the given
- /// shaped type.
- DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
+ /// shaped type. Returns DenseElementsAttr or DenseStringElementsAttr.
+ Attribute getAttr(SMLoc loc, ShapedType type);
ArrayRef<int64_t> getShape() const { return shape; }
@@ -487,7 +487,7 @@ class TensorLiteralParser {
std::vector<APFloat> &floatValues);
/// Build a Dense String attribute for the given type.
- DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
+ DenseStringElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
/// Build a Dense attribute with hex data for the given type.
DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
@@ -539,7 +539,7 @@ ParseResult TensorLiteralParser::parse(bool allowHex) {
/// Build a dense attribute instance with the parsed elements and the given
/// shaped type.
-DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
+Attribute TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
Type eltType = type.getElementType();
// Check to see if we parse the literal from a hex string.
@@ -679,8 +679,8 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
}
/// Build a Dense String attribute for the given type.
-DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
- Type eltTy) {
+DenseStringElementsAttr
+TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, Type eltTy) {
if (hexStorage.has_value()) {
auto stringValue = hexStorage->getStringValue();
return DenseStringElementsAttr::get(type, {stringValue});
@@ -1174,6 +1174,13 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
if (!type)
return nullptr;
+ // SparseElementsAttr only supports int/float element types.
+ if (!type.getElementType().isIntOrIndexOrFloat()) {
+ emitError(loc) << "sparse elements attribute does not support string "
+ "element type";
+ return nullptr;
+ }
+
// Construct the sparse elements attr using zero element indice/value
// attributes.
ShapedType indicesType =
@@ -1219,9 +1226,10 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
// Otherwise, set the shape to the one parsed by the literal parser.
indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
}
- auto indices = indiceParser.getAttr(indicesLoc, indicesType);
- if (!indices)
+ auto indicesAttr = indiceParser.getAttr(indicesLoc, indicesType);
+ if (!indicesAttr)
return nullptr;
+ auto indices = llvm::cast<DenseIntElementsAttr>(indicesAttr);
// If the values are a splat, set the shape explicitly based on the number of
// indices. The number of indices is encoded in the first dimension of the
@@ -1231,10 +1239,18 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
valuesParser.getShape().empty()
? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
: RankedTensorType::get(valuesParser.getShape(), valuesEltType);
- auto values = valuesParser.getAttr(valuesLoc, valuesType);
- if (!values)
+ auto valuesAttr = valuesParser.getAttr(valuesLoc, valuesType);
+ if (!valuesAttr)
return nullptr;
+ // SparseElementsAttr only supports DenseElementsAttr for values (not string).
+ auto values = llvm::dyn_cast<DenseElementsAttr>(valuesAttr);
+ if (!values) {
+ emitError(valuesLoc)
+ << "dense string elements not supported in sparse elements attribute";
+ return nullptr;
+ }
+
// Build the sparse elements attribute by the indices and values.
return getChecked<SparseElementsAttr>(loc, type, indices, values);
}
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 44a3deaf57db5..7325179c047c5 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -728,8 +728,8 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
for (intptr_t i = 0; i < numElements; ++i)
values.push_back(unwrap(strs[i]));
- return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
- values));
+ return wrap(DenseStringElementsAttr::get(
+ llvm::cast<ShapedType>(unwrap(shapedType)), values));
}
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
@@ -743,12 +743,18 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
//===----------------------------------------------------------------------===//
bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
- return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat();
+ Attribute a = unwrap(attr);
+ if (auto strAttr = llvm::dyn_cast<DenseStringElementsAttr>(a))
+ return strAttr.isSplat();
+ return llvm::cast<DenseElementsAttr>(a).isSplat();
}
MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
+ mlir::Attribute a = unwrap(attr);
+ if (auto strAttr = llvm::dyn_cast<DenseStringElementsAttr>(a))
+ return wrap(strAttr.getSplatValue<mlir::Attribute>());
return wrap(
- llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>());
+ llvm::cast<DenseElementsAttr>(a).getSplatValue<mlir::Attribute>());
}
int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>();
@@ -778,8 +784,8 @@ double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>();
}
MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
- return wrap(
- llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>());
+ return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr))
+ .getSplatValue<llvm::StringRef>());
}
//===----------------------------------------------------------------------===//
@@ -824,8 +830,8 @@ double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
}
MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
intptr_t pos) {
- return wrap(
- llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]);
+ return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr))
+ .getValues<StringRef>()[pos]);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index bbbc9198a68ab..e288be3271fab 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -589,23 +589,14 @@ static bool hasSameNumElementsOrSplat(ShapedType type, const Values &values) {
//===----------------------------------------------------------------------===//
DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
- DenseElementsAttr attr, size_t index)
+ Attribute attr, size_t index)
: llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
Attribute, Attribute, Attribute>(
attr.getAsOpaquePointer(), index) {}
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
- Type eltTy = owner.getElementType();
-
- // Handle strings specially.
- if (llvm::isa<DenseStringElementsAttr>(owner)) {
- ArrayRef<StringRef> vals = owner.getRawStringData();
- return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
- }
-
- // All other types should implement DenseElementTypeInterface.
- auto denseEltTy = llvm::cast<DenseElementType>(eltTy);
+ auto denseEltTy = llvm::cast<DenseElementType>(owner.getElementType());
ArrayRef<char> rawData = owner.getRawData();
// Storage is byte-aligned: align bit size up to next byte boundary.
size_t bitSize = denseEltTy.getDenseElementBitSize();
@@ -864,28 +855,13 @@ template class DenseArrayAttrImpl<double>;
/// Method for support type inquiry through isa, cast and dyn_cast.
bool DenseElementsAttr::classof(Attribute attr) {
- return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr);
+ return llvm::isa<DenseIntOrFPElementsAttr>(attr);
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(hasSameNumElementsOrSplat(type, values));
- Type eltType = type.getElementType();
-
- // Handle strings specially.
- if (!llvm::isa<DenseElementType>(eltType)) {
- SmallVector<StringRef, 8> stringValues;
- stringValues.reserve(values.size());
- for (Attribute attr : values) {
- assert(llvm::isa<StringAttr>(attr) &&
- "expected string value for non-DenseElementType element");
- stringValues.push_back(llvm::cast<StringAttr>(attr).getValue());
- }
- return get(type, stringValues);
- }
-
- // All other types go through DenseElementTypeInterface.
- auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType);
+ auto denseEltType = llvm::dyn_cast<DenseElementType>(type.getElementType());
assert(denseEltType &&
"attempted to get DenseElementsAttr with unsupported element type");
SmallVector<char> data;
@@ -906,12 +882,6 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
values.size()));
}
-DenseElementsAttr DenseElementsAttr::get(ShapedType type,
- ArrayRef<StringRef> values) {
- assert(!type.getElementType().isIntOrFloat());
- return DenseStringElementsAttr::get(type, values);
-}
-
/// Constructs a dense integer elements attribute from an array of APInt
/// values. Each APInt value is expected to have the same bitwidth as the
/// element type of 'type'.
@@ -1048,9 +1018,6 @@ bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
/// values are the same.
bool DenseElementsAttr::isSplat() const {
// Splat iff the data array has exactly one element.
- if (isa<DenseStringElementsAttr>(*this))
- return getRawStringData().size() == 1;
- // FP/Int case.
size_t storageSize = llvm::divideCeil(
getDenseElementBitWidth(getType().getElementType()), CHAR_BIT);
return getRawData().size() == storageSize;
@@ -1100,10 +1067,6 @@ ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
}
-ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
- return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
-}
-
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
@@ -1390,6 +1353,27 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
return false;
}
+//===----------------------------------------------------------------------===//
+// DenseStringElementsAttr
+//===----------------------------------------------------------------------===//
+
+ShapedType DenseStringElementsAttr::getType() const {
+ return static_cast<const DenseStringElementsAttrStorage *>(impl)->type;
+}
+
+ArrayRef<StringRef> DenseStringElementsAttr::getRawStringData() const {
+ return static_cast<const DenseStringElementsAttrStorage *>(impl)->data;
+}
+
+Attribute
+DenseStringElementsAttr::StringAttributeElementIterator::operator*() const {
+ auto attr = llvm::cast<DenseStringElementsAttr>(
+ Attribute::getFromOpaquePointer(this->base));
+ auto data = attr.getRawStringData();
+ return StringAttr::get(attr.isSplat() ? data.front() : data[this->index],
+ attr.getElementType());
+}
+
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
@@ -1557,10 +1541,6 @@ Attribute SparseElementsAttr::getZeroAttr() const {
ArrayRef<Attribute>{zero, zero});
}
- // Handle string type.
- if (llvm::isa<DenseStringElementsAttr>(getValues()))
- return StringAttr::get("", eltType);
-
// Otherwise, this is an integer.
return IntegerAttr::get(eltType, 0);
}
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 3bb6e38b4d613..c4a415e626760 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -797,10 +797,6 @@ func.func @sparsetensorattr() -> () {
// CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> ()
"foof321"(){bar = sparse<> : tensor<f32>} : () -> ()
-// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> ()
- "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> ()
-// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> ()
- "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> ()
return
}
diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td
index 3e5896a00182b..c502b21c3baf8 100644
--- a/mlir/test/mlir-tblgen/openmp-clause-ops.td
+++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td
@@ -59,7 +59,7 @@ def OpenMP_MyFirstClause : OpenMP_Clause<
// CHECK-NEXT: ::mlir::IntegerAttr complexOptIntAttr;
// CHECK-NEXT: ::mlir::ElementsAttr elementsAttr;
-// CHECK-NEXT: ::mlir::DenseElementsAttr stringElementsAttr;
+// CHECK-NEXT: ::mlir::DenseStringElementsAttr stringElementsAttr;
// CHECK-NEXT: }
def OpenMP_MySecondClause : OpenMP_Clause<
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 404aa8c0dcf3d..f72aebabce280 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -38,6 +38,21 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
EXPECT_TRUE(newValue == splatElt);
}
+template <>
+void testSplat<StringRef>(Type eltType, const StringRef &splatElt) {
+ RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
+
+ DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, splatElt);
+ EXPECT_TRUE(splat.isSplat());
+
+ auto detectedSplat =
+ DenseStringElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
+ EXPECT_EQ(detectedSplat, splat);
+
+ for (auto newValue : detectedSplat.getValues<StringRef>())
+ EXPECT_TRUE(newValue == splatElt);
+}
+
namespace {
TEST(DenseSplatTest, BoolSplat) {
MLIRContext context;
@@ -184,8 +199,16 @@ TEST(DenseSplatTest, StringAttrSplat) {
context.allowUnregisteredDialects();
Type stringType =
OpaqueType::get(StringAttr::get(&context, "test"), "string");
+ RankedTensorType shape = RankedTensorType::get({2, 1}, stringType);
Attribute stringAttr = StringAttr::get("test-string", stringType);
- testSplat(stringType, stringAttr);
+ StringRef value = llvm::cast<StringAttr>(stringAttr).getValue();
+ DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, value);
+ EXPECT_TRUE(splat.isSplat());
+ auto detectedSplat =
+ DenseStringElementsAttr::get(shape, llvm::ArrayRef({value, value}));
+ EXPECT_EQ(detectedSplat, splat);
+ for (auto newValue : detectedSplat.getValues<StringRef>())
+ EXPECT_TRUE(newValue == value);
}
TEST(DenseComplexTest, ComplexFloatSplat) {
@@ -396,11 +419,9 @@ TEST(SparseElementsAttrTest, GetZero) {
IntegerType intTy = IntegerType::get(&context, 32);
FloatType floatTy = Float32Type::get(&context);
- Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
- ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
auto indicesType =
RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
@@ -413,13 +434,8 @@ TEST(SparseElementsAttrTest, GetZero) {
RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
- RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
- auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
-
auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
- auto sparseString =
- SparseElementsAttr::get(tensorString, indices, stringValue);
// Only index (0, 0) contains an element, others are supposed to return
// the zero/empty value.
@@ -432,11 +448,6 @@ TEST(SparseElementsAttrTest, GetZero) {
cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
-
- auto zeroStringValue =
- cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
- EXPECT_TRUE(zeroStringValue.empty());
- EXPECT_TRUE(zeroStringValue.getType() == stringTy);
}
//===----------------------------------------------------------------------===//
More information about the llvm-branch-commits
mailing list