[llvm-branch-commits] [mlir] [mlir][IR] `DenseElementsAttr`: Remove `i1` dense packing special case (PR #180397)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Feb 8 01:10:13 PST 2026
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/180397
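This removes the bit-packed storage special case for `i1` in `DenseElementsAttr`: `i1` elements are now stored like other element types, one byte (0 or 1) per element, so hex-encoded `i1` dense literals now use one byte per element instead of one bit.
Discussion: https://discourse.llvm.org/t/denseelementsattr-i1-element-type/62525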
Depends on #179122.
>From 907f66d32fcd7caf572f05bf86fae1f5b138bb8e Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 8 Feb 2026 09:07:31 +0000
Subject: [PATCH] [mlir][IR] `DenseElementsAttr`: Remove `i1` dense packing
special case
---
mlir/lib/IR/AttributeDetail.h | 52 +----------------------
mlir/lib/IR/BuiltinAttributes.cpp | 60 +++------------------------
mlir/test/IR/attribute-roundtrip.mlir | 10 -----
mlir/test/IR/parse-literal.mlir | 8 ++--
mlir/unittests/IR/AttributeTest.cpp | 14 -------
5 files changed, 11 insertions(+), 133 deletions(-)
delete mode 100644 mlir/test/IR/attribute-roundtrip.mlir
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 7af5c8cd9191d..c60886bc061ce 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -33,10 +33,6 @@ namespace detail {
/// Return the bit width which DenseElementsAttr should use for this type.
inline size_t getDenseElementBitWidth(Type eltType) {
- // i1 is stored as a single bit (bit-packed storage).
- if (eltType.isInteger(1))
- return 1;
- // Check for DenseElementTypeInterface.
if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType))
return denseEltType.getDenseElementBitSize();
llvm_unreachable("unsupported element type");
@@ -92,10 +88,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
// If the data is already known to be a splat, the key hash value is
// directly the data buffer.
- bool isBoolData = ty.getElementType().isInteger(1);
if (isKnownSplat) {
- if (isBoolData)
- return getKeyForSplatBoolData(ty, data[0] != 0);
return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
}
@@ -105,12 +98,8 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
size_t numElements = ty.getNumElements();
assert(numElements != 1 && "splat of 1 element should already be detected");
- // Handle boolean values directly as they are packed to 1-bit.
- if (isBoolData)
- return getKeyForBoolData(ty, data, numElements);
-
size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
- // Non 1-bit dense elements are padded to 8-bits.
+ // Each element occupies a whole number of bytes (bit width rounded up).
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
assert(((data.size() / storageSize) == numElements) &&
"data does not hold expected number of elements");
@@ -129,45 +118,6 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
}
- /// Construct a key with a set of boolean data.
- static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
- size_t numElements) {
- ArrayRef<char> splatData = data;
- bool splatValue = splatData.front() & 1;
-
- // Check the simple case where the data matches the known splat value.
- if (splatData == ArrayRef<char>(splatValue ? kSplatTrue : kSplatFalse))
- return getKeyForSplatBoolData(ty, splatValue);
-
- // Handle the case where the potential splat value is 1 and the number of
- // elements is non 8-bit aligned.
- size_t numOddElements = numElements % CHAR_BIT;
- if (splatValue && numOddElements != 0) {
- // Check that all bits are set in the last value.
- char lastElt = splatData.back();
- if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements))
- return KeyTy(ty, data, llvm::hash_value(data));
-
- // If this is the only element, the data is known to be a splat.
- if (splatData.size() == 1)
- return getKeyForSplatBoolData(ty, splatValue);
- splatData = splatData.drop_back();
- }
-
- // Check that the data buffer corresponds to a splat of the proper mask.
- char mask = splatValue ? ~0 : 0;
- return llvm::all_of(splatData, [mask](char c) { return c == mask; })
- ? getKeyForSplatBoolData(ty, splatValue)
- : KeyTy(ty, data, llvm::hash_value(data));
- }
-
- /// Return a key to use for a boolean splat of the given value.
- static KeyTy getKeyForSplatBoolData(ShapedType type, bool splatValue) {
- const char &splatData = splatValue ? kSplatTrue : kSplatFalse;
- return KeyTy(type, splatData, llvm::hash_value(splatData),
- /*isSplat=*/true);
- }
-
/// Hash the key for the storage.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(key.type, key.hashCode);
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index d9c5fd9acb811..b2f5853269f0a 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -460,7 +460,7 @@ const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;
/// Get the bitwidth of a dense element type within the buffer.
-/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
+/// DenseElementsAttr requires all bitwidths to be aligned by 8.
static size_t getDenseElementStorageWidth(size_t origWidth) {
- return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
+ return llvm::alignTo<8>(origWidth);
}
static size_t getDenseElementStorageWidth(Type elementType) {
return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
@@ -622,12 +622,6 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
Type eltTy = owner.getElementType();
- // Handle i1 (boolean) specially - it's bit-packed and doesn't use interface.
- if (eltTy.isInteger(1)) {
- bool value = *BoolElementIterator(owner, index);
- return IntegerAttr::get(eltTy, APInt(1, value));
- }
-
// Handle strings specially.
if (llvm::isa<DenseStringElementsAttr>(owner)) {
ArrayRef<StringRef> vals = owner.getRawStringData();
@@ -654,7 +648,7 @@ DenseElementsAttr::BoolElementIterator::BoolElementIterator(
attr.getRawData().data(), attr.isSplat(), dataIndex) {}
bool DenseElementsAttr::BoolElementIterator::operator*() const {
- return getBit(getData(), getDataIndex());
+ return static_cast<bool>(getData()[getDataIndex()]);
}
//===----------------------------------------------------------------------===//
@@ -900,18 +894,8 @@ bool DenseElementsAttr::classof(Attribute attr) {
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(hasSameNumElementsOrSplat(type, values));
-
Type eltType = type.getElementType();
- // Handle i1 (boolean) specially - it's bit-packed.
- if (eltType.isInteger(1)) {
- SmallVector<bool> boolValues;
- boolValues.reserve(values.size());
- for (Attribute attr : values)
- boolValues.push_back(llvm::cast<IntegerAttr>(attr).getValue().isOne());
- return get(type, boolValues);
- }
-
// Handle strings specially.
if (!llvm::isa<DenseElementType>(eltType)) {
SmallVector<StringRef, 8> stringValues;
@@ -941,25 +925,9 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<bool> values) {
assert(hasSameNumElementsOrSplat(type, values));
assert(type.getElementType().isInteger(1));
-
- SmallVector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
-
- if (!values.empty()) {
- bool isSplat = true;
- bool firstValue = values[0];
- for (int i = 0, e = values.size(); i != e; ++i) {
- isSplat &= values[i] == firstValue;
- setBit(buff.data(), i, values[i]);
- }
-
- // Splat of bool is encoded as a byte with all-ones in it.
- if (isSplat) {
- buff.resize(1);
- buff[0] = values[0] ? -1 : 0;
- }
- }
-
- return DenseIntOrFPElementsAttr::getRaw(type, buff);
+ return DenseIntOrFPElementsAttr::getRaw(
+ type, ArrayRef<char>(reinterpret_cast<const char *>(values.data()),
+ values.size()));
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@@ -1030,23 +998,7 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
// The initializer is always a splat if the result type has a single element.
detectedSplat = numElements == 1;
- // Storage width of 1 is special as it is packed by the bit.
- if (storageWidth == 1) {
- // Check for a splat, or a buffer equal to the number of elements which
- // consists of either all 0's or all 1's.
- if (rawBuffer.size() == 1) {
- auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
- if (rawByte == 0 || rawByte == 0xff) {
- detectedSplat = true;
- return true;
- }
- }
-
- // This is a valid non-splat buffer if it has the right size.
- return rawBufferWidth == llvm::alignTo<8>(numElements);
- }
-
- // All other types are 8-bit aligned, so we can just check the buffer width
+ // All types are 8-bit aligned, so we can just check the buffer width
// to know if only a single initializer element was passed in.
if (rawBufferWidth == storageWidth) {
detectedSplat = true;
diff --git a/mlir/test/IR/attribute-roundtrip.mlir b/mlir/test/IR/attribute-roundtrip.mlir
deleted file mode 100644
index 974dbcae6cf0a..0000000000000
--- a/mlir/test/IR/attribute-roundtrip.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: mlir-opt -canonicalize %s | mlir-opt | FileCheck %s
-
-// CHECK-LABEL: @large_i1_tensor_roundtrip
-func.func @large_i1_tensor_roundtrip() -> tensor<160xi1> {
- %cst_0 = arith.constant dense<"0xFFF00000FF000000FF000000FF000000FF000000"> : tensor<160xi1>
- %cst_1 = arith.constant dense<"0xFF000000FF000000FF000000FF000000FF0000F0"> : tensor<160xi1>
- // CHECK: dense<"0xFF000000FF000000FF000000FF000000FF000000">
- %0 = arith.andi %cst_0, %cst_1 : tensor<160xi1>
- return %0 : tensor<160xi1>
-}
diff --git a/mlir/test/IR/parse-literal.mlir b/mlir/test/IR/parse-literal.mlir
index 71b25e1d86480..36867c56075d0 100644
--- a/mlir/test/IR/parse-literal.mlir
+++ b/mlir/test/IR/parse-literal.mlir
@@ -36,8 +36,8 @@ func.func @parse_i4_tensor() -> tensor<32xi4> {
}
// CHECK-LABEL: @parse_i1_tensor
-func.func @parse_i1_tensor() -> tensor<256xi1> {
- // CHECK: dense<"0x0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F"> : tensor<256xi1>
- %0 = arith.constant dense<"0x0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F"> : tensor<256xi1>
- return %0 : tensor<256xi1>
+func.func @parse_i1_tensor() -> tensor<32xi1> {
+ // CHECK: dense<[true, false, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, false, false, false, false, true]> : tensor<32xi1>
+ %0 = arith.constant dense<"0x0100010001010101010101010101010101010101010101010100000000000001"> : tensor<32xi1>
+ return %0 : tensor<32xi1>
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index fd40404bf3008..404aa8c0dcf3d 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -76,20 +76,6 @@ TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
EXPECT_EQ(trueSplat, trueSplatFromRaw);
}
-TEST(DenseSplatTest, BoolSplatSmall) {
- MLIRContext context;
- Builder builder(&context);
-
- // Check that splats that don't fill entire byte are handled properly.
- auto tensorType = RankedTensorType::get({4}, builder.getI1Type());
- std::vector<char> data{0b00001111};
- auto trueSplatFromRaw =
- DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data);
- EXPECT_TRUE(trueSplatFromRaw.isSplat());
- DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true);
- EXPECT_EQ(trueSplat, trueSplatFromRaw);
-}
-
TEST(DenseSplatTest, LargeBoolSplat) {
constexpr int64_t boolCount = 56;
More information about the llvm-branch-commits mailing list