[Mlir-commits] [mlir] 0cb5d7f - [mlir] Add value_begin/value_end methods to DenseElementsAttr
River Riddle
llvmlistbot at llvm.org
Mon Sep 20 18:58:12 PDT 2021
Author: River Riddle
Date: 2021-09-21T01:57:43Z
New Revision: 0cb5d7fc7fd3eeb40b6ecf9b34a497d46bcba6c6
URL: https://github.com/llvm/llvm-project/commit/0cb5d7fc7fd3eeb40b6ecf9b34a497d46bcba6c6
DIFF: https://github.com/llvm/llvm-project/commit/0cb5d7fc7fd3eeb40b6ecf9b34a497d46bcba6c6.diff
LOG: [mlir] Add value_begin/value_end methods to DenseElementsAttr
Currently DenseElementsAttr only exposes the ability to get the full range of values for a given type T, but there are many situations where we just want the beginning/end iterator. This revision adds proper value_begin/value_end methods for all of the supported T types, and also cleans up a bit of the interface.
Differential Revision: https://reviews.llvm.org/D104173
Added:
Modified:
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/include/mlir/Dialect/CommonFolders.h
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Interfaces/InferTypeOpInterface.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp
Removed:
################################################################################
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 8acbc37c77f7c..a647a51f0c9ab 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -165,7 +165,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
- auto valueIt = constantValue.getValues<FloatAttr>().begin();
+ auto valueIt = constantValue.value_begin<FloatAttr>();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index c1ad4dc66e996..1684dc6cf32bc 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -164,7 +164,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
- auto valueIt = constantValue.getValues<FloatAttr>().begin();
+ auto valueIt = constantValue.value_begin<FloatAttr>();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 8acbc37c77f7c..a647a51f0c9ab 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -165,7 +165,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
- auto valueIt = constantValue.getValues<FloatAttr>().begin();
+ auto valueIt = constantValue.value_begin<FloatAttr>();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index fb5991fa72af7..a52b89027c5b6 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -58,8 +58,8 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
auto lhs = operands[0].cast<ElementsAttr>();
auto rhs = operands[1].cast<ElementsAttr>();
- auto lhsIt = lhs.getValues<ElementValueT>().begin();
- auto rhsIt = rhs.getValues<ElementValueT>().begin();
+ auto lhsIt = lhs.value_begin<ElementValueT>();
+ auto rhsIt = rhs.value_begin<ElementValueT>();
SmallVector<ElementValueT, 4> elementResults;
elementResults.reserve(lhs.getNumElements());
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index d332718bb9b17..6edd56b09b55e 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -51,6 +51,9 @@ class ElementsAttr : public Attribute {
/// with static shape.
ShapedType getType() const;
+ /// Return the element type of this ElementsAttr.
+ Type getElementType() const;
+
/// Return the value at the given index. The index is expected to refer to a
/// valid element.
Attribute getValue(ArrayRef<uint64_t> index) const;
@@ -65,8 +68,9 @@ class ElementsAttr : public Attribute {
/// Return the elements of this attribute as a value of type 'T'. Note:
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
/// iteration.
- template <typename T>
- iterator_range<T> getValues() const;
+ template <typename T> iterator_range<T> getValues() const;
+ template <typename T> iterator<T> value_begin() const;
+ template <typename T> iterator<T> value_end() const;
/// Return if the given 'index' refers to a valid element in this attribute.
bool isValidIndex(ArrayRef<uint64_t> index) const;
@@ -417,7 +421,7 @@ class DenseElementsAttr : public ElementsAttr {
T>::type
getSplatValue() const {
assert(isSplat() && "expected the attribute to be a splat");
- return *getValues<T>().begin();
+ return *value_begin<T>();
}
/// Return the splat value for derived attribute element types.
template <typename T>
@@ -436,15 +440,21 @@ class DenseElementsAttr : public ElementsAttr {
template <typename T>
T getValue(ArrayRef<uint64_t> index) const {
// Skip to the element corresponding to the flattened index.
- return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
+ return getFlatValue<T>(getFlattenedIndex(index));
+ }
+ /// Return the value at the given flattened index.
+ template <typename T> T getFlatValue(uint64_t index) const {
+ return *std::next(value_begin<T>(), index);
}
/// Return the held element values as a range of integer or floating-point
/// values.
- template <typename T, typename = typename std::enable_if<
- (!std::is_same<T, bool>::value &&
- std::numeric_limits<T>::is_integer) ||
- is_valid_cpp_fp_type<T>::value>::type>
+ template <typename T>
+ using IntFloatValueTemplateCheckT =
+ typename std::enable_if<(!std::is_same<T, bool>::value &&
+ std::numeric_limits<T>::is_integer) ||
+ is_valid_cpp_fp_type<T>::value>::type;
+ template <typename T, typename = IntFloatValueTemplateCheckT<T>>
llvm::iterator_range<ElementIterator<T>> getValues() const {
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
std::numeric_limits<T>::is_signed));
@@ -453,13 +463,27 @@ class DenseElementsAttr : public ElementsAttr {
return {ElementIterator<T>(rawData, splat, 0),
ElementIterator<T>(rawData, splat, getNumElements())};
}
+ template <typename T, typename = IntFloatValueTemplateCheckT<T>>
+ ElementIterator<T> value_begin() const {
+ assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
+ std::numeric_limits<T>::is_signed));
+ return ElementIterator<T>(getRawData().data(), isSplat(), 0);
+ }
+ template <typename T, typename = IntFloatValueTemplateCheckT<T>>
+ ElementIterator<T> value_end() const {
+ assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
+ std::numeric_limits<T>::is_signed));
+ return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
+ }
/// Return the held element values as a range of std::complex.
+ template <typename T, typename ElementT>
+ using ComplexValueTemplateCheckT =
+ typename std::enable_if<detail::is_complex_t<T>::value &&
+ (std::numeric_limits<ElementT>::is_integer ||
+ is_valid_cpp_fp_type<ElementT>::value)>::type;
template <typename T, typename ElementT = typename T::value_type,
- typename = typename std::enable_if<
- detail::is_complex_t<T>::value &&
- (std::numeric_limits<ElementT>::is_integer ||
- is_valid_cpp_fp_type<ElementT>::value)>::type>
+ typename = ComplexValueTemplateCheckT<T, ElementT>>
llvm::iterator_range<ElementIterator<T>> getValues() const {
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
std::numeric_limits<ElementT>::is_signed));
@@ -468,10 +492,26 @@ class DenseElementsAttr : public ElementsAttr {
return {ElementIterator<T>(rawData, splat, 0),
ElementIterator<T>(rawData, splat, getNumElements())};
}
+ template <typename T, typename ElementT = typename T::value_type,
+ typename = ComplexValueTemplateCheckT<T, ElementT>>
+ ElementIterator<T> value_begin() const {
+ assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+ std::numeric_limits<ElementT>::is_signed));
+ return ElementIterator<T>(getRawData().data(), isSplat(), 0);
+ }
+ template <typename T, typename ElementT = typename T::value_type,
+ typename = ComplexValueTemplateCheckT<T, ElementT>>
+ ElementIterator<T> value_end() const {
+ assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+ std::numeric_limits<ElementT>::is_signed));
+ return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
+ }
/// Return the held element values as a range of StringRef.
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, StringRef>::value>::type>
+ template <typename T>
+ using StringRefValueTemplateCheckT =
+ typename std::enable_if<std::is_same<T, StringRef>::value>::type;
+ template <typename T, typename = StringRefValueTemplateCheckT<T>>
llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
auto stringRefs = getRawStringData();
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
@@ -479,80 +519,156 @@ class DenseElementsAttr : public ElementsAttr {
return {ElementIterator<StringRef>(ptr, splat, 0),
ElementIterator<StringRef>(ptr, splat, getNumElements())};
}
+ template <typename T, typename = StringRefValueTemplateCheckT<T>>
+ ElementIterator<StringRef> value_begin() const {
+ const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
+ return ElementIterator<StringRef>(ptr, isSplat(), 0);
+ }
+ template <typename T, typename = StringRefValueTemplateCheckT<T>>
+ ElementIterator<StringRef> value_end() const {
+ const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
+ return ElementIterator<StringRef>(ptr, isSplat(), getNumElements());
+ }
/// Return the held element values as a range of Attributes.
- llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, Attribute>::value>::type>
+ template <typename T>
+ using AttributeValueTemplateCheckT =
+ typename std::enable_if<std::is_same<T, Attribute>::value>::type;
+ template <typename T, typename = AttributeValueTemplateCheckT<T>>
llvm::iterator_range<AttributeElementIterator> getValues() const {
- return getAttributeValues();
+ return {value_begin<Attribute>(), value_end<Attribute>()};
+ }
+ template <typename T, typename = AttributeValueTemplateCheckT<T>>
+ AttributeElementIterator value_begin() const {
+ return AttributeElementIterator(*this, 0);
+ }
+ template <typename T, typename = AttributeValueTemplateCheckT<T>>
+ AttributeElementIterator value_end() const {
+ return AttributeElementIterator(*this, getNumElements());
}
- AttributeElementIterator attr_value_begin() const;
- AttributeElementIterator attr_value_end() const;
/// Return the held element values a range of T, where T is a derived
/// attribute type.
template <typename T>
+ using DerivedAttrValueTemplateCheckT =
+ typename std::enable_if<std::is_base_of<Attribute, T>::value &&
+ !std::is_same<Attribute, T>::value>::type;
+ template <typename T>
using DerivedAttributeElementIterator =
llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
- template <typename T, typename = typename std::enable_if<
- std::is_base_of<Attribute, T>::value &&
- !std::is_same<Attribute, T>::value>::type>
+ template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
- return llvm::map_range(getAttributeValues(),
+ return llvm::map_range(getValues<Attribute>(),
static_cast<T (*)(Attribute)>(castFn));
}
+ template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
+ DerivedAttributeElementIterator<T> value_begin() const {
+ auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
+ return {value_begin<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
+ }
+ template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
+ DerivedAttributeElementIterator<T> value_end() const {
+ auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
+ return {value_end<Attribute>(), static_cast<T (*)(Attribute)>(castFn)};
+ }
/// Return the held element values as a range of bool. The element type of
/// this attribute must be of integer type of bitwidth 1.
- llvm::iterator_range<BoolElementIterator> getBoolValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, bool>::value>::type>
+ template <typename T>
+ using BoolValueTemplateCheckT =
+ typename std::enable_if<std::is_same<T, bool>::value>::type;
+ template <typename T, typename = BoolValueTemplateCheckT<T>>
llvm::iterator_range<BoolElementIterator> getValues() const {
- return getBoolValues();
+ assert(isValidBool() && "bool is not the value of this elements attribute");
+ return {BoolElementIterator(*this, 0),
+ BoolElementIterator(*this, getNumElements())};
+ }
+ template <typename T, typename = BoolValueTemplateCheckT<T>>
+ BoolElementIterator value_begin() const {
+ assert(isValidBool() && "bool is not the value of this elements attribute");
+ return BoolElementIterator(*this, 0);
+ }
+ template <typename T, typename = BoolValueTemplateCheckT<T>>
+ BoolElementIterator value_end() const {
+ assert(isValidBool() && "bool is not the value of this elements attribute");
+ return BoolElementIterator(*this, getNumElements());
}
/// Return the held element values as a range of APInts. The element type of
/// this attribute must be of integer type.
- llvm::iterator_range<IntElementIterator> getIntValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, APInt>::value>::type>
+ template <typename T>
+ using APIntValueTemplateCheckT =
+ typename std::enable_if<std::is_same<T, APInt>::value>::type;
+ template <typename T, typename = APIntValueTemplateCheckT<T>>
llvm::iterator_range<IntElementIterator> getValues() const {
- return getIntValues();
+ assert(getElementType().isIntOrIndex() && "expected integral type");
+ return {raw_int_begin(), raw_int_end()};
+ }
+ template <typename T, typename = APIntValueTemplateCheckT<T>>
+ IntElementIterator value_begin() const {
+ assert(getElementType().isIntOrIndex() && "expected integral type");
+ return raw_int_begin();
+ }
+ template <typename T, typename = APIntValueTemplateCheckT<T>>
+ IntElementIterator value_end() const {
+ assert(getElementType().isIntOrIndex() && "expected integral type");
+ return raw_int_end();
}
- IntElementIterator int_value_begin() const;
- IntElementIterator int_value_end() const;
/// Return the held element values as a range of complex APInts. The element
/// type of this attribute must be a complex of integer type.
- llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, std::complex<APInt>>::value>::type>
+ template <typename T>
+ using ComplexAPIntValueTemplateCheckT = typename std::enable_if<
+ std::is_same<T, std::complex<APInt>>::value>::type;
+ template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
llvm::iterator_range<ComplexIntElementIterator> getValues() const {
return getComplexIntValues();
}
+ template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
+ ComplexIntElementIterator value_begin() const {
+ return complex_value_begin();
+ }
+ template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
+ ComplexIntElementIterator value_end() const {
+ return complex_value_end();
+ }
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
- llvm::iterator_range<FloatElementIterator> getFloatValues() const;
- template <typename T, typename = typename std::enable_if<
- std::is_same<T, APFloat>::value>::type>
+ template <typename T>
+ using APFloatValueTemplateCheckT =
+ typename std::enable_if<std::is_same<T, APFloat>::value>::type;
+ template <typename T, typename = APFloatValueTemplateCheckT<T>>
llvm::iterator_range<FloatElementIterator> getValues() const {
return getFloatValues();
}
- FloatElementIterator float_value_begin() const;
- FloatElementIterator float_value_end() const;
+ template <typename T, typename = APFloatValueTemplateCheckT<T>>
+ FloatElementIterator value_begin() const {
+ return float_value_begin();
+ }
+ template <typename T, typename = APFloatValueTemplateCheckT<T>>
+ FloatElementIterator value_end() const {
+ return float_value_end();
+ }
/// Return the held element values as a range of complex APFloat. The element
/// type of this attribute must be a complex of float type.
- llvm::iterator_range<ComplexFloatElementIterator>
- getComplexFloatValues() const;
- template <typename T, typename = typename std::enable_if<std::is_same<
- T, std::complex<APFloat>>::value>::type>
+ template <typename T>
+ using ComplexAPFloatValueTemplateCheckT = typename std::enable_if<
+ std::is_same<T, std::complex<APFloat>>::value>::type;
+ template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
return getComplexFloatValues();
}
+ template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
+ ComplexFloatElementIterator value_begin() const {
+ return complex_float_value_begin();
+ }
+ template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
+ ComplexFloatElementIterator value_end() const {
+ return complex_float_value_end();
+ }
/// Return the raw storage data held by this attribute. Users should generally
/// not use this directly, as the internal storage format is not always in the
@@ -590,13 +706,25 @@ class DenseElementsAttr : public ElementsAttr {
function_ref<APInt(const APFloat &)> mapping) const;
protected:
- /// Get iterators to the raw APInt values for each element in this attribute.
+ /// Iterators to various elements that require out-of-line definition. These
+ /// are hidden from the user to encourage consistent use of the
+ /// getValues/value_begin/value_end API.
IntElementIterator raw_int_begin() const {
return IntElementIterator(*this, 0);
}
IntElementIterator raw_int_end() const {
return IntElementIterator(*this, getNumElements());
}
+ llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
+ ComplexIntElementIterator complex_value_begin() const;
+ ComplexIntElementIterator complex_value_end() const;
+ llvm::iterator_range<FloatElementIterator> getFloatValues() const;
+ FloatElementIterator float_value_begin() const;
+ FloatElementIterator float_value_end() const;
+ llvm::iterator_range<ComplexFloatElementIterator>
+ getComplexFloatValues() const;
+ ComplexFloatElementIterator complex_float_value_begin() const;
+ ComplexFloatElementIterator complex_float_value_end() const;
/// Overload of the raw 'get' method that asserts that the given type is of
/// complex type. This method is used to verify type invariants that the
@@ -616,11 +744,8 @@ class DenseElementsAttr : public ElementsAttr {
/// Check the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
+ bool isValidBool() const { return getElementType().isInteger(1); }
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
-
- /// Check the information for a C++ data type, check if this type is valid for
- /// the current attribute. This method is used to verify specific type
- /// invariants that the templatized 'getValues' method cannot.
bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
};
@@ -806,7 +931,7 @@ template <typename T>
auto SparseElementsAttr::getValues() const
-> llvm::iterator_range<iterator<T>> {
auto zeroValue = getZeroValue<T>();
- auto valueIt = getValues().getValues<T>().begin();
+ auto valueIt = getValues().value_begin<T>();
const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
std::function<T(ptrdiff_t)> mapFn =
[flatSparseIndices{std::move(flatSparseIndices)},
@@ -821,6 +946,14 @@ auto SparseElementsAttr::getValues() const
};
return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
}
+template <typename T>
+auto SparseElementsAttr::value_begin() const -> iterator<T> {
+ return getValues<T>().begin();
+}
+template <typename T>
+auto SparseElementsAttr::value_end() const -> iterator<T> {
+ return getValues<T>().end();
+}
namespace detail {
/// This class represents a general iterator over the values of an ElementsAttr.
@@ -833,8 +966,7 @@ class ElementsAttrIterator
// NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
// inside of a conversion operator.
using DenseIteratorT = typename std::enable_if<
- true,
- decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
+ true, decltype(std::declval<DenseElementsAttr>().value_begin<T>())>::type;
using SparseIteratorT = SparseElementsAttr::iterator<T>;
/// A union containing the specific iterators for each derived attribute kind.
@@ -960,6 +1092,21 @@ auto ElementsAttr::getValues() const -> iterator_range<T> {
llvm_unreachable("unexpected attribute kind");
}
+template <typename T> auto ElementsAttr::value_begin() const -> iterator<T> {
+ if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
+ return iterator<T>(*this, denseAttr.value_begin<T>());
+ if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
+ return iterator<T>(*this, sparseAttr.value_begin<T>());
+ llvm_unreachable("unexpected attribute kind");
+}
+template <typename T> auto ElementsAttr::value_end() const -> iterator<T> {
+ if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>())
+ return iterator<T>(*this, denseAttr.value_end<T>());
+ if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>())
+ return iterator<T>(*this, sparseAttr.value_end<T>());
+ llvm_unreachable("unexpected attribute kind");
+}
+
} // end namespace mlir.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 25e54cbfd68c9..e39d56f146ef1 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -721,6 +721,8 @@ def Builtin_SparseElementsAttr
/// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types,
/// etc.
template <typename T> llvm::iterator_range<iterator<T>> getValues() const;
+ template <typename T> iterator<T> value_begin() const;
+ template <typename T> iterator<T> value_end() const;
/// Return the value of the element at the given index. The 'index' is
/// expected to refer to a valid element.
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 4ae54d4caad4e..a2ee06722f0d8 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -505,48 +505,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
// Indexed accessors.
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
- return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<bool>(pos);
}
int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
- return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int8_t>(pos);
}
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
- return *(unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint8_t>(pos);
}
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
- return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int32_t>(pos);
}
uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
- return *(
- unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint32_t>(pos);
}
int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
- return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int64_t>(pos);
}
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
- return *(
- unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint64_t>(pos);
}
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
- return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<float>(pos);
}
double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
- return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
- pos);
+ return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<double>(pos);
}
MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
intptr_t pos) {
return wrap(
- *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
- pos));
+ unwrap(attr).cast<DenseElementsAttr>().getFlatValue<StringRef>(pos));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index c20f4ed6b567d..bd1e4ade4ec7a 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -127,7 +127,7 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0))
return failure();
- if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1),
+ if (llvm::any_of(llvm::drop_begin(localSize.getValues<APInt>(), 1),
[](const APInt &size) { return !size.isOneValue(); }))
return failure();
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index c59f291adc1e9..0da2209e2df08 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -558,9 +558,9 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
if (srcElemType != dstElemType) {
SmallVector<Attribute, 8> elements;
if (srcElemType.isa<FloatType>()) {
- for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
- FloatAttr dstAttr = convertFloatAttr(
- srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
+ for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
+ FloatAttr dstAttr =
+ convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -568,10 +568,9 @@ LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
} else if (srcElemType.isInteger(1)) {
return failure();
} else {
- for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
- IntegerAttr dstAttr =
- convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
- dstElemType.cast<IntegerType>(), rewriter);
+ for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
+ IntegerAttr dstAttr = convertIntegerAttr(
+ srcAttr, dstElemType.cast<IntegerType>(), rewriter);
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 558ba4d04acbb..4184b9cadf4dc 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1610,7 +1610,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultTy.getRank());
- for (auto permutation : llvm::enumerate(perms.getIntValues())) {
+ for (auto permutation : llvm::enumerate(perms.getValues<APInt>())) {
inputExprs[permutation.value().getZExtValue()] =
rewriter.getAffineDimExpr(permutation.index());
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index b3fe10d05ce04..26ab9551854a5 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -337,11 +337,12 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
auto attrName =
OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
auto sizeAttr = op->template getAttrOfType<DenseIntElementsAttr>(attrName);
+
+ // Async dependencies is the only variadic operand.
if (!sizeAttr)
- return; // Async dependencies is the only variadic operand.
- SmallVector<int32_t, 8> sizes;
- for (auto size : sizeAttr.getIntValues())
- sizes.push_back(size.getSExtValue());
+ return;
+
+ SmallVector<int32_t, 8> sizes(sizeAttr.getValues<int32_t>());
++sizes.front();
op->setAttr(attrName, Builder(op->getContext()).getI32VectorAttr(sizes));
}
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index b6d072f9f60a1..f3b69ae85182e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1825,8 +1825,9 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
// and hence was replaced.
if (complexElementType.isa<IntegerType>()) {
bool isSigned = !complexElementType.isUnsignedInteger();
+ auto valueIt = attr.value_begin<std::complex<APInt>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- auto complexValue = *(attr.getComplexIntValues().begin() + index);
+ auto complexValue = *(valueIt + index);
os << "(";
printDenseIntElement(complexValue.real(), os, isSigned);
os << ",";
@@ -1834,8 +1835,9 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
os << ")";
});
} else {
+ auto valueIt = attr.value_begin<std::complex<APFloat>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- auto complexValue = *(attr.getComplexFloatValues().begin() + index);
+ auto complexValue = *(valueIt + index);
os << "(";
printFloatValue(complexValue.real(), os);
os << ",";
@@ -1845,15 +1847,15 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
}
} else if (elementType.isIntOrIndex()) {
bool isSigned = !elementType.isUnsignedInteger();
- auto intValues = attr.getIntValues();
+ auto valueIt = attr.value_begin<APInt>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printDenseIntElement(*(intValues.begin() + index), os, isSigned);
+ printDenseIntElement(*(valueIt + index), os, isSigned);
});
} else {
assert(elementType.isa<FloatType>() && "unexpected element type");
- auto floatValues = attr.getFloatValues();
+ auto valueIt = attr.value_begin<APFloat>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printFloatValue(*(floatValues.begin() + index), os);
+ printFloatValue(*(valueIt + index), os);
});
}
}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index d906adc3e8151..a3c7fb0af9293 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -390,6 +390,8 @@ ShapedType ElementsAttr::getType() const {
return Attribute::getType().cast<ShapedType>();
}
+Type ElementsAttr::getElementType() const { return getType().getElementType(); }
+
/// Returns the number of elements held by this attribute.
int64_t ElementsAttr::getNumElements() const {
return getType().getNumElements();
@@ -635,7 +637,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
- Type eltTy = owner.getType().getElementType();
+ Type eltTy = owner.getElementType();
if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
if (eltTy.isa<IndexType>())
@@ -690,7 +692,7 @@ DenseElementsAttr::IntElementIterator::IntElementIterator(
DenseElementsAttr attr, size_t dataIndex)
: DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
attr.getRawData().data(), attr.isSplat(), dataIndex),
- bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
+ bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
APInt DenseElementsAttr::IntElementIterator::operator*() const {
return readBits(getData(),
@@ -707,7 +709,7 @@ DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
std::complex<APInt>, std::complex<APInt>,
std::complex<APInt>>(
attr.getRawData().data(), attr.isSplat(), dataIndex) {
- auto complexType = attr.getType().getElementType().cast<ComplexType>();
+ auto complexType = attr.getElementType().cast<ComplexType>();
bitWidth = getDenseElementBitWidth(complexType.getElementType());
}
@@ -930,21 +932,15 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
isInt, isSigned);
}
-/// A method used to verify specific type invariants that the templatized 'get'
-/// method cannot.
bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
bool isSigned) const {
- return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
- isSigned);
+ return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
}
-
-/// Check the information for a C++ data type, check if this type is valid for
-/// the current attribute.
bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
bool isSigned) const {
return ::isValidIntOrFloat(
- getType().getElementType().cast<ComplexType>().getElementType(),
- dataEltSize / 2, isInt, isSigned);
+ getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
+ isInt, isSigned);
}
/// Returns true if this attribute corresponds to a splat, i.e. if all element
@@ -953,76 +949,69 @@ bool DenseElementsAttr::isSplat() const {
return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
}
-/// Return the held element values as a range of Attributes.
-auto DenseElementsAttr::getAttributeValues() const
- -> llvm::iterator_range<AttributeElementIterator> {
- return {attr_value_begin(), attr_value_end()};
-}
-auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
- return AttributeElementIterator(*this, 0);
-}
-auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
- return AttributeElementIterator(*this, getNumElements());
+/// Return if the given complex type has an integer element type.
+static bool isComplexOfIntType(Type type) {
+ return type.cast<ComplexType>().getElementType().isa<IntegerType>();
}
-/// Return the held element values as a range of bool. The element type of
-/// this attribute must be of integer type of bitwidth 1.
-auto DenseElementsAttr::getBoolValues() const
- -> llvm::iterator_range<BoolElementIterator> {
- auto eltType = getType().getElementType().dyn_cast<IntegerType>();
- assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
- (void)eltType;
- return {BoolElementIterator(*this, 0),
- BoolElementIterator(*this, getNumElements())};
-}
-
-/// Return the held element values as a range of APInts. The element type of
-/// this attribute must be of integer type.
-auto DenseElementsAttr::getIntValues() const
- -> llvm::iterator_range<IntElementIterator> {
- assert(getType().getElementType().isIntOrIndex() && "expected integral type");
- return {raw_int_begin(), raw_int_end()};
-}
-auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
- assert(getType().getElementType().isIntOrIndex() && "expected integral type");
- return raw_int_begin();
-}
-auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
- assert(getType().getElementType().isIntOrIndex() && "expected integral type");
- return raw_int_end();
-}
auto DenseElementsAttr::getComplexIntValues() const
-> llvm::iterator_range<ComplexIntElementIterator> {
- Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
- (void)eltTy;
- assert(eltTy.isa<IntegerType>() && "expected complex integral type");
+ assert(isComplexOfIntType(getElementType()) &&
+ "expected complex integral type");
return {ComplexIntElementIterator(*this, 0),
ComplexIntElementIterator(*this, getNumElements())};
}
+auto DenseElementsAttr::complex_value_begin() const
+ -> ComplexIntElementIterator {
+ assert(isComplexOfIntType(getElementType()) &&
+ "expected complex integral type");
+ return ComplexIntElementIterator(*this, 0);
+}
+auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
+ assert(isComplexOfIntType(getElementType()) &&
+ "expected complex integral type");
+ return ComplexIntElementIterator(*this, getNumElements());
+}
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
auto DenseElementsAttr::getFloatValues() const
-> llvm::iterator_range<FloatElementIterator> {
- auto elementType = getType().getElementType().cast<FloatType>();
+ auto elementType = getElementType().cast<FloatType>();
const auto &elementSemantics = elementType.getFloatSemantics();
return {FloatElementIterator(elementSemantics, raw_int_begin()),
FloatElementIterator(elementSemantics, raw_int_end())};
}
auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
- return getFloatValues().begin();
+ auto elementType = getElementType().cast<FloatType>();
+ return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
}
auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
- return getFloatValues().end();
+ auto elementType = getElementType().cast<FloatType>();
+ return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
}
+
auto DenseElementsAttr::getComplexFloatValues() const
-> llvm::iterator_range<ComplexFloatElementIterator> {
- Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
+ Type eltTy = getElementType().cast<ComplexType>().getElementType();
assert(eltTy.isa<FloatType>() && "expected complex float type");
const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
return {{semantics, {*this, 0}},
{semantics, {*this, static_cast<size_t>(getNumElements())}}};
}
+auto DenseElementsAttr::complex_float_value_begin() const
+ -> ComplexFloatElementIterator {
+ Type eltTy = getElementType().cast<ComplexType>().getElementType();
+ assert(eltTy.isa<FloatType>() && "expected complex float type");
+ return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
+}
+auto DenseElementsAttr::complex_float_value_end() const
+ -> ComplexFloatElementIterator {
+ Type eltTy = getElementType().cast<ComplexType>().getElementType();
+ assert(eltTy.isa<FloatType>() && "expected complex float type");
+ return {eltTy.cast<FloatType>().getFloatSemantics(),
+ {*this, static_cast<size_t>(getNumElements())}};
+}
/// Return the raw storage data held by this attribute.
ArrayRef<char> DenseElementsAttr::getRawData() const {
@@ -1374,19 +1363,19 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
/// Get a zero APFloat for the given sparse attribute.
APFloat SparseElementsAttr::getZeroAPFloat() const {
- auto eltType = getType().getElementType().cast<FloatType>();
+ auto eltType = getElementType().cast<FloatType>();
return APFloat(eltType.getFloatSemantics());
}
/// Get a zero APInt for the given sparse attribute.
APInt SparseElementsAttr::getZeroAPInt() const {
- auto eltType = getType().getElementType().cast<IntegerType>();
+ auto eltType = getElementType().cast<IntegerType>();
return APInt::getZero(eltType.getWidth());
}
/// Get a zero attribute for the given attribute type.
Attribute SparseElementsAttr::getZeroAttr() const {
- auto eltType = getType().getElementType();
+ auto eltType = getElementType();
// Handle floating point elements.
if (eltType.isa<FloatType>())
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 457c993b92ad3..6cba5742a3e0f 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1024,7 +1024,7 @@ LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op,
return op->emitOpError("requires 1D i32 elements attribute '")
<< attrName << "'";
- if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
+ if (llvm::any_of(sizeAttr.getValues<APInt>(), [](const APInt &element) {
return !element.isNonNegative();
}))
return op->emitOpError("'")
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 7a3f602ca3d43..9676553aeaa57 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -51,7 +51,7 @@ void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
auto dattr = attr.cast<DenseIntElementsAttr>();
res.clear();
res.reserve(dattr.size());
- for (auto it : dattr.getIntValues())
+ for (auto it : dattr.getValues<APInt>())
res.push_back(it.getSExtValue());
} else {
auto vals = val.get<ShapedTypeComponents *>()->getDims();
@@ -71,7 +71,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
return t.cast<ShapedType>().getDimSize(index);
if (auto attr = val.dyn_cast<Attribute>())
return attr.cast<DenseIntElementsAttr>()
- .getValue<APInt>({static_cast<uint64_t>(index)})
+ .getFlatValue<APInt>(index)
.getSExtValue();
auto *stc = val.get<ShapedTypeComponents *>();
return stc->getDims()[index];
@@ -94,7 +94,7 @@ bool ShapeAdaptor::hasStaticShape() const {
return t.cast<ShapedType>().hasStaticShape();
if (auto attr = val.dyn_cast<Attribute>()) {
auto dattr = attr.cast<DenseIntElementsAttr>();
- for (auto index : dattr.getIntValues())
+ for (auto index : dattr.getValues<APInt>())
if (ShapedType::isDynamic(index.getSExtValue()))
return false;
return true;
@@ -115,7 +115,7 @@ int64_t ShapeAdaptor::getNumElements() const {
if (auto attr = val.dyn_cast<Attribute>()) {
auto dattr = attr.cast<DenseIntElementsAttr>();
int64_t num = 1;
- for (auto index : dattr.getIntValues()) {
+ for (auto index : dattr.getValues<APInt>()) {
num *= index.getZExtValue();
assert(num >= 0 && "integer overflow in element count computation");
}
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 3945cfa6ee0f7..199a8c49aca8c 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -294,7 +294,8 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
if (!nested)
return nullptr;
- values.append(nested.attr_value_begin(), nested.attr_value_end());
+ values.append(nested.value_begin<Attribute>(),
+ nested.value_end<Attribute>());
}
return DenseElementsAttr::get(outerType, values);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index e40fa9b8e6b7a..9562e9b1c1f92 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -83,12 +83,14 @@ const char *opSegmentSizeAttrInitCode = R"(
auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>();
)";
const char *attrSizedSegmentValueRangeCalcCode = R"(
- auto sizeAttrValues = sizeAttr.getValues<uint32_t>();
+ const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
+ if (sizeAttr.isSplat())
+ return {*sizeAttrValueIt * index, *sizeAttrValueIt};
+
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
- start += *(sizeAttrValues.begin() + i);
- unsigned size = *(sizeAttrValues.begin() + index);
- return {start, size};
+ start += sizeAttrValueIt[i];
+ return {start, sizeAttrValueIt[index]};
)";
// The logic to calculate the actual value range for a declared operand
// of an op with variadic of variadic operands within the OpAdaptor.
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index ef0bdd81ee3a5..4fcdf1d673195 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -158,7 +158,7 @@ TEST(StructsGenTest, GetElements) {
auto denseAttr = returnedAttr.dyn_cast<mlir::DenseElementsAttr>();
ASSERT_TRUE(denseAttr);
- for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) {
+ for (const auto &valIndexIt : llvm::enumerate(denseAttr.getValues<APInt>())) {
EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1);
}
}
More information about the Mlir-commits
mailing list