[llvm] [ADT] Specialize ValueIsPresent for PointerUnion (PR #121847)

Sergei Barannikov via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 8 04:59:37 PST 2025


https://github.com/s-barannikov updated https://github.com/llvm/llvm-project/pull/121847

From 107a35db5dfb64e7067ff12e7932b358796d3864 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Tue, 7 Jan 2025 00:49:48 +0300
Subject: [PATCH 1/3] [ADT] Make null PointerUnion with different active
 members compare equal

Currently, two null instances of `PointerUnion` with different active
members compare unequal. In some cases, this results in
counterintuitive behavior when using functions from `Casting.h`, e.g.:

```
  PointerUnion<int *, float *> U;
  // U = (int *)nullptr;
  dyn_cast<int *>(U); // Aborts
  dyn_cast<float *>(U); // Aborts
  U = (float *)nullptr;
  dyn_cast<int *>(U); // OK
  dyn_cast<float *>(U); // OK
```

`dyn_cast` should abort in all cases because the argument is null.
Currently, it aborts only if the first member is active. This happens
because the partial template specialization of `ValueIsPresent` for
nullable types compares the union with a union constructed from nullptr
(which has its first member active), and the two unions compare equal
only if their active members are the same.

This patch makes two instances of a union compare equal if they are both
null, regardless of their active members, and updates two call sites that
relied on the old behavior.
---
 llvm/include/llvm/ADT/PointerUnion.h      | 12 ++++++------
 llvm/lib/CodeGen/RegisterBankInfo.cpp     |  4 ++--
 llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp |  4 ++--
 llvm/unittests/ADT/PointerUnionTest.cpp   | 13 ++++++++++---
 4 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index 7d4ed02b622626..bea9d6a4569b0b 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -198,14 +198,14 @@ class PointerUnion
   }
 };
 
-template <typename ...PTs>
-bool operator==(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
-  return lhs.getOpaqueValue() == rhs.getOpaqueValue();
+template <typename... PTs>
+bool operator==(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
+  return (!LHS && !RHS) || LHS.getOpaqueValue() == RHS.getOpaqueValue();
 }
 
-template <typename ...PTs>
-bool operator!=(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
-  return lhs.getOpaqueValue() != rhs.getOpaqueValue();
+template <typename... PTs>
+bool operator!=(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
+  return !operator==(LHS, RHS);
 }
 
 template <typename ...PTs>
diff --git a/llvm/lib/CodeGen/RegisterBankInfo.cpp b/llvm/lib/CodeGen/RegisterBankInfo.cpp
index e1720b038e2361..5a8cf13ad11fd5 100644
--- a/llvm/lib/CodeGen/RegisterBankInfo.cpp
+++ b/llvm/lib/CodeGen/RegisterBankInfo.cpp
@@ -134,10 +134,10 @@ const TargetRegisterClass *RegisterBankInfo::constrainGenericRegister(
 
   // If the register already has a class, fallback to MRI::constrainRegClass.
   auto &RegClassOrBank = MRI.getRegClassOrRegBank(Reg);
-  if (isa<const TargetRegisterClass *>(RegClassOrBank))
+  if (isa_and_present<const TargetRegisterClass *>(RegClassOrBank))
     return MRI.constrainRegClass(Reg, &RC);
 
-  const RegisterBank *RB = cast<const RegisterBank *>(RegClassOrBank);
+  const auto *RB = dyn_cast_if_present<const RegisterBank *>(RegClassOrBank);
   // Otherwise, all we can do is ensure the bank covers the class, and set it.
   if (RB && !RB->covers(RC))
     return nullptr;
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
index 704435dad65d7b..8fa656c77e90ed 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp
@@ -3708,10 +3708,10 @@ const TargetRegisterClass *
 SIRegisterInfo::getConstrainedRegClassForOperand(const MachineOperand &MO,
                                          const MachineRegisterInfo &MRI) const {
   const RegClassOrRegBank &RCOrRB = MRI.getRegClassOrRegBank(MO.getReg());
-  if (const RegisterBank *RB = dyn_cast<const RegisterBank *>(RCOrRB))
+  if (const auto *RB = dyn_cast_if_present<const RegisterBank *>(RCOrRB))
     return getRegClassForTypeOnBank(MRI.getType(MO.getReg()), *RB);
 
-  if (const auto *RC = dyn_cast<const TargetRegisterClass *>(RCOrRB))
+  if (const auto *RC = dyn_cast_if_present<const TargetRegisterClass *>(RCOrRB))
     return getAllocatableClass(RC);
 
   return nullptr;
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
index acddb789601494..d870339b01397d 100644
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -53,9 +53,16 @@ TEST_F(PointerUnionTest, Comparison) {
   EXPECT_TRUE(i4 != l4);
   EXPECT_TRUE(f4 != l4);
   EXPECT_TRUE(l4 != d4);
-  EXPECT_TRUE(i4null != f4null);
-  EXPECT_TRUE(i4null != l4null);
-  EXPECT_TRUE(i4null != d4null);
+  EXPECT_TRUE(i4null == f4null);
+  EXPECT_FALSE(i4null != f4null);
+  EXPECT_TRUE(i4null == l4null);
+  EXPECT_FALSE(i4null != l4null);
+  EXPECT_TRUE(i4null == d4null);
+  EXPECT_FALSE(i4null != d4null);
+  EXPECT_FALSE(i4null == i4);
+  EXPECT_TRUE(i4null != i4);
+  EXPECT_FALSE(i4null == f4);
+  EXPECT_TRUE(i4null != f4);
 }
 
 TEST_F(PointerUnionTest, Null) {

From 13a909854341c6bdea5095a88b8b3a7c256d51c4 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Wed, 8 Jan 2025 00:11:24 +0300
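Subject: [PATCH 2/3] [ADT] Specialize ValueIsPresent for PointerUnion

Rather than changing the equality operators, restore their original
behavior (comparing opaque values) and instead specialize
`ValueIsPresent` for `PointerUnion` so that `isPresent()` returns false
for any null union, regardless of which member is active.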

---
 llvm/include/llvm/ADT/PointerUnion.h    | 23 +++++++++++++++++------
 llvm/unittests/ADT/PointerUnionTest.cpp | 18 ++++++++----------
 2 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index bea9d6a4569b0b..59bfaf6d3a92d4 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -198,14 +198,14 @@ class PointerUnion
   }
 };
 
-template <typename... PTs>
-bool operator==(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
-  return (!LHS && !RHS) || LHS.getOpaqueValue() == RHS.getOpaqueValue();
+template <typename ...PTs>
+bool operator==(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
+  return lhs.getOpaqueValue() == rhs.getOpaqueValue();
 }
 
-template <typename... PTs>
-bool operator!=(PointerUnion<PTs...> LHS, PointerUnion<PTs...> RHS) {
-  return !operator==(LHS, RHS);
+template <typename ...PTs>
+bool operator!=(PointerUnion<PTs...> lhs, PointerUnion<PTs...> rhs) {
+  return lhs.getOpaqueValue() != rhs.getOpaqueValue();
 }
 
 template <typename ...PTs>
@@ -259,6 +259,17 @@ struct CastInfo<To, const PointerUnion<PTs...>>
                                           CastInfo<To, PointerUnion<PTs...>>> {
 };
 
+// The default implementation of isPresent() for nullable types returns true
+// if the active member is not the first one, even if its value is nullptr.
+// Override the default behavior to return false for all possible null values.
+template <typename... PTs>
+struct ValueIsPresent<PointerUnion<PTs...>,
+                      std::enable_if_t<IsNullable<PointerUnion<PTs...>>>> {
+  using Union = PointerUnion<PTs...>;
+  static bool isPresent(const Union &V) { return static_cast<bool>(V); }
+  static decltype(auto) unwrapValue(Union &V) { return V; }
+};
+
 // Teach SmallPtrSet that PointerUnion is "basically a pointer", that has
 // # low bits available = min(PT1bits,PT2bits)-1.
 template <typename ...PTs>
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
index d870339b01397d..a28d532865cbc1 100644
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -53,16 +53,9 @@ TEST_F(PointerUnionTest, Comparison) {
   EXPECT_TRUE(i4 != l4);
   EXPECT_TRUE(f4 != l4);
   EXPECT_TRUE(l4 != d4);
-  EXPECT_TRUE(i4null == f4null);
-  EXPECT_FALSE(i4null != f4null);
-  EXPECT_TRUE(i4null == l4null);
-  EXPECT_FALSE(i4null != l4null);
-  EXPECT_TRUE(i4null == d4null);
-  EXPECT_FALSE(i4null != d4null);
-  EXPECT_FALSE(i4null == i4);
-  EXPECT_TRUE(i4null != i4);
-  EXPECT_FALSE(i4null == f4);
-  EXPECT_TRUE(i4null != f4);
+  EXPECT_TRUE(i4null != f4null);
+  EXPECT_TRUE(i4null != l4null);
+  EXPECT_TRUE(i4null != d4null);
 }
 
 TEST_F(PointerUnionTest, Null) {
@@ -215,6 +208,11 @@ TEST_F(PointerUnionTest, NewCastInfra) {
   EXPECT_FALSE(isa<float *>(d4null));
   EXPECT_FALSE(isa<long long *>(d4null));
 
+  EXPECT_FALSE(isa_and_present<int *>(i4null));
+  EXPECT_FALSE(isa_and_present<float *>(f4null));
+  EXPECT_FALSE(isa_and_present<long long *>(l4null));
+  EXPECT_FALSE(isa_and_present<double *>(d4null));
+
   // test cast<>
   EXPECT_EQ(cast<float *>(a), &f);
   EXPECT_EQ(cast<int *>(b), &i);

From 8abd5cd49031c0b6b956d6b15d957bdf663a93f4 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Wed, 8 Jan 2025 13:53:13 +0300
Subject: [PATCH 3/3] [ADT] Spell out the return type of unwrapValue in PointerUnion.h

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 llvm/include/llvm/ADT/PointerUnion.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
index 59bfaf6d3a92d4..fb854665322c26 100644
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -267,7 +267,7 @@ struct ValueIsPresent<PointerUnion<PTs...>,
                       std::enable_if_t<IsNullable<PointerUnion<PTs...>>>> {
   using Union = PointerUnion<PTs...>;
   static bool isPresent(const Union &V) { return static_cast<bool>(V); }
-  static decltype(auto) unwrapValue(Union &V) { return V; }
+  static Union &unwrapValue(Union &V) { return V; }
 };
 
 // Teach SmallPtrSet that PointerUnion is "basically a pointer", that has



More information about the llvm-commits mailing list