[Mlir-commits] [mlir] 5cf5708 - [mlir][ElementsAttr] Change value_begin_impl to try_value_begin_impl

Jeff Niu llvmlistbot at llvm.org
Tue Aug 30 14:13:01 PDT 2022


Author: Jeff Niu
Date: 2022-08-30T14:12:46-07:00
New Revision: 5cf5708628a3a465d7845fb43e46bd7f97761ead

URL: https://github.com/llvm/llvm-project/commit/5cf5708628a3a465d7845fb43e46bd7f97761ead
DIFF: https://github.com/llvm/llvm-project/commit/5cf5708628a3a465d7845fb43e46bd7f97761ead.diff

LOG: [mlir][ElementsAttr] Change value_begin_impl to try_value_begin_impl

This patch changes `value_begin_impl` to a fallible
`try_value_begin_impl` so that specific cases can fail iteration if the
type doesn't match the internal storage.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D132904

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/Support/LogicalResult.h
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/test/IR/elements-attr-interface.mlir
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 58dad618d3cee..a9e6e0e12051a 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -54,33 +54,36 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     using NonContiguousIterableTypesT = std::tuple<APInt, Attribute>;
     ```
 
-    * Provide a `iterator value_begin_impl(OverloadToken<T>) const` overload for
-      each iterable type
+    * Provide a `FailureOr<iterator> try_value_begin_impl(OverloadToken<T>) const`
+      overload for each iterable type
 
     These overloads should return an iterator to the start of the range for the
-    respective iterable type. Consider the example i64 elements attribute
-    described in the previous section. This attribute may define the
-    value_begin_impl overloads like so:
+    respective iterable type or fail if the type cannot be iterated. Consider
+    the example i64 elements attribute described in the previous section. This
+    attribute may define the try_value_begin_impl overloads like so:
 
     ```c++
     /// Provide begin iterators for the various iterable types.
     /// * uint64_t
-    auto value_begin_impl(OverloadToken<uint64_t>) const {
+    FailureOr<const uint64_t *>
+    try_value_begin_impl(OverloadToken<uint64_t>) const {
       return getElements().begin();
     }
     /// * APInt
-    auto value_begin_impl(OverloadToken<llvm::APInt>) const {
-      return llvm::map_range(getElements(), [=](uint64_t value) {
+    auto try_value_begin_impl(OverloadToken<llvm::APInt>) const {
+      auto it = llvm::map_range(getElements(), [=](uint64_t value) {
         return llvm::APInt(/*numBits=*/64, value);
       }).begin();
+      return FailureOr<decltype(it)>(std::move(it));
     }
     /// * Attribute
-    auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
+    auto try_value_begin_impl(OverloadToken<mlir::Attribute>) const {
       mlir::Type elementType = getType().getElementType();
-      return llvm::map_range(getElements(), [=](uint64_t value) {
+      auto it = llvm::map_range(getElements(), [=](uint64_t value) {
         return mlir::IntegerAttr::get(elementType,
                                       llvm::APInt(/*numBits=*/64, value));
       }).begin();
+      return FailureOr<decltype(it)>(std::move(it));
     }
     ```
 
@@ -244,18 +247,22 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
           /*isSplat=*/false, nullptr);
       }
 
-      auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
+      auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
+      if (::mlir::failed(valueIt))
+        return ::mlir::failure();
       return ::mlir::detail::ElementsAttrIndexer::contiguous(
-        $_attr.isSplat(), &*valueIt);
+        $_attr.isSplat(), &**valueIt);
     }
     /// Build an indexer for the given type `T`, which is represented via a
     /// non-contiguous range.
     template <typename T>
     ::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
         /*isContiguous*/std::false_type) const {
-      auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
+      auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
+      if (::mlir::failed(valueIt))
+        return ::mlir::failure();
       return ::mlir::detail::ElementsAttrIndexer::nonContiguous(
-        $_attr.isSplat(), valueIt);
+        $_attr.isSplat(), *valueIt);
     }
 
   public:
@@ -275,7 +282,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
     /// type `T`.
     template <typename T>
     auto value_begin() const {
-      return $_attr.value_begin_impl(OverloadToken<T>());
+      return *$_attr.try_value_begin_impl(OverloadToken<T>());
     }
 
     /// Return the elements of this attribute as a value of type 'T'.

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 2c4b92b6fa5fc..36a856afe90b1 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -215,13 +215,20 @@ def Builtin_DenseArray : Builtin_Attr<
     /// ElementsAttr implementation.
     using ContiguousIterableTypesT =
         std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
-    const bool *value_begin_impl(OverloadToken<bool>) const;
-    const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
-    const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
-    const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
-    const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
-    const float *value_begin_impl(OverloadToken<float>) const;
-    const double *value_begin_impl(OverloadToken<double>) const;
+    FailureOr<const bool *>
+    try_value_begin_impl(OverloadToken<bool>) const;
+    FailureOr<const int8_t *>
+    try_value_begin_impl(OverloadToken<int8_t>) const;
+    FailureOr<const int16_t *>
+    try_value_begin_impl(OverloadToken<int16_t>) const;
+    FailureOr<const int32_t *>
+    try_value_begin_impl(OverloadToken<int32_t>) const;
+    FailureOr<const int64_t *>
+    try_value_begin_impl(OverloadToken<int64_t>) const;
+    FailureOr<const float *>
+    try_value_begin_impl(OverloadToken<float>) const;
+    FailureOr<const double *>
+    try_value_begin_impl(OverloadToken<double>) const;
   }];
 
   let genVerifyDecl = 1;
@@ -292,10 +299,11 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
       APFloat, std::complex<APFloat>
     >;
 
-    /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+    /// Provide a `try_value_begin_impl` to enable iteration within
+    /// ElementsAttr.
     template <typename T>
-    auto value_begin_impl(OverloadToken<T>) const {
-      return value_begin<T>();
+    auto try_value_begin_impl(OverloadToken<T>) const {
+      return ::mlir::success(value_begin<T>());
     }
 
     /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of
@@ -421,10 +429,11 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
     using ContiguousIterableTypesT = std::tuple<StringRef>;
     using NonContiguousIterableTypesT = std::tuple<Attribute>;
 
-    /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+    /// Provide a `try_value_begin_impl` to enable iteration within
+    /// ElementsAttr.
     template <typename T>
-    auto value_begin_impl(OverloadToken<T>) const {
-      return value_begin<T>();
+    auto try_value_begin_impl(OverloadToken<T>) const {
+      return ::mlir::success(value_begin<T>());
     }
 
   protected:
@@ -892,10 +901,11 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
     >;
     using ElementsAttr::Trait<SparseElementsAttr>::getValues;
 
-    /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+    /// Provide a `try_value_begin_impl` to enable iteration within
+    /// ElementsAttr.
     template <typename T>
-    auto value_begin_impl(OverloadToken<T>) const {
-      return value_begin<T>();
+    auto try_value_begin_impl(OverloadToken<T>) const {
+      return ::mlir::success(value_begin<T>());
     }
 
     template <typename T>

diff  --git a/mlir/include/mlir/Support/LogicalResult.h b/mlir/include/mlir/Support/LogicalResult.h
index 4967b921edc45..d603777f4394d 100644
--- a/mlir/include/mlir/Support/LogicalResult.h
+++ b/mlir/include/mlir/Support/LogicalResult.h
@@ -99,6 +99,13 @@ class [[nodiscard]] FailureOr : public Optional<T> {
   using Optional<T>::has_value;
 };
 
+/// Wrap a value on the success path in a FailureOr of the same value type.
+template <typename T,
+          typename = std::enable_if_t<!std::is_convertible_v<T, bool>>>
+inline auto success(T &&t) {
+  return FailureOr<std::decay_t<T>>(std::forward<T>(t));
+}
+
 /// This class represents success/failure for parsing-like operations that find
 /// it important to chain together failable operations with `||`.  This is an
 /// extended version of `LogicalResult` that allows for explicit conversion to

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8d060ef233872..c911599eaab1d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -763,26 +763,47 @@ DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-const bool *DenseArrayAttr::value_begin_impl(OverloadToken<bool>) const {
-  return cast<DenseBoolArrayAttr>().asArrayRef().begin();
-}
-const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken<int8_t>) const {
-  return cast<DenseI8ArrayAttr>().asArrayRef().begin();
-}
-const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken<int16_t>) const {
-  return cast<DenseI16ArrayAttr>().asArrayRef().begin();
-}
-const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken<int32_t>) const {
-  return cast<DenseI32ArrayAttr>().asArrayRef().begin();
-}
-const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken<int64_t>) const {
-  return cast<DenseI64ArrayAttr>().asArrayRef().begin();
-}
-const float *DenseArrayAttr::value_begin_impl(OverloadToken<float>) const {
-  return cast<DenseF32ArrayAttr>().asArrayRef().begin();
-}
-const double *DenseArrayAttr::value_begin_impl(OverloadToken<double>) const {
-  return cast<DenseF64ArrayAttr>().asArrayRef().begin();
+FailureOr<const bool *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<bool>) const {
+  if (auto attr = dyn_cast<DenseBoolArrayAttr>())
+    return attr.asArrayRef().begin();
+  return failure();
+}
+FailureOr<const int8_t *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<int8_t>) const {
+  if (auto attr = dyn_cast<DenseI8ArrayAttr>())
+    return attr.asArrayRef().begin();
+  return failure();
+}
+FailureOr<const int16_t *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<int16_t>) const {
+  if (auto attr = dyn_cast<DenseI16ArrayAttr>())
+    return attr.asArrayRef().begin();
+  return failure();
+}
+FailureOr<const int32_t *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<int32_t>) const {
+  if (auto attr = dyn_cast<DenseI32ArrayAttr>())
+    return attr.asArrayRef().begin();
+  return failure();
+}
+FailureOr<const int64_t *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<int64_t>) const {
+  if (auto attr = dyn_cast<DenseI64ArrayAttr>())
+    return attr.asArrayRef().begin();
+  return failure();
+}
+FailureOr<const float *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<float>) const {
+  if (auto attr = dyn_cast<DenseF32ArrayAttr>())
+    return attr.asArrayRef().begin();
+  return failure();
+}
+FailureOr<const double *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<double>) const {
+  if (auto attr = dyn_cast<DenseF64ArrayAttr>())
+    return attr.asArrayRef().begin();
+  return failure();
 }
 
 namespace {

diff  --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir
index 81fa495b45c7b..38b2b8aebb8a7 100644
--- a/mlir/test/IR/elements-attr-interface.mlir
+++ b/mlir/test/IR/elements-attr-interface.mlir
@@ -28,18 +28,24 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
 arith.constant dense<> : tensor<0xi64>
 
 // expected-error @below {{Test iterating `bool`: true, false, true, false, true, false}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
 arith.constant array<i1: true, false, true, false, true, false>
 // expected-error @below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
 arith.constant array<i8: 10, 11, -12, 13, 14>
 // expected-error @below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
 arith.constant array<i16: 10, 11, -12, 13, 14>
 // expected-error @below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
 arith.constant array<i32: 10, 11, -12, 13, 14>
 // expected-error @below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
 arith.constant array<i64: 10, 11, -12, 13, 14>
 // expected-error @below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
 arith.constant array<f32: 10., 11., -12., 13., 14.>
 // expected-error @below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
 arith.constant array<f64: 10., 11., -12., 13., 14.>
 
 // Check that we handle an external constant parsed from the config.

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 6bd1ad236d39e..1a7ccdc0694c3 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -94,22 +94,23 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
 
     /// Provide begin iterators for the various iterable types.
     // * uint64_t
-    auto value_begin_impl(OverloadToken<uint64_t>) const {
+    mlir::FailureOr<const uint64_t *>
+    try_value_begin_impl(OverloadToken<uint64_t>) const {
       return getElements().begin();
     }
     // * Attribute
-    auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
+    auto try_value_begin_impl(OverloadToken<mlir::Attribute>) const {
       mlir::Type elementType = getType().getElementType();
-      return llvm::map_range(getElements(), [=](uint64_t value) {
+      return mlir::success(llvm::map_range(getElements(), [=](uint64_t value) {
         return mlir::IntegerAttr::get(elementType,
                                       llvm::APInt(/*numBits=*/64, value));
-      }).begin();
+      }).begin());
     }
     // * APInt
-    auto value_begin_impl(OverloadToken<llvm::APInt>) const {
-      return llvm::map_range(getElements(), [=](uint64_t value) {
+    auto try_value_begin_impl(OverloadToken<llvm::APInt>) const {
+      return mlir::success(llvm::map_range(getElements(), [=](uint64_t value) {
         return llvm::APInt(/*numBits=*/64, value);
-      }).begin();
+      }).begin());
     }
   }];
   let genVerifyDecl = 1;
@@ -257,7 +258,8 @@ def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
 
     /// Provide begin iterators for the various iterable types.
     // * uint64_t
-    auto value_begin_impl(OverloadToken<uint64_t>) const {
+    mlir::FailureOr<const uint64_t *>
+    try_value_begin_impl(OverloadToken<uint64_t>) const {
       return getElements().begin();
     }
   }];

diff  --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index 23fde121682cc..cbef0bca2494d 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -64,6 +64,7 @@ struct TestElementsAttrInterface
               .Case([&](DenseF64ArrayAttr attr) {
                 testElementsAttrIteration<double>(op, attr, "double");
               });
+          testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
           continue;
         }
         testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");


        


More information about the Mlir-commits mailing list