[Mlir-commits] [mlir] 5cf5708 - [mlir][ElementsAttr] Change value_begin_impl to try_value_begin_impl
Jeff Niu
llvmlistbot at llvm.org
Tue Aug 30 14:13:01 PDT 2022
Author: Jeff Niu
Date: 2022-08-30T14:12:46-07:00
New Revision: 5cf5708628a3a465d7845fb43e46bd7f97761ead
URL: https://github.com/llvm/llvm-project/commit/5cf5708628a3a465d7845fb43e46bd7f97761ead
DIFF: https://github.com/llvm/llvm-project/commit/5cf5708628a3a465d7845fb43e46bd7f97761ead.diff
LOG: [mlir][ElementsAttr] Change value_begin_impl to try_value_begin_impl
This patch changes `value_begin_impl` to a failable
`try_value_begin_impl` so that specific cases can fail iteration if the
type doesn't match the internal storage.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D132904
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/Support/LogicalResult.h
mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/IR/elements-attr-interface.mlir
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 58dad618d3cee..a9e6e0e12051a 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -54,33 +54,36 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
using NonContiguousIterableTypesT = std::tuple<APInt, Attribute>;
```
- * Provide a `iterator value_begin_impl(OverloadToken<T>) const` overload for
- each iterable type
+ * Provide a `FailureOr<iterator> try_value_begin_impl(OverloadToken<T>) const`
+ overload for each iterable type
These overloads should return an iterator to the start of the range for the
- respective iterable type. Consider the example i64 elements attribute
- described in the previous section. This attribute may define the
- value_begin_impl overloads like so:
+ respective iterable type or fail if the type cannot be iterated. Consider
+ the example i64 elements attribute described in the previous section. This
+ attribute may define the try_value_begin_impl overloads like so:
```c++
/// Provide begin iterators for the various iterable types.
/// * uint64_t
- auto value_begin_impl(OverloadToken<uint64_t>) const {
+ FailureOr<const uint64_t *>
+ value_begin_impl(OverloadToken<uint64_t>) const {
return getElements().begin();
}
/// * APInt
auto value_begin_impl(OverloadToken<llvm::APInt>) const {
- return llvm::map_range(getElements(), [=](uint64_t value) {
+ auto it = llvm::map_range(getElements(), [=](uint64_t value) {
return llvm::APInt(/*numBits=*/64, value);
}).begin();
+ return FailureOr<decltype(it)>(std::move(it));
}
/// * Attribute
auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
mlir::Type elementType = getType().getElementType();
- return llvm::map_range(getElements(), [=](uint64_t value) {
+ auto it = llvm::map_range(getElements(), [=](uint64_t value) {
return mlir::IntegerAttr::get(elementType,
llvm::APInt(/*numBits=*/64, value));
}).begin();
+ return FailureOr<decltype(it)>(std::move(it));
}
```
@@ -244,18 +247,22 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
/*isSplat=*/false, nullptr);
}
- auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
+ auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
+ if (::mlir::failed(valueIt))
+ return ::mlir::failure();
return ::mlir::detail::ElementsAttrIndexer::contiguous(
- $_attr.isSplat(), &*valueIt);
+ $_attr.isSplat(), &**valueIt);
}
/// Build an indexer for the given type `T`, which is represented via a
/// non-contiguous range.
template <typename T>
::mlir::FailureOr<::mlir::detail::ElementsAttrIndexer> buildValueResult(
/*isContiguous*/std::false_type) const {
- auto valueIt = $_attr.value_begin_impl(OverloadToken<T>());
+ auto valueIt = $_attr.try_value_begin_impl(OverloadToken<T>());
+ if (::mlir::failed(valueIt))
+ return ::mlir::failure();
return ::mlir::detail::ElementsAttrIndexer::nonContiguous(
- $_attr.isSplat(), valueIt);
+ $_attr.isSplat(), *valueIt);
}
public:
@@ -275,7 +282,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
/// type `T`.
template <typename T>
auto value_begin() const {
- return $_attr.value_begin_impl(OverloadToken<T>());
+ return *$_attr.try_value_begin_impl(OverloadToken<T>());
}
/// Return the elements of this attribute as a value of type 'T'.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 2c4b92b6fa5fc..36a856afe90b1 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -215,13 +215,20 @@ def Builtin_DenseArray : Builtin_Attr<
/// ElementsAttr implementation.
using ContiguousIterableTypesT =
std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
- const bool *value_begin_impl(OverloadToken<bool>) const;
- const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
- const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
- const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
- const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
- const float *value_begin_impl(OverloadToken<float>) const;
- const double *value_begin_impl(OverloadToken<double>) const;
+ FailureOr<const bool *>
+ try_value_begin_impl(OverloadToken<bool>) const;
+ FailureOr<const int8_t *>
+ try_value_begin_impl(OverloadToken<int8_t>) const;
+ FailureOr<const int16_t *>
+ try_value_begin_impl(OverloadToken<int16_t>) const;
+ FailureOr<const int32_t *>
+ try_value_begin_impl(OverloadToken<int32_t>) const;
+ FailureOr<const int64_t *>
+ try_value_begin_impl(OverloadToken<int64_t>) const;
+ FailureOr<const float *>
+ try_value_begin_impl(OverloadToken<float>) const;
+ FailureOr<const double *>
+ try_value_begin_impl(OverloadToken<double>) const;
}];
let genVerifyDecl = 1;
@@ -292,10 +299,11 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
APFloat, std::complex<APFloat>
>;
- /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+ /// Provide a `try_value_begin_impl` to enable iteration within
+ /// ElementsAttr.
template <typename T>
- auto value_begin_impl(OverloadToken<T>) const {
- return value_begin<T>();
+ auto try_value_begin_impl(OverloadToken<T>) const {
+ return ::mlir::success(value_begin<T>());
}
/// Convert endianess of input ArrayRef for big-endian(BE) machines. All of
@@ -421,10 +429,11 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
using ContiguousIterableTypesT = std::tuple<StringRef>;
using NonContiguousIterableTypesT = std::tuple<Attribute>;
- /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+ /// Provide a `try_value_begin_impl` to enable iteration within
+ /// ElementsAttr.
template <typename T>
- auto value_begin_impl(OverloadToken<T>) const {
- return value_begin<T>();
+ auto try_value_begin_impl(OverloadToken<T>) const {
+ return ::mlir::success(value_begin<T>());
}
protected:
@@ -892,10 +901,11 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
>;
using ElementsAttr::Trait<SparseElementsAttr>::getValues;
- /// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
+ /// Provide a `try_value_begin_impl` to enable iteration within
+ /// ElementsAttr.
template <typename T>
- auto value_begin_impl(OverloadToken<T>) const {
- return value_begin<T>();
+ auto try_value_begin_impl(OverloadToken<T>) const {
+ return ::mlir::success(value_begin<T>());
}
template <typename T>
diff --git a/mlir/include/mlir/Support/LogicalResult.h b/mlir/include/mlir/Support/LogicalResult.h
index 4967b921edc45..d603777f4394d 100644
--- a/mlir/include/mlir/Support/LogicalResult.h
+++ b/mlir/include/mlir/Support/LogicalResult.h
@@ -99,6 +99,13 @@ class [[nodiscard]] FailureOr : public Optional<T> {
using Optional<T>::has_value;
};
+/// Wrap a value on the success path in a FailureOr of the same value type.
+template <typename T,
+ typename = std::enable_if_t<!std::is_convertible_v<T, bool>>>
+inline auto success(T &&t) {
+ return FailureOr<std::decay_t<T>>(std::forward<T>(t));
+}
+
/// This class represents success/failure for parsing-like operations that find
/// it important to chain together failable operations with `||`. This is an
/// extended version of `LogicalResult` that allows for explicit conversion to
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 8d060ef233872..c911599eaab1d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -763,26 +763,47 @@ DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
-const bool *DenseArrayAttr::value_begin_impl(OverloadToken<bool>) const {
- return cast<DenseBoolArrayAttr>().asArrayRef().begin();
-}
-const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken<int8_t>) const {
- return cast<DenseI8ArrayAttr>().asArrayRef().begin();
-}
-const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken<int16_t>) const {
- return cast<DenseI16ArrayAttr>().asArrayRef().begin();
-}
-const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken<int32_t>) const {
- return cast<DenseI32ArrayAttr>().asArrayRef().begin();
-}
-const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken<int64_t>) const {
- return cast<DenseI64ArrayAttr>().asArrayRef().begin();
-}
-const float *DenseArrayAttr::value_begin_impl(OverloadToken<float>) const {
- return cast<DenseF32ArrayAttr>().asArrayRef().begin();
-}
-const double *DenseArrayAttr::value_begin_impl(OverloadToken<double>) const {
- return cast<DenseF64ArrayAttr>().asArrayRef().begin();
+FailureOr<const bool *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<bool>) const {
+ if (auto attr = dyn_cast<DenseBoolArrayAttr>())
+ return attr.asArrayRef().begin();
+ return failure();
+}
+FailureOr<const int8_t *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<int8_t>) const {
+ if (auto attr = dyn_cast<DenseI8ArrayAttr>())
+ return attr.asArrayRef().begin();
+ return failure();
+}
+FailureOr<const int16_t *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<int16_t>) const {
+ if (auto attr = dyn_cast<DenseI16ArrayAttr>())
+ return attr.asArrayRef().begin();
+ return failure();
+}
+FailureOr<const int32_t *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<int32_t>) const {
+ if (auto attr = dyn_cast<DenseI32ArrayAttr>())
+ return attr.asArrayRef().begin();
+ return failure();
+}
+FailureOr<const int64_t *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<int64_t>) const {
+ if (auto attr = dyn_cast<DenseI64ArrayAttr>())
+ return attr.asArrayRef().begin();
+ return failure();
+}
+FailureOr<const float *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<float>) const {
+ if (auto attr = dyn_cast<DenseF32ArrayAttr>())
+ return attr.asArrayRef().begin();
+ return failure();
+}
+FailureOr<const double *>
+DenseArrayAttr::try_value_begin_impl(OverloadToken<double>) const {
+ if (auto attr = dyn_cast<DenseF64ArrayAttr>())
+ return attr.asArrayRef().begin();
+ return failure();
}
namespace {
diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir
index 81fa495b45c7b..38b2b8aebb8a7 100644
--- a/mlir/test/IR/elements-attr-interface.mlir
+++ b/mlir/test/IR/elements-attr-interface.mlir
@@ -28,18 +28,24 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
arith.constant dense<> : tensor<0xi64>
// expected-error @below {{Test iterating `bool`: true, false, true, false, true, false}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
arith.constant array<i1: true, false, true, false, true, false>
// expected-error @below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
arith.constant array<i8: 10, 11, -12, 13, 14>
// expected-error @below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
arith.constant array<i16: 10, 11, -12, 13, 14>
// expected-error @below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
arith.constant array<i32: 10, 11, -12, 13, 14>
// expected-error @below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
arith.constant array<i64: 10, 11, -12, 13, 14>
// expected-error @below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
arith.constant array<f32: 10., 11., -12., 13., 14.>
// expected-error @below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+// expected-error @below {{Test iterating `int64_t`: unable to iterate type}}
arith.constant array<f64: 10., 11., -12., 13., 14.>
// Check that we handle an external constant parsed from the config.
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 6bd1ad236d39e..1a7ccdc0694c3 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -94,22 +94,23 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
/// Provide begin iterators for the various iterable types.
// * uint64_t
- auto value_begin_impl(OverloadToken<uint64_t>) const {
+ mlir::FailureOr<const uint64_t *>
+ try_value_begin_impl(OverloadToken<uint64_t>) const {
return getElements().begin();
}
// * Attribute
- auto value_begin_impl(OverloadToken<mlir::Attribute>) const {
+ auto try_value_begin_impl(OverloadToken<mlir::Attribute>) const {
mlir::Type elementType = getType().getElementType();
- return llvm::map_range(getElements(), [=](uint64_t value) {
+ return mlir::success(llvm::map_range(getElements(), [=](uint64_t value) {
return mlir::IntegerAttr::get(elementType,
llvm::APInt(/*numBits=*/64, value));
- }).begin();
+ }).begin());
}
// * APInt
- auto value_begin_impl(OverloadToken<llvm::APInt>) const {
- return llvm::map_range(getElements(), [=](uint64_t value) {
+ auto try_value_begin_impl(OverloadToken<llvm::APInt>) const {
+ return mlir::success(llvm::map_range(getElements(), [=](uint64_t value) {
return llvm::APInt(/*numBits=*/64, value);
- }).begin();
+ }).begin());
}
}];
let genVerifyDecl = 1;
@@ -257,7 +258,8 @@ def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
/// Provide begin iterators for the various iterable types.
// * uint64_t
- auto value_begin_impl(OverloadToken<uint64_t>) const {
+ mlir::FailureOr<const uint64_t *>
+ try_value_begin_impl(OverloadToken<uint64_t>) const {
return getElements().begin();
}
}];
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index 23fde121682cc..cbef0bca2494d 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -64,6 +64,7 @@ struct TestElementsAttrInterface
.Case([&](DenseF64ArrayAttr attr) {
testElementsAttrIteration<double>(op, attr, "double");
});
+ testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
continue;
}
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
More information about the Mlir-commits
mailing list