[Mlir-commits] [mlir] da2a6f4 - [mlir][DenseElementsAttr] Add support for ComplexType elements
River Riddle
llvmlistbot at llvm.org
Tue May 5 12:46:41 PDT 2020
Author: River Riddle
Date: 2020-05-05T12:42:37-07:00
New Revision: da2a6f4e3b5235d871c2e81ae1b0577002733653
URL: https://github.com/llvm/llvm-project/commit/da2a6f4e3b5235d871c2e81ae1b0577002733653
DIFF: https://github.com/llvm/llvm-project/commit/da2a6f4e3b5235d871c2e81ae1b0577002733653.diff
LOG: [mlir][DenseElementsAttr] Add support for ComplexType elements
This revision adds support for storing ComplexType elements inside of a DenseElementsAttr. Complex values are stored as an array of two elements (the real part followed by the imaginary part), matching the layout of std::complex. There is currently no dedicated attribute storage for ComplexType, but DenseElementsAttr provides an API for accessing and creating elements using std::complex<>. Given that the internal implementation of DenseElementsAttr is already fairly opaque, the only real complexity here is in the printing/parsing. This revision keeps it simple for now and always uses hex when printing complex elements. A follow-up will add prettier syntax for this.
Differential Revision: https://reviews.llvm.org/D79281
Added:
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/dense-elements-hex.mlir
mlir/unittests/IR/AttributeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 6d24cd087648..6b3edf14a39b 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -13,6 +13,7 @@
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
+#include <complex>
namespace mlir {
class AffineMap;
@@ -687,6 +688,11 @@ class DenseElementIndexedIteratorImpl
/// Return the data base pointer.
const char *getData() const { return this->base.getPointer(); }
};
+
+/// Type trait detector that checks if a given type T is a complex type.
+template <typename T> struct is_complex_t : public std::false_type {};
+template <typename T>
+struct is_complex_t<std::complex<T>> : public std::true_type {};
} // namespace detail
/// An attribute that represents a reference to a dense vector or tensor object.
@@ -724,11 +730,27 @@ class DenseElementsAttr : public ElementsAttr {
/// Constructs a dense integer elements attribute from a single element.
template <typename T, typename = typename std::enable_if<
std::numeric_limits<T>::is_integer ||
- llvm::is_one_of<T, float, double>::value>::type>
+ llvm::is_one_of<T, float, double>::value ||
+ detail::is_complex_t<T>::value>::type>
static DenseElementsAttr get(const ShapedType &type, T value) {
return get(type, llvm::makeArrayRef(value));
}
+ /// Constructs a dense complex elements attribute from an array of complex
+ /// values. Each value is expected to be the same bitwidth of the element type
+ /// of 'type'. 'type' must be a vector or tensor with static shape.
+ template <typename T, typename ElementT = typename T::value_type,
+ typename = typename std::enable_if<
+ detail::is_complex_t<T>::value &&
+ (std::numeric_limits<ElementT>::is_integer ||
+ llvm::is_one_of<ElementT, float, double>::value)>::type>
+ static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
+ const char *data = reinterpret_cast<const char *>(values.data());
+ return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)),
+ sizeof(T), std::numeric_limits<ElementT>::is_integer,
+ std::numeric_limits<ElementT>::is_signed);
+ }
+
/// Overload of the above 'get' method that is specialized for boolean values.
static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
@@ -764,6 +786,12 @@ class DenseElementsAttr : public ElementsAttr {
ArrayRef<char> rawBuffer,
bool isSplatBuffer);
+ /// Returns true if the given buffer is a valid raw buffer for the given type.
+ /// `detectedSplat` is set if the buffer is valid and represents a splat
+ /// buffer.
+ static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer,
+ bool &detectedSplat);
+
//===--------------------------------------------------------------------===//
// Iterators
//===--------------------------------------------------------------------===//
@@ -900,12 +928,28 @@ class DenseElementsAttr : public ElementsAttr {
llvm::iterator_range<ElementIterator<T>> getValues() const {
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
std::numeric_limits<T>::is_signed));
- auto rawData = getRawData().data();
+ const char *rawData = getRawData().data();
+ bool splat = isSplat();
+ return {ElementIterator<T>(rawData, splat, 0),
+ ElementIterator<T>(rawData, splat, getNumElements())};
+ }
+
+ /// Return the held element values as a range of std::complex.
+ template <typename T, typename ElementT = typename T::value_type,
+ typename = typename std::enable_if<
+ detail::is_complex_t<T>::value &&
+ (std::numeric_limits<ElementT>::is_integer ||
+ llvm::is_one_of<ElementT, float, double>::value)>::type>
+ llvm::iterator_range<ElementIterator<T>> getValues() const {
+ assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+ std::numeric_limits<ElementT>::is_signed));
+ const char *rawData = getRawData().data();
bool splat = isSplat();
return {ElementIterator<T>(rawData, splat, 0),
ElementIterator<T>(rawData, splat, getNumElements())};
}
+ /// Return the held element values as a range of StringRef.
template <typename T, typename = typename std::enable_if<
std::is_same<T, StringRef>::value>::type>
llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
@@ -1010,6 +1054,13 @@ class DenseElementsAttr : public ElementsAttr {
return IntElementIterator(*this, getNumElements());
}
+ /// Overload of the raw 'get' method that asserts that the given type is of
+ /// complex type. This method is used to verify type invariants that the
+ /// templatized 'get' method cannot.
+ static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
+ int64_t dataEltSize, bool isInt,
+ bool isSigned);
+
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
@@ -1022,6 +1073,11 @@ class DenseElementsAttr : public ElementsAttr {
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
+
+ /// Check the information for a C++ data type, check if this type is valid for
+ /// the current attribute. This method is used to verify specific type
+ /// invariants that the templatized 'getValues' method cannot.
+ bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
};
/// An attribute class for representing dense arrays of strings. The structure
@@ -1074,6 +1130,13 @@ class DenseIntOrFPElementsAttr
static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
bool isSplat);
+ /// Overload of the raw 'get' method that asserts that the given type is of
+ /// complex type. This method is used to verify type invariants that the
+ /// templatized 'get' method cannot.
+ static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
+ int64_t dataEltSize, bool isInt,
+ bool isSigned);
+
/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
@@ -1287,20 +1350,15 @@ class SparseElementsAttr
return getZeroAPFloat();
}
- /// Get a zero for a StringRef.
+ /// Get a zero for a C++ integer, float, StringRef, or complex type.
template <typename T>
- typename std::enable_if<std::is_same<StringRef, T>::value, T>::type
- getZeroValue() const {
- return StringRef();
- }
-
- /// Get a zero for an C++ integer or float type.
- template <typename T>
- typename std::enable_if<std::numeric_limits<T>::is_integer ||
- llvm::is_one_of<T, float, double>::value,
- T>::type
+ typename std::enable_if<
+ std::numeric_limits<T>::is_integer ||
+ llvm::is_one_of<T, float, double, StringRef>::value ||
+ detail::is_complex_t<T>::value,
+ T>::type
getZeroValue() const {
- return T(0);
+ return T();
}
/// Flatten, and return, all of the sparse indices in this attribute in
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 005571afb59a..4305f8b9af19 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1468,50 +1468,21 @@ static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
/// Print the float element of the given DenseElementsAttr at 'index'.
static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
- unsigned index, bool isSigned) {
- assert(isSigned && "floating point values are always signed");
+ unsigned index) {
APFloat value = *std::next(attr.float_value_begin(), index);
printFloatValue(value, os);
}
-static void printDenseStringElement(DenseStringElementsAttr attr,
- raw_ostream &os, unsigned index) {
- os << "\"";
- printEscapedString(attr.getRawStringData()[index], os);
- os << "\"";
-}
-
-void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
- bool allowHex) {
- if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
- printDenseStringElementsAttr(stringAttr);
- return;
- }
-
- printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
- allowHex);
-}
-
-void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
- bool allowHex) {
- auto type = attr.getType();
- auto shape = type.getShape();
- auto rank = type.getRank();
- bool isSigned = !type.getElementType().isUnsignedInteger();
-
- // The function used to print elements of this attribute.
- auto printEltFn = type.getElementType().isIntOrIndex()
- ? printDenseIntElement
- : printDenseFloatElement;
-
+static void
+printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
+ function_ref<void(unsigned)> printEltFn) {
// Special case for 0-d and splat tensors.
- if (attr.isSplat()) {
- printEltFn(attr, os, 0, isSigned);
- return;
- }
+ if (isSplat)
+ return printEltFn(0);
// Special case for degenerate tensors.
auto numElements = type.getNumElements();
+ int64_t rank = type.getRank();
if (numElements == 0) {
for (int i = 0; i < rank; ++i)
os << '[';
@@ -1520,14 +1491,6 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
return;
}
- // Check to see if we should format this attribute as a hex string.
- if (allowHex && shouldPrintElementsAttrWithHex(numElements)) {
- ArrayRef<char> rawData = attr.getRawData();
- os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
- << "\"";
- return;
- }
-
// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.
@@ -1537,7 +1500,8 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
// The number of brackets that have been opened and not closed.
unsigned openBrackets = 0;
- auto bumpCounter = [&]() {
+ auto shape = type.getShape();
+ auto bumpCounter = [&] {
// Bump the least significant digit.
++counter[rank - 1];
// Iterate backwards bubbling back the increment.
@@ -1557,68 +1521,60 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
- printEltFn(attr, os, idx, isSigned);
+ printEltFn(idx);
bumpCounter();
}
while (openBrackets-- > 0)
os << ']';
}
-void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
- auto type = attr.getType();
- auto shape = type.getShape();
- auto rank = type.getRank();
+void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
+ bool allowHex) {
+ if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
+ return printDenseStringElementsAttr(stringAttr);
- // Special case for 0-d and splat tensors.
- if (attr.isSplat()) {
- printDenseStringElement(attr, os, 0);
- return;
- }
+ printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
+ allowHex);
+}
- // Special case for degenerate tensors.
+void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
+ bool allowHex) {
+ auto type = attr.getType();
+ auto elementType = type.getElementType();
+
+ // Check to see if we should format this attribute as a hex string.
+ // TODO: Add support for formatting complex elements nicely.
auto numElements = type.getNumElements();
- if (numElements == 0) {
- for (int i = 0; i < rank; ++i)
- os << '[';
- for (int i = 0; i < rank; ++i)
- os << ']';
+ if (type.getElementType().isa<ComplexType>() ||
+ (!attr.isSplat() && allowHex &&
+ shouldPrintElementsAttrWithHex(numElements))) {
+ ArrayRef<char> rawData = attr.getRawData();
+ os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
+ << "\"";
return;
}
- // We use a mixed-radix counter to iterate through the shape. When we bump a
- // non-least-significant digit, we emit a close bracket. When we next emit an
- // element we re-open all closed brackets.
-
- // The mixed-radix counter, with radices in 'shape'.
- SmallVector<unsigned, 4> counter(rank, 0);
- // The number of brackets that have been opened and not closed.
- unsigned openBrackets = 0;
+ if (elementType.isIntOrIndex()) {
+ bool isSigned = !elementType.isUnsignedInteger();
+ printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
+ printDenseIntElement(attr, os, index, isSigned);
+ });
+ } else {
+ assert(elementType.isa<FloatType>() && "unexpected element type");
+ printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
+ printDenseFloatElement(attr, os, index);
+ });
+ }
+}
- auto bumpCounter = [&]() {
- // Bump the least significant digit.
- ++counter[rank - 1];
- // Iterate backwards bubbling back the increment.
- for (unsigned i = rank - 1; i > 0; --i)
- if (counter[i] >= shape[i]) {
- // Index 'i' is rolled over. Bump (i-1) and close a bracket.
- counter[i] = 0;
- ++counter[i - 1];
- --openBrackets;
- os << ']';
- }
+void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
+ ArrayRef<StringRef> data = attr.getRawStringData();
+ auto printFn = [&](unsigned index) {
+ os << "\"";
+ printEscapedString(data[index], os);
+ os << "\"";
};
-
- for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
- if (idx != 0)
- os << ", ";
- while (openBrackets++ < rank)
- os << '[';
- openBrackets = rank;
- printDenseStringElement(attr, os, idx);
- bumpCounter();
- }
- while (openBrackets-- > 0)
- os << ']';
+ printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}
void ModulePrinter::printType(Type type) {
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 201f058571af..4084dde74919 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -374,6 +374,8 @@ struct TypeAttributeStorage : public AttributeStorage {
/// Return the bit width which DenseElementsAttr should use for this type.
inline size_t getDenseElementBitWidth(Type eltType) {
+ if (ComplexType complex = eltType.dyn_cast<ComplexType>())
+ return getDenseElementBitWidth(complex.getElementType()) * 2;
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
if (eltType.isBF16())
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 5cb2d5a2ea84..72c638a68f73 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -779,24 +779,44 @@ DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
}
+/// Returns true if the given buffer is a valid raw buffer for the given type.
+bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
+ ArrayRef<char> rawBuffer,
+ bool &detectedSplat) {
+ size_t elementWidth = getDenseElementBitWidth(type.getElementType());
+ size_t storageWidth = getDenseElementStorageWidth(elementWidth);
+ size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
+
+ // Storage width of 1 is special as it is packed by the bit.
+ if (storageWidth == 1) {
+ // Check for a splat, or a buffer equal to the number of elements.
+ if ((detectedSplat = rawBuffer.size() == 1))
+ return true;
+ return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
+ }
+ // All other types are 8-bit aligned.
+ if ((detectedSplat = rawBufferWidth == storageWidth))
+ return true;
+ return rawBufferWidth == (storageWidth * type.getNumElements());
+}
+
/// Check the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
-static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
+static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
bool isSigned) {
// Make sure that the data element size is the same as the type element width.
- if (getDenseElementBitWidth(type.getElementType()) !=
+ if (getDenseElementBitWidth(type) !=
static_cast<size_t>(dataEltSize * CHAR_BIT))
return false;
// Check that the element type is either float or integer or index.
if (!isInt)
- return type.getElementType().isa<FloatType>();
-
- if (type.getElementType().isIndex())
+ return type.isa<FloatType>();
+ if (type.isIndex())
return true;
- auto intType = type.getElementType().dyn_cast<IntegerType>();
+ auto intType = type.dyn_cast<IntegerType>();
if (!intType)
return false;
@@ -807,6 +827,13 @@ static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
}
/// Defaults down the subclass implementation.
+DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize,
+ bool isInt, bool isSigned) {
+ return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
+ isSigned);
+}
DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
ArrayRef<char> data,
int64_t dataEltSize,
@@ -820,7 +847,17 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
/// method cannot.
bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
bool isSigned) const {
- return ::isValidIntOrFloat(getType(), dataEltSize, isInt, isSigned);
+ return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
+ isSigned);
+}
+
+/// Check the information for a C++ data type, check if this type is valid for
+/// the current attribute.
+bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
+ bool isSigned) const {
+ return ::isValidIntOrFloat(
+ getType().getElementType().cast<ComplexType>().getElementType(),
+ dataEltSize / 2, isInt, isSigned);
}
/// Returns if this attribute corresponds to a splat, i.e. if all element
@@ -964,6 +1001,23 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
type, data, isSplat);
}
+/// Overload of the raw 'get' method that asserts that the given type is of
+/// complex type. This method is used to verify type invariants that the
+/// templatized 'get' method cannot.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
+ ArrayRef<char> data,
+ int64_t dataEltSize,
+ bool isInt,
+ bool isSigned) {
+ assert(::isValidIntOrFloat(
+ type.getElementType().cast<ComplexType>().getElementType(),
+ dataEltSize / 2, isInt, isSigned));
+
+ int64_t numElements = data.size() / dataEltSize;
+ assert(numElements == 1 || numElements == type.getNumElements());
+ return getRaw(type, data, /*isSplat=*/numElements == 1);
+}
+
/// Overload of the 'getRaw' method that asserts that the given type is of
/// integer type. This method is used to verify type invariants that the
/// templatized 'get' method cannot.
@@ -971,7 +1025,8 @@ DenseElementsAttr
DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
int64_t dataEltSize, bool isInt,
bool isSigned) {
- assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned));
+ assert(
+ ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
int64_t numElements = data.size() / dataEltSize;
assert(numElements == 1 || numElements == type.getNumElements());
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index d129b867fb0c..3689983d45e2 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -2041,7 +2041,8 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
Type eltType = type.getElementType();
// Check to see if we parse the literal from a hex string.
- if (hexStorage.hasValue() && eltType.isIntOrFloat())
+ if (hexStorage.hasValue() &&
+ (eltType.isIntOrFloat() || eltType.isa<ComplexType>()))
return getHexAttr(loc, type);
// Check that the parsed storage size has the same number of elements to the
@@ -2063,6 +2064,13 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
if (auto floatTy = eltType.dyn_cast<FloatType>())
return getFloatAttr(loc, type, floatTy);
+ // If parsing a complex type.
+ // TODO: Support complex elements with pretty element printing.
+ if (eltType.isa<ComplexType>()) {
+ p.emitError(loc) << "complex elements only support hex formatting";
+ return nullptr;
+ }
+
// Other types are assumed to be string representations.
return getStringAttr(loc, type, type.getElementType());
}
@@ -2196,9 +2204,10 @@ DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
ShapedType type) {
Type elementType = type.getElementType();
- if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
- p.emitError(loc) << "expected floating-point or integer element type, got "
- << elementType;
+ if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
+ p.emitError(loc)
+ << "expected floating-point, integer, or complex element type, got "
+ << elementType;
return nullptr;
}
@@ -2206,21 +2215,15 @@ DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
return nullptr;
- // Check that the size of the hex data corresponds to the size of the type, or
- // a splat of the type.
- // TODO: bf16 is currently stored as a double, this should be removed when
- // APFloat properly supports it.
- int64_t elementWidth =
- elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
- if (static_cast<int64_t>(data.size() * CHAR_BIT) !=
- (type.getNumElements() * elementWidth)) {
+ ArrayRef<char> rawData(data.data(), data.size());
+ bool detectedSplat = false;
+ if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
p.emitError(loc) << "elements hex data size is invalid for provided type: "
<< type;
return nullptr;
}
- return DenseElementsAttr::getFromRawBuffer(
- type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false);
+ return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
}
ParseResult TensorLiteralParser::parseElement() {
diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
index 0375004e1d84..87c0acf80341 100644
--- a/mlir/test/IR/dense-elements-hex.mlir
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -7,6 +7,12 @@
// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
+// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex<f64>>
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex<f64>>} : () -> ()
+
+// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex<f64>>
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex<f64>>} : () -> ()
+
// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16>
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> ()
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 7d22b5e5a07f..7f91d57b01ae 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -25,6 +25,9 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
auto detectedSplat =
DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
EXPECT_EQ(detectedSplat, splat);
+
+ for (auto newValue : detectedSplat.template getValues<EltTy>())
+ EXPECT_EQ(newValue, splatElt);
}
namespace {
@@ -162,4 +165,18 @@ TEST(DenseSplatTest, StringAttrSplat) {
testSplat(stringType, stringAttr);
}
+TEST(DenseComplexTest, ComplexFloatSplat) {
+ MLIRContext context;
+ ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
+ std::complex<float> value(10.0, 15.0);
+ testSplat(complexType, value);
+}
+
+TEST(DenseComplexTest, ComplexIntSplat) {
+ MLIRContext context;
+ ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
+ std::complex<int64_t> value(10, 15);
+ testSplat(complexType, value);
+}
+
} // end namespace
More information about the Mlir-commits
mailing list