[llvm] [ADT] Specialize ValueIsPresent for PointerUnion (PR #121847)
Sergei Barannikov via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 8 04:59:37 PST 2025
https://github.com/s-barannikov updated https://github.com/llvm/llvm-project/pull/121847
>From 107a35db5dfb64e7067ff12e7932b358796d3864 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Tue, 7 Jan 2025 00:49:48 +0300
Subject: [PATCH 1/3] [ADT] Make null PointerUnion with different active
members compare equal
Currently, two null instances of `PointerUnion` with different active
members compare unequal. In some cases, this results in
counterintuitive behavior when using functions from `Casting.h`, e.g.:
```
PointerUnion<int *, float *> U;
// U = (int *)nullptr;
dyn_cast<int *>(U); // Aborts
dyn_cast<float *>(U); // Aborts
U = (float *)nullptr;
dyn_cast<int *>(U); // OK
dyn_cast<float *>(U); // OK
```
`dyn_cast` should abort in all cases because the argument is null.
Currently, it aborts only if the first member is active. This happens
because the partial template specialization of `ValueIsPresent` for
nullable types compares the union with a union constructed from nullptr
(whose active member is the first one), and the two unions compare equal
only if their active members are the same.
This patch makes two instances of a union compare equal if they are both
null regardless of their active members, and fixes two places where the
old behavior was exploited.
---
llvm/include/llvm/ADT/PointerUnion.h | 12 ++++++------
llvm/lib/CodeGen/RegisterBankInfo.cpp | 4 ++--
llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp | 4 ++--
llvm/unittests/ADT/PointerUnionTest.cpp | 13 ++++++++++---
4 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index 7d4ed02b622626..bea9d6a4569b0b 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -198,14 +198,14 @@ class PointerUnion
}
};
-template <typename ...PTs>
-bool operator==(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
- return lhs.getOpaqueValue() == rhs.getOpaqueValue();
+template <typename... PTs>
+bool operator==(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
+ return (!LHS && !RHS) || LHS.getOpaqueValue() == RHS.getOpaqueValue();
}
-template <typename ...PTs>
-bool operator!=(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
- return lhs.getOpaqueValue() != rhs.getOpaqueValue();
+template <typename... PTs>
+bool operator!=(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
+ return !operator==(LHS, RHS);
}
template <typename ...PTs>
diff --git a/llvm/lib/CodeGen/RegisterBankInfo.cpp b/llvm/lib/CodeGen/RegisterBankInfo.cpp
index e1720b038e2361..5a8cf13ad11fd5 100644
--- a/llvm/lib/CodeGen/RegisterBankInfo.cpp
+++ b/llvm/lib/CodeGen/RegisterBankInfo.cpp
@@ -134,10 +134,10 @@ const TargetRegisterClass *RegisterBankInfo::constrainGenericRegister(
// If the register already has a class, fallback to MRI::constrainRegClass.
auto &RegClassOrBank = MRI.getRegClassOrRegBank(Reg);
- if (isa<const TargetRegisterClass *>(RegClassOrBank))
+ if (isa_and_present<const TargetRegisterClass *>(RegClassOrBank))
return MRI.constrainRegClass(Reg, &RC);
- const RegisterBank *RB = cast<const RegisterBank *>(RegClassOrBank);
+ const auto *RB = dyn_cast_if_present<const RegisterBank *>(RegClassOrBank);
// Otherwise, all we can do is ensure the bank covers the class, and set it.
if (RB && !RB->covers(RC))
return nullptr;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
index 704435dad65d7b..8fa656c77e90ed 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -3708,10 +3708,10 @@ const TargetRegisterClass *
SIRegisterInfo::getConstrainedRegClassForOperand(const MachineOperand &MO,
const MachineRegisterInfo &MRI) const {
const RegClassOrRegBank &RCOrRB = MRI.getRegClassOrRegBank(MO.getReg());
- if (const RegisterBank *RB = dyn_cast<const RegisterBank *>(RCOrRB))
+ if (const auto *RB = dyn_cast_if_present<const RegisterBank *>(RCOrRB))
return getRegClassForTypeOnBank(MRI.getType(MO.getReg()), *RB);
- if (const auto *RC = dyn_cast<const TargetRegisterClass *>(RCOrRB))
+ if (const auto *RC = dyn_cast_if_present<const TargetRegisterClass *>(RCOrRB))
return getAllocatableClass(RC);
return nullptr;
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
index acddb789601494..d870339b01397d 100644
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -53,9 +53,16 @@ TEST_F(PointerUnionTest, Comparison) {
EXPECT_TRUE(i4 != l4);
EXPECT_TRUE(f4 != l4);
EXPECT_TRUE(l4 != d4);
- EXPECT_TRUE(i4null != f4null);
- EXPECT_TRUE(i4null != l4null);
- EXPECT_TRUE(i4null != d4null);
+ EXPECT_TRUE(i4null == f4null);
+ EXPECT_FALSE(i4null != f4null);
+ EXPECT_TRUE(i4null == l4null);
+ EXPECT_FALSE(i4null != l4null);
+ EXPECT_TRUE(i4null == d4null);
+ EXPECT_FALSE(i4null != d4null);
+ EXPECT_FALSE(i4null == i4);
+ EXPECT_TRUE(i4null != i4);
+ EXPECT_FALSE(i4null == f4);
+ EXPECT_TRUE(i4null != f4);
}
TEST_F(PointerUnionTest, Null) {
>From 13a909854341c6bdea5095a88b8b3a7c256d51c4 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Wed, 8 Jan 2025 00:11:24 +0300
Subject: [PATCH 2/3] Specialize ValueIsPresent for PointerUnion
---
llvm/include/llvm/ADT/PointerUnion.h | 23 +++++++++++++++++------
llvm/unittests/ADT/PointerUnionTest.cpp | 18 ++++++++----------
2 files changed, 25 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index bea9d6a4569b0b..59bfaf6d3a92d4 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -198,14 +198,14 @@ class PointerUnion
}
};
-template <typename... PTs>
-bool operator==(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
- return (!LHS && !RHS) || LHS.getOpaqueValue() == RHS.getOpaqueValue();
+template <typename ...PTs>
+bool operator==(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
+ return lhs.getOpaqueValue() == rhs.getOpaqueValue();
}
-template <typename... PTs>
-bool operator!=(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
- return !operator==(LHS, RHS);
+template <typename ...PTs>
+bool operator!=(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
+ return lhs.getOpaqueValue() != rhs.getOpaqueValue();
}
template <typename ...PTs>
@@ -259,6 +259,17 @@ struct CastInfo<To, const PointerUnion<PTs...>>
CastInfo<To, PointerUnion<PTs...>>> {
};
+// The default implementation of isPresent() for nullable types returns true
+// if the active member is not the first one, even if its value is nullptr.
+// Override the default behavior to return false for all possible null values.
+template <typename... PTs>
+struct ValueIsPresent<PointerUnion<PTs...>,
+ std::enable_if_t<IsNullable<PointerUnion<PTs...>>>> {
+ using Union = PointerUnion<PTs...>;
+ static bool isPresent(const Union &V) { return static_cast<bool>(V); }
+ static decltype(auto) unwrapValue(Union &V) { return V; }
+};
+
// Teach SmallPtrSet that PointerUnion is "basically a pointer", that has
// # low bits available = min(PT1bits,PT2bits)-1.
template <typename ...PTs>
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
index d870339b01397d..a28d532865cbc1 100644
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -53,16 +53,9 @@ TEST_F(PointerUnionTest, Comparison) {
EXPECT_TRUE(i4 != l4);
EXPECT_TRUE(f4 != l4);
EXPECT_TRUE(l4 != d4);
- EXPECT_TRUE(i4null == f4null);
- EXPECT_FALSE(i4null != f4null);
- EXPECT_TRUE(i4null == l4null);
- EXPECT_FALSE(i4null != l4null);
- EXPECT_TRUE(i4null == d4null);
- EXPECT_FALSE(i4null != d4null);
- EXPECT_FALSE(i4null == i4);
- EXPECT_TRUE(i4null != i4);
- EXPECT_FALSE(i4null == f4);
- EXPECT_TRUE(i4null != f4);
+ EXPECT_TRUE(i4null != f4null);
+ EXPECT_TRUE(i4null != l4null);
+ EXPECT_TRUE(i4null != d4null);
}
TEST_F(PointerUnionTest, Null) {
@@ -215,6 +208,11 @@ TEST_F(PointerUnionTest, NewCastInfra) {
EXPECT_FALSE(isa<float *>(d4null));
EXPECT_FALSE(isa<long long *>(d4null));
+ EXPECT_FALSE(isa_and_present<int *>(i4null));
+ EXPECT_FALSE(isa_and_present<float *>(f4null));
+ EXPECT_FALSE(isa_and_present<long long *>(l4null));
+ EXPECT_FALSE(isa_and_present<double *>(d4null));
+
// test cast<>
EXPECT_EQ(cast<float *>(a), &f);
EXPECT_EQ(cast<int *>(b), &i);
>From 8abd5cd49031c0b6b956d6b15d957bdf663a93f4 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Wed, 8 Jan 2025 13:53:13 +0300
Subject: [PATCH 3/3] Update llvm/include/llvm/ADT/PointerUnion.h
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
llvm/include/llvm/ADT/PointerUnion.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index 59bfaf6d3a92d4..fb854665322c26 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -267,7 +267,7 @@ struct ValueIsPresent<PointerUnion<PTs...>,
std::enable_if_t<IsNullable<PointerUnion<PTs...>>>> {
using Union = PointerUnion<PTs...>;
static bool isPresent(const Union &V) { return static_cast<bool>(V); }
- static decltype(auto) unwrapValue(Union &V) { return V; }
+ static Union &unwrapValue(Union &V) { return V; }
};
// Teach SmallPtrSet that PointerUnion is "basically a pointer", that has
More information about the llvm-commits mailing list