[llvm] [ADT] Make null PointerUnion with different active members compare equal (PR #121847)
Sergei Barannikov via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 6 14:03:25 PST 2025
https://github.com/s-barannikov created https://github.com/llvm/llvm-project/pull/121847
Currently, two null instances of `PointerUnion` with different active members compare unequal. This may result in counterintuitive behavior when using functions from `Casting.h`, e.g.:
```C++
PointerUnion<int *, float *> U;
// U = (int *)nullptr;
dyn_cast<int *>(U); // Aborts
dyn_cast<float *>(U); // Aborts
U = (float *)nullptr;
dyn_cast<int *>(U); // OK
dyn_cast<float *>(U); // OK
```
`dyn_cast` should abort in all cases because the argument is null. Currently, it aborts only if the first member is active. This happens because the partial template specialization of `ValueIsPresent` for nullable types compares the union with a union constructed from nullptr, and the two unions compare equal only if their active members are the same.
This patch makes two instances of a union compare equal if they are both null regardless of their active members, and fixes two places where the old behavior was exploited.
From 107a35db5dfb64e7067ff12e7932b358796d3864 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Tue, 7 Jan 2025 00:49:48 +0300
Subject: [PATCH] [ADT] Make null PointerUnion with different active members
compare equal
Currently, two null instances of `PointerUnion` with different active
members compare unequal. In some cases, this results in
counterintuitive behavior when using functions from `Casting.h`, e.g.:
```
PointerUnion<int *, float *> U;
// U = (int *)nullptr;
dyn_cast<int *>(U); // Aborts
dyn_cast<float *>(U); // Aborts
U = (float *)nullptr;
dyn_cast<int *>(U); // OK
dyn_cast<float *>(U); // OK
```
`dyn_cast` should abort in all cases because the argument is null.
Currently, it aborts only if the first member is active. This happens
because the partial template specialization of `ValueIsPresent` for
nullable types compares the union with a union constructed from nullptr,
and the two unions compare equal only if their active members are the
same.
This patch makes two instances of a union compare equal if they are both
null regardless of their active members, and fixes two places where the
old behavior was exploited.
---
llvm/include/llvm/ADT/PointerUnion.h | 12 ++++++------
llvm/lib/CodeGen/RegisterBankInfo.cpp | 4 ++--
llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp | 4 ++--
llvm/unittests/ADT/PointerUnionTest.cpp | 13 ++++++++++---
4 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index 7d4ed02b622626..bea9d6a4569b0b 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -198,14 +198,14 @@ class PointerUnion
}
};
-template <typename ...PTs>
-bool operator==(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
- return lhs.getOpaqueValue() == rhs.getOpaqueValue();
+template <typename... PTs>
+bool operator==(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
+ return (!LHS && !RHS) || LHS.getOpaqueValue() == RHS.getOpaqueValue();
}
-template <typename ...PTs>
-bool operator!=(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
- return lhs.getOpaqueValue() != rhs.getOpaqueValue();
+template <typename... PTs>
+bool operator!=(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
+ return !operator==(LHS, RHS);
}
template <typename ...PTs>
diff --git a/llvm/lib/CodeGen/RegisterBankInfo.cpp b/llvm/lib/CodeGen/RegisterBankInfo.cpp
index e1720b038e2361..5a8cf13ad11fd5 100644
--- a/llvm/lib/CodeGen/RegisterBankInfo.cpp
+++ b/llvm/lib/CodeGen/RegisterBankInfo.cpp
@@ -134,10 +134,10 @@ const TargetRegisterClass *RegisterBankInfo::constrainGenericRegister(
// If the register already has a class, fallback to MRI::constrainRegClass.
auto &RegClassOrBank = MRI.getRegClassOrRegBank(Reg);
- if (isa<const TargetRegisterClass *>(RegClassOrBank))
+ if (isa_and_present<const TargetRegisterClass *>(RegClassOrBank))
return MRI.constrainRegClass(Reg, &RC);
- const RegisterBank *RB = cast<const RegisterBank *>(RegClassOrBank);
+ const auto *RB = dyn_cast_if_present<const RegisterBank *>(RegClassOrBank);
// Otherwise, all we can do is ensure the bank covers the class, and set it.
if (RB && !RB->covers(RC))
return nullptr;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
index 704435dad65d7b..8fa656c77e90ed 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -3708,10 +3708,10 @@ const TargetRegisterClass *
SIRegisterInfo::getConstrainedRegClassForOperand(const MachineOperand &MO,
const MachineRegisterInfo &MRI) const {
const RegClassOrRegBank &RCOrRB = MRI.getRegClassOrRegBank(MO.getReg());
- if (const RegisterBank *RB = dyn_cast<const RegisterBank *>(RCOrRB))
+ if (const auto *RB = dyn_cast_if_present<const RegisterBank *>(RCOrRB))
return getRegClassForTypeOnBank(MRI.getType(MO.getReg()), *RB);
- if (const auto *RC = dyn_cast<const TargetRegisterClass *>(RCOrRB))
+ if (const auto *RC = dyn_cast_if_present<const TargetRegisterClass *>(RCOrRB))
return getAllocatableClass(RC);
return nullptr;
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
index acddb789601494..d870339b01397d 100644
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -53,9 +53,16 @@ TEST_F(PointerUnionTest, Comparison) {
EXPECT_TRUE(i4 != l4);
EXPECT_TRUE(f4 != l4);
EXPECT_TRUE(l4 != d4);
- EXPECT_TRUE(i4null != f4null);
- EXPECT_TRUE(i4null != l4null);
- EXPECT_TRUE(i4null != d4null);
+ EXPECT_TRUE(i4null == f4null);
+ EXPECT_FALSE(i4null != f4null);
+ EXPECT_TRUE(i4null == l4null);
+ EXPECT_FALSE(i4null != l4null);
+ EXPECT_TRUE(i4null == d4null);
+ EXPECT_FALSE(i4null != d4null);
+ EXPECT_FALSE(i4null == i4);
+ EXPECT_TRUE(i4null != i4);
+ EXPECT_FALSE(i4null == f4);
+ EXPECT_TRUE(i4null != f4);
}
TEST_F(PointerUnionTest, Null) {
More information about the llvm-commits
mailing list