[Mlir-commits] [mlir] 910fff1 - [mlir][DenseStringAttr] Fix support for splat detection and iteration
River Riddle
llvmlistbot at llvm.org
Sun Apr 26 13:58:12 PDT 2020
Author: River Riddle
Date: 2020-04-26T13:53:57-07:00
New Revision: 910fff1c1dd182243fbbf867a402bc1a48281618
URL: https://github.com/llvm/llvm-project/commit/910fff1c1dd182243fbbf867a402bc1a48281618
DIFF: https://github.com/llvm/llvm-project/commit/910fff1c1dd182243fbbf867a402bc1a48281618.diff
LOG: [mlir][DenseStringAttr] Fix support for splat detection and iteration
This revision also adds proper tests for splat detection.
Added:
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AttributeDetail.h
mlir/unittests/IR/AttributeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index b4f1d025be85..477494661972 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -896,6 +896,8 @@ class DenseElementsAttr : public ElementsAttr {
ElementIterator<T>(rawData, splat, getNumElements())};
}
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, StringRef>::value>::type>
llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
auto stringRefs = getRawStringData();
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 011f4d7650f4..201f058571af 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -594,7 +594,7 @@ struct DenseStringElementsAttributeStorage
// If the data is already known to be a splat, the key hash value is
// directly the data buffer.
if (isKnownSplat)
- return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
+ return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
// Handle the simple case of only one element.
assert(ty.getNumElements() != 1 &&
@@ -610,8 +610,8 @@ struct DenseStringElementsAttributeStorage
if (!firstElt.equals(data[i]))
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
- // Otherwise, this is a splat.
- return KeyTy(ty, data, hashVal, /*isSplat=*/true);
+ // Otherwise, this is a splat so just return the hash of the first element.
+ return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
}
/// Hash the key for the storage.
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 122c1787e301..ad4b422eae91 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "gtest/gtest.h"
@@ -15,7 +16,7 @@ using namespace mlir::detail;
template <typename EltTy>
static void testSplat(Type eltType, const EltTy &splatElt) {
- VectorType shape = VectorType::get({2, 1}, eltType);
+ RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
// Check that the generated splat is the same for 1 element and N elements.
DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
@@ -30,7 +31,7 @@ namespace {
TEST(DenseSplatTest, BoolSplat) {
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
- VectorType shape = VectorType::get({2, 2}, boolTy);
+ RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
// Check that splat is automatically detected for boolean values.
/// True.
@@ -55,7 +56,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
- VectorType shape = VectorType::get({boolCount}, boolTy);
+ RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
// Check that splat is automatically detected for boolean values.
/// True.
@@ -78,7 +79,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
TEST(DenseSplatTest, BoolNonSplat) {
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
- VectorType shape = VectorType::get({6}, boolTy);
+ RankedTensorType shape = RankedTensorType::get({6}, boolTy);
// Check that we properly handle non-splat values.
DenseElementsAttr nonSplat =
@@ -145,4 +146,12 @@ TEST(DenseSplatTest, BF16Splat) {
testSplat(floatTy, value);
}
+TEST(DenseSplatTest, StringSplat) {
+ MLIRContext context;
+ Type stringType =
+ OpaqueType::get(Identifier::get("test", &context), "string", &context);
+ StringRef value = "test-string";
+ testSplat(stringType, value);
+}
+
} // end namespace
More information about the Mlir-commits mailing list