[Mlir-commits] [mlir] 24ad385 - [mlir][DenseElementsAttr] Add support for opaque APFloat/APInt complex values.
River Riddle
llvmlistbot at llvm.org
Tue May 5 12:46:43 PDT 2020
Author: River Riddle
Date: 2020-05-05T12:42:37-07:00
New Revision: 24ad3858842552a8052c12c44c7707dd822898bf
URL: https://github.com/llvm/llvm-project/commit/24ad3858842552a8052c12c44c7707dd822898bf
DIFF: https://github.com/llvm/llvm-project/commit/24ad3858842552a8052c12c44c7707dd822898bf.diff
LOG: [mlir][DenseElementsAttr] Add support for opaque APFloat/APInt complex values.
This revision allows creating DenseElementsAttrs from, and accessing their elements as, std::complex<APInt>/std::complex<APFloat> values. This makes it possible to access and transform complex values opaquely, and is used by the printer/parser to provide pretty printing for complex values. The printed form for complex values matches that of std::complex, i.e.:
```
// `(` element `,` element `)`
dense<(10,10)> : tensor<complex<i64>>
```
Differential Revision: https://reviews.llvm.org/D79296
Added:
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/dense-elements-hex.mlir
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
mlir/unittests/IR/AttributeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 6b3edf14a39b..e0b4e5f43737 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -764,12 +764,26 @@ class DenseElementsAttr : public ElementsAttr {
/// shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
+ /// Constructs a dense complex elements attribute from an array of
+ /// std::complex<APInt> values. Each APInt component is expected to have the
+ /// same bitwidth as the element type of the complex element type of 'type'.
+ /// 'type' must be a vector or tensor with static shape.
+ static DenseElementsAttr get(ShapedType type,
+ ArrayRef<std::complex<APInt>> values);
+
/// Constructs a dense float elements attribute from an array of APFloat
/// values. Each APFloat value is expected to have the same bitwidth as the
/// element type of 'type'. 'type' must be a vector or tensor with static
/// shape.
static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
+ /// Constructs a dense complex elements attribute from an array of
+ /// std::complex<APFloat> values. Each APFloat component is expected to have
+ /// the same bitwidth as the element type of the complex element type of
+ /// 'type'. 'type' must be a vector or tensor with static shape.
+ static DenseElementsAttr get(ShapedType type,
+ ArrayRef<std::complex<APFloat>> values);
+
/// Construct a dense elements attribute for an initializer_list of values.
/// Each value is expected to be the same bitwidth of the element type of
/// 'type'. 'type' must be a vector or tensor with static shape.
@@ -868,6 +882,26 @@ class DenseElementsAttr : public ElementsAttr {
size_t bitWidth;
};
+ /// A utility iterator that allows walking over the internal raw complex APInt
+ /// values.
+ class ComplexIntElementIterator
+ : public detail::DenseElementIndexedIteratorImpl<
+ ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>,
+ std::complex<APInt>> {
+ public:
+ /// Accesses the raw std::complex<APInt> value at this iterator position.
+ std::complex<APInt> operator*() const;
+
+ private:
+ friend DenseElementsAttr;
+
+ /// Constructs a new iterator.
+ ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex);
+
+ /// The bitwidth of the element type.
+ size_t bitWidth;
+ };
+
/// Iterator for walking over APFloat values.
class FloatElementIterator final
: public llvm::mapped_iterator<IntElementIterator,
@@ -881,6 +915,21 @@ class DenseElementsAttr : public ElementsAttr {
using reference = APFloat;
};
+ /// Iterator for walking over complex APFloat values.
+ class ComplexFloatElementIterator final
+ : public llvm::mapped_iterator<
+ ComplexIntElementIterator,
+ std::function<std::complex<APFloat>(const std::complex<APInt> &)>> {
+ friend DenseElementsAttr;
+
+ /// Initializes the float element iterator to the specified iterator.
+ ComplexFloatElementIterator(const llvm::fltSemantics &smt,
+ ComplexIntElementIterator it);
+
+ public:
+ using reference = std::complex<APFloat>;
+ };
+
//===--------------------------------------------------------------------===//
// Value Querying
//===--------------------------------------------------------------------===//
@@ -1004,6 +1053,15 @@ class DenseElementsAttr : public ElementsAttr {
IntElementIterator int_value_begin() const;
IntElementIterator int_value_end() const;
+ /// Return the held element values as a range of complex APInts. The element
+ /// type of this attribute must be a complex of integer type.
+ llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, std::complex<APInt>>::value>::type>
+ llvm::iterator_range<ComplexIntElementIterator> getValues() const {
+ return getComplexIntValues();
+ }
+
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
llvm::iterator_range<FloatElementIterator> getFloatValues() const;
@@ -1015,6 +1073,16 @@ class DenseElementsAttr : public ElementsAttr {
FloatElementIterator float_value_begin() const;
FloatElementIterator float_value_end() const;
+ /// Return the held element values as a range of complex APFloat. The element
+ /// type of this attribute must be a complex of float type.
+ llvm::iterator_range<ComplexFloatElementIterator>
+ getComplexFloatValues() const;
+ template <typename T, typename = typename std::enable_if<std::is_same<
+ T, std::complex<APFloat>>::value>::type>
+ llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
+ return getComplexFloatValues();
+ }
+
/// Return the raw storage data held by this attribute. Users should generally
/// not use this directly, as the internal storage format is not always in the
/// form the user might expect.
@@ -1120,10 +1188,17 @@ class DenseIntOrFPElementsAttr
protected:
friend DenseElementsAttr;
+ /// Constructs a dense elements attribute from an array of raw APFloat values.
+ /// Each APFloat value is expected to have the same bitwidth as the element
+ /// type of 'type'. 'type' must be a vector or tensor with static shape.
+ static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
+ ArrayRef<APFloat> values, bool isSplat);
+
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'. 'type' must be a vector or tensor with static shape.
- static DenseElementsAttr getRaw(ShapedType type, ArrayRef<APInt> values);
+ static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
+ ArrayRef<APInt> values, bool isSplat);
/// Get or create a new dense elements attribute instance with the given raw
/// data buffer. 'type' must be a vector or tensor with static shape.
@@ -1343,19 +1418,34 @@ class SparseElementsAttr
getZeroValue() const {
return getZeroAPInt();
}
+ template <typename T>
+ typename std::enable_if<std::is_same<std::complex<APInt>, T>::value, T>::type
+ getZeroValue() const {
+ APInt intZero = getZeroAPInt();
+ return {intZero, intZero};
+ }
/// Get a zero for an APFloat.
template <typename T>
typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
getZeroValue() const {
return getZeroAPFloat();
}
+ template <typename T>
+ typename std::enable_if<std::is_same<std::complex<APFloat>, T>::value,
+ T>::type
+ getZeroValue() const {
+ APFloat floatZero = getZeroAPFloat();
+ return {floatZero, floatZero};
+ }
/// Get a zero for a C++ integer, float, StringRef, or complex type.
template <typename T>
typename std::enable_if<
std::numeric_limits<T>::is_integer ||
llvm::is_one_of<T, float, double, StringRef>::value ||
- detail::is_complex_t<T>::value,
+ (detail::is_complex_t<T>::value &&
+ !llvm::is_one_of<T, std::complex<APInt>,
+ std::complex<APFloat>>::value),
T>::type
getZeroValue() const {
return T();
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4305f8b9af19..f58656005cd6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1456,23 +1456,15 @@ void ModulePrinter::printAttribute(Attribute attr,
}
}
-/// Print the integer element of the given DenseElementsAttr at 'index'.
-static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
- unsigned index, bool isSigned) {
- APInt value = *std::next(attr.int_value_begin(), index);
+/// Print the integer element of a DenseElementsAttr.
+static void printDenseIntElement(const APInt &value, raw_ostream &os,
+ bool isSigned) {
if (value.getBitWidth() == 1)
os << (value.getBoolValue() ? "true" : "false");
else
value.print(os, isSigned);
}
-/// Print the float element of the given DenseElementsAttr at 'index'.
-static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
- unsigned index) {
- APFloat value = *std::next(attr.float_value_begin(), index);
- printFloatValue(value, os);
-}
-
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
function_ref<void(unsigned)> printEltFn) {
@@ -1543,26 +1535,45 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
auto elementType = type.getElementType();
// Check to see if we should format this attribute as a hex string.
- // TODO: Add support for formatting complex elements nicely.
auto numElements = type.getNumElements();
- if (type.getElementType().isa<ComplexType>() ||
- (!attr.isSplat() && allowHex &&
- shouldPrintElementsAttrWithHex(numElements))) {
+ if (!attr.isSplat() && allowHex &&
+ shouldPrintElementsAttrWithHex(numElements)) {
ArrayRef<char> rawData = attr.getRawData();
os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
<< "\"";
return;
}
- if (elementType.isIntOrIndex()) {
+ if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
+ auto printComplexValue = [&](auto complexValues, auto printFn,
+ raw_ostream &os, auto &&... params) {
+ printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
+ auto complexValue = *(complexValues.begin() + index);
+ os << "(";
+ printFn(complexValue.real(), os, params...);
+ os << ",";
+ printFn(complexValue.imag(), os, params...);
+ os << ")";
+ });
+ };
+
+ Type complexElementType = complexTy.getElementType();
+ if (complexElementType.isa<IntegerType>())
+ printComplexValue(attr.getComplexIntValues(), printDenseIntElement, os,
+ /*isSigned=*/!complexElementType.isUnsignedInteger());
+ else
+ printComplexValue(attr.getComplexFloatValues(), printFloatValue, os);
+ } else if (elementType.isIntOrIndex()) {
bool isSigned = !elementType.isUnsignedInteger();
+ auto intValues = attr.getIntValues();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printDenseIntElement(attr, os, index, isSigned);
+ printDenseIntElement(*(intValues.begin() + index), os, isSigned);
});
} else {
assert(elementType.isa<FloatType>() && "unexpected element type");
+ auto floatValues = attr.getFloatValues();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printDenseFloatElement(attr, os, index);
+ printFloatValue(*(floatValues.begin() + index), os);
});
}
}
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 4084dde74919..ffd1b34504e3 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -374,8 +374,9 @@ struct TypeAttributeStorage : public AttributeStorage {
/// Return the bit width which DenseElementsAttr should use for this type.
inline size_t getDenseElementBitWidth(Type eltType) {
- if (ComplexType complex = eltType.dyn_cast<ComplexType>())
- return getDenseElementBitWidth(complex.getElementType()) * 2;
+ // Align the width for complex to 8 to make storage and interpretation easier.
+ if (ComplexType comp = eltType.dyn_cast<ComplexType>())
+ return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
if (eltType.isBF16())
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 72c638a68f73..a150fcd4d323 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -537,6 +537,9 @@ uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
static size_t getDenseElementStorageWidth(size_t origWidth) {
return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
}
+static size_t getDenseElementStorageWidth(Type elementType) {
+ return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
+}
/// Set a bit to a specific value.
static void setBit(char *rawData, size_t bitPos, bool value) {
@@ -613,14 +616,15 @@ static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
// DenseElementAttr Iterators
//===----------------------------------------------------------------------===//
-/// Constructs a new iterator.
+//===----------------------------------------------------------------------===//
+// AttributeElementIterator
+
DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
DenseElementsAttr attr, size_t index)
: llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
Attribute, Attribute, Attribute>(
attr.getAsOpaquePointer(), index) {}
-/// Accesses the Attribute value at this iterator position.
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
Type eltTy = owner.getType().getElementType();
@@ -640,37 +644,75 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
llvm_unreachable("unexpected element type");
}
-/// Constructs a new iterator.
+//===----------------------------------------------------------------------===//
+// BoolElementIterator
+
DenseElementsAttr::BoolElementIterator::BoolElementIterator(
DenseElementsAttr attr, size_t dataIndex)
: DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
attr.getRawData().data(), attr.isSplat(), dataIndex) {}
-/// Accesses the bool value at this iterator position.
bool DenseElementsAttr::BoolElementIterator::operator*() const {
return getBit(getData(), getDataIndex());
}
-/// Constructs a new iterator.
+//===----------------------------------------------------------------------===//
+// IntElementIterator
+
DenseElementsAttr::IntElementIterator::IntElementIterator(
DenseElementsAttr attr, size_t dataIndex)
: DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
attr.getRawData().data(), attr.isSplat(), dataIndex),
bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
-/// Accesses the raw APInt value at this iterator position.
APInt DenseElementsAttr::IntElementIterator::operator*() const {
return readBits(getData(),
getDataIndex() * getDenseElementStorageWidth(bitWidth),
bitWidth);
}
+//===----------------------------------------------------------------------===//
+// ComplexIntElementIterator
+
+DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
+ DenseElementsAttr attr, size_t dataIndex)
+ : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
+ std::complex<APInt>, std::complex<APInt>,
+ std::complex<APInt>>(
+ attr.getRawData().data(), attr.isSplat(), dataIndex) {
+ auto complexType = attr.getType().getElementType().cast<ComplexType>();
+ bitWidth = getDenseElementBitWidth(complexType.getElementType());
+}
+
+std::complex<APInt>
+DenseElementsAttr::ComplexIntElementIterator::operator*() const {
+ size_t storageWidth = getDenseElementStorageWidth(bitWidth);
+ size_t offset = getDataIndex() * storageWidth * 2;
+ return {readBits(getData(), offset, bitWidth),
+ readBits(getData(), offset + storageWidth, bitWidth)};
+}
+
+//===----------------------------------------------------------------------===//
+// FloatElementIterator
+
DenseElementsAttr::FloatElementIterator::FloatElementIterator(
const llvm::fltSemantics &smt, IntElementIterator it)
: llvm::mapped_iterator<IntElementIterator,
std::function<APFloat(const APInt &)>>(
it, [&](const APInt &val) { return APFloat(smt, val); }) {}
+//===----------------------------------------------------------------------===//
+// ComplexFloatElementIterator
+
+DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
+ const llvm::fltSemantics &smt, ComplexIntElementIterator it)
+ : llvm::mapped_iterator<
+ ComplexIntElementIterator,
+ std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
+ it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
+ return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
+ }) {}
+
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//
@@ -753,7 +795,21 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APInt> values) {
assert(type.getElementType().isIntOrIndex());
- return DenseIntOrFPElementsAttr::getRaw(type, values);
+ assert(hasSameElementsOrSplat(type, values));
+ size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
+ return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
+ /*isSplat=*/(values.size() == 1));
+}
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+ ArrayRef<std::complex<APInt>> values) {
+ ComplexType complex = type.getElementType().cast<ComplexType>();
+ assert(complex.getElementType().isa<IntegerType>());
+ assert(hasSameElementsOrSplat(type, values));
+ size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
+ ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
+ values.size() * 2);
+ return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
+ /*isSplat=*/(values.size() == 1));
}
// Constructs a dense float elements attribute from an array of APFloat
@@ -762,12 +818,22 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<APFloat> values) {
assert(type.getElementType().isa<FloatType>());
-
- // Convert the APFloat values to APInt and create a dense elements attribute.
- std::vector<APInt> intValues(values.size());
- for (unsigned i = 0, e = values.size(); i != e; ++i)
- intValues[i] = values[i].bitcastToAPInt();
- return DenseIntOrFPElementsAttr::getRaw(type, intValues);
+ assert(hasSameElementsOrSplat(type, values));
+ size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
+ return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
+ /*isSplat=*/(values.size() == 1));
+}
+DenseElementsAttr
+DenseElementsAttr::get(ShapedType type,
+ ArrayRef<std::complex<APFloat>> values) {
+ ComplexType complex = type.getElementType().cast<ComplexType>();
+ assert(complex.getElementType().isa<FloatType>());
+ assert(hasSameElementsOrSplat(type, values));
+ ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
+ values.size() * 2);
+ size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
+ return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
+ /*isSplat=*/(values.size() == 1));
}
/// Construct a dense elements attribute from a raw buffer representing the
@@ -783,8 +849,7 @@ DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
ArrayRef<char> rawBuffer,
bool &detectedSplat) {
- size_t elementWidth = getDenseElementBitWidth(type.getElementType());
- size_t storageWidth = getDenseElementStorageWidth(elementWidth);
+ size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
// Storage width of 1 is special as it is packed by the bit.
@@ -904,13 +969,20 @@ auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
assert(getType().getElementType().isIntOrIndex() && "expected integral type");
return raw_int_end();
}
+auto DenseElementsAttr::getComplexIntValues() const
+ -> llvm::iterator_range<ComplexIntElementIterator> {
+ Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
+ (void)eltTy;
+ assert(eltTy.isa<IntegerType>() && "expected complex integral type");
+ return {ComplexIntElementIterator(*this, 0),
+ ComplexIntElementIterator(*this, getNumElements())};
+}
/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
auto DenseElementsAttr::getFloatValues() const
-> llvm::iterator_range<FloatElementIterator> {
auto elementType = getType().getElementType().cast<FloatType>();
- assert(elementType.isa<FloatType>() && "expected float type");
const auto &elementSemantics = elementType.getFloatSemantics();
return {FloatElementIterator(elementSemantics, raw_int_begin()),
FloatElementIterator(elementSemantics, raw_int_end())};
@@ -921,6 +993,14 @@ auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
return getFloatValues().end();
}
+auto DenseElementsAttr::getComplexFloatValues() const
+ -> llvm::iterator_range<ComplexFloatElementIterator> {
+ Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
+ assert(eltTy.isa<FloatType>() && "expected complex float type");
+ const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
+ return {{semantics, {*this, 0}},
+ {semantics, {*this, static_cast<size_t>(getNumElements())}}};
+}
/// Return the raw storage data held by this attribute.
ArrayRef<char> DenseElementsAttr::getRawData() const {
@@ -972,23 +1052,42 @@ DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
// DenseIntOrFPElementsAttr
//===----------------------------------------------------------------------===//
+/// Utility method to write a range of APInt values to a buffer.
+template <typename APRangeT>
+static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
+ APRangeT &&values) {
+ data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
+ size_t offset = 0;
+ for (auto it = values.begin(), e = values.end(); it != e;
+ ++it, offset += storageWidth) {
+ assert((*it).getBitWidth() <= storageWidth);
+ writeBits(data.data(), offset, *it);
+ }
+}
+
+/// Constructs a dense elements attribute from an array of raw APFloat values.
+/// Each APFloat value is expected to have the same bitwidth as the element
+/// type of 'type'. 'type' must be a vector or tensor with static shape.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
+ size_t storageWidth,
+ ArrayRef<APFloat> values,
+ bool isSplat) {
+ std::vector<char> data;
+ auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
+ writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
+ return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
+}
+
/// Constructs a dense elements attribute from an array of raw APInt values.
/// Each APInt value is expected to have the same bitwidth as the element type
/// of 'type'.
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
- ArrayRef<APInt> values) {
- assert(hasSameElementsOrSplat(type, values));
-
- size_t bitWidth = getDenseElementBitWidth(type.getElementType());
- size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
- std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
- values.size());
- for (unsigned i = 0, e = values.size(); i != e; ++i) {
- assert(values[i].getBitWidth() == bitWidth);
- writeBits(elementData.data(), i * storageBitWidth, values[i]);
- }
- return DenseIntOrFPElementsAttr::getRaw(type, elementData,
- /*isSplat=*/(values.size() == 1));
+ size_t storageWidth,
+ ArrayRef<APInt> values,
+ bool isSplat) {
+ std::vector<char> data;
+ writeAPIntsToBuffer(storageWidth, data, values);
+ return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
}
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 3689983d45e2..384c69e44a05 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1956,29 +1956,13 @@ class TensorLiteralParser {
ArrayRef<int64_t> getShape() const { return shape; }
private:
- enum class ElementKind { Boolean, Integer, Float, String };
-
- /// Return a string to represent the given element kind.
- const char *getElementKindStr(ElementKind kind) {
- switch (kind) {
- case ElementKind::Boolean:
- return "'boolean'";
- case ElementKind::Integer:
- return "'integer'";
- case ElementKind::Float:
- return "'float'";
- case ElementKind::String:
- return "'string'";
- }
- llvm_unreachable("unknown element kind");
- }
-
- /// Build a Dense Integer attribute for the given type.
- DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
+ /// Get the parsed elements for an integer attribute.
+ ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
+ std::vector<APInt> &intValues);
- /// Build a Dense Float attribute for the given type.
- DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
- FloatType eltTy);
+ /// Get the parsed elements for a float attribute.
+ ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
+ std::vector<APFloat> &floatValues);
/// Build a Dense String attribute for the given type.
DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
@@ -2011,9 +1995,6 @@ class TensorLiteralParser {
/// Storage used when parsing elements, this is a pair of <is_negated, token>.
std::vector<std::pair<bool, Token>> storage;
- /// A flag that indicates the type of elements that have been parsed.
- Optional<ElementKind> knownEltKind;
-
/// Storage used when parsing elements that were stored as hex values.
Optional<Token> hexStorage;
};
@@ -2053,22 +2034,40 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
return nullptr;
}
- // If the type is an integer, build a set of APInt values from the storage
- // with the correct bitwidth.
- if (auto intTy = eltType.dyn_cast<IntegerType>())
- return getIntAttr(loc, type, intTy);
- if (auto indexTy = eltType.dyn_cast<IndexType>())
- return getIntAttr(loc, type, indexTy);
-
- // If parsing a floating point type.
- if (auto floatTy = eltType.dyn_cast<FloatType>())
- return getFloatAttr(loc, type, floatTy);
+ // Handle complex types in the specific element type cases below.
+ bool isComplex = false;
+ if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
+ eltType = complexTy.getElementType();
+ isComplex = true;
+ }
- // If parsing a complex type.
- // TODO: Support complex elements with pretty element printing.
- if (eltType.isa<ComplexType>()) {
- p.emitError(loc) << "complex elements only support hex formatting";
- return nullptr;
+ // Handle integer and index types.
+ if (eltType.isIntOrIndex()) {
+ std::vector<APInt> intValues;
+ if (failed(getIntAttrElements(loc, eltType, intValues)))
+ return nullptr;
+ if (isComplex) {
+ // If this is a complex, treat the parsed values as complex values.
+ auto complexData = llvm::makeArrayRef(
+ reinterpret_cast<std::complex<APInt> *>(intValues.data()),
+ intValues.size() / 2);
+ return DenseElementsAttr::get(type, complexData);
+ }
+ return DenseElementsAttr::get(type, intValues);
+ }
+ // Handle floating point types.
+ if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
+ std::vector<APFloat> floatValues;
+ if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
+ return nullptr;
+ if (isComplex) {
+ // If this is a complex, treat the parsed values as complex values.
+ auto complexData = llvm::makeArrayRef(
+ reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
+ floatValues.size() / 2);
+ return DenseElementsAttr::get(type, complexData);
+ }
+ return DenseElementsAttr::get(type, floatValues);
}
// Other types are assumed to be string representations.
@@ -2076,39 +2075,36 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
}
/// Get the parsed elements for an integer attribute.
-DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
- ShapedType type, Type eltTy) {
- std::vector<APInt> intElements;
- intElements.reserve(storage.size());
- auto isUintType = type.getElementType().isUnsignedInteger();
+ParseResult
+TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
+ std::vector<APInt> &intValues) {
+ intValues.reserve(storage.size());
+ bool isUintType = eltTy.isUnsignedInteger();
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
auto tokenLoc = token.getLoc();
if (isNegative && isUintType) {
- p.emitError(tokenLoc)
- << "expected unsigned integer elements, but parsed negative value";
- return nullptr;
+ return p.emitError(tokenLoc)
+ << "expected unsigned integer elements, but parsed negative value";
}
// Check to see if floating point values were parsed.
if (token.is(Token::floatliteral)) {
- p.emitError(tokenLoc)
- << "expected integer elements, but parsed floating-point";
- return nullptr;
+ return p.emitError(tokenLoc)
+ << "expected integer elements, but parsed floating-point";
}
assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
"unexpected token type");
if (token.isAny(Token::kw_true, Token::kw_false)) {
if (!eltTy.isInteger(1)) {
- p.emitError(tokenLoc)
- << "expected i1 type for 'true' or 'false' values";
- return nullptr;
+ return p.emitError(tokenLoc)
+ << "expected i1 type for 'true' or 'false' values";
}
APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
- intElements.push_back(apInt);
+ intValues.push_back(apInt);
continue;
}
@@ -2116,19 +2112,16 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
Optional<APInt> apInt =
buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
if (!apInt)
- return (p.emitError(tokenLoc, "integer constant out of range for type"),
- nullptr);
- intElements.push_back(*apInt);
+ return p.emitError(tokenLoc, "integer constant out of range for type");
+ intValues.push_back(*apInt);
}
-
- return DenseElementsAttr::get(type, intElements);
+ return success();
}
/// Get the parsed elements for a float attribute.
-DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
- ShapedType type,
- FloatType eltTy) {
- std::vector<APFloat> floatValues;
+ParseResult
+TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
+ std::vector<APFloat> &floatValues) {
floatValues.reserve(storage.size());
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
@@ -2137,34 +2130,31 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
// Handle hexadecimal float literals.
if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
if (isNegative) {
- p.emitError(token.getLoc())
- << "hexadecimal float literal should not have a leading minus";
- return nullptr;
+ return p.emitError(token.getLoc())
+ << "hexadecimal float literal should not have a leading minus";
}
auto val = token.getUInt64IntegerValue();
if (!val.hasValue()) {
- p.emitError("hexadecimal float constant out of range for attribute");
- return nullptr;
+ return p.emitError(
+ "hexadecimal float constant out of range for attribute");
}
Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
if (!apVal)
- return nullptr;
+ return failure();
floatValues.push_back(*apVal);
continue;
}
// Check to see if any decimal integers or booleans were parsed.
- if (!token.is(Token::floatliteral)) {
- p.emitError() << "expected floating-point elements, but parsed integer";
- return nullptr;
- }
+ if (!token.is(Token::floatliteral))
+ return p.emitError()
+ << "expected floating-point elements, but parsed integer";
// Build the float values from tokens.
auto val = token.getFloatingPointValue();
- if (!val.hasValue()) {
- p.emitError("floating point value too large for attribute");
- return nullptr;
- }
+ if (!val.hasValue())
+ return p.emitError("floating point value too large for attribute");
+
// Treat BF16 as double because it is not supported in LLVM's APFloat.
APFloat apVal(isNegative ? -*val : *val);
if (!eltTy.isBF16() && !eltTy.isF64()) {
@@ -2174,8 +2164,7 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
}
floatValues.push_back(apVal);
}
-
- return DenseElementsAttr::get(type, floatValues);
+ return success();
}
/// Build a Dense String attribute for the given type.
@@ -2250,6 +2239,17 @@ ParseResult TensorLiteralParser::parseElement() {
storage.emplace_back(/*isNegative=*/ false, p.getToken());
p.consumeToken();
break;
+
+ // Parse a complex element of the form '(' element ',' element ')'.
+ case Token::l_paren:
+ p.consumeToken(Token::l_paren);
+ if (parseElement() ||
+ p.parseToken(Token::comma, "expected ',' between complex elements") ||
+ parseElement() ||
+ p.parseToken(Token::r_paren, "expected ')' after complex elements"))
+ return failure();
+ break;
+
default:
return p.emitError("expected element literal of primitive type");
}
diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
index 87c0acf80341..e0e12418e1d5 100644
--- a/mlir/test/IR/dense-elements-hex.mlir
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -7,11 +7,8 @@
// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
-// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex<f64>>
-"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex<f64>>} : () -> ()
-
-// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex<f64>>
-"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex<f64>>} : () -> ()
+// CHECK: dense<(1.000000e+01,5.000000e+00)> : tensor<2xcomplex<f64>>
+"foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144000000000000024400000000000001440"> : tensor<2xcomplex<f64>>} : () -> ()
// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16>
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> ()
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 96b234834481..926c2e3c2857 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -689,6 +689,22 @@ func @elementsattr_toolarge2() -> () {
// -----
+"foo"(){bar = dense<[()]> : tensor<complex<i64>>} : () -> () // expected-error {{expected element literal of primitive type}}
+
+// -----
+
+"foo"(){bar = dense<[(10)]> : tensor<complex<i64>>} : () -> () // expected-error {{expected ',' between complex elements}}
+
+// -----
+
+"foo"(){bar = dense<[(10,)]> : tensor<complex<i64>>} : () -> () // expected-error {{expected element literal of primitive type}}
+
+// -----
+
+"foo"(){bar = dense<[(10,10]> : tensor<complex<i64>>} : () -> () // expected-error {{expected ')' after complex elements}}
+
+// -----
+
func @elementsattr_malformed_opaque() -> () {
^bb0:
"foo"(){bar = opaque<10, "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected dialect namespace}}
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 2170927df927..34b1da4282f1 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -702,6 +702,15 @@ func @densetensorattr() -> () {
"index"(){bar = dense<1> : tensor<index>} : () -> ()
// CHECK: "index"() {bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
"index"(){bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
+
+ // CHECK: dense<(1,1)> : tensor<complex<i64>>
+ "complex_attr"(){bar = dense<(1,1)> : tensor<complex<i64>>} : () -> ()
+ // CHECK: dense<[(1,1), (2,2)]> : tensor<2xcomplex<i64>>
+ "complex_attr"(){bar = dense<[(1,1), (2,2)]> : tensor<2xcomplex<i64>>} : () -> ()
+ // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
+ "complex_attr"(){bar = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>} : () -> ()
+ // CHECK: dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>
+ "complex_attr"(){bar = dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>} : () -> ()
return
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 7f91d57b01ae..8fda2a2e73b6 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -27,7 +27,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
EXPECT_EQ(detectedSplat, splat);
for (auto newValue : detectedSplat.template getValues<EltTy>())
- EXPECT_EQ(newValue, splatElt);
+ EXPECT_TRUE(newValue == splatElt);
}
namespace {
@@ -179,4 +179,18 @@ TEST(DenseComplexTest, ComplexIntSplat) {
testSplat(complexType, value);
}
+TEST(DenseComplexTest, ComplexAPFloatSplat) {
+ MLIRContext context;
+ ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
+ std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
+ testSplat(complexType, value);
+}
+
+TEST(DenseComplexTest, ComplexAPIntSplat) {
+ MLIRContext context;
+ ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
+ std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
+ testSplat(complexType, value);
+}
+
} // end namespace
More information about the Mlir-commits mailing list