[Mlir-commits] [llvm] [mlir] [mlir] Add Repeated<T> constructors for TypeRange and ValueRange (PR #186923)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Mar 17 11:50:42 PDT 2026
https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/186923
>From cb054c9522cc9c58d95c3b5eba26992e13af8999 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 16 Mar 2026 17:25:59 -0400
Subject: [PATCH 1/2] [mlir] Add Repeated<T> constructors for TypeRange and
ValueRange
Many MLIR APIs end up using a range of the same Type / Value repeated N
times, due to the dimensionality of the problem. Allocating a vector
of N identical elements is wasteful.
Add `Repeated<T>::Storage` as PointerUnion variants in TypeRange
and ValueRange, enabling O(1) storage for repeated elements.
Size remains 2 pointers (16 bytes on 64-bit) for both range types.
Also update several MLIR dialects and conversions to exercise the new
code.
Co-Authored-By: Claude Opus 4.6 <noreply at anthropic.com>
---
mlir/include/mlir/IR/TypeRange.h | 16 ++++--
mlir/include/mlir/IR/ValueRange.h | 13 +++--
mlir/include/mlir/Support/LLVM.h | 3 ++
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 4 +-
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 4 +-
.../Conversion/VectorToAMX/VectorToAMX.cpp | 2 +-
.../OwnershipBasedBufferDeallocation.cpp | 2 +-
.../Transforms/VectorTransferOpTransforms.cpp | 4 +-
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 2 +-
mlir/lib/IR/OperationSupport.cpp | 6 +++
mlir/lib/IR/TypeRange.cpp | 14 +++++
mlir/unittests/IR/OperationSupportTest.cpp | 51 +++++++++++++++++++
13 files changed, 107 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index c6cbf3461bcd7..3debed6212778 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -33,7 +33,9 @@ namespace mlir {
class TypeRange : public llvm::detail::indexed_accessor_range_base<
TypeRange,
llvm::PointerUnion<const Value *, const Type *,
- OpOperand *, detail::OpResultImpl *>,
+ OpOperand *, detail::OpResultImpl *,
+ const Repeated<Type>::Storage *,
+ const Repeated<Value>::Storage *>,
Type, Type, Type> {
public:
using RangeBaseT::RangeBaseT;
@@ -51,6 +53,10 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
: TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types LLVM_LIFETIME_BOUND)
: TypeRange(ArrayRef<Type>(types)) {}
+ /// Constructs a range from a repeated type. The Repeated object must outlive
+ /// this range.
+ TypeRange(const Repeated<Type> &repeatedValue LLVM_LIFETIME_BOUND)
+ : RangeBaseT(&repeatedValue.storage, repeatedValue.count) {}
private:
/// The owner of the range is either:
@@ -58,8 +64,12 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
/// * A pointer to the first element of an array of types.
/// * A pointer to the first element of an array of operands.
/// * A pointer to the first element of an array of results.
- using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
- detail::OpResultImpl *>;
+ /// * A pointer to a Repeated<Type>::Storage (single type repeated N times).
+ /// * A pointer to a Repeated<Value>::Storage (single value repeated N times,
+ /// dereferenced via getType()).
+ using OwnerT = llvm::PointerUnion<
+ const Value *, const Type *, OpOperand *, detail::OpResultImpl *,
+ const Repeated<Type>::Storage *, const Repeated<Value>::Storage *>;
/// See `llvm::detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(OwnerT object, ptrdiff_t index);
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f04ed0544c0f6..d40de878d5d10 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -17,6 +17,7 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/Repeated.h"
#include "llvm/ADT/Sequence.h"
#include <optional>
@@ -383,13 +384,15 @@ class ResultRange::UseIterator final
class ValueRange final
: public llvm::detail::indexed_accessor_range_base<
ValueRange,
- PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
+ PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *,
+ const Repeated<Value>::Storage *>,
Value, Value, Value> {
public:
/// The type representing the owner of a ValueRange. This is either a list of
- /// values, operands, or results.
+ /// values, operands, results, or a repeated single value.
using OwnerT =
- PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
+ PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *,
+ const Repeated<Value>::Storage *>;
using RangeBaseT::RangeBaseT;
@@ -412,6 +415,10 @@ class ValueRange final
ValueRange(ArrayRef<Value> values = {});
ValueRange(OperandRange values);
ValueRange(ResultRange values);
+ /// Constructs a range from a repeated value. The Repeated object must outlive
+ /// this range.
+ ValueRange(const Repeated<Value> &repeatedValue LLVM_LIFETIME_BOUND)
+ : RangeBaseT(&repeatedValue.storage, repeatedValue.count) {}
/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h
index 81bbd717c4f80..8b27ee8c4fed6 100644
--- a/mlir/include/mlir/Support/LLVM.h
+++ b/mlir/include/mlir/Support/LLVM.h
@@ -54,6 +54,8 @@ template <typename T>
class MutableArrayRef;
template <typename... PT>
class PointerUnion;
+template <typename T>
+struct Repeated;
template <typename T, typename Vector, typename Set, unsigned N>
class SetVector;
template <typename T, unsigned N>
@@ -125,6 +127,7 @@ template <typename AllocatorTy = llvm::MallocAllocator>
using StringSet = llvm::StringSet<AllocatorTy>;
using llvm::MutableArrayRef;
using llvm::PointerUnion;
+using llvm::Repeated;
using llvm::SmallPtrSet;
using llvm::SmallPtrSetImpl;
using llvm::SmallVector;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 498bea0fd17b4..6a705ebab7aa4 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -95,7 +95,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// New arguments will simply be `llvm.ptr` with the correct address space
Type workgroupPtrType =
rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
- SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
+ Repeated<Type> argTypes(numAttributions, workgroupPtrType);
// Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
std::array attrs{
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 78f0fe1392962..e4b5da7a5ea92 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -155,11 +155,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
int count = vectorType.getNumElements();
intType = VectorType::get(count, intType);
- SmallVector<Value> signSplat(count, signMask);
+ Repeated<Value> signSplat(count, signMask);
signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
signSplat);
- SmallVector<Value> valueSplat(count, valueMask);
+ Repeated<Value> valueSplat(count, valueMask);
valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
valueSplat);
}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index b46026b855b90..7e9c9090c51df 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -117,8 +117,8 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto one = createIndexConst(rewriter, loc, 1);
// Loop bounds
- auto lbs = llvm::SmallVector<Value>(2, zero);
- auto steps = llvm::SmallVector<Value>(2, one);
+ auto lbs = Repeated<Value>(2, zero);
+ auto steps = Repeated<Value>(2, one);
auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index bce67b3e4748b..c6182379026df 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -354,7 +354,7 @@ static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
rewriter, loc,
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
- SmallVector<Value> indices(2, zeroIndex);
+ Repeated<Value> indices(2, zeroIndex);
x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index e625f172a3bf3..f73c8476bf20e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -723,7 +723,7 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
// the outside.
Value falseVal = buildBoolValue(builder, op.getLoc(), false);
op->insertOperands(op->getNumOperands(),
- SmallVector<Value>(numMemrefOperands, falseVal));
+ Repeated<Value>(numMemrefOperands, falseVal));
int counter = op->getNumResults();
unsigned numMemrefResults = llvm::count_if(op->getResults(), isMemref);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 19db8b3b48a25..babd321e484bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -544,7 +544,7 @@ class TransferReadDropUnitDimsPattern
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
- SmallVector<Value> zeros(reducedRank, c0);
+ Repeated<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
Operation *newTransferReadOp = vector::TransferReadOp::create(
@@ -658,7 +658,7 @@ class TransferWriteDropUnitDimsPattern
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
- SmallVector<Value> zeros(reducedRank, c0);
+ Repeated<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index df8e6cf167348..9585f5a1d774a 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -357,7 +357,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
builder, loc,
/*vectorType=*/vecToReadTy,
/*source=*/source,
- /*indices=*/SmallVector<Value>(vecToReadRank, zero),
+ /*indices=*/Repeated<Value>(vecToReadRank, zero),
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 3ff61daaac60b..f1ee879136756 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -654,6 +654,9 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
return {value + index};
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return {operand + index};
+ // All elements are identical; the owner pointer never advances.
+ if (llvm::isa<const Repeated<Value>::Storage *>(owner))
+ return owner;
return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
@@ -662,6 +665,9 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
return value[index];
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return operand[index].get();
+ if (auto *repeated =
+ llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(owner))
+ return repeated->value;
return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
}
diff --git a/mlir/lib/IR/TypeRange.cpp b/mlir/lib/IR/TypeRange.cpp
index 88e788aa1b2b8..d14ad76c83e75 100644
--- a/mlir/lib/IR/TypeRange.cpp
+++ b/mlir/lib/IR/TypeRange.cpp
@@ -31,6 +31,10 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
this->base = result;
else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
this->base = operand;
+ else if (auto *repeated =
+ llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(
+ owner))
+ this->base = repeated;
else
this->base = cast<const Value *>(owner);
}
@@ -43,6 +47,10 @@ TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
return {operand + index};
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return {result->getNextResultAtOffset(index)};
+ // All elements are identical; the owner pointer never advances.
+ if (llvm::isa<const Repeated<Type>::Storage *,
+ const Repeated<Value>::Storage *>(object))
+ return object;
return {llvm::dyn_cast_if_present<const Type *>(object) + index};
}
@@ -54,5 +62,11 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
return (operand + index)->get().getType();
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return result->getNextResultAtOffset(index)->getType();
+ if (auto *repeated =
+ llvm::dyn_cast_if_present<const Repeated<Type>::Storage *>(object))
+ return repeated->value;
+ if (auto *repeated =
+ llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(object))
+ return repeated->value.getType();
return llvm::dyn_cast_if_present<const Type *>(object)[index];
}
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index f00d5c1f7f927..6319fcbb0f216 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -11,7 +11,9 @@
#include "../../test/lib/Dialect/Test/TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeRange.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Repeated.h"
#include "llvm/Support/FormatVariadic.h"
#include "gtest/gtest.h"
@@ -375,4 +377,53 @@ TEST(OperationCloneTest, CloneWithDifferentResults) {
cloneOp->destroy();
}
+TEST(RepeatedRangeTest, TypeRangeFromRepeatedType) {
+ MLIRContext context;
+ Builder builder(&context);
+ Type i32 = builder.getI32Type();
+
+ llvm::Repeated<Type> rep(3, i32);
+ TypeRange range(rep);
+
+ EXPECT_EQ(range.size(), 3u);
+ EXPECT_FALSE(range.empty());
+ for (Type t : range)
+ EXPECT_EQ(t, i32);
+
+ llvm::Repeated<Type> emptyRep(0, Type{});
+ TypeRange emptyTypeRange(emptyRep);
+
+ EXPECT_EQ(emptyTypeRange.size(), 0u);
+ EXPECT_TRUE(emptyTypeRange.empty());
+}
+
+TEST(RepeatedRangeTest, ValueRangeFromRepeatedValue) {
+ Value nullVal;
+ llvm::Repeated<Value> rep(4, nullVal);
+ ValueRange range(rep);
+
+ EXPECT_EQ(range.size(), 4u);
+ EXPECT_FALSE(range.empty());
+ for (Value v : range)
+ EXPECT_EQ(v, nullVal);
+}
+
+TEST(RepeatedRangeTest, TypeRangeFromRepeatedValueViaValueRange) {
+ MLIRContext context;
+ Builder builder(&context);
+ Type i32 = builder.getI32Type();
+
+ Operation *op = createOp(&context, /*operands=*/{}, i32);
+ Value val = op->getResult(0);
+
+ llvm::Repeated<Value> rep(3, val);
+ TypeRange tr = ValueRange(rep);
+
+ EXPECT_EQ(tr.size(), 3u);
+ for (Type t : tr)
+ EXPECT_EQ(t, i32);
+
+ op->destroy();
+}
+
} // namespace
>From 8af9d42e0f4c7e32cc012e947178ca2649e08f6c Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 17 Mar 2026 13:47:14 -0400
Subject: [PATCH 2/2] [llvm][ADT] Add VariablePointerUnion with variable-length
tag encoding
Add VariablePointerUnionImpl, a pointer union that encodes type tags
using a variable number of low pointer bits. Types are grouped into
tiers by their NumLowBitsAvailable; each non-final tier reserves one
code as an escape prefix to the next tier. This allows more variants
than a fixed-width tag when pointer types have heterogeneous alignment
(e.g., 2-bit + 3-bit + 4-bit types on 32-bit platforms).
The public alias VariablePointerUnion<PTs...> is a zero-cost alias to
plain PointerUnion when all types have sufficient alignment for a
fixed-width tag (the common case on 64-bit), and only uses the
variable-length encoding when necessary.
Also adds alignas(max(16, alignof(T))) to Repeated<T> so that
const Repeated<T>* has more low bits available than a plain pointer,
enabling its use as a PointerUnion variant in TypeRange on 32-bit.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply at anthropic.com>
---
llvm/include/llvm/ADT/PointerUnion.h | 325 ++++++++++++++++++++++
llvm/include/llvm/ADT/Repeated.h | 7 +-
llvm/unittests/ADT/PointerUnionTest.cpp | 340 ++++++++++++++++++++++++
llvm/unittests/ADT/RepeatedTest.cpp | 11 +
4 files changed, 682 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index d9087dd1c516e..e2f6dfa99728b 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -21,6 +21,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include <algorithm>
+#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
@@ -277,6 +278,330 @@ template <typename ...PTs> struct DenseMapInfo<PointerUnion<PTs...>> {
}
};
+//===----------------------------------------------------------------------===//
+// VariablePointerUnion - variable-length tag encoding for heterogeneous
+// pointer alignments.
+//===----------------------------------------------------------------------===//
+
+namespace pointer_union_detail {
+
+/// Tag descriptor for one type in a variable-length tag encoding.
+struct TagEntry {
+ intptr_t value; ///< Bit pattern stored in the low bits.
+ intptr_t mask; ///< Mask covering all tag bits for this entry.
+};
+
+/// True if types are in non-decreasing NumLowBitsAvailable order.
+template <typename... PTs>
+constexpr bool typesInAscendingBitOrder() {
+ if constexpr (sizeof...(PTs) <= 1)
+ return true;
+ else {
+ int bits[] = {PointerLikeTypeTraits<PTs>::NumLowBitsAvailable...};
+ for (size_t i = 1; i < sizeof...(PTs); ++i)
+ if (bits[i] < bits[i - 1])
+ return false;
+ return true;
+ }
+}
+
+/// True if the variable-length encoding has enough capacity for all types.
+template <typename... PTs>
+constexpr bool extendedTagsFit() {
+ if constexpr (sizeof...(PTs) == 0)
+ return true;
+ else {
+ constexpr size_t N = sizeof...(PTs);
+ int bits[] = {PointerLikeTypeTraits<PTs>::NumLowBitsAvailable...};
+ int prevBits = 0;
+ size_t i = 0;
+ while (i < N) {
+ int tierBits = bits[i];
+ int newBits = tierBits - prevBits;
+ size_t tierEnd = i;
+ while (tierEnd < N && bits[tierEnd] == tierBits)
+ ++tierEnd;
+ bool isLastTier = (tierEnd == N);
+ size_t typesInTier = tierEnd - i;
+ size_t capacity =
+ isLastTier ? size_t(1) << newBits : (size_t(1) << newBits) - 1;
+ if (typesInTier > capacity)
+ return false;
+ prevBits = tierBits;
+ i = tierEnd;
+ }
+ return true;
+ }
+}
+
+/// Compute variable-length tag table. Types must be in ascending
+/// NumLowBitsAvailable order. Groups types into tiers by bit count;
+/// each non-final tier reserves one code as an escape prefix.
+template <typename... PTs>
+constexpr std::array<TagEntry, sizeof...(PTs)> computeExtendedTags() {
+ constexpr size_t N = sizeof...(PTs);
+ std::array<TagEntry, N> result = {};
+ if constexpr (N == 0)
+ return result;
+ else {
+ int bits[] = {PointerLikeTypeTraits<PTs>::NumLowBitsAvailable...};
+ intptr_t escapePrefix = 0;
+ int prevBits = 0;
+ size_t i = 0;
+ while (i < N) {
+ int tierBits = bits[i];
+ int newBits = tierBits - prevBits;
+ size_t tierEnd = i;
+ while (tierEnd < N && bits[tierEnd] == tierBits)
+ ++tierEnd;
+ for (size_t j = 0; j < tierEnd - i; ++j) {
+ result[i + j].value = escapePrefix | (intptr_t(j) << prevBits);
+ result[i + j].mask = (intptr_t(1) << tierBits) - 1;
+ }
+ intptr_t escapeCode = (intptr_t(1) << newBits) - 1;
+ escapePrefix |= escapeCode << prevBits;
+ prevBits = tierBits;
+ i = tierEnd;
+ }
+ return result;
+ }
+}
+
+/// A pointer union that encodes the type tag in a variable number of low
+/// pointer bits, allowing more variants than a fixed-width tag when pointer
+/// types have heterogeneous alignment.
+///
+/// Do not use this class directly; use the VariablePointerUnion alias, which
+/// falls back to the faster PointerUnion when all types have sufficient
+/// alignment for a fixed-width tag.
+template <typename... PTs>
+class VariablePointerUnionImpl {
+ static_assert(TypesAreDistinct<PTs...>::value,
+ "VariablePointerUnion alternative types cannot be repeated");
+ static_assert(typesInAscendingBitOrder<PTs...>(),
+ "Types must be listed in ascending NumLowBitsAvailable order");
+ static_assert(extendedTagsFit<PTs...>(),
+ "Too many types for the available low bits");
+
+ static constexpr std::array<TagEntry, sizeof...(PTs)> TagTable =
+ computeExtendedTags<PTs...>();
+ // First type always gets tag 0, which is required by getAddrOfPtr1().
+ static_assert(TagTable[0].value == 0,
+ "First type must have tag value 0 for getAddrOfPtr1");
+
+ detail::PunnedPointer<void *> Val;
+
+ using First = TypeAtIndex<0, PTs...>;
+
+ // Qualified: CastInfo lives in ::llvm, not in pointer_union_detail.
+ template <typename To, typename From, typename Enable>
+ friend struct ::llvm::CastInfo;
+
+ // Null is stored as tag-only (pointer bits all zero). Each type has a
+ // distinct tag value, so we compare against all of them. For typical
+ // union sizes (5-6 types) the compiler optimizes this into a short chain.
+ template <size_t... Is>
+ static constexpr bool isNullCheck(intptr_t v, std::index_sequence<Is...>) {
+ return ((v == TagTable[Is].value) || ...);
+ }
+
+ template <typename T>
+ static intptr_t encode(T V) {
+ constexpr size_t Idx = FirstIndexOfType<T, PTs...>::value;
+ void *vp =
+ const_cast<void *>(PointerLikeTypeTraits<T>::getAsVoidPointer(V));
+ intptr_t ptrInt = reinterpret_cast<intptr_t>(vp);
+ assert((ptrInt & TagTable[Idx].mask) == 0 &&
+ "Pointer low bits collide with tag");
+ return ptrInt | TagTable[Idx].value;
+ }
+
+public:
+ VariablePointerUnionImpl() : Val(intptr_t(0)) {}
+ VariablePointerUnionImpl(std::nullptr_t) : VariablePointerUnionImpl() {}
+
+ template <typename T,
+ typename = std::enable_if_t<is_one_of<T, PTs...>::value>>
+ VariablePointerUnionImpl(T V) : Val(encode(V)) {}
+
+ template <typename T,
+ typename = std::enable_if_t<is_one_of<T, PTs...>::value>>
+ VariablePointerUnionImpl &operator=(T V) {
+ Val = encode(V);
+ return *this;
+ }
+
+ const VariablePointerUnionImpl &operator=(std::nullptr_t) {
+ Val = intptr_t(0);
+ return *this;
+ }
+
+ /// Test if the pointer held in the union is null, regardless of
+ /// which type it is.
+ bool isNull() const {
+ return isNullCheck(Val.asInt(), std::index_sequence_for<PTs...>{});
+ }
+
+ explicit operator bool() const { return !isNull(); }
+
+ template <typename T> [[deprecated("Use isa instead")]] bool is() const {
+ return isa<T>(*this);
+ }
+
+ template <typename T> [[deprecated("Use cast instead")]] T get() const {
+ assert(isa<T>(*this) && "Invalid accessor called");
+ return cast<T>(*this);
+ }
+
+ template <typename T> inline T dyn_cast() const {
+ return llvm::dyn_cast_if_present<T>(*this);
+ }
+
+ /// If the union is set to the first pointer type get an address pointing to
+ /// it.
+ First const *getAddrOfPtr1() const {
+ return const_cast<VariablePointerUnionImpl *>(this)->getAddrOfPtr1();
+ }
+
+ /// If the union is set to the first pointer type get an address pointing to
+ /// it.
+ First *getAddrOfPtr1() {
+ assert(isa<First>(*this) && "Val is not the first pointer");
+ assert(
+ PointerLikeTypeTraits<First>::getAsVoidPointer(cast<First>(*this)) ==
+ reinterpret_cast<void *>(Val.asInt()) &&
+ "Can't get the address because PointerLikeTypeTraits changes the ptr");
+ return const_cast<First *>(
+ reinterpret_cast<const First *>(Val.getPointerAddress()));
+ }
+
+ void *getOpaqueValue() const {
+ return reinterpret_cast<void *>(Val.asInt());
+ }
+
+ static inline VariablePointerUnionImpl getFromOpaqueValue(void *VP) {
+ VariablePointerUnionImpl V;
+ V.Val = reinterpret_cast<intptr_t>(VP);
+ return V;
+ }
+
+ friend bool operator==(VariablePointerUnionImpl lhs,
+ VariablePointerUnionImpl rhs) {
+ return lhs.getOpaqueValue() == rhs.getOpaqueValue();
+ }
+
+ friend bool operator!=(VariablePointerUnionImpl lhs,
+ VariablePointerUnionImpl rhs) {
+ return lhs.getOpaqueValue() != rhs.getOpaqueValue();
+ }
+
+ friend bool operator<(VariablePointerUnionImpl lhs,
+ VariablePointerUnionImpl rhs) {
+ return lhs.getOpaqueValue() < rhs.getOpaqueValue();
+ }
+};
+
+/// Helper to enforce ascending bit order even when the alias resolves to
+/// plain PointerUnion, ensuring cross-platform portability.
+template <typename... PTs>
+struct VariablePointerUnionChecker {
+ static_assert(
+ typesInAscendingBitOrder<PTs...>(),
+ "VariablePointerUnion types must be in ascending "
+ "NumLowBitsAvailable order for cross-platform portability");
+ using type = std::conditional_t<
+ (lowBitsAvailable<PTs...>() >= bitsRequired(sizeof...(PTs))),
+ PointerUnion<PTs...>, VariablePointerUnionImpl<PTs...>>;
+};
+
+} // end namespace pointer_union_detail
+
+// Specialization of CastInfo for VariablePointerUnionImpl
+template <typename To, typename... PTs>
+struct CastInfo<To,
+ pointer_union_detail::VariablePointerUnionImpl<PTs...>>
+ : public DefaultDoCastIfPossible<
+ To, pointer_union_detail::VariablePointerUnionImpl<PTs...>,
+ CastInfo<To,
+ pointer_union_detail::VariablePointerUnionImpl<PTs...>>> {
+ using From = pointer_union_detail::VariablePointerUnionImpl<PTs...>;
+
+ static inline bool isPossible(From &f) {
+ constexpr size_t Idx = FirstIndexOfType<To, PTs...>::value;
+ intptr_t v = reinterpret_cast<intptr_t>(f.getOpaqueValue());
+ return (v & From::TagTable[Idx].mask) == From::TagTable[Idx].value;
+ }
+
+ static To doCast(From &f) {
+ assert(isPossible(f) && "cast to an incompatible type!");
+ constexpr intptr_t ptrMask =
+ ~((intptr_t(1) << PointerLikeTypeTraits<To>::NumLowBitsAvailable) - 1);
+ void *ptr = reinterpret_cast<void *>(
+ reinterpret_cast<intptr_t>(f.getOpaqueValue()) & ptrMask);
+ return PointerLikeTypeTraits<To>::getFromVoidPointer(ptr);
+ }
+
+ static inline To castFailed() { return To(); }
+};
+
+template <typename To, typename... PTs>
+struct CastInfo<
+ To, const pointer_union_detail::VariablePointerUnionImpl<PTs...>>
+ : public ConstStrippingForwardingCast<
+ To,
+ const pointer_union_detail::VariablePointerUnionImpl<PTs...>,
+ CastInfo<To, pointer_union_detail::VariablePointerUnionImpl<
+ PTs...>>> {};
+
+template <typename... PTs>
+struct PointerLikeTypeTraits<
+ pointer_union_detail::VariablePointerUnionImpl<PTs...>> {
+ using Impl = pointer_union_detail::VariablePointerUnionImpl<PTs...>;
+
+ static inline void *getAsVoidPointer(const Impl &P) {
+ return P.getOpaqueValue();
+ }
+
+ static inline Impl getFromVoidPointer(void *P) {
+ return Impl::getFromOpaqueValue(P);
+ }
+
+ // All low bits are consumed by the variable-length tag.
+ static constexpr int NumLowBitsAvailable = 0;
+};
+
+template <typename... PTs>
+struct DenseMapInfo<
+ pointer_union_detail::VariablePointerUnionImpl<PTs...>> {
+ using Union = pointer_union_detail::VariablePointerUnionImpl<PTs...>;
+ using FirstInfo = DenseMapInfo<TypeAtIndex<0, PTs...>>;
+
+ static inline Union getEmptyKey() { return Union(FirstInfo::getEmptyKey()); }
+
+ static inline Union getTombstoneKey() {
+ return Union(FirstInfo::getTombstoneKey());
+ }
+
+ static unsigned getHashValue(const Union &UnionVal) {
+ intptr_t key = (intptr_t)UnionVal.getOpaqueValue();
+ return DenseMapInfo<intptr_t>::getHashValue(key);
+ }
+
+ static bool isEqual(const Union &LHS, const Union &RHS) {
+ return LHS == RHS;
+ }
+};
+
+/// A pointer union that uses variable-length tag encoding to support more
+/// types than a fixed-width tag when pointer types have heterogeneous
+/// alignment. Types must be listed in ascending NumLowBitsAvailable order.
+///
+/// When all types have enough low bits for a fixed-width tag (the common case
+/// on 64-bit platforms), this is a zero-cost alias to PointerUnion.
+template <typename... PTs>
+using VariablePointerUnion =
+ typename pointer_union_detail::VariablePointerUnionChecker<PTs...>::type;
+
} // end namespace llvm
#endif // LLVM_ADT_POINTERUNION_H
diff --git a/llvm/include/llvm/ADT/Repeated.h b/llvm/include/llvm/ADT/Repeated.h
index f821f3f1f73ca..699651012d91b 100644
--- a/llvm/include/llvm/ADT/Repeated.h
+++ b/llvm/include/llvm/ADT/Repeated.h
@@ -16,6 +16,7 @@
#include "llvm/ADT/iterator.h"
+#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
@@ -72,7 +73,11 @@ class RepeatedIterator
///
/// `Repeated<T>` is also a proper random-access range: `begin()`/`end()`
/// return iterators that always dereference to the same stored value.
-template <typename T> struct [[nodiscard]] Repeated {
+// At least 16-byte aligned so that Repeated<T>* has more low bits available
+// than a plain pointer. The primary use case is pointer-like types (e.g. MLIR
+// Type, Value) where Repeated<T>* appears in a PointerUnion alongside them.
+template <typename T>
+struct [[nodiscard]] alignas(std::max(size_t{16}, alignof(T))) Repeated {
/// Wrapper for the stored value used as a PointerUnion target in range
/// types (e.g., TypeRange, ValueRange).
struct Storage {
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
index d8ac3aed76da2..a9972bca15174 100644
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/DenseMap.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -289,4 +290,343 @@ TEST_F(PointerUnionTest, NewCastInfra) {
"type mismatch for cast with PointerUnion");
}
+//===----------------------------------------------------------------------===//
+// VariablePointerUnion tests
+//===----------------------------------------------------------------------===//
+
+template <int I> struct alignas(4) Align4 {};
+template <int I> struct alignas(8) Align8 {};
+template <int I> struct alignas(16) Align16 {};
+
+// 2-tier VPU: 3 x 2-bit + 2 x 3-bit types.
+using VPU2 = VariablePointerUnion<Align4<0> *, Align4<1> *, Align4<2> *,
+ Align8<0> *, Align8<1> *>;
+
+// 3-tier VPU: 3 x 2-bit + 1 x 3-bit + 2 x 4-bit types.
+using VPU3 = VariablePointerUnion<Align4<0> *, Align4<1> *, Align4<2> *,
+ Align8<0> *, Align16<0> *, Align16<1> *>;
+
+// When all types have sufficient bits, alias resolves to PointerUnion.
+static_assert(std::is_same_v<VariablePointerUnion<Align8<0> *, Align8<1> *,
+ Align8<2> *, Align8<3> *>,
+ PointerUnion<Align8<0> *, Align8<1> *,
+ Align8<2> *, Align8<3> *>>,
+ "Same-alignment VPU should resolve to PointerUnion");
+
+// Boundary: 8 types with 3 bits each, needs exactly 3 bits -> PointerUnion.
+static_assert(
+ std::is_same_v<VariablePointerUnion<Align8<0> *, Align8<1> *, Align8<2> *,
+ Align8<3> *, Align8<4> *, Align8<5> *,
+ Align8<6> *, Align8<7> *>,
+ PointerUnion<Align8<0> *, Align8<1> *, Align8<2> *,
+ Align8<3> *, Align8<4> *, Align8<5> *,
+ Align8<6> *, Align8<7> *>>,
+ "8-type same-alignment VPU should still be PointerUnion");
+
+// Mixed alignment should use the Impl.
+static_assert(!std::is_same_v<VPU2, PointerUnion<Align4<0> *, Align4<1> *,
+ Align4<2> *, Align8<0> *,
+ Align8<1> *>>,
+ "Mixed-alignment VPU should use VariablePointerUnionImpl");
+
+// NumLowBitsAvailable is 0 for VariablePointerUnionImpl.
+static_assert(PointerLikeTypeTraits<VPU2>::NumLowBitsAvailable == 0);
+
+struct VariablePointerUnion2TierTest : public testing::Test {
+ Align4<0> a0;
+ Align4<1> a1;
+ Align4<2> a2;
+ Align8<0> b0;
+ Align8<1> b1;
+
+ VPU2 pa0, pa1, pa2, pb0, pb1, null;
+ VPU2 na0, na1, na2, nb0, nb1;
+
+ VariablePointerUnion2TierTest()
+ : pa0(&a0), pa1(&a1), pa2(&a2), pb0(&b0), pb1(&b1), null(),
+ na0((Align4<0> *)nullptr), na1((Align4<1> *)nullptr),
+ na2((Align4<2> *)nullptr), nb0((Align8<0> *)nullptr),
+ nb1((Align8<1> *)nullptr) {}
+};
+
+TEST_F(VariablePointerUnion2TierTest, Isa) {
+ // Tier 0 types
+ EXPECT_TRUE(isa<Align4<0> *>(pa0));
+ EXPECT_FALSE(isa<Align4<1> *>(pa0));
+ EXPECT_FALSE(isa<Align4<2> *>(pa0));
+ EXPECT_FALSE(isa<Align8<0> *>(pa0));
+ EXPECT_FALSE(isa<Align8<1> *>(pa0));
+
+ EXPECT_TRUE(isa<Align4<1> *>(pa1));
+ EXPECT_TRUE(isa<Align4<2> *>(pa2));
+
+ // Tier 1 types
+ EXPECT_TRUE(isa<Align8<0> *>(pb0));
+ EXPECT_FALSE(isa<Align4<0> *>(pb0));
+ EXPECT_FALSE(isa<Align8<1> *>(pb0));
+
+ EXPECT_TRUE(isa<Align8<1> *>(pb1));
+ EXPECT_FALSE(isa<Align8<0> *>(pb1));
+
+ // Null pointers preserve type identity
+ EXPECT_TRUE(isa<Align4<0> *>(na0));
+ EXPECT_TRUE(isa<Align8<1> *>(nb1));
+ EXPECT_FALSE(isa<Align8<0> *>(na0));
+}
+
+TEST_F(VariablePointerUnion2TierTest, Cast) {
+ EXPECT_EQ(cast<Align4<0> *>(pa0), &a0);
+ EXPECT_EQ(cast<Align4<1> *>(pa1), &a1);
+ EXPECT_EQ(cast<Align4<2> *>(pa2), &a2);
+ EXPECT_EQ(cast<Align8<0> *>(pb0), &b0);
+ EXPECT_EQ(cast<Align8<1> *>(pb1), &b1);
+}
+
+TEST_F(VariablePointerUnion2TierTest, DynCast) {
+ EXPECT_EQ(dyn_cast<Align4<0> *>(pa0), &a0);
+ EXPECT_EQ(dyn_cast<Align4<1> *>(pa0), nullptr);
+ EXPECT_EQ(dyn_cast<Align8<0> *>(pa0), nullptr);
+
+ EXPECT_EQ(dyn_cast<Align8<0> *>(pb0), &b0);
+ EXPECT_EQ(dyn_cast<Align4<0> *>(pb0), nullptr);
+
+ // pb1 has the all-ones tag -- most likely to expose masking bugs.
+ EXPECT_EQ(dyn_cast<Align8<1> *>(pb1), &b1);
+ EXPECT_EQ(dyn_cast<Align4<0> *>(pb1), nullptr);
+ EXPECT_EQ(dyn_cast<Align4<1> *>(pb1), nullptr);
+ EXPECT_EQ(dyn_cast<Align4<2> *>(pb1), nullptr);
+ EXPECT_EQ(dyn_cast<Align8<0> *>(pb1), nullptr);
+
+ EXPECT_EQ(dyn_cast_if_present<Align4<0> *>(na0), nullptr);
+ EXPECT_EQ(dyn_cast_if_present<Align8<0> *>(na0), nullptr);
+ EXPECT_EQ(dyn_cast_if_present<Align8<0> *>(nb0), nullptr);
+}
+
+TEST_F(VariablePointerUnion2TierTest, Null) {
+ EXPECT_FALSE(pa0.isNull());
+ EXPECT_FALSE(pb0.isNull());
+ EXPECT_TRUE(null.isNull());
+ EXPECT_TRUE(!null);
+ EXPECT_TRUE((bool)pa0);
+
+ EXPECT_TRUE(na0.isNull());
+ EXPECT_TRUE(na1.isNull());
+ EXPECT_TRUE(na2.isNull());
+ EXPECT_TRUE(nb0.isNull());
+ EXPECT_TRUE(nb1.isNull());
+}
+
+TEST_F(VariablePointerUnion2TierTest, NullDiscrimination) {
+ // Null pointers of different types have different opaque values.
+ EXPECT_NE(na0, na1);
+ EXPECT_NE(na0, na2);
+ EXPECT_NE(na0, nb0);
+ EXPECT_NE(na1, nb0);
+ EXPECT_NE(nb0, nb1);
+
+ // A default-constructed union compares equal to a null of the first type.
+ EXPECT_EQ(null, na0);
+}
+
+TEST_F(VariablePointerUnion2TierTest, Comparison) {
+ EXPECT_EQ(pa0, pa0);
+ EXPECT_NE(pa0, pa1);
+ EXPECT_NE(pa0, pb0);
+
+ VPU2 other(&a0);
+ EXPECT_EQ(pa0, other);
+}
+
+TEST_F(VariablePointerUnion2TierTest, Assignment) {
+ VPU2 u;
+ EXPECT_TRUE(u.isNull());
+
+ u = &a0;
+ EXPECT_TRUE(isa<Align4<0> *>(u));
+ EXPECT_EQ(cast<Align4<0> *>(u), &a0);
+
+ u = &b0;
+ EXPECT_TRUE(isa<Align8<0> *>(u));
+ EXPECT_EQ(cast<Align8<0> *>(u), &b0);
+
+ u = &a2;
+ EXPECT_TRUE(isa<Align4<2> *>(u));
+
+ u = nullptr;
+ EXPECT_TRUE(u.isNull());
+}
+
+TEST_F(VariablePointerUnion2TierTest, GetAddrOfPtr1) {
+ EXPECT_TRUE((void *)pa0.getAddrOfPtr1() == (void *)&pa0);
+ EXPECT_TRUE((void *)null.getAddrOfPtr1() == (void *)&null);
+}
+
+TEST_F(VariablePointerUnion2TierTest, OpaqueValueRoundTrip) {
+ void *opaque = pa0.getOpaqueValue();
+ VPU2 restored = VPU2::getFromOpaqueValue(opaque);
+ EXPECT_EQ(pa0, restored);
+ EXPECT_EQ(cast<Align4<0> *>(restored), &a0);
+
+ opaque = pb0.getOpaqueValue();
+ restored = VPU2::getFromOpaqueValue(opaque);
+ EXPECT_EQ(pb0, restored);
+ EXPECT_EQ(cast<Align8<0> *>(restored), &b0);
+
+ opaque = pb1.getOpaqueValue();
+ restored = VPU2::getFromOpaqueValue(opaque);
+ EXPECT_EQ(pb1, restored);
+ EXPECT_EQ(cast<Align8<1> *>(restored), &b1);
+}
+
+// 3-tier tests
+
+struct VariablePointerUnion3TierTest : public testing::Test {
+ Align4<0> a0;
+ Align4<1> a1;
+ Align4<2> a2;
+ Align8<0> b0;
+ Align16<0> c0;
+ Align16<1> c1;
+
+ VPU3 pa0, pa1, pa2, pb0, pc0, pc1, null;
+
+ VariablePointerUnion3TierTest()
+ : pa0(&a0), pa1(&a1), pa2(&a2), pb0(&b0), pc0(&c0), pc1(&c1), null() {}
+};
+
+TEST_F(VariablePointerUnion3TierTest, Isa) {
+ EXPECT_TRUE(isa<Align4<0> *>(pa0));
+ EXPECT_FALSE(isa<Align8<0> *>(pa0));
+ EXPECT_FALSE(isa<Align16<0> *>(pa0));
+
+ EXPECT_TRUE(isa<Align8<0> *>(pb0));
+ EXPECT_FALSE(isa<Align4<0> *>(pb0));
+ EXPECT_FALSE(isa<Align16<0> *>(pb0));
+
+ EXPECT_TRUE(isa<Align16<0> *>(pc0));
+ EXPECT_FALSE(isa<Align4<0> *>(pc0));
+ EXPECT_FALSE(isa<Align8<0> *>(pc0));
+ EXPECT_FALSE(isa<Align16<1> *>(pc0));
+
+ EXPECT_TRUE(isa<Align16<1> *>(pc1));
+ EXPECT_FALSE(isa<Align16<0> *>(pc1));
+}
+
+TEST_F(VariablePointerUnion3TierTest, Cast) {
+ EXPECT_EQ(cast<Align4<0> *>(pa0), &a0);
+ EXPECT_EQ(cast<Align4<1> *>(pa1), &a1);
+ EXPECT_EQ(cast<Align4<2> *>(pa2), &a2);
+ EXPECT_EQ(cast<Align8<0> *>(pb0), &b0);
+ EXPECT_EQ(cast<Align16<0> *>(pc0), &c0);
+ EXPECT_EQ(cast<Align16<1> *>(pc1), &c1);
+}
+
+TEST_F(VariablePointerUnion3TierTest, DynCast) {
+ EXPECT_EQ(dyn_cast<Align4<0> *>(pa0), &a0);
+ EXPECT_EQ(dyn_cast<Align8<0> *>(pa0), nullptr);
+ EXPECT_EQ(dyn_cast<Align16<0> *>(pa0), nullptr);
+
+ EXPECT_EQ(dyn_cast<Align8<0> *>(pb0), &b0);
+ EXPECT_EQ(dyn_cast<Align4<0> *>(pb0), nullptr);
+ EXPECT_EQ(dyn_cast<Align16<0> *>(pb0), nullptr);
+
+ EXPECT_EQ(dyn_cast<Align16<0> *>(pc0), &c0);
+ EXPECT_EQ(dyn_cast<Align16<1> *>(pc0), nullptr);
+ EXPECT_EQ(dyn_cast<Align4<0> *>(pc0), nullptr);
+
+ EXPECT_EQ(dyn_cast<Align16<1> *>(pc1), &c1);
+ EXPECT_EQ(dyn_cast<Align16<0> *>(pc1), nullptr);
+}
+
+TEST_F(VariablePointerUnion3TierTest, Null) {
+ EXPECT_TRUE(null.isNull());
+ EXPECT_FALSE(pa0.isNull());
+ EXPECT_FALSE(pb0.isNull());
+ EXPECT_FALSE(pc0.isNull());
+ EXPECT_FALSE(pc1.isNull());
+
+ VPU3 na0((Align4<0> *)nullptr);
+ VPU3 nb0((Align8<0> *)nullptr);
+ VPU3 nc0((Align16<0> *)nullptr);
+ VPU3 nc1((Align16<1> *)nullptr);
+ EXPECT_TRUE(na0.isNull());
+ EXPECT_TRUE(nb0.isNull());
+ EXPECT_TRUE(nc0.isNull());
+ EXPECT_TRUE(nc1.isNull());
+
+ // Null discrimination across all three tiers.
+ EXPECT_NE(na0, nb0);
+ EXPECT_NE(nb0, nc0);
+ EXPECT_NE(nc0, nc1);
+ EXPECT_NE(na0, nc0);
+}
+
+TEST_F(VariablePointerUnion3TierTest, Assignment) {
+ VPU3 u;
+ EXPECT_TRUE(u.isNull());
+
+ u = &a0;
+ EXPECT_TRUE(isa<Align4<0> *>(u));
+ EXPECT_EQ(cast<Align4<0> *>(u), &a0);
+
+ u = &b0;
+ EXPECT_TRUE(isa<Align8<0> *>(u));
+ EXPECT_EQ(cast<Align8<0> *>(u), &b0);
+
+ u = &c1;
+ EXPECT_TRUE(isa<Align16<1> *>(u));
+ EXPECT_EQ(cast<Align16<1> *>(u), &c1);
+
+ u = nullptr;
+ EXPECT_TRUE(u.isNull());
+}
+
+TEST_F(VariablePointerUnion3TierTest, OpaqueValueRoundTrip) {
+ // pb0 has tag 0x3 which doubles as the escape prefix for tier-2.
+ void *opaque = pb0.getOpaqueValue();
+ VPU3 restored = VPU3::getFromOpaqueValue(opaque);
+ EXPECT_EQ(pb0, restored);
+ EXPECT_EQ(cast<Align8<0> *>(restored), &b0);
+
+ opaque = pc0.getOpaqueValue();
+ restored = VPU3::getFromOpaqueValue(opaque);
+ EXPECT_EQ(pc0, restored);
+ EXPECT_EQ(cast<Align16<0> *>(restored), &c0);
+
+ opaque = pc1.getOpaqueValue();
+ restored = VPU3::getFromOpaqueValue(opaque);
+ EXPECT_EQ(pc1, restored);
+ EXPECT_EQ(cast<Align16<1> *>(restored), &c1);
+}
+
+TEST_F(VariablePointerUnion3TierTest, ConstCast) {
+ const VPU3 cpc0(&c0);
+ EXPECT_TRUE(isa<Align16<0> *>(cpc0));
+ EXPECT_FALSE(isa<Align4<0> *>(cpc0));
+ EXPECT_EQ(cast<Align16<0> *>(cpc0), &c0);
+ EXPECT_EQ(dyn_cast<Align8<0> *>(cpc0), nullptr);
+}
+
+TEST(VariablePointerUnionDenseMapTest, BasicOperations) {
+ Align4<0> a0;
+ Align8<0> b0;
+ Align8<1> b1;
+
+ DenseMap<VPU2, int> map;
+ VPU2 ka(&a0), kb(&b0), kb1(&b1);
+
+ map[ka] = 1;
+ map[kb] = 2;
+ map[kb1] = 3;
+
+ EXPECT_EQ(map[ka], 1);
+ EXPECT_EQ(map[kb], 2);
+ EXPECT_EQ(map[kb1], 3);
+
+ EXPECT_EQ(map.count(ka), 1u);
+ map.erase(ka);
+ EXPECT_EQ(map.count(ka), 0u);
+ EXPECT_EQ(map.count(kb), 1u);
+}
+
} // end anonymous namespace
diff --git a/llvm/unittests/ADT/RepeatedTest.cpp b/llvm/unittests/ADT/RepeatedTest.cpp
index f55be4b22ee5f..673f67db7addc 100644
--- a/llvm/unittests/ADT/RepeatedTest.cpp
+++ b/llvm/unittests/ADT/RepeatedTest.cpp
@@ -96,5 +96,16 @@ TEST(RepeatedTest, IteratorTraits) {
SUCCEED();
}
+TEST(RepeatedTest, Alignment) {
+ // Repeated<T> must be at least 16-byte aligned (for VariablePointerUnion
+ // tag bits), but must also respect T's natural alignment when it exceeds 16.
+ static_assert(alignof(Repeated<char>) == 16);
+
+ struct alignas(32) Over { int x; };
+ static_assert(alignof(Repeated<Over>) == 32);
+
+ SUCCEED();
+}
+
} // anonymous namespace
} // namespace llvm
More information about the Mlir-commits
mailing list