[Mlir-commits] [mlir] da2a6f4 - [mlir][DenseElementsAttr] Add support for ComplexType elements

River Riddle llvmlistbot at llvm.org
Tue May 5 12:46:41 PDT 2020


Author: River Riddle
Date: 2020-05-05T12:42:37-07:00
New Revision: da2a6f4e3b5235d871c2e81ae1b0577002733653

URL: https://github.com/llvm/llvm-project/commit/da2a6f4e3b5235d871c2e81ae1b0577002733653
DIFF: https://github.com/llvm/llvm-project/commit/da2a6f4e3b5235d871c2e81ae1b0577002733653.diff

LOG: [mlir][DenseElementsAttr] Add support for ComplexType elements

This revision adds support for storing ComplexType elements inside of a DenseElementsAttr. We store complex objects as an array of two elements (the real part followed by the imaginary part), matching the definition of std::complex. There is no current attribute storage for ComplexType, but DenseElementsAttr provides an API for access/creation using std::complex<>. Given that the internal implementation of DenseElementsAttr is already fairly opaque, the only real complexity here is in the printing/parsing. This revision keeps it simple for now and always uses hex when printing complex elements. A followup will add prettier syntax for this.

Differential Revision: https://reviews.llvm.org/D79281

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/AttributeDetail.h
    mlir/lib/IR/Attributes.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/dense-elements-hex.mlir
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 6d24cd087648..6b3edf14a39b 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -13,6 +13,7 @@
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
+#include <complex>
 
 namespace mlir {
 class AffineMap;
@@ -687,6 +688,11 @@ class DenseElementIndexedIteratorImpl
   /// Return the data base pointer.
   const char *getData() const { return this->base.getPointer(); }
 };
+
+/// Type trait detector that checks if a given type T is a complex type.
+template <typename T> struct is_complex_t : public std::false_type {};
+template <typename T>
+struct is_complex_t<std::complex<T>> : public std::true_type {};
 } // namespace detail
 
 /// An attribute that represents a reference to a dense vector or tensor object.
@@ -724,11 +730,27 @@ class DenseElementsAttr : public ElementsAttr {
   /// Constructs a dense integer elements attribute from a single element.
   template <typename T, typename = typename std::enable_if<
                             std::numeric_limits<T>::is_integer ||
-                            llvm::is_one_of<T, float, double>::value>::type>
+                            llvm::is_one_of<T, float, double>::value ||
+                            detail::is_complex_t<T>::value>::type>
   static DenseElementsAttr get(const ShapedType &type, T value) {
     return get(type, llvm::makeArrayRef(value));
   }
 
+  /// Constructs a dense complex elements attribute from an array of complex
+  /// values. Each value is expected to be the same bitwidth of the element type
+  /// of 'type'. 'type' must be a vector or tensor with static shape.
+  template <typename T, typename ElementT = typename T::value_type,
+            typename = typename std::enable_if<
+                detail::is_complex_t<T>::value &&
+                (std::numeric_limits<ElementT>::is_integer ||
+                 llvm::is_one_of<ElementT, float, double>::value)>::type>
+  static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
+    const char *data = reinterpret_cast<const char *>(values.data());
+    return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)),
+                         sizeof(T), std::numeric_limits<ElementT>::is_integer,
+                         std::numeric_limits<ElementT>::is_signed);
+  }
+
   /// Overload of the above 'get' method that is specialized for boolean values.
   static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values);
 
@@ -764,6 +786,12 @@ class DenseElementsAttr : public ElementsAttr {
                                             ArrayRef<char> rawBuffer,
                                             bool isSplatBuffer);
 
+  /// Returns true if the given buffer is a valid raw buffer for the given type.
+  /// `detectedSplat` is set if the buffer is valid and represents a splat
+  /// buffer.
+  static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer,
+                               bool &detectedSplat);
+
   //===--------------------------------------------------------------------===//
   // Iterators
   //===--------------------------------------------------------------------===//
@@ -900,12 +928,28 @@ class DenseElementsAttr : public ElementsAttr {
   llvm::iterator_range<ElementIterator<T>> getValues() const {
     assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
                              std::numeric_limits<T>::is_signed));
-    auto rawData = getRawData().data();
+    const char *rawData = getRawData().data();
+    bool splat = isSplat();
+    return {ElementIterator<T>(rawData, splat, 0),
+            ElementIterator<T>(rawData, splat, getNumElements())};
+  }
+
+  /// Return the held element values as a range of std::complex.
+  template <typename T, typename ElementT = typename T::value_type,
+            typename = typename std::enable_if<
+                detail::is_complex_t<T>::value &&
+                (std::numeric_limits<ElementT>::is_integer ||
+                 llvm::is_one_of<ElementT, float, double>::value)>::type>
+  llvm::iterator_range<ElementIterator<T>> getValues() const {
+    assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+                          std::numeric_limits<ElementT>::is_signed));
+    const char *rawData = getRawData().data();
     bool splat = isSplat();
     return {ElementIterator<T>(rawData, splat, 0),
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
 
+  /// Return the held element values as a range of StringRef.
   template <typename T, typename = typename std::enable_if<
                             std::is_same<T, StringRef>::value>::type>
   llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
@@ -1010,6 +1054,13 @@ class DenseElementsAttr : public ElementsAttr {
     return IntElementIterator(*this, getNumElements());
   }
 
+  /// Overload of the raw 'get' method that asserts that the given type is of
+  /// complex type. This method is used to verify type invariants that the
+  /// templatized 'get' method cannot.
+  static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
+                                         int64_t dataEltSize, bool isInt,
+                                         bool isSigned);
+
   /// Overload of the raw 'get' method that asserts that the given type is of
   /// integer or floating-point type. This method is used to verify type
   /// invariants that the templatized 'get' method cannot.
@@ -1022,6 +1073,11 @@ class DenseElementsAttr : public ElementsAttr {
   /// the current attribute. This method is used to verify specific type
   /// invariants that the templatized 'getValues' method cannot.
   bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;
+
+  /// Check the information for a C++ data type, check if this type is valid for
+  /// the current attribute. This method is used to verify specific type
+  /// invariants that the templatized 'getValues' method cannot.
+  bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
 };
 
 /// An attribute class for representing dense arrays of strings. The structure
@@ -1074,6 +1130,13 @@ class DenseIntOrFPElementsAttr
   static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
                                   bool isSplat);
 
+  /// Overload of the raw 'get' method that asserts that the given type is of
+  /// complex type. This method is used to verify type invariants that the
+  /// templatized 'get' method cannot.
+  static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
+                                         int64_t dataEltSize, bool isInt,
+                                         bool isSigned);
+
   /// Overload of the raw 'get' method that asserts that the given type is of
   /// integer or floating-point type. This method is used to verify type
   /// invariants that the templatized 'get' method cannot.
@@ -1287,20 +1350,15 @@ class SparseElementsAttr
     return getZeroAPFloat();
   }
 
-  /// Get a zero for a StringRef.
+  /// Get a zero for an C++ integer, float, StringRef, or complex type.
   template <typename T>
-  typename std::enable_if<std::is_same<StringRef, T>::value, T>::type
-  getZeroValue() const {
-    return StringRef();
-  }
-
-  /// Get a zero for an C++ integer or float type.
-  template <typename T>
-  typename std::enable_if<std::numeric_limits<T>::is_integer ||
-                              llvm::is_one_of<T, float, double>::value,
-                          T>::type
+  typename std::enable_if<
+      std::numeric_limits<T>::is_integer ||
+          llvm::is_one_of<T, float, double, StringRef>::value ||
+          detail::is_complex_t<T>::value,
+      T>::type
   getZeroValue() const {
-    return T(0);
+    return T();
   }
 
   /// Flatten, and return, all of the sparse indices in this attribute in

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 005571afb59a..4305f8b9af19 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1468,50 +1468,21 @@ static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
 
 /// Print the float element of the given DenseElementsAttr at 'index'.
 static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
-                                   unsigned index, bool isSigned) {
-  assert(isSigned && "floating point values are always signed");
+                                   unsigned index) {
   APFloat value = *std::next(attr.float_value_begin(), index);
   printFloatValue(value, os);
 }
 
-static void printDenseStringElement(DenseStringElementsAttr attr,
-                                    raw_ostream &os, unsigned index) {
-  os << "\"";
-  printEscapedString(attr.getRawStringData()[index], os);
-  os << "\"";
-}
-
-void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
-                                           bool allowHex) {
-  if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
-    printDenseStringElementsAttr(stringAttr);
-    return;
-  }
-
-  printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
-                                allowHex);
-}
-
-void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
-                                                  bool allowHex) {
-  auto type = attr.getType();
-  auto shape = type.getShape();
-  auto rank = type.getRank();
-  bool isSigned = !type.getElementType().isUnsignedInteger();
-
-  // The function used to print elements of this attribute.
-  auto printEltFn = type.getElementType().isIntOrIndex()
-                        ? printDenseIntElement
-                        : printDenseFloatElement;
-
+static void
+printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
+                           function_ref<void(unsigned)> printEltFn) {
   // Special case for 0-d and splat tensors.
-  if (attr.isSplat()) {
-    printEltFn(attr, os, 0, isSigned);
-    return;
-  }
+  if (isSplat)
+    return printEltFn(0);
 
   // Special case for degenerate tensors.
   auto numElements = type.getNumElements();
+  int64_t rank = type.getRank();
   if (numElements == 0) {
     for (int i = 0; i < rank; ++i)
       os << '[';
@@ -1520,14 +1491,6 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
     return;
   }
 
-  // Check to see if we should format this attribute as a hex string.
-  if (allowHex && shouldPrintElementsAttrWithHex(numElements)) {
-    ArrayRef<char> rawData = attr.getRawData();
-    os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
-       << "\"";
-    return;
-  }
-
   // We use a mixed-radix counter to iterate through the shape. When we bump a
   // non-least-significant digit, we emit a close bracket. When we next emit an
   // element we re-open all closed brackets.
@@ -1537,7 +1500,8 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
   // The number of brackets that have been opened and not closed.
   unsigned openBrackets = 0;
 
-  auto bumpCounter = [&]() {
+  auto shape = type.getShape();
+  auto bumpCounter = [&] {
     // Bump the least significant digit.
     ++counter[rank - 1];
     // Iterate backwards bubbling back the increment.
@@ -1557,68 +1521,60 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
     while (openBrackets++ < rank)
       os << '[';
     openBrackets = rank;
-    printEltFn(attr, os, idx, isSigned);
+    printEltFn(idx);
     bumpCounter();
   }
   while (openBrackets-- > 0)
     os << ']';
 }
 
-void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
-  auto type = attr.getType();
-  auto shape = type.getShape();
-  auto rank = type.getRank();
+void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
+                                           bool allowHex) {
+  if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
+    return printDenseStringElementsAttr(stringAttr);
 
-  // Special case for 0-d and splat tensors.
-  if (attr.isSplat()) {
-    printDenseStringElement(attr, os, 0);
-    return;
-  }
+  printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
+                                allowHex);
+}
 
-  // Special case for degenerate tensors.
+void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
+                                                  bool allowHex) {
+  auto type = attr.getType();
+  auto elementType = type.getElementType();
+
+  // Check to see if we should format this attribute as a hex string.
+  // TODO: Add support for formatting complex elements nicely.
   auto numElements = type.getNumElements();
-  if (numElements == 0) {
-    for (int i = 0; i < rank; ++i)
-      os << '[';
-    for (int i = 0; i < rank; ++i)
-      os << ']';
+  if (type.getElementType().isa<ComplexType>() ||
+      (!attr.isSplat() && allowHex &&
+       shouldPrintElementsAttrWithHex(numElements))) {
+    ArrayRef<char> rawData = attr.getRawData();
+    os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
+       << "\"";
     return;
   }
 
-  // We use a mixed-radix counter to iterate through the shape. When we bump a
-  // non-least-significant digit, we emit a close bracket. When we next emit an
-  // element we re-open all closed brackets.
-
-  // The mixed-radix counter, with radices in 'shape'.
-  SmallVector<unsigned, 4> counter(rank, 0);
-  // The number of brackets that have been opened and not closed.
-  unsigned openBrackets = 0;
+  if (elementType.isIntOrIndex()) {
+    bool isSigned = !elementType.isUnsignedInteger();
+    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
+      printDenseIntElement(attr, os, index, isSigned);
+    });
+  } else {
+    assert(elementType.isa<FloatType>() && "unexpected element type");
+    printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
+      printDenseFloatElement(attr, os, index);
+    });
+  }
+}
 
-  auto bumpCounter = [&]() {
-    // Bump the least significant digit.
-    ++counter[rank - 1];
-    // Iterate backwards bubbling back the increment.
-    for (unsigned i = rank - 1; i > 0; --i)
-      if (counter[i] >= shape[i]) {
-        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
-        counter[i] = 0;
-        ++counter[i - 1];
-        --openBrackets;
-        os << ']';
-      }
+void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
+  ArrayRef<StringRef> data = attr.getRawStringData();
+  auto printFn = [&](unsigned index) {
+    os << "\"";
+    printEscapedString(data[index], os);
+    os << "\"";
   };
-
-  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
-    if (idx != 0)
-      os << ", ";
-    while (openBrackets++ < rank)
-      os << '[';
-    openBrackets = rank;
-    printDenseStringElement(attr, os, idx);
-    bumpCounter();
-  }
-  while (openBrackets-- > 0)
-    os << ']';
+  printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
 }
 
 void ModulePrinter::printType(Type type) {

diff  --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 201f058571af..4084dde74919 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -374,6 +374,8 @@ struct TypeAttributeStorage : public AttributeStorage {
 
 /// Return the bit width which DenseElementsAttr should use for this type.
 inline size_t getDenseElementBitWidth(Type eltType) {
+  if (ComplexType complex = eltType.dyn_cast<ComplexType>())
+    return getDenseElementBitWidth(complex.getElementType()) * 2;
   // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
   // with double semantics.
   if (eltType.isBF16())

diff  --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 5cb2d5a2ea84..72c638a68f73 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -779,24 +779,44 @@ DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
 }
 
+/// Returns true if the given buffer is a valid raw buffer for the given type.
+bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
+                                         ArrayRef<char> rawBuffer,
+                                         bool &detectedSplat) {
+  size_t elementWidth = getDenseElementBitWidth(type.getElementType());
+  size_t storageWidth = getDenseElementStorageWidth(elementWidth);
+  size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
+
+  // Storage width of 1 is special as it is packed by the bit.
+  if (storageWidth == 1) {
+    // Check for a splat, or a buffer equal to the number of elements.
+    if ((detectedSplat = rawBuffer.size() == 1))
+      return true;
+    return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
+  }
+  // All other types are 8-bit aligned.
+  if ((detectedSplat = rawBufferWidth == storageWidth))
+    return true;
+  return rawBufferWidth == (storageWidth * type.getNumElements());
+}
+
 /// Check the information for a C++ data type, check if this type is valid for
 /// the current attribute. This method is used to verify specific type
 /// invariants that the templatized 'getValues' method cannot.
-static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
+static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
                               bool isSigned) {
   // Make sure that the data element size is the same as the type element width.
-  if (getDenseElementBitWidth(type.getElementType()) !=
+  if (getDenseElementBitWidth(type) !=
       static_cast<size_t>(dataEltSize * CHAR_BIT))
     return false;
 
   // Check that the element type is either float or integer or index.
   if (!isInt)
-    return type.getElementType().isa<FloatType>();
-
-  if (type.getElementType().isIndex())
+    return type.isa<FloatType>();
+  if (type.isIndex())
     return true;
 
-  auto intType = type.getElementType().dyn_cast<IntegerType>();
+  auto intType = type.dyn_cast<IntegerType>();
   if (!intType)
     return false;
 
@@ -807,6 +827,13 @@ static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
 }
 
 /// Defaults down the subclass implementation.
+DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
+                                                   ArrayRef<char> data,
+                                                   int64_t dataEltSize,
+                                                   bool isInt, bool isSigned) {
+  return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
+                                                 isSigned);
+}
 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
                                                       ArrayRef<char> data,
                                                       int64_t dataEltSize,
@@ -820,7 +847,17 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
 /// method cannot.
 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
                                           bool isSigned) const {
-  return ::isValidIntOrFloat(getType(), dataEltSize, isInt, isSigned);
+  return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
+                             isSigned);
+}
+
+/// Check the information for a C++ data type, check if this type is valid for
+/// the current attribute.
+bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
+                                       bool isSigned) const {
+  return ::isValidIntOrFloat(
+      getType().getElementType().cast<ComplexType>().getElementType(),
+      dataEltSize / 2, isInt, isSigned);
 }
 
 /// Returns if this attribute corresponds to a splat, i.e. if all element
@@ -964,6 +1001,23 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
                    type, data, isSplat);
 }
 
+/// Overload of the raw 'get' method that asserts that the given type is of
+/// complex type. This method is used to verify type invariants that the
+/// templatized 'get' method cannot.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
+                                                          ArrayRef<char> data,
+                                                          int64_t dataEltSize,
+                                                          bool isInt,
+                                                          bool isSigned) {
+  assert(::isValidIntOrFloat(
+      type.getElementType().cast<ComplexType>().getElementType(),
+      dataEltSize / 2, isInt, isSigned));
+
+  int64_t numElements = data.size() / dataEltSize;
+  assert(numElements == 1 || numElements == type.getNumElements());
+  return getRaw(type, data, /*isSplat=*/numElements == 1);
+}
+
 /// Overload of the 'getRaw' method that asserts that the given type is of
 /// integer type. This method is used to verify type invariants that the
 /// templatized 'get' method cannot.
@@ -971,7 +1025,8 @@ DenseElementsAttr
 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
                                            int64_t dataEltSize, bool isInt,
                                            bool isSigned) {
-  assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned));
+  assert(
+      ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
 
   int64_t numElements = data.size() / dataEltSize;
   assert(numElements == 1 || numElements == type.getNumElements());

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index d129b867fb0c..3689983d45e2 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -2041,7 +2041,8 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
   Type eltType = type.getElementType();
 
   // Check to see if we parse the literal from a hex string.
-  if (hexStorage.hasValue() && eltType.isIntOrFloat())
+  if (hexStorage.hasValue() &&
+      (eltType.isIntOrFloat() || eltType.isa<ComplexType>()))
     return getHexAttr(loc, type);
 
   // Check that the parsed storage size has the same number of elements to the
@@ -2063,6 +2064,13 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
   if (auto floatTy = eltType.dyn_cast<FloatType>())
     return getFloatAttr(loc, type, floatTy);
 
+  // If parsing a complex type.
+  // TODO: Support complex elements with pretty element printing.
+  if (eltType.isa<ComplexType>()) {
+    p.emitError(loc) << "complex elements only support hex formatting";
+    return nullptr;
+  }
+
   // Other types are assumed to be string representations.
   return getStringAttr(loc, type, type.getElementType());
 }
@@ -2196,9 +2204,10 @@ DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
                                                   ShapedType type) {
   Type elementType = type.getElementType();
-  if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
-    p.emitError(loc) << "expected floating-point or integer element type, got "
-                     << elementType;
+  if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
+    p.emitError(loc)
+        << "expected floating-point, integer, or complex element type, got "
+        << elementType;
     return nullptr;
   }
 
@@ -2206,21 +2215,15 @@ DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
   if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
     return nullptr;
 
-  // Check that the size of the hex data corresponds to the size of the type, or
-  // a splat of the type.
-  // TODO: bf16 is currently stored as a double, this should be removed when
-  // APFloat properly supports it.
-  int64_t elementWidth =
-      elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
-  if (static_cast<int64_t>(data.size() * CHAR_BIT) !=
-      (type.getNumElements() * elementWidth)) {
+  ArrayRef<char> rawData(data.data(), data.size());
+  bool detectedSplat = false;
+  if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
     p.emitError(loc) << "elements hex data size is invalid for provided type: "
                      << type;
     return nullptr;
   }
 
-  return DenseElementsAttr::getFromRawBuffer(
-      type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false);
+  return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
 }
 
 ParseResult TensorLiteralParser::parseElement() {

diff  --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
index 0375004e1d84..87c0acf80341 100644
--- a/mlir/test/IR/dense-elements-hex.mlir
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -7,6 +7,12 @@
 // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
 "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
 
+// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex<f64>>
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex<f64>>} : () -> ()
+
+// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex<f64>>
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex<f64>>} : () -> ()
+
 // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16>
 "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> ()
 

diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 7d22b5e5a07f..7f91d57b01ae 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -25,6 +25,9 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
   auto detectedSplat =
       DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
   EXPECT_EQ(detectedSplat, splat);
+
+  for (auto newValue : detectedSplat.template getValues<EltTy>())
+    EXPECT_EQ(newValue, splatElt);
 }
 
 namespace {
@@ -162,4 +165,18 @@ TEST(DenseSplatTest, StringAttrSplat) {
   testSplat(stringType, stringAttr);
 }
 
+TEST(DenseComplexTest, ComplexFloatSplat) {
+  MLIRContext context;
+  ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
+  std::complex<float> value(10.0, 15.0);
+  testSplat(complexType, value);
+}
+
+TEST(DenseComplexTest, ComplexIntSplat) {
+  MLIRContext context;
+  ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
+  std::complex<int64_t> value(10, 15);
+  testSplat(complexType, value);
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list