[Mlir-commits] [mlir] 910fff1 - [mlir][DenseStringAttr] Fix support for splat detection and iteration

River Riddle llvmlistbot at llvm.org
Sun Apr 26 13:58:12 PDT 2020


Author: River Riddle
Date: 2020-04-26T13:53:57-07:00
New Revision: 910fff1c1dd182243fbbf867a402bc1a48281618

URL: https://github.com/llvm/llvm-project/commit/910fff1c1dd182243fbbf867a402bc1a48281618
DIFF: https://github.com/llvm/llvm-project/commit/910fff1c1dd182243fbbf867a402bc1a48281618.diff

LOG: [mlir][DenseStringAttr] Fix support for splat detection and iteration

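This revision also adds tests for splat detection, including a new string splat test, and switches the existing splat tests from VectorType to RankedTensorType.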

Added: 
    

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/lib/IR/AttributeDetail.h
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index b4f1d025be85..477494661972 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -896,6 +896,8 @@ class DenseElementsAttr : public ElementsAttr {
             ElementIterator<T>(rawData, splat, getNumElements())};
   }
 
+  template <typename T, typename = typename std::enable_if<
+                            std::is_same<T, StringRef>::value>::type>
   llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
     auto stringRefs = getRawStringData();
     const char *ptr = reinterpret_cast<const char *>(stringRefs.data());

diff  --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 011f4d7650f4..201f058571af 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -594,7 +594,7 @@ struct DenseStringElementsAttributeStorage
     // If the data is already known to be a splat, the key hash value is
     // directly the data buffer.
     if (isKnownSplat)
-      return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
+      return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
 
     // Handle the simple case of only one element.
     assert(ty.getNumElements() != 1 &&
@@ -610,8 +610,8 @@ struct DenseStringElementsAttributeStorage
       if (!firstElt.equals(data[i]))
         return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
 
-    // Otherwise, this is a splat.
-    return KeyTy(ty, data, hashVal, /*isSplat=*/true);
+    // Otherwise, this is a splat so just return the hash of the first element.
+    return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
   }
 
   /// Hash the key for the storage.

diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 122c1787e301..ad4b422eae91 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
 #include "mlir/IR/StandardTypes.h"
 #include "gtest/gtest.h"
 
@@ -15,7 +16,7 @@ using namespace mlir::detail;
 
 template <typename EltTy>
 static void testSplat(Type eltType, const EltTy &splatElt) {
-  VectorType shape = VectorType::get({2, 1}, eltType);
+  RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
 
   // Check that the generated splat is the same for 1 element and N elements.
   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
@@ -30,7 +31,7 @@ namespace {
 TEST(DenseSplatTest, BoolSplat) {
   MLIRContext context;
   IntegerType boolTy = IntegerType::get(1, &context);
-  VectorType shape = VectorType::get({2, 2}, boolTy);
+  RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
 
   // Check that splat is automatically detected for boolean values.
   /// True.
@@ -55,7 +56,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
 
   MLIRContext context;
   IntegerType boolTy = IntegerType::get(1, &context);
-  VectorType shape = VectorType::get({boolCount}, boolTy);
+  RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
 
   // Check that splat is automatically detected for boolean values.
   /// True.
@@ -78,7 +79,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
 TEST(DenseSplatTest, BoolNonSplat) {
   MLIRContext context;
   IntegerType boolTy = IntegerType::get(1, &context);
-  VectorType shape = VectorType::get({6}, boolTy);
+  RankedTensorType shape = RankedTensorType::get({6}, boolTy);
 
   // Check that we properly handle non-splat values.
   DenseElementsAttr nonSplat =
@@ -145,4 +146,12 @@ TEST(DenseSplatTest, BF16Splat) {
   testSplat(floatTy, value);
 }
 
+TEST(DenseSplatTest, StringSplat) {
+  MLIRContext context;
+  Type stringType =
+      OpaqueType::get(Identifier::get("test", &context), "string", &context);
+  StringRef value = "test-string";
+  testSplat(stringType, value);
+}
+
 } // end namespace


        


More information about the Mlir-commits mailing list