[llvm] [ADT] Make null PointerUnion with different active members compare equal (PR #121847)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 6 14:03:57 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Sergei Barannikov (s-barannikov)

<details>
<summary>Changes</summary>

Currently, two null instances of `PointerUnion` with different active members compare unequal. This may result in counterintuitive behavior when using functions from `Casting.h`, e.g.:

```C++
  PointerUnion<int *, float *> U;
  // Default construction is equivalent to U = (int *)nullptr;
  dyn_cast<int *>(U); // Aborts
  dyn_cast<float *>(U); // Aborts
  U = (float *)nullptr;
  dyn_cast<int *>(U); // OK
  dyn_cast<float *>(U); // OK
```

`dyn_cast` should abort in all cases because the argument is null. Currently, it aborts only if the first member is active. This happens because the partial template specialization of `ValueIsPresent` for nullable types compares the union with a union constructed from `nullptr`, and the two unions compare equal only if their active members are the same.

This patch makes two instances of a union compare equal if they are both null, regardless of their active members, and fixes two places that relied on the old behavior.

---
Full diff: https://github.com/llvm/llvm-project/pull/121847.diff


4 Files Affected:

- (modified) llvm/include/llvm/ADT/PointerUnion.h (+6-6) 
- (modified) llvm/lib/CodeGen/RegisterBankInfo.cpp (+2-2) 
- (modified) llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp (+2-2) 
- (modified) llvm/unittests/ADT/PointerUnionTest.cpp (+10-3) 


``````````diff
diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index 7d4ed02b622626..bea9d6a4569b0b 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -198,14 +198,14 @@ class PointerUnion
   }
 };
 
-template <typename ...PTs>
-bool operator==(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
-  return lhs.getOpaqueValue() == rhs.getOpaqueValue();
+template <typename... PTs>
+bool operator==(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
+  return (!LHS && !RHS) || LHS.getOpaqueValue() == RHS.getOpaqueValue();
 }
 
-template <typename ...PTs>
-bool operator!=(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
-  return lhs.getOpaqueValue() != rhs.getOpaqueValue();
+template <typename... PTs>
+bool operator!=(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
+  return !operator==(LHS, RHS);
 }
 
 template <typename ...PTs>
diff --git a/llvm/lib/CodeGen/RegisterBankInfo.cpp b/llvm/lib/CodeGen/RegisterBankInfo.cpp
index e1720b038e2361..5a8cf13ad11fd5 100644
--- a/llvm/lib/CodeGen/RegisterBankInfo.cpp
+++ b/llvm/lib/CodeGen/RegisterBankInfo.cpp
@@ -134,10 +134,10 @@ const TargetRegisterClass *RegisterBankInfo::constrainGenericRegister(
 
   // If the register already has a class, fallback to MRI::constrainRegClass.
   auto &RegClassOrBank = MRI.getRegClassOrRegBank(Reg);
-  if (isa<const TargetRegisterClass *>(RegClassOrBank))
+  if (isa_and_present<const TargetRegisterClass *>(RegClassOrBank))
     return MRI.constrainRegClass(Reg, &RC);
 
-  const RegisterBank *RB = cast<const RegisterBank *>(RegClassOrBank);
+  const auto *RB = dyn_cast_if_present<const RegisterBank *>(RegClassOrBank);
   // Otherwise, all we can do is ensure the bank covers the class, and set it.
   if (RB && !RB->covers(RC))
     return nullptr;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
index 704435dad65d7b..8fa656c77e90ed 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -3708,10 +3708,10 @@ const TargetRegisterClass *
 SIRegisterInfo::getConstrainedRegClassForOperand(const MachineOperand &MO,
                                          const MachineRegisterInfo &MRI) const {
   const RegClassOrRegBank &RCOrRB = MRI.getRegClassOrRegBank(MO.getReg());
-  if (const RegisterBank *RB = dyn_cast<const RegisterBank *>(RCOrRB))
+  if (const auto *RB = dyn_cast_if_present<const RegisterBank *>(RCOrRB))
     return getRegClassForTypeOnBank(MRI.getType(MO.getReg()), *RB);
 
-  if (const auto *RC = dyn_cast<const TargetRegisterClass *>(RCOrRB))
+  if (const auto *RC = dyn_cast_if_present<const TargetRegisterClass *>(RCOrRB))
     return getAllocatableClass(RC);
 
   return nullptr;
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
index acddb789601494..d870339b01397d 100644
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -53,9 +53,16 @@ TEST_F(PointerUnionTest, Comparison) {
   EXPECT_TRUE(i4 != l4);
   EXPECT_TRUE(f4 != l4);
   EXPECT_TRUE(l4 != d4);
-  EXPECT_TRUE(i4null != f4null);
-  EXPECT_TRUE(i4null != l4null);
-  EXPECT_TRUE(i4null != d4null);
+  EXPECT_TRUE(i4null == f4null);
+  EXPECT_FALSE(i4null != f4null);
+  EXPECT_TRUE(i4null == l4null);
+  EXPECT_FALSE(i4null != l4null);
+  EXPECT_TRUE(i4null == d4null);
+  EXPECT_FALSE(i4null != d4null);
+  EXPECT_FALSE(i4null == i4);
+  EXPECT_TRUE(i4null != i4);
+  EXPECT_FALSE(i4null == f4);
+  EXPECT_TRUE(i4null != f4);
 }
 
 TEST_F(PointerUnionTest, Null) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/121847


More information about the llvm-commits mailing list