[Mlir-commits] [mlir] 24ad385 - [mlir][DenseElementsAttr] Add support for opaque APFloat/APInt complex values.

River Riddle llvmlistbot at llvm.org
Tue May 5 12:46:43 PDT 2020


Author: River Riddle
Date: 2020-05-05T12:42:37-07:00
New Revision: 24ad3858842552a8052c12c44c7707dd822898bf

URL: https://github.com/llvm/llvm-project/commit/24ad3858842552a8052c12c44c7707dd822898bf
DIFF: https://github.com/llvm/llvm-project/commit/24ad3858842552a8052c12c44c7707dd822898bf.diff

LOG: [mlir][DenseElementsAttr] Add support for opaque APFloat/APInt complex values.

This revision adds support for creating DenseElementsAttrs from, and accessing their elements as, std::complex<APInt>/std::complex<APFloat> values. This enables opaque access to and transformation of complex values, and is used by the printer and parser to pretty-print complex elements. The textual form of a complex value matches that of std::complex, i.e.:

```
// `(` element `,` element `)`
dense<(10,10)> : tensor<complex<i64>>
```

Differential Revision: https://reviews.llvm.org/D79296

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/AttributeDetail.h
    mlir/lib/IR/Attributes.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/dense-elements-hex.mlir
    mlir/test/IR/invalid.mlir
    mlir/test/IR/parser.mlir
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 6b3edf14a39b..e0b4e5f43737 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -764,12 +764,26 @@ class DenseElementsAttr : public ElementsAttr {
   /// shape.
   static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values);
 
+  /// Constructs a dense complex elements attribute from an array of complex
+  /// APInt values. Each real and imaginary part must have the bitwidth of the
+  /// integer type wrapped by the complex element type of 'type'. 'type' must
+  /// be a vector or tensor with static shape.
+  static DenseElementsAttr get(ShapedType type,
+                               ArrayRef<std::complex<APInt>> values);
+
   /// Constructs a dense float elements attribute from an array of APFloat
   /// values. Each APFloat value is expected to have the same bitwidth as the
   /// element type of 'type'. 'type' must be a vector or tensor with static
   /// shape.
   static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values);
 
+  /// Constructs a dense complex elements attribute from an array of complex
+  /// APFloat values. Each real and imaginary part must have the bitwidth of
+  /// the float type wrapped by the complex element type of 'type'. 'type'
+  /// must be a vector or tensor with static shape.
+  static DenseElementsAttr get(ShapedType type,
+                               ArrayRef<std::complex<APFloat>> values);
+
   /// Construct a dense elements attribute for an initializer_list of values.
   /// Each value is expected to be the same bitwidth of the element type of
   /// 'type'. 'type' must be a vector or tensor with static shape.
@@ -868,6 +882,26 @@ class DenseElementsAttr : public ElementsAttr {
     size_t bitWidth;
   };
 
+  /// A utility iterator that allows walking over the internal raw complex APInt
+  /// values.
+  class ComplexIntElementIterator
+      : public detail::DenseElementIndexedIteratorImpl<
+            ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>,
+            std::complex<APInt>> {
+  public:
+    /// Accesses the raw std::complex<APInt> value at this iterator position.
+    std::complex<APInt> operator*() const;
+
+  private:
+    friend DenseElementsAttr;
+
+    /// Constructs a new iterator.
+    ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex);
+
+    /// The bitwidth of the element type.
+    size_t bitWidth;
+  };
+
   /// Iterator for walking over APFloat values.
   class FloatElementIterator final
       : public llvm::mapped_iterator<IntElementIterator,
@@ -881,6 +915,21 @@ class DenseElementsAttr : public ElementsAttr {
     using reference = APFloat;
   };
 
+  /// Iterator for walking over complex APFloat values.
+  class ComplexFloatElementIterator final
+      : public llvm::mapped_iterator<
+            ComplexIntElementIterator,
+            std::function<std::complex<APFloat>(const std::complex<APInt> &)>> {
+    friend DenseElementsAttr;
+
+    /// Wraps 'it', mapping each complex APInt to a complex APFloat of 'smt'.
+    ComplexFloatElementIterator(const llvm::fltSemantics &smt,
+                                ComplexIntElementIterator it);
+
+  public:
+    using reference = std::complex<APFloat>;
+  };
+
   //===--------------------------------------------------------------------===//
   // Value Querying
   //===--------------------------------------------------------------------===//
@@ -1004,6 +1053,15 @@ class DenseElementsAttr : public ElementsAttr {
   IntElementIterator int_value_begin() const;
   IntElementIterator int_value_end() const;
 
+  /// Return the held element values as a range of complex APInts. The element
+  /// type of this attribute must be a complex of integer type.
+  llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
+  template <typename T, typename = typename std::enable_if<
+                            std::is_same<T, std::complex<APInt>>::value>::type>
+  llvm::iterator_range<ComplexIntElementIterator> getValues() const {
+    return getComplexIntValues();
+  }
+
   /// Return the held element values as a range of APFloat. The element type of
   /// this attribute must be of float type.
   llvm::iterator_range<FloatElementIterator> getFloatValues() const;
@@ -1015,6 +1073,16 @@ class DenseElementsAttr : public ElementsAttr {
   FloatElementIterator float_value_begin() const;
   FloatElementIterator float_value_end() const;
 
+  /// Return the held element values as a range of complex APFloat. The element
+  /// type of this attribute must be a complex of float type.
+  llvm::iterator_range<ComplexFloatElementIterator>
+  getComplexFloatValues() const;
+  template <typename T, typename = typename std::enable_if<std::is_same<
+                            T, std::complex<APFloat>>::value>::type>
+  llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
+    return getComplexFloatValues();
+  }
+
   /// Return the raw storage data held by this attribute. Users should generally
   /// not use this directly, as the internal storage format is not always in the
   /// form the user might expect.
@@ -1120,10 +1188,17 @@ class DenseIntOrFPElementsAttr
 protected:
   friend DenseElementsAttr;
 
+  /// Constructs a dense elements attribute from raw APFloat values. Each value
+  /// is bitcast to an APInt (at most 'storageWidth' bits wide) and written
+  /// 'storageWidth' bits apart. 'type' must have a static shape.
+  static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
+                                  ArrayRef<APFloat> values, bool isSplat);
+
   /// Constructs a dense elements attribute from an array of raw APInt values.
   /// Each APInt value is expected to have the same bitwidth as the element type
   /// of 'type'. 'type' must be a vector or tensor with static shape.
-  static DenseElementsAttr getRaw(ShapedType type, ArrayRef<APInt> values);
+  static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
+                                  ArrayRef<APInt> values, bool isSplat);
 
   /// Get or create a new dense elements attribute instance with the given raw
   /// data buffer. 'type' must be a vector or tensor with static shape.
@@ -1343,19 +1418,34 @@ class SparseElementsAttr
   getZeroValue() const {
     return getZeroAPInt();
   }
+  template <typename T>
+  typename std::enable_if<std::is_same<std::complex<APInt>, T>::value, T>::type
+  getZeroValue() const {
+    APInt intZero = getZeroAPInt();
+    return {intZero, intZero};
+  }
   /// Get a zero for an APFloat.
   template <typename T>
   typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
   getZeroValue() const {
     return getZeroAPFloat();
   }
+  template <typename T>
+  typename std::enable_if<std::is_same<std::complex<APFloat>, T>::value,
+                          T>::type
+  getZeroValue() const {
+    APFloat floatZero = getZeroAPFloat();
+    return {floatZero, floatZero};
+  }
 
   /// Get a zero for an C++ integer, float, StringRef, or complex type.
   template <typename T>
   typename std::enable_if<
       std::numeric_limits<T>::is_integer ||
           llvm::is_one_of<T, float, double, StringRef>::value ||
-          detail::is_complex_t<T>::value,
+          (detail::is_complex_t<T>::value &&
+           !llvm::is_one_of<T, std::complex<APInt>,
+                            std::complex<APFloat>>::value),
       T>::type
   getZeroValue() const {
     return T();

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 4305f8b9af19..f58656005cd6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1456,23 +1456,15 @@ void ModulePrinter::printAttribute(Attribute attr,
   }
 }
 
-/// Print the integer element of the given DenseElementsAttr at 'index'.
-static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
-                                 unsigned index, bool isSigned) {
-  APInt value = *std::next(attr.int_value_begin(), index);
+/// Print the integer element of a DenseElementsAttr.
+static void printDenseIntElement(const APInt &value, raw_ostream &os,
+                                 bool isSigned) {
   if (value.getBitWidth() == 1)
     os << (value.getBoolValue() ? "true" : "false");
   else
     value.print(os, isSigned);
 }
 
-/// Print the float element of the given DenseElementsAttr at 'index'.
-static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
-                                   unsigned index) {
-  APFloat value = *std::next(attr.float_value_begin(), index);
-  printFloatValue(value, os);
-}
-
 static void
 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                            function_ref<void(unsigned)> printEltFn) {
@@ -1543,26 +1535,45 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
   auto elementType = type.getElementType();
 
   // Check to see if we should format this attribute as a hex string.
-  // TODO: Add support for formatting complex elements nicely.
   auto numElements = type.getNumElements();
-  if (type.getElementType().isa<ComplexType>() ||
-      (!attr.isSplat() && allowHex &&
-       shouldPrintElementsAttrWithHex(numElements))) {
+  if (!attr.isSplat() && allowHex &&
+      shouldPrintElementsAttrWithHex(numElements)) {
     ArrayRef<char> rawData = attr.getRawData();
     os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
        << "\"";
     return;
   }
 
-  if (elementType.isIntOrIndex()) {
+  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
+    auto printComplexValue = [&](auto complexValues, auto printFn,
+                                 raw_ostream &os, auto &&... params) {
+      printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
+        auto complexValue = *(complexValues.begin() + index);
+        os << "(";
+        printFn(complexValue.real(), os, params...);
+        os << ",";
+        printFn(complexValue.imag(), os, params...);
+        os << ")";
+      });
+    };
+
+    Type complexElementType = complexTy.getElementType();
+    if (complexElementType.isa<IntegerType>())
+      printComplexValue(attr.getComplexIntValues(), printDenseIntElement, os,
+                        /*isSigned=*/!complexElementType.isUnsignedInteger());
+    else
+      printComplexValue(attr.getComplexFloatValues(), printFloatValue, os);
+  } else if (elementType.isIntOrIndex()) {
     bool isSigned = !elementType.isUnsignedInteger();
+    auto intValues = attr.getIntValues();
     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-      printDenseIntElement(attr, os, index, isSigned);
+      printDenseIntElement(*(intValues.begin() + index), os, isSigned);
     });
   } else {
     assert(elementType.isa<FloatType>() && "unexpected element type");
+    auto floatValues = attr.getFloatValues();
     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
-      printDenseFloatElement(attr, os, index);
+      printFloatValue(*(floatValues.begin() + index), os);
     });
   }
 }

diff  --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 4084dde74919..ffd1b34504e3 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -374,8 +374,9 @@ struct TypeAttributeStorage : public AttributeStorage {
 
 /// Return the bit width which DenseElementsAttr should use for this type.
 inline size_t getDenseElementBitWidth(Type eltType) {
-  if (ComplexType complex = eltType.dyn_cast<ComplexType>())
-    return getDenseElementBitWidth(complex.getElementType()) * 2;
+  // Align the width for complex to 8 to make storage and interpretation easier.
+  if (ComplexType comp = eltType.dyn_cast<ComplexType>())
+    return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
   // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
   // with double semantics.
   if (eltType.isBF16())

diff  --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 72c638a68f73..a150fcd4d323 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -537,6 +537,9 @@ uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
 static size_t getDenseElementStorageWidth(size_t origWidth) {
   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
 }
+static size_t getDenseElementStorageWidth(Type elementType) {
+  return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
+}
 
 /// Set a bit to a specific value.
 static void setBit(char *rawData, size_t bitPos, bool value) {
@@ -613,14 +616,15 @@ static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
 // DenseElementAttr Iterators
 //===----------------------------------------------------------------------===//
 
-/// Constructs a new iterator.
+//===----------------------------------------------------------------------===//
+// AttributeElementIterator
+
 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
     DenseElementsAttr attr, size_t index)
     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
                                       Attribute, Attribute, Attribute>(
           attr.getAsOpaquePointer(), index) {}
 
-/// Accesses the Attribute value at this iterator position.
 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
   Type eltTy = owner.getType().getElementType();
@@ -640,37 +644,75 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
   llvm_unreachable("unexpected element type");
 }
 
-/// Constructs a new iterator.
+//===----------------------------------------------------------------------===//
+// BoolElementIterator
+
 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
     DenseElementsAttr attr, size_t dataIndex)
     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
 
-/// Accesses the bool value at this iterator position.
 bool DenseElementsAttr::BoolElementIterator::operator*() const {
   return getBit(getData(), getDataIndex());
 }
 
-/// Constructs a new iterator.
+//===----------------------------------------------------------------------===//
+// IntElementIterator
+
 DenseElementsAttr::IntElementIterator::IntElementIterator(
     DenseElementsAttr attr, size_t dataIndex)
     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
           attr.getRawData().data(), attr.isSplat(), dataIndex),
       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
 
-/// Accesses the raw APInt value at this iterator position.
 APInt DenseElementsAttr::IntElementIterator::operator*() const {
   return readBits(getData(),
                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
                   bitWidth);
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexIntElementIterator
+
+DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
+    DenseElementsAttr attr, size_t dataIndex)
+    : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
+                                      std::complex<APInt>, std::complex<APInt>,
+                                      std::complex<APInt>>(
+          attr.getRawData().data(), attr.isSplat(), dataIndex) {
+  auto complexType = attr.getType().getElementType().cast<ComplexType>();
+  bitWidth = getDenseElementBitWidth(complexType.getElementType());
+}
+
+std::complex<APInt>
+DenseElementsAttr::ComplexIntElementIterator::operator*() const {
+  size_t storageWidth = llvm::alignTo<8>(bitWidth); // Even i1 is byte aligned.
+  size_t offset = getDataIndex() * storageWidth * 2;
+  return {readBits(getData(), offset, bitWidth),
+          readBits(getData(), offset + storageWidth, bitWidth)};
+}
+
+//===----------------------------------------------------------------------===//
+// FloatElementIterator
+
 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
     const llvm::fltSemantics &smt, IntElementIterator it)
     : llvm::mapped_iterator<IntElementIterator,
                             std::function<APFloat(const APInt &)>>(
           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
 
+//===----------------------------------------------------------------------===//
+// ComplexFloatElementIterator
+
+DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
+    const llvm::fltSemantics &smt, ComplexIntElementIterator it)
+    : llvm::mapped_iterator<
+          ComplexIntElementIterator,
+          std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
+          it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
+            return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
+          }) {}
+
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
@@ -753,7 +795,21 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<APInt> values) {
   assert(type.getElementType().isIntOrIndex());
-  return DenseIntOrFPElementsAttr::getRaw(type, values);
+  assert(hasSameElementsOrSplat(type, values));
+  size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
+  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
+                                          /*isSplat=*/(values.size() == 1));
+}
+DenseElementsAttr DenseElementsAttr::get(ShapedType type,
+                                         ArrayRef<std::complex<APInt>> values) {
+  ComplexType complex = type.getElementType().cast<ComplexType>();
+  assert(complex.getElementType().isa<IntegerType>());
+  assert(hasSameElementsOrSplat(type, values));
+  size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
+  ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
+                          values.size() * 2);
+  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
+                                          /*isSplat=*/(values.size() == 1));
 }
 
 // Constructs a dense float elements attribute from an array of APFloat
@@ -762,12 +818,22 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<APFloat> values) {
   assert(type.getElementType().isa<FloatType>());
-
-  // Convert the APFloat values to APInt and create a dense elements attribute.
-  std::vector<APInt> intValues(values.size());
-  for (unsigned i = 0, e = values.size(); i != e; ++i)
-    intValues[i] = values[i].bitcastToAPInt();
-  return DenseIntOrFPElementsAttr::getRaw(type, intValues);
+  assert(hasSameElementsOrSplat(type, values));
+  size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
+  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
+                                          /*isSplat=*/(values.size() == 1));
+}
+DenseElementsAttr
+DenseElementsAttr::get(ShapedType type,
+                       ArrayRef<std::complex<APFloat>> values) {
+  ComplexType complex = type.getElementType().cast<ComplexType>();
+  assert(complex.getElementType().isa<FloatType>());
+  assert(hasSameElementsOrSplat(type, values));
+  ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
+                           values.size() * 2);
+  size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
+  return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
+                                          /*isSplat=*/(values.size() == 1));
 }
 
 /// Construct a dense elements attribute from a raw buffer representing the
@@ -783,8 +849,7 @@ DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
                                          ArrayRef<char> rawBuffer,
                                          bool &detectedSplat) {
-  size_t elementWidth = getDenseElementBitWidth(type.getElementType());
-  size_t storageWidth = getDenseElementStorageWidth(elementWidth);
+  size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
 
   // Storage width of 1 is special as it is packed by the bit.
@@ -904,13 +969,20 @@ auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
   return raw_int_end();
 }
+auto DenseElementsAttr::getComplexIntValues() const
+    -> llvm::iterator_range<ComplexIntElementIterator> {
+  Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
+  (void)eltTy;
+  assert(eltTy.isa<IntegerType>() && "expected complex integral type");
+  return {ComplexIntElementIterator(*this, 0),
+          ComplexIntElementIterator(*this, getNumElements())};
+}
 
 /// Return the held element values as a range of APFloat. The element type of
 /// this attribute must be of float type.
 auto DenseElementsAttr::getFloatValues() const
     -> llvm::iterator_range<FloatElementIterator> {
   auto elementType = getType().getElementType().cast<FloatType>();
-  assert(elementType.isa<FloatType>() && "expected float type");
   const auto &elementSemantics = elementType.getFloatSemantics();
   return {FloatElementIterator(elementSemantics, raw_int_begin()),
           FloatElementIterator(elementSemantics, raw_int_end())};
@@ -921,6 +993,14 @@ auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
   return getFloatValues().end();
 }
+auto DenseElementsAttr::getComplexFloatValues() const
+    -> llvm::iterator_range<ComplexFloatElementIterator> {
+  Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
+  assert(eltTy.isa<FloatType>() && "expected complex float type");
+  const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
+  return {{semantics, {*this, 0}},
+          {semantics, {*this, static_cast<size_t>(getNumElements())}}};
+}
 
 /// Return the raw storage data held by this attribute.
 ArrayRef<char> DenseElementsAttr::getRawData() const {
@@ -972,23 +1052,42 @@ DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
 // DenseIntOrFPElementsAttr
 //===----------------------------------------------------------------------===//
 
+/// Utility method to write a range of APInt values to a buffer.
+template <typename APRangeT>
+static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
+                                APRangeT &&values) {
+  data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
+  size_t offset = 0;
+  for (auto it = values.begin(), e = values.end(); it != e;
+       ++it, offset += storageWidth) {
+    assert((*it).getBitWidth() <= storageWidth);
+    writeBits(data.data(), offset, *it);
+  }
+}
+
+/// Constructs a dense elements attribute from raw APFloat values. Each value
+/// is bitcast to an APInt (at most 'storageWidth' bits wide) and written
+/// 'storageWidth' bits apart. 'type' must have a static shape.
+DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
+                                                   size_t storageWidth,
+                                                   ArrayRef<APFloat> values,
+                                                   bool isSplat) {
+  std::vector<char> data;
+  auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
+  writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
+  return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
+}
+
 /// Constructs a dense elements attribute from an array of raw APInt values.
 /// Each APInt value is expected to have the same bitwidth as the element type
 /// of 'type'.
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
-                                                   ArrayRef<APInt> values) {
-  assert(hasSameElementsOrSplat(type, values));
-
-  size_t bitWidth = getDenseElementBitWidth(type.getElementType());
-  size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
-  std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
-                                values.size());
-  for (unsigned i = 0, e = values.size(); i != e; ++i) {
-    assert(values[i].getBitWidth() == bitWidth);
-    writeBits(elementData.data(), i * storageBitWidth, values[i]);
-  }
-  return DenseIntOrFPElementsAttr::getRaw(type, elementData,
-                                          /*isSplat=*/(values.size() == 1));
+                                                   size_t storageWidth,
+                                                   ArrayRef<APInt> values,
+                                                   bool isSplat) {
+  std::vector<char> data;
+  writeAPIntsToBuffer(storageWidth, data, values);
+  return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
 }
 
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 3689983d45e2..384c69e44a05 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1956,29 +1956,13 @@ class TensorLiteralParser {
   ArrayRef<int64_t> getShape() const { return shape; }
 
 private:
-  enum class ElementKind { Boolean, Integer, Float, String };
-
-  /// Return a string to represent the given element kind.
-  const char *getElementKindStr(ElementKind kind) {
-    switch (kind) {
-    case ElementKind::Boolean:
-      return "'boolean'";
-    case ElementKind::Integer:
-      return "'integer'";
-    case ElementKind::Float:
-      return "'float'";
-    case ElementKind::String:
-      return "'string'";
-    }
-    llvm_unreachable("unknown element kind");
-  }
-
-  /// Build a Dense Integer attribute for the given type.
-  DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
+  /// Get the parsed elements for an integer attribute.
+  ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
+                                 std::vector<APInt> &intValues);
 
-  /// Build a Dense Float attribute for the given type.
-  DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
-                                 FloatType eltTy);
+  /// Get the parsed elements for a float attribute.
+  ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
+                                   std::vector<APFloat> &floatValues);
 
   /// Build a Dense String attribute for the given type.
   DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
@@ -2011,9 +1995,6 @@ class TensorLiteralParser {
   /// Storage used when parsing elements, this is a pair of <is_negated, token>.
   std::vector<std::pair<bool, Token>> storage;
 
-  /// A flag that indicates the type of elements that have been parsed.
-  Optional<ElementKind> knownEltKind;
-
   /// Storage used when parsing elements that were stored as hex values.
   Optional<Token> hexStorage;
 };
@@ -2053,22 +2034,40 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
     return nullptr;
   }
 
-  // If the type is an integer, build a set of APInt values from the storage
-  // with the correct bitwidth.
-  if (auto intTy = eltType.dyn_cast<IntegerType>())
-    return getIntAttr(loc, type, intTy);
-  if (auto indexTy = eltType.dyn_cast<IndexType>())
-    return getIntAttr(loc, type, indexTy);
-
-  // If parsing a floating point type.
-  if (auto floatTy = eltType.dyn_cast<FloatType>())
-    return getFloatAttr(loc, type, floatTy);
+  // Handle complex types in the specific element type cases below.
+  bool isComplex = false;
+  if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
+    eltType = complexTy.getElementType();
+    isComplex = true;
+  }
 
-  // If parsing a complex type.
-  // TODO: Support complex elements with pretty element printing.
-  if (eltType.isa<ComplexType>()) {
-    p.emitError(loc) << "complex elements only support hex formatting";
-    return nullptr;
+  // Handle integer and index types.
+  if (eltType.isIntOrIndex()) {
+    std::vector<APInt> intValues;
+    if (failed(getIntAttrElements(loc, eltType, intValues)))
+      return nullptr;
+    if (isComplex) {
+      // If this is a complex, treat the parsed values as complex values.
+      auto complexData = llvm::makeArrayRef(
+          reinterpret_cast<std::complex<APInt> *>(intValues.data()),
+          intValues.size() / 2);
+      return DenseElementsAttr::get(type, complexData);
+    }
+    return DenseElementsAttr::get(type, intValues);
+  }
+  // Handle floating point types.
+  if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
+    std::vector<APFloat> floatValues;
+    if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
+      return nullptr;
+    if (isComplex) {
+      // If this is a complex, treat the parsed values as complex values.
+      auto complexData = llvm::makeArrayRef(
+          reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
+          floatValues.size() / 2);
+      return DenseElementsAttr::get(type, complexData);
+    }
+    return DenseElementsAttr::get(type, floatValues);
   }
 
   // Other types are assumed to be string representations.
@@ -2076,39 +2075,36 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
 }
 
 /// Build a Dense Integer attribute for the given type.
-DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
-                                                  ShapedType type, Type eltTy) {
-  std::vector<APInt> intElements;
-  intElements.reserve(storage.size());
-  auto isUintType = type.getElementType().isUnsignedInteger();
+ParseResult
+TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
+                                        std::vector<APInt> &intValues) {
+  intValues.reserve(storage.size());
+  bool isUintType = eltTy.isUnsignedInteger();
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
     const Token &token = signAndToken.second;
     auto tokenLoc = token.getLoc();
 
     if (isNegative && isUintType) {
-      p.emitError(tokenLoc)
-          << "expected unsigned integer elements, but parsed negative value";
-      return nullptr;
+      return p.emitError(tokenLoc)
+             << "expected unsigned integer elements, but parsed negative value";
     }
 
     // Check to see if floating point values were parsed.
     if (token.is(Token::floatliteral)) {
-      p.emitError(tokenLoc)
-          << "expected integer elements, but parsed floating-point";
-      return nullptr;
+      return p.emitError(tokenLoc)
+             << "expected integer elements, but parsed floating-point";
     }
 
     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
            "unexpected token type");
     if (token.isAny(Token::kw_true, Token::kw_false)) {
       if (!eltTy.isInteger(1)) {
-        p.emitError(tokenLoc)
-            << "expected i1 type for 'true' or 'false' values";
-        return nullptr;
+        return p.emitError(tokenLoc)
+               << "expected i1 type for 'true' or 'false' values";
       }
       APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
-      intElements.push_back(apInt);
+      intValues.push_back(apInt);
       continue;
     }
 
@@ -2116,19 +2112,16 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
     Optional<APInt> apInt =
         buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
     if (!apInt)
-      return (p.emitError(tokenLoc, "integer constant out of range for type"),
-              nullptr);
-    intElements.push_back(*apInt);
+      return p.emitError(tokenLoc, "integer constant out of range for type");
+    intValues.push_back(*apInt);
   }
-
-  return DenseElementsAttr::get(type, intElements);
+  return success();
 }
 
 /// Build a Dense Float attribute for the given type.
-DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
-                                                    ShapedType type,
-                                                    FloatType eltTy) {
-  std::vector<APFloat> floatValues;
+ParseResult
+TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
+                                          std::vector<APFloat> &floatValues) {
   floatValues.reserve(storage.size());
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
@@ -2137,34 +2130,31 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
     // Handle hexadecimal float literals.
     if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
       if (isNegative) {
-        p.emitError(token.getLoc())
-            << "hexadecimal float literal should not have a leading minus";
-        return nullptr;
+        return p.emitError(token.getLoc())
+               << "hexadecimal float literal should not have a leading minus";
       }
       auto val = token.getUInt64IntegerValue();
       if (!val.hasValue()) {
-        p.emitError("hexadecimal float constant out of range for attribute");
-        return nullptr;
+        return p.emitError(
+            "hexadecimal float constant out of range for attribute");
       }
       Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
       if (!apVal)
-        return nullptr;
+        return failure();
       floatValues.push_back(*apVal);
       continue;
     }
 
     // Check to see if any decimal integers or booleans were parsed.
-    if (!token.is(Token::floatliteral)) {
-      p.emitError() << "expected floating-point elements, but parsed integer";
-      return nullptr;
-    }
+    if (!token.is(Token::floatliteral))
+      return p.emitError()
+             << "expected floating-point elements, but parsed integer";
 
     // Build the float values from tokens.
     auto val = token.getFloatingPointValue();
-    if (!val.hasValue()) {
-      p.emitError("floating point value too large for attribute");
-      return nullptr;
-    }
+    if (!val.hasValue())
+      return p.emitError("floating point value too large for attribute");
+
     // Treat BF16 as double because it is not supported in LLVM's APFloat.
     APFloat apVal(isNegative ? -*val : *val);
     if (!eltTy.isBF16() && !eltTy.isF64()) {
@@ -2174,8 +2164,7 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
     }
     floatValues.push_back(apVal);
   }
-
-  return DenseElementsAttr::get(type, floatValues);
+  return success();
 }
 
 /// Build a Dense String attribute for the given type.
@@ -2250,6 +2239,17 @@ ParseResult TensorLiteralParser::parseElement() {
     storage.emplace_back(/*isNegative=*/ false, p.getToken());
     p.consumeToken();
     break;
+
+  // Parse a complex element of the form '(' element ',' element ')'.
+  case Token::l_paren:
+    p.consumeToken(Token::l_paren);
+    if (parseElement() ||
+        p.parseToken(Token::comma, "expected ',' between complex elements") ||
+        parseElement() ||
+        p.parseToken(Token::r_paren, "expected ')' after complex elements"))
+      return failure();
+    break;
+
   default:
     return p.emitError("expected element literal of primitive type");
   }

diff  --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
index 87c0acf80341..e0e12418e1d5 100644
--- a/mlir/test/IR/dense-elements-hex.mlir
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -7,11 +7,8 @@
 // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
 "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
 
-// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex<f64>>
-"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<1xcomplex<f64>>} : () -> ()
-
-// CHECK: dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex<f64>>
-"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<10xcomplex<f64>>} : () -> ()
+// CHECK: dense<(1.000000e+01,5.000000e+00)> : tensor<2xcomplex<f64>>
+"foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144000000000000024400000000000001440"> : tensor<2xcomplex<f64>>} : () -> ()
 
 // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16>
 "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> ()

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 96b234834481..926c2e3c2857 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -689,6 +689,22 @@ func @elementsattr_toolarge2() -> () {
 
 // -----
 
+"foo"(){bar = dense<[()]> : tensor<complex<i64>>} : () -> () // expected-error {{expected element literal of primitive type}}
+
+// -----
+
+"foo"(){bar = dense<[(10)]> : tensor<complex<i64>>} : () -> () // expected-error {{expected ',' between complex elements}}
+
+// -----
+
+"foo"(){bar = dense<[(10,)]> : tensor<complex<i64>>} : () -> () // expected-error {{expected element literal of primitive type}}
+
+// -----
+
+"foo"(){bar = dense<[(10,10]> : tensor<complex<i64>>} : () -> () // expected-error {{expected ')' after complex elements}}
+
+// -----
+
 func @elementsattr_malformed_opaque() -> () {
 ^bb0:
   "foo"(){bar = opaque<10, "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected dialect namespace}}

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 2170927df927..34b1da4282f1 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -702,6 +702,15 @@ func @densetensorattr() -> () {
   "index"(){bar = dense<1> : tensor<index>} : () -> ()
 // CHECK: "index"() {bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
   "index"(){bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
+
+  // CHECK: dense<(1,1)> : tensor<complex<i64>>
+  "complex_attr"(){bar = dense<(1,1)> : tensor<complex<i64>>} : () -> ()
+  // CHECK: dense<[(1,1), (2,2)]> : tensor<2xcomplex<i64>>
+  "complex_attr"(){bar = dense<[(1,1), (2,2)]> : tensor<2xcomplex<i64>>} : () -> ()
+  // CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
+  "complex_attr"(){bar = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>} : () -> ()
+  // CHECK: dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>
+  "complex_attr"(){bar = dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>} : () -> ()
   return
 }
 

diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 7f91d57b01ae..8fda2a2e73b6 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -27,7 +27,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
   EXPECT_EQ(detectedSplat, splat);
 
   for (auto newValue : detectedSplat.template getValues<EltTy>())
-    EXPECT_EQ(newValue, splatElt);
+    EXPECT_TRUE(newValue == splatElt);
 }
 
 namespace {
@@ -179,4 +179,18 @@ TEST(DenseComplexTest, ComplexIntSplat) {
   testSplat(complexType, value);
 }
 
+TEST(DenseComplexTest, ComplexAPFloatSplat) {
+  MLIRContext context;
+  ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
+  std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
+  testSplat(complexType, value);
+}
+
+TEST(DenseComplexTest, ComplexAPIntSplat) {
+  MLIRContext context;
+  ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
+  std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
+  testSplat(complexType, value);
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list