[Mlir-commits] [mlir] 1c66bac - [mlir] Fix try_value_begin_impl for DenseElementsAttr

Jeff Niu llvmlistbot at llvm.org
Tue Aug 30 14:13:03 PDT 2022


Author: Jeff Niu
Date: 2022-08-30T14:12:48-07:00
New Revision: 1c66bacd6cde1f37d6ac96c45b389666a1334ec0

URL: https://github.com/llvm/llvm-project/commit/1c66bacd6cde1f37d6ac96c45b389666a1334ec0
DIFF: https://github.com/llvm/llvm-project/commit/1c66bacd6cde1f37d6ac96c45b389666a1334ec0.diff

LOG: [mlir] Fix try_value_begin_impl for DenseElementsAttr

The previous implementation would still crash if the element type was
not iterable. This patch changes SparseElementsAttr to properly
implement `try_value_begin_impl` according to ElementsAttr and changes
DenseElementsAttr to implement `tryGetValues` as the basis for querying
element values.

Depends on D132904

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D132958

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/lib/IR/BuiltinAttributes.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index a102311beabc..5e9370061597 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -392,7 +392,45 @@ class DenseElementsAttr : public Attribute {
     return getSplatValue<Attribute>().template cast<T>();
   }
 
-  /// Return the held element values as a range of integer or floating-point
+  /// Try to get an iterator of the given type to the start of the held element
+  /// values. Return failure if the type cannot be iterated.
+  template <typename T>
+  auto try_value_begin() const {
+    auto range = tryGetValues<T>();
+    using iterator = decltype(range->begin());
+    return failed(range) ? FailureOr<iterator>(failure()) : range->begin();
+  }
+
+  /// Try to get an iterator of the given type to the end of the held element
+  /// values. Return failure if the type cannot be iterated.
+  template <typename T>
+  auto try_value_end() const {
+    auto range = tryGetValues<T>();
+    using iterator = decltype(range->begin());
+    return failed(range) ? FailureOr<iterator>(failure()) : range->end();
+  }
+
+  /// Return the held element values as a range of the given type.
+  template <typename T>
+  auto getValues() const {
+    auto range = tryGetValues<T>();
+    assert(succeeded(range) && "element type cannot be iterated");
+    return std::move(*range);
+  }
+
+  /// Get an iterator of the given type to the start of the held element values.
+  template <typename T>
+  auto value_begin() const {
+    return getValues<T>().begin();
+  }
+
+  /// Get an iterator of the given type to the end of the held element values.
+  template <typename T>
+  auto value_end() const {
+    return getValues<T>().end();
+  }
+
+  /// Try to get the held element values as a range of integer or floating-point
   /// values.
   template <typename T>
   using IntFloatValueTemplateCheckT =
@@ -400,28 +438,18 @@ class DenseElementsAttr : public Attribute {
                                std::numeric_limits<T>::is_integer) ||
                               is_valid_cpp_fp_type<T>::value>::type;
   template <typename T, typename = IntFloatValueTemplateCheckT<T>>
-  iterator_range_impl<ElementIterator<T>> getValues() const {
-    assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
-                             std::numeric_limits<T>::is_signed));
+  FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const {
+    if (!isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
+                           std::numeric_limits<T>::is_signed))
+      return failure();
     const char *rawData = getRawData().data();
     bool splat = isSplat();
-    return {getType(), ElementIterator<T>(rawData, splat, 0),
-            ElementIterator<T>(rawData, splat, getNumElements())};
-  }
-  template <typename T, typename = IntFloatValueTemplateCheckT<T>>
-  ElementIterator<T> value_begin() const {
-    assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
-                             std::numeric_limits<T>::is_signed));
-    return ElementIterator<T>(getRawData().data(), isSplat(), 0);
-  }
-  template <typename T, typename = IntFloatValueTemplateCheckT<T>>
-  ElementIterator<T> value_end() const {
-    assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
-                             std::numeric_limits<T>::is_signed));
-    return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
+    return iterator_range_impl<ElementIterator<T>>(
+        getType(), ElementIterator<T>(rawData, splat, 0),
+        ElementIterator<T>(rawData, splat, getNumElements()));
   }
 
-  /// Return the held element values as a range of std::complex.
+  /// Try to get the held element values as a range of std::complex.
   template <typename T, typename ElementT>
   using ComplexValueTemplateCheckT =
       typename std::enable_if<detail::is_complex_t<T>::value &&
@@ -429,70 +457,45 @@ class DenseElementsAttr : public Attribute {
                                is_valid_cpp_fp_type<ElementT>::value)>::type;
   template <typename T, typename ElementT = typename T::value_type,
             typename = ComplexValueTemplateCheckT<T, ElementT>>
-  iterator_range_impl<ElementIterator<T>> getValues() const {
-    assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
-                          std::numeric_limits<ElementT>::is_signed));
+  FailureOr<iterator_range_impl<ElementIterator<T>>> tryGetValues() const {
+    if (!isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
+                        std::numeric_limits<ElementT>::is_signed))
+      return failure();
     const char *rawData = getRawData().data();
     bool splat = isSplat();
-    return {getType(), ElementIterator<T>(rawData, splat, 0),
-            ElementIterator<T>(rawData, splat, getNumElements())};
-  }
-  template <typename T, typename ElementT = typename T::value_type,
-            typename = ComplexValueTemplateCheckT<T, ElementT>>
-  ElementIterator<T> value_begin() const {
-    assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
-                          std::numeric_limits<ElementT>::is_signed));
-    return ElementIterator<T>(getRawData().data(), isSplat(), 0);
-  }
-  template <typename T, typename ElementT = typename T::value_type,
-            typename = ComplexValueTemplateCheckT<T, ElementT>>
-  ElementIterator<T> value_end() const {
-    assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
-                          std::numeric_limits<ElementT>::is_signed));
-    return ElementIterator<T>(getRawData().data(), isSplat(), getNumElements());
+    return iterator_range_impl<ElementIterator<T>>(
+        getType(), ElementIterator<T>(rawData, splat, 0),
+        ElementIterator<T>(rawData, splat, getNumElements()));
   }
 
-  /// Return the held element values as a range of StringRef.
+  /// Try to get the held element values as a range of StringRef.
   template <typename T>
   using StringRefValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, StringRef>::value>::type;
   template <typename T, typename = StringRefValueTemplateCheckT<T>>
-  iterator_range_impl<ElementIterator<StringRef>> getValues() const {
+  FailureOr<iterator_range_impl<ElementIterator<StringRef>>>
+  tryGetValues() const {
     auto stringRefs = getRawStringData();
     const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
     bool splat = isSplat();
-    return {getType(), ElementIterator<StringRef>(ptr, splat, 0),
-            ElementIterator<StringRef>(ptr, splat, getNumElements())};
-  }
-  template <typename T, typename = StringRefValueTemplateCheckT<T>>
-  ElementIterator<StringRef> value_begin() const {
-    const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
-    return ElementIterator<StringRef>(ptr, isSplat(), 0);
-  }
-  template <typename T, typename = StringRefValueTemplateCheckT<T>>
-  ElementIterator<StringRef> value_end() const {
-    const char *ptr = reinterpret_cast<const char *>(getRawStringData().data());
-    return ElementIterator<StringRef>(ptr, isSplat(), getNumElements());
+    return iterator_range_impl<ElementIterator<StringRef>>(
+        getType(), ElementIterator<StringRef>(ptr, splat, 0),
+        ElementIterator<StringRef>(ptr, splat, getNumElements()));
   }
 
-  /// Return the held element values as a range of Attributes.
+  /// Try to get the held element values as a range of Attributes.
   template <typename T>
   using AttributeValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, Attribute>::value>::type;
   template <typename T, typename = AttributeValueTemplateCheckT<T>>
-  iterator_range_impl<AttributeElementIterator> getValues() const {
-    return {getType(), value_begin<Attribute>(), value_end<Attribute>()};
-  }
-  template <typename T, typename = AttributeValueTemplateCheckT<T>>
-  AttributeElementIterator value_begin() const {
-    return AttributeElementIterator(*this, 0);
-  }
-  template <typename T, typename = AttributeValueTemplateCheckT<T>>
-  AttributeElementIterator value_end() const {
-    return AttributeElementIterator(*this, getNumElements());
+  FailureOr<iterator_range_impl<AttributeElementIterator>>
+  tryGetValues() const {
+    return iterator_range_impl<AttributeElementIterator>(
+        getType(), AttributeElementIterator(*this, 0),
+        AttributeElementIterator(*this, getNumElements()));
   }
 
-  /// Return the held element values a range of T, where T is a derived
+  /// Try to get the held element values a range of T, where T is a derived
   /// attribute type.
   template <typename T>
   using DerivedAttrValueTemplateCheckT =
@@ -510,115 +513,71 @@ class DenseElementsAttr : public Attribute {
     T mapElement(Attribute attr) const { return attr.cast<T>(); }
   };
   template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
-  iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
+  FailureOr<iterator_range_impl<DerivedAttributeElementIterator<T>>>
+  tryGetValues() const {
     using DerivedIterT = DerivedAttributeElementIterator<T>;
-    return {getType(), DerivedIterT(value_begin<Attribute>()),
-            DerivedIterT(value_end<Attribute>())};
-  }
-  template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
-  DerivedAttributeElementIterator<T> value_begin() const {
-    return {value_begin<Attribute>()};
-  }
-  template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
-  DerivedAttributeElementIterator<T> value_end() const {
-    return {value_end<Attribute>()};
+    return iterator_range_impl<DerivedIterT>(
+        getType(), DerivedIterT(value_begin<Attribute>()),
+        DerivedIterT(value_end<Attribute>()));
   }
 
-  /// Return the held element values as a range of bool. The element type of
+  /// Try to get the held element values as a range of bool. The element type of
   /// this attribute must be of integer type of bitwidth 1.
   template <typename T>
   using BoolValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, bool>::value>::type;
   template <typename T, typename = BoolValueTemplateCheckT<T>>
-  iterator_range_impl<BoolElementIterator> getValues() const {
-    assert(isValidBool() && "bool is not the value of this elements attribute");
-    return {getType(), BoolElementIterator(*this, 0),
-            BoolElementIterator(*this, getNumElements())};
-  }
-  template <typename T, typename = BoolValueTemplateCheckT<T>>
-  BoolElementIterator value_begin() const {
-    assert(isValidBool() && "bool is not the value of this elements attribute");
-    return BoolElementIterator(*this, 0);
-  }
-  template <typename T, typename = BoolValueTemplateCheckT<T>>
-  BoolElementIterator value_end() const {
-    assert(isValidBool() && "bool is not the value of this elements attribute");
-    return BoolElementIterator(*this, getNumElements());
+  FailureOr<iterator_range_impl<BoolElementIterator>> tryGetValues() const {
+    if (!isValidBool())
+      return failure();
+    return iterator_range_impl<BoolElementIterator>(
+        getType(), BoolElementIterator(*this, 0),
+        BoolElementIterator(*this, getNumElements()));
   }
 
-  /// Return the held element values as a range of APInts. The element type of
-  /// this attribute must be of integer type.
+  /// Try to get the held element values as a range of APInts. The element type
+  /// of this attribute must be of integer type.
   template <typename T>
   using APIntValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, APInt>::value>::type;
   template <typename T, typename = APIntValueTemplateCheckT<T>>
-  iterator_range_impl<IntElementIterator> getValues() const {
-    assert(getElementType().isIntOrIndex() && "expected integral type");
-    return {getType(), raw_int_begin(), raw_int_end()};
-  }
-  template <typename T, typename = APIntValueTemplateCheckT<T>>
-  IntElementIterator value_begin() const {
-    assert(getElementType().isIntOrIndex() && "expected integral type");
-    return raw_int_begin();
-  }
-  template <typename T, typename = APIntValueTemplateCheckT<T>>
-  IntElementIterator value_end() const {
-    assert(getElementType().isIntOrIndex() && "expected integral type");
-    return raw_int_end();
+  FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const {
+    if (!getElementType().isIntOrIndex())
+      return failure();
+    return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
+                                                   raw_int_end());
   }
 
-  /// Return the held element values as a range of complex APInts. The element
-  /// type of this attribute must be a complex of integer type.
+  /// Try to get the held element values as a range of complex APInts. The
+  /// element type of this attribute must be a complex of integer type.
   template <typename T>
   using ComplexAPIntValueTemplateCheckT = typename std::enable_if<
       std::is_same<T, std::complex<APInt>>::value>::type;
   template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
-  iterator_range_impl<ComplexIntElementIterator> getValues() const {
-    return getComplexIntValues();
-  }
-  template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
-  ComplexIntElementIterator value_begin() const {
-    return complex_value_begin();
-  }
-  template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
-  ComplexIntElementIterator value_end() const {
-    return complex_value_end();
+  FailureOr<iterator_range_impl<ComplexIntElementIterator>>
+  tryGetValues() const {
+    return tryGetComplexIntValues();
   }
 
-  /// Return the held element values as a range of APFloat. The element type of
-  /// this attribute must be of float type.
+  /// Try to get the held element values as a range of APFloat. The element type
+  /// of this attribute must be of float type.
   template <typename T>
   using APFloatValueTemplateCheckT =
       typename std::enable_if<std::is_same<T, APFloat>::value>::type;
   template <typename T, typename = APFloatValueTemplateCheckT<T>>
-  iterator_range_impl<FloatElementIterator> getValues() const {
-    return getFloatValues();
-  }
-  template <typename T, typename = APFloatValueTemplateCheckT<T>>
-  FloatElementIterator value_begin() const {
-    return float_value_begin();
-  }
-  template <typename T, typename = APFloatValueTemplateCheckT<T>>
-  FloatElementIterator value_end() const {
-    return float_value_end();
+  FailureOr<iterator_range_impl<FloatElementIterator>> tryGetValues() const {
+    return tryGetFloatValues();
   }
 
-  /// Return the held element values as a range of complex APFloat. The element
-  /// type of this attribute must be a complex of float type.
+  /// Try to get the held element values as a range of complex APFloat. The
+  /// element type of this attribute must be a complex of float type.
   template <typename T>
   using ComplexAPFloatValueTemplateCheckT = typename std::enable_if<
       std::is_same<T, std::complex<APFloat>>::value>::type;
   template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
-  iterator_range_impl<ComplexFloatElementIterator> getValues() const {
-    return getComplexFloatValues();
-  }
-  template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
-  ComplexFloatElementIterator value_begin() const {
-    return complex_float_value_begin();
-  }
-  template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
-  ComplexFloatElementIterator value_end() const {
-    return complex_float_value_end();
+  FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
+  tryGetValues() const {
+    return tryGetComplexFloatValues();
   }
 
   /// Return the raw storage data held by this attribute. Users should generally
@@ -687,16 +646,12 @@ class DenseElementsAttr : public Attribute {
   IntElementIterator raw_int_end() const {
     return IntElementIterator(*this, getNumElements());
   }
-  iterator_range_impl<ComplexIntElementIterator> getComplexIntValues() const;
-  ComplexIntElementIterator complex_value_begin() const;
-  ComplexIntElementIterator complex_value_end() const;
-  iterator_range_impl<FloatElementIterator> getFloatValues() const;
-  FloatElementIterator float_value_begin() const;
-  FloatElementIterator float_value_end() const;
-  iterator_range_impl<ComplexFloatElementIterator>
-  getComplexFloatValues() const;
-  ComplexFloatElementIterator complex_float_value_begin() const;
-  ComplexFloatElementIterator complex_float_value_end() const;
+  FailureOr<iterator_range_impl<ComplexIntElementIterator>>
+  tryGetComplexIntValues() const;
+  FailureOr<iterator_range_impl<FloatElementIterator>>
+  tryGetFloatValues() const;
+  FailureOr<iterator_range_impl<ComplexFloatElementIterator>>
+  tryGetComplexFloatValues() const;
 
   /// Overload of the raw 'get' method that asserts that the given type is of
   /// complex type. This method is used to verify type invariants that the
@@ -973,8 +928,8 @@ class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
             function_ref<APInt(const APFloat &)> mapping) const;
 
   /// Iterator access to the float element values.
-  iterator begin() const { return float_value_begin(); }
-  iterator end() const { return float_value_end(); }
+  iterator begin() const { return tryGetFloatValues()->begin(); }
+  iterator end() const { return tryGetFloatValues()->end(); }
 
   /// Method for supporting type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr);
@@ -1026,12 +981,15 @@ class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
 //===----------------------------------------------------------------------===//
 
 template <typename T>
-auto SparseElementsAttr::value_begin() const -> iterator<T> {
+auto SparseElementsAttr::try_value_begin_impl(OverloadToken<T>) const
+    -> FailureOr<iterator<T>> {
   auto zeroValue = getZeroValue<T>();
-  auto valueIt = getValues().value_begin<T>();
+  auto valueIt = getValues().try_value_begin<T>();
+  if (failed(valueIt))
+    return failure();
   const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
   std::function<T(ptrdiff_t)> mapFn =
-      [flatSparseIndices{flatSparseIndices}, valueIt{std::move(valueIt)},
+      [flatSparseIndices{flatSparseIndices}, valueIt{std::move(*valueIt)},
        zeroValue{std::move(zeroValue)}](ptrdiff_t index) {
         // Try to map the current index to one of the sparse indices.
         for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)

diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 36a856afe90b..da24332b1c56 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -303,7 +303,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
     /// ElementsAttr.
     template <typename T>
     auto try_value_begin_impl(OverloadToken<T>) const {
-      return ::mlir::success(value_begin<T>());
+      return try_value_begin<T>();
     }
 
     /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of
@@ -433,7 +433,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
     /// ElementsAttr.
     template <typename T>
     auto try_value_begin_impl(OverloadToken<T>) const {
-      return ::mlir::success(value_begin<T>());
+      return try_value_begin<T>();
     }
 
   protected:
@@ -900,23 +900,17 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
       StringRef
     >;
     using ElementsAttr::Trait<SparseElementsAttr>::getValues;
-
-    /// Provide a `try_value_begin_impl` to enable iteration within
-    /// ElementsAttr.
-    template <typename T>
-    auto try_value_begin_impl(OverloadToken<T>) const {
-      return ::mlir::success(value_begin<T>());
-    }
+    using ElementsAttr::Trait<SparseElementsAttr>::value_begin;
 
     template <typename T>
     using iterator =
         llvm::mapped_iterator<typename decltype(llvm::seq<ptrdiff_t>(0, 0))::iterator,
                               std::function<T(ptrdiff_t)>>;
 
-    /// Return the values of this attribute in the form of the given type 'T'.
-    /// 'T'  may be any of Attribute, APInt, APFloat, c++ integer/float types,
-    /// etc.
-    template <typename T> iterator<T> value_begin() const;
+    /// Provide a `try_value_begin_impl` to enable iteration within
+    /// ElementsAttr.
+    template <typename T>
+    FailureOr<iterator<T>> try_value_begin_impl(OverloadToken<T>) const;
 
   private:
     /// Get a zero APFloat for the given sparse attribute.

diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index c911599eaab1..da7e694e7459 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1219,68 +1219,42 @@ bool DenseElementsAttr::isSplat() const {
 }
 
 /// Return if the given complex type has an integer element type.
-LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
+static bool isComplexOfIntType(Type type) {
   return type.cast<ComplexType>().getElementType().isa<IntegerType>();
 }
 
-auto DenseElementsAttr::getComplexIntValues() const
-    -> iterator_range_impl<ComplexIntElementIterator> {
-  assert(isComplexOfIntType(getElementType()) &&
-         "expected complex integral type");
-  return {getType(), ComplexIntElementIterator(*this, 0),
-          ComplexIntElementIterator(*this, getNumElements())};
-}
-auto DenseElementsAttr::complex_value_begin() const
-    -> ComplexIntElementIterator {
-  assert(isComplexOfIntType(getElementType()) &&
-         "expected complex integral type");
-  return ComplexIntElementIterator(*this, 0);
-}
-auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
-  assert(isComplexOfIntType(getElementType()) &&
-         "expected complex integral type");
-  return ComplexIntElementIterator(*this, getNumElements());
-}
-
-/// Return the held element values as a range of APFloat. The element type of
-/// this attribute must be of float type.
-auto DenseElementsAttr::getFloatValues() const
-    -> iterator_range_impl<FloatElementIterator> {
-  auto elementType = getElementType().cast<FloatType>();
-  const auto &elementSemantics = elementType.getFloatSemantics();
-  return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
-          FloatElementIterator(elementSemantics, raw_int_end())};
-}
-auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
-  auto elementType = getElementType().cast<FloatType>();
-  return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
-}
-auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
-  auto elementType = getElementType().cast<FloatType>();
-  return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
-}
-
-auto DenseElementsAttr::getComplexFloatValues() const
-    -> iterator_range_impl<ComplexFloatElementIterator> {
-  Type eltTy = getElementType().cast<ComplexType>().getElementType();
-  assert(eltTy.isa<FloatType>() && "expected complex float type");
-  const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
-  return {getType(),
-          {semantics, {*this, 0}},
-          {semantics, {*this, static_cast<size_t>(getNumElements())}}};
-}
-auto DenseElementsAttr::complex_float_value_begin() const
-    -> ComplexFloatElementIterator {
-  Type eltTy = getElementType().cast<ComplexType>().getElementType();
-  assert(eltTy.isa<FloatType>() && "expected complex float type");
-  return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
-}
-auto DenseElementsAttr::complex_float_value_end() const
-    -> ComplexFloatElementIterator {
-  Type eltTy = getElementType().cast<ComplexType>().getElementType();
-  assert(eltTy.isa<FloatType>() && "expected complex float type");
-  return {eltTy.cast<FloatType>().getFloatSemantics(),
-          {*this, static_cast<size_t>(getNumElements())}};
+auto DenseElementsAttr::tryGetComplexIntValues() const
+    -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> {
+  if (!isComplexOfIntType(getElementType()))
+    return failure();
+  return iterator_range_impl<ComplexIntElementIterator>(
+      getType(), ComplexIntElementIterator(*this, 0),
+      ComplexIntElementIterator(*this, getNumElements()));
+}
+
+auto DenseElementsAttr::tryGetFloatValues() const
+    -> FailureOr<iterator_range_impl<FloatElementIterator>> {
+  auto eltTy = getElementType().dyn_cast<FloatType>();
+  if (!eltTy)
+    return failure();
+  const auto &elementSemantics = eltTy.getFloatSemantics();
+  return iterator_range_impl<FloatElementIterator>(
+      getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
+      FloatElementIterator(elementSemantics, raw_int_end()));
+}
+
+auto DenseElementsAttr::tryGetComplexFloatValues() const
+    -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> {
+  auto complexTy = getElementType().dyn_cast<ComplexType>();
+  if (!complexTy)
+    return failure();
+  auto eltTy = complexTy.getElementType().dyn_cast<FloatType>();
+  if (!eltTy)
+    return failure();
+  const auto &semantics = eltTy.getFloatSemantics();
+  return iterator_range_impl<ComplexFloatElementIterator>(
+      getType(), {semantics, {*this, 0}},
+      {semantics, {*this, static_cast<size_t>(getNumElements())}});
 }
 
 /// Return the raw storage data held by this attribute.


        


More information about the Mlir-commits mailing list