[llvm] [BitmaskEnum] Add support for shift operators. (PR #118007)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 28 07:51:52 PST 2024


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/118007

>From 6c851dcd793536360cc5c48454b559dd45f3ce9e Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 28 Nov 2024 13:30:16 +0000
Subject: [PATCH] [BitmaskEnum] Add support for shift operators.

For enums that describe a bitmask in which groups of successive bits each
encode one of the enum's own values, it is useful to support operator<< and
operator>> (along with <<= and >>=) as well.

For example:

  enum class E : unsigned {
    // 2 bits per option
    OptionA = 0,
    OptionB = 1,
    OptionC = 2,
    OptionD = 3,
    OptionMask = 3,

    // Given 3 values in the bitmask X, Y and Z, each is 2 bits in size
    // and represents a choice of OptionA..OptionD.
    ShiftX = 0,
    ShiftY = 2,
    ShiftZ = 4,
  };

  // The mask can be encoded with:
  E mask{}; // Value-initialized (all bits clear), since |= reads it.
  mask |= getOptionFor(X) << E::ShiftX;
  mask |= getOptionFor(Y) << E::ShiftY;
  mask |= getOptionFor(Z) << E::ShiftZ;

  // And to extract a value:
  E OptionForX = (mask >> E::ShiftX) & E::OptionMask;
  E OptionForY = (mask >> E::ShiftY) & E::OptionMask;
  E OptionForZ = (mask >> E::ShiftZ) & E::OptionMask;
---
 llvm/include/llvm/ADT/BitmaskEnum.h    | 26 ++++++++++++++++++++++++++
 llvm/unittests/ADT/BitmaskEnumTest.cpp | 16 ++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/llvm/include/llvm/ADT/BitmaskEnum.h b/llvm/include/llvm/ADT/BitmaskEnum.h
index c87e7cac65a5b1..dcb13bd8ba51a5 100644
--- a/llvm/include/llvm/ADT/BitmaskEnum.h
+++ b/llvm/include/llvm/ADT/BitmaskEnum.h
@@ -85,9 +85,13 @@
   using ::llvm::BitmaskEnumDetail::operator|;                                  \
   using ::llvm::BitmaskEnumDetail::operator&;                                  \
   using ::llvm::BitmaskEnumDetail::operator^;                                  \
+  using ::llvm::BitmaskEnumDetail::operator<<;                                 \
+  using ::llvm::BitmaskEnumDetail::operator>>;                                 \
   using ::llvm::BitmaskEnumDetail::operator|=;                                 \
   using ::llvm::BitmaskEnumDetail::operator&=;                                 \
   using ::llvm::BitmaskEnumDetail::operator^=;                                 \
+  using ::llvm::BitmaskEnumDetail::operator<<=;                                \
+  using ::llvm::BitmaskEnumDetail::operator>>=;                                \
   /* Force a semicolon at the end of this macro. */                            \
   using ::llvm::BitmaskEnumDetail::any
 
@@ -162,6 +166,16 @@ constexpr E operator^(E LHS, E RHS) {
   return static_cast<E>(Underlying(LHS) ^ Underlying(RHS));
 }
 
+template <typename E, typename = std::enable_if_t<is_bitmask_enum<E>::value>>
+constexpr E operator<<(E LHS, E RHS) {
+  return static_cast<E>(Underlying(LHS) << Underlying(RHS));
+}
+
+template <typename E, typename = std::enable_if_t<is_bitmask_enum<E>::value>>
+constexpr E operator>>(E LHS, E RHS) {
+  return static_cast<E>(Underlying(LHS) >> Underlying(RHS));
+}
+
 // |=, &=, and ^= return a reference to LHS, to match the behavior of the
 // operators on builtin types.
 
@@ -183,6 +197,18 @@ E &operator^=(E &LHS, E RHS) {
   return LHS;
 }
 
+template <typename E, typename = std::enable_if_t<is_bitmask_enum<E>::value>>
+E &operator<<=(E &LHS, E RHS) {
+  LHS = LHS << RHS;
+  return LHS;
+}
+
+template <typename E, typename = std::enable_if_t<is_bitmask_enum<E>::value>>
+E &operator>>=(E &LHS, E RHS) {
+  LHS = LHS >> RHS;
+  return LHS;
+}
+
 } // namespace BitmaskEnumDetail
 
 // Enable bitmask enums in namespace ::llvm and all nested namespaces.
diff --git a/llvm/unittests/ADT/BitmaskEnumTest.cpp b/llvm/unittests/ADT/BitmaskEnumTest.cpp
index c78937c3571fd1..2c0a80342a54c3 100644
--- a/llvm/unittests/ADT/BitmaskEnumTest.cpp
+++ b/llvm/unittests/ADT/BitmaskEnumTest.cpp
@@ -130,6 +130,34 @@ TEST(BitmaskEnumTest, BitwiseXorEquals) {
   EXPECT_EQ(V3, f2);
 }
 
+TEST(BitmaskEnumTest, BitwiseShift) {
+  Flags f = (F1 << F1);
+  EXPECT_EQ(f, F2);
+
+  Flags f2 = F1;
+  f2 <<= F1;
+  EXPECT_EQ(f2, F2);
+
+  Flags f3 = (F1 >> F1);
+  EXPECT_EQ(f3, F0);
+
+  Flags f4 = F1;
+  f4 >>= F1;
+  EXPECT_EQ(f4, F0);
+
+  // F1 >> F1 gives F0, which a broken >> that always returned zero would
+  // also produce; so check a right shift that keeps a bit set, and that
+  // shifting by F0 is the identity.
+  EXPECT_EQ((F2 >> F1), F1);
+  EXPECT_EQ((F1 << F0), F1);
+
+  // Like the builtin operators, the compound shifts return a reference to LHS.
+  Flags f5 = F1;
+  EXPECT_EQ(&(f5 <<= F1), &f5);
+  EXPECT_EQ(&(f5 >>= F1), &f5);
+  EXPECT_EQ(f5, F1);
+}
+
 TEST(BitmaskEnumTest, ConstantExpression) {
   constexpr Flags f1 = ~F1;
   constexpr Flags f2 = F1 | F2;



More information about the llvm-commits mailing list