[Mlir-commits] [mlir] 9e0900c - [mlir] Fix DenseElementsAttr treatment of bool splat of "true"

River Riddle llvmlistbot at llvm.org
Tue Sep 13 11:39:45 PDT 2022


Author: River Riddle
Date: 2022-09-13T11:39:20-07:00
New Revision: 9e0900cbf1cbc2b8366df66a562bf031c8a2b8db

URL: https://github.com/llvm/llvm-project/commit/9e0900cbf1cbc2b8366df66a562bf031c8a2b8db
DIFF: https://github.com/llvm/llvm-project/commit/9e0900cbf1cbc2b8366df66a562bf031c8a2b8db.diff

LOG: [mlir] Fix DenseElementsAttr treatment of bool splat of "true"

Boolean splats currently can't round-trip via the "raw" DenseElementsAttr
API. This is because internally we treat true splats in some cases as "1" (one bit set)
and in other cases as "0xFF" (all bits set). This commit cleans up this handling to
consistently use 0xFF (all bits set) as the value for a splat of true.

Differential Revision: https://reviews.llvm.org/D133743

Added: 
    

Modified: 
    mlir/lib/IR/AttributeDetail.h
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index ced9dcf6c7b7c..bc336dcc121e9 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -76,22 +76,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
 
   /// Compare this storage instance with the provided key.
   bool operator==(const KeyTy &key) const {
-    if (key.type != type)
-      return false;
-
-    // For boolean splats we need to explicitly check that the first bit is the
-    // same. Boolean values are packed at the bit level, and even though a splat
-    // is detected the rest of the bits in the first byte may differ from the
-    // splat value.
-    if (key.type.getElementType().isInteger(1)) {
-      if (key.isSplat != isSplat)
-        return false;
-      if (isSplat)
-        return (key.data.front() & 1) == data.front();
-    }
-
-    // Otherwise, we can default to just checking the data.
-    return key.data == data;
+    return key.type == type && key.data == data;
   }
 
   /// Construct a key from a shaped type, raw data buffer, and a flag that
@@ -105,8 +90,12 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
 
     // If the data is already known to be a splat, the key hash value is
     // directly the data buffer.
-    if (isKnownSplat)
+    bool isBoolData = ty.getElementType().isInteger(1);
+    if (isKnownSplat) {
+      if (isBoolData)
+        return getKeyForSplatBoolData(ty, data[0] != 0);
       return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
+    }
 
     // Otherwise, we need to check if the data corresponds to a splat or not.
 
@@ -115,7 +104,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
     assert(numElements != 1 && "splat of 1 element should already be detected");
 
     // Handle boolean values directly as they are packed to 1-bit.
-    if (ty.getElementType().isInteger(1) == 1)
+    if (isBoolData)
       return getKeyForBoolData(ty, data, numElements);
 
     size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
@@ -144,12 +133,9 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
     ArrayRef<char> splatData = data;
     bool splatValue = splatData.front() & 1;
 
-    // Helper functor to generate a KeyTy for a boolean splat value.
-    auto generateSplatKey = [=] {
-      return KeyTy(ty, data.take_front(1),
-                   llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
-                   /*isSplat=*/true);
-    };
+    // Check the simple case where the data matches the known splat value.
+    if (splatData == ArrayRef<char>(splatValue ? kSplatTrue : kSplatFalse))
+      return getKeyForSplatBoolData(ty, splatValue);
 
     // Handle the case where the potential splat value is 1 and the number of
     // elements is non 8-bit aligned.
@@ -162,17 +148,24 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
 
       // If this is the only element, the data is known to be a splat.
       if (splatData.size() == 1)
-        return generateSplatKey();
+        return getKeyForSplatBoolData(ty, splatValue);
       splatData = splatData.drop_back();
     }
 
     // Check that the data buffer corresponds to a splat of the proper mask.
     char mask = splatValue ? ~0 : 0;
     return llvm::all_of(splatData, [mask](char c) { return c == mask; })
-               ? generateSplatKey()
+               ? getKeyForSplatBoolData(ty, splatValue)
                : KeyTy(ty, data, llvm::hash_value(data));
   }
 
+  /// Return a key to use for a boolean splat of the given value.
+  static KeyTy getKeyForSplatBoolData(ShapedType type, bool splatValue) {
+    const char &splatData = splatValue ? kSplatTrue : kSplatFalse;
+    return KeyTy(type, splatData, llvm::hash_value(splatData),
+                 /*isSplat=*/true);
+  }
+
   /// Hash the key for the storage.
   static llvm::hash_code hashKey(const KeyTy &key) {
     return llvm::hash_combine(key.type, key.hashCode);
@@ -188,10 +181,6 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
       char *rawData = reinterpret_cast<char *>(
           allocator.allocate(data.size(), alignof(uint64_t)));
       std::memcpy(rawData, data.data(), data.size());
-
-      // If this is a boolean splat, make sure only the first bit is used.
-      if (key.isSplat && key.type.getElementType().isInteger(1))
-        rawData[0] &= 1;
       copy = ArrayRef<char>(rawData, data.size());
     }
 
@@ -200,6 +189,10 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
   }
 
   ArrayRef<char> data;
+
+  /// The values used to denote a boolean splat value.
+  static constexpr char kSplatTrue = ~0;
+  static constexpr char kSplatFalse = 0;
 };
 
 /// An attribute representing a reference to a dense vector or tensor object

diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index cffac41bd9532..e3c8cf45f2f2d 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -58,6 +58,20 @@ TEST(DenseSplatTest, BoolSplat) {
   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
   EXPECT_EQ(detectedSplat, falseSplat);
 }
+TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
+  MLIRContext context;
+  IntegerType boolTy = IntegerType::get(&context, 1);
+  RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
+
+  // Check that splat booleans properly round trip via the raw API.
+  DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
+  EXPECT_TRUE(trueSplat.isSplat());
+  DenseElementsAttr trueSplatFromRaw =
+      DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
+  EXPECT_TRUE(trueSplatFromRaw.isSplat());
+
+  EXPECT_EQ(trueSplat, trueSplatFromRaw);
+}
 
 TEST(DenseSplatTest, LargeBoolSplat) {
   constexpr int64_t boolCount = 56;


        


More information about the Mlir-commits mailing list