[Mlir-commits] [mlir] 2e71dad - [mlir][DenseElementsAttr] Allow for custom floating point types in `getValues`

River Riddle llvmlistbot at llvm.org
Thu Nov 12 22:51:52 PST 2020


Author: River Riddle
Date: 2020-11-12T22:47:30-08:00
New Revision: 2e71dad3328e03337eca53352e7d45b6efc7e0a2

URL: https://github.com/llvm/llvm-project/commit/2e71dad3328e03337eca53352e7d45b6efc7e0a2
DIFF: https://github.com/llvm/llvm-project/commit/2e71dad3328e03337eca53352e7d45b6efc7e0a2.diff

LOG: [mlir][DenseElementsAttr] Allow for custom floating point types in `getValues`

Some users have native C++ data types that correspond to floating-point element types stored within a DenseElementsAttr which have no builtin C++ equivalent (e.g. bfloat16/half/etc.). This revision allows such users to use those native types directly, and removes the need to go through APFloat when the much faster native value path is available.

Differential Revision: https://reviews.llvm.org/D91402

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index f02470891fcf..0432d0584f9c 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -675,6 +675,21 @@ class DenseElementsAttr : public ElementsAttr {
 public:
   using ElementsAttr::ElementsAttr;
 
+  /// Type trait used to check if the given type T is a potentially valid C++
+  /// floating point type that can be used to access the underlying element
+  /// types of a DenseElementsAttr (e.g. float/double, or bfloat16/half types).
+  // TODO: Use std::disjunction when C++17 is supported.
+  template <typename T> struct is_valid_cpp_fp_type {
+    /// True for builtin float/double, or for any type with a specialized,
+    /// non-integer std::numeric_limits (e.g. user-defined bfloat16/half).
+    /// Declared `static constexpr`, not `static inline constexpr`: inline
+    /// variables are C++17-only, which this code does not yet assume.
+    static constexpr bool value =
+        llvm::is_one_of<T, float, double>::value ||
+        (std::numeric_limits<T>::is_specialized &&
+         !std::numeric_limits<T>::is_integer);
+  };
+
   /// Method for support type inquiry through isa, cast and dyn_cast.
   static bool classof(Attribute attr);
 
@@ -690,7 +705,7 @@ class DenseElementsAttr : public ElementsAttr {
   /// static shape.
   template <typename T, typename = typename std::enable_if<
                             std::numeric_limits<T>::is_integer ||
-                            llvm::is_one_of<T, float, double>::value>::type>
+                            is_valid_cpp_fp_type<T>::value>::type>
   static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
     const char *data = reinterpret_cast<const char *>(values.data());
     return getRawIntOrFloat(
@@ -701,7 +716,7 @@ class DenseElementsAttr : public ElementsAttr {
   /// Constructs a dense integer elements attribute from a single element.
   template <typename T, typename = typename std::enable_if<
                             std::numeric_limits<T>::is_integer ||
-                            llvm::is_one_of<T, float, double>::value ||
+                            is_valid_cpp_fp_type<T>::value ||
                             detail::is_complex_t<T>::value>::type>
   static DenseElementsAttr get(const ShapedType &type, T value) {
     return get(type, llvm::makeArrayRef(value));
@@ -714,7 +729,7 @@ class DenseElementsAttr : public ElementsAttr {
             typename = typename std::enable_if<
                 detail::is_complex_t<T>::value &&
                 (std::numeric_limits<ElementT>::is_integer ||
-                 llvm::is_one_of<ElementT, float, double>::value)>::type>
+                 is_valid_cpp_fp_type<ElementT>::value)>::type>
   static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) {
     const char *data = reinterpret_cast<const char *>(values.data());
     return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)),
@@ -944,7 +959,7 @@ class DenseElementsAttr : public ElementsAttr {
   template <typename T, typename = typename std::enable_if<
                             (!std::is_same<T, bool>::value &&
                              std::numeric_limits<T>::is_integer) ||
-                            llvm::is_one_of<T, float, double>::value>::type>
+                            is_valid_cpp_fp_type<T>::value>::type>
   llvm::iterator_range<ElementIterator<T>> getValues() const {
     assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
                              std::numeric_limits<T>::is_signed));
@@ -959,7 +974,7 @@ class DenseElementsAttr : public ElementsAttr {
             typename = typename std::enable_if<
                 detail::is_complex_t<T>::value &&
                 (std::numeric_limits<ElementT>::is_integer ||
-                 llvm::is_one_of<ElementT, float, double>::value)>::type>
+                 is_valid_cpp_fp_type<ElementT>::value)>::type>
   llvm::iterator_range<ElementIterator<T>> getValues() const {
     assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
                           std::numeric_limits<ElementT>::is_signed));
@@ -1411,7 +1426,8 @@ class SparseElementsAttr
   template <typename T>
   typename std::enable_if<
       std::numeric_limits<T>::is_integer ||
-          llvm::is_one_of<T, float, double, StringRef>::value ||
+          DenseElementsAttr::is_valid_cpp_fp_type<T>::value ||
+          std::is_same<T, StringRef>::value ||
           (detail::is_complex_t<T>::value &&
            !llvm::is_one_of<T, std::complex<APInt>,
                             std::complex<APFloat>>::value),


        


More information about the Mlir-commits mailing list