[llvm] [CodeGen] Use 128bits for LaneBitmask. (PR #111157)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 4 06:51:19 PDT 2024


https://github.com/sdesmalen-arm created https://github.com/llvm/llvm-project/pull/111157

This follows on from the conversation on #109797.

FWIW, I've considered several alternatives to this patch:

(1) Using APInt as storage type rather than 'uint64_t Mask[2]'.

(1) makes the code a bit simpler to read, but APInt only stores values of up to 64 bits inline and otherwise dynamically allocates a larger buffer. Because we know the value is always 128 bits, this extra dynamic allocation is undesirable. We also rarely need the full power of APInt since most uses are bitwise operations, so I chose to represent the mask as a uint64_t array instead, converting to/from APInt only when necessary.
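To make the storage trade-off concrete, here is a minimal standalone sketch of a fixed two-word mask. `Mask128` and `lanesInRange` are hypothetical names used only for illustration; this is not the LLVM class, just the same idea of keeping both 64-bit words inline so that no heap buffer is ever involved.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for the patched LaneBitmask; not the LLVM class itself.
struct Mask128 {
  // Words[0] holds lanes 0..63, Words[1] holds lanes 64..127.
  uint64_t Words[2];

  constexpr explicit Mask128(uint64_t Lo = 0, uint64_t Hi = 0)
      : Words{Lo, Hi} {}

  constexpr Mask128 operator|(Mask128 M) const {
    return Mask128(Words[0] | M.Words[0], Words[1] | M.Words[1]);
  }
  constexpr Mask128 operator&(Mask128 M) const {
    return Mask128(Words[0] & M.Words[0], Words[1] & M.Words[1]);
  }
  constexpr bool operator==(Mask128 M) const {
    return Words[0] == M.Words[0] && Words[1] == M.Words[1];
  }
  constexpr bool any() const { return Words[0] != 0 || Words[1] != 0; }

  static constexpr Mask128 getLane(unsigned Lane) {
    return Lane >= 64 ? Mask128(0, 1ULL << (Lane - 64))
                      : Mask128(1ULL << Lane, 0);
  }
};

// The whole value lives inline: two words, trivially copyable, usable in
// constant expressions.
static_assert(sizeof(Mask128) == 2 * sizeof(uint64_t), "stored inline");
static_assert(std::is_trivially_copyable<Mask128>::value, "no heap buffer");
static_assert((Mask128::getLane(3) | Mask128::getLane(100)) ==
                  Mask128(1ULL << 3, 1ULL << 36),
              "bitwise ops work per word");

// Hypothetical helper: build a mask covering lanes First..Last inclusive.
inline Mask128 lanesInRange(unsigned First, unsigned Last) {
  assert(First <= Last && Last < 128 && "lane index out of range");
  Mask128 M;
  for (unsigned L = First; L <= Last; ++L)
    M = M | Mask128::getLane(L);
  return M;
}
```

Because every operation is a pair of word-sized bitwise instructions, the type stays `constexpr`-friendly, which a heap-backed representation would not be for values wider than 64 bits.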

Because it is inconvenient that increasing the BitWidth applies to all targets, I tried to see if there is a way to make the bitwidth dynamic, by:

(2) Making the BitWidth dynamic per target by passing it to
    constructors.
(3) Modification of (2) that describes 'none', 'all' and 'lane' with
    an enum which doesn't require a BitWidth, until doing
    arithmetic with an explicit bit mask (which does have a BitWidth).

Unfortunately, neither approach seems feasible. Approach (2) would require passing the TargetRegisterInfo/MCRegisterInfo to many places that need to instantiate a LaneBitmask value but where this info is not available.

Approach (3) raises other questions, such as: what is the meaning of 'operator==' when one value is a mask and the other is an 'all' enum? If we let 'operator==' discard the bitwidth such that a 64-bit all-true bitmask == LaneBitmask::all() (using the 'all' enum), then we could end up in a situation where:

  X == LaneBitmask::all() && Y == LaneBitmask::all()

but `X != Y`.
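The problem can be reproduced with a small standalone model. `DynMask` below is a hypothetical type invented for this illustration (it is not code from the patch); it assumes an `operator==` that ignores the width whenever one side is the symbolic 'all' or 'none' value, but compares widths when both sides are explicit masks.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of approach (3): a mask is either symbolic (None/All)
// or an explicit bit pattern with its own width (1..64 bits in this sketch).
struct DynMask {
  enum KindTy { None, All, Bits } Kind;
  unsigned Width; // only meaningful when Kind == Bits
  uint64_t Value; // only meaningful when Kind == Bits

  static DynMask all() { return {All, 0, 0}; }
  static DynMask none() { return {None, 0, 0}; }
  static DynMask bits(unsigned Width, uint64_t Value) {
    return {Bits, Width, Value};
  }

  bool isAllOnes() const {
    if (Kind != Bits)
      return Kind == All;
    uint64_t Full = Width >= 64 ? ~0ULL : ((1ULL << Width) - 1);
    return Value == Full;
  }
  bool isZero() const { return Kind == None || (Kind == Bits && Value == 0); }

  bool operator==(const DynMask &M) const {
    // Two explicit masks: the width is part of the value.
    if (Kind == Bits && M.Kind == Bits)
      return Width == M.Width && Value == M.Value;
    // At least one side is symbolic: the width is discarded.
    if (Kind == All || M.Kind == All)
      return isAllOnes() && M.isAllOnes();
    return isZero() && M.isZero();
  }
  bool operator!=(const DynMask &M) const { return !(*this == M); }
};
```

With `X = DynMask::bits(32, 0xFFFFFFFF)` and `Y = DynMask::bits(64, ~0ULL)`, both compare equal to `DynMask::all()` yet `X != Y`, so equality is no longer transitive and cannot safely back hashing or ordered containers.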

I considered replacing the equality operators with methods that take a RegisterInfo pointer, but the LaneBitmask struct is used in STL containers, which require a plain 'operator==' or 'operator<'. We could work around that by providing custom lambdas (that call the method with the TargetInfo pointer), but this gets increasingly hacky.
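For a sense of what that workaround would look like, here is a rough sketch using a stateful comparator object rather than a lambda. `TargetWidth`, `RawMask`, `MaskLess` and `makeMaskSet` are all hypothetical names; `TargetWidth` merely stands in for whatever register-info object would supply the per-target bit width.

```cpp
#include <cassert>
#include <cstdint>
#include <set>

// Hypothetical stand-in for the register info that knows the target's width.
struct TargetWidth {
  unsigned Bits;
};

// Hypothetical mask type without a plain operator<.
struct RawMask {
  uint64_t Value;
};

// Comparator that needs the target context to compare two masks: only the
// low TW->Bits bits are considered significant.
struct MaskLess {
  const TargetWidth *TW;
  bool operator()(RawMask A, RawMask B) const {
    uint64_t Live = TW->Bits >= 64 ? ~0ULL : ((1ULL << TW->Bits) - 1);
    return (A.Value & Live) < (B.Value & Live);
  }
};

using MaskSet = std::set<RawMask, MaskLess>;

// Every container has to be constructed with a comparator instance that
// carries the context pointer; a default-constructed MaskSet is not possible.
inline MaskSet makeMaskSet(const TargetWidth &TW) {
  return MaskSet(MaskLess{&TW});
}
```

The cost is visible in the types: each set or map of masks must name the comparator and be handed a live context pointer at construction, which is the kind of plumbing the paragraph above describes as hacky.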

Perhaps just using more bits isn't actually that bad in practice.

From a782fcf51af8d832c372396ad52c24029723028a Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Wed, 21 Aug 2024 14:56:10 +0100
Subject: [PATCH] [CodeGen] Use 128bits for LaneBitmask.

This follows on from the conversation on #109797.

FWIW, I've considered several alternatives to this patch:

(1) Using APInt as storage type rather than 'uint64_t Mask[2]'.

(1) makes the code a bit simpler to read, but APInt only stores values
of up to 64 bits inline and otherwise dynamically allocates a larger
buffer. Because we know the value is always 128 bits, this extra
dynamic allocation is undesirable. We also rarely need the full power
of APInt since most uses are bitwise operations, so I chose to
represent the mask as a uint64_t array instead, converting to/from
APInt only when necessary.

Because it is inconvenient that increasing the BitWidth applies to all
targets, I tried to see if there is a way to make the bitwidth dynamic,
by:

(2) Making the BitWidth dynamic per target by passing it to
    constructors.
(3) Modification of (2) that describes 'none', 'all' and 'lane' with
    an enum which doesn't require a BitWidth, until doing
    arithmetic with an explicit bit mask (which does have a BitWidth).

Unfortunately, neither approach seems feasible. Approach (2) would
require passing the TargetRegisterInfo/MCRegisterInfo to many places
that need to instantiate a LaneBitmask value but where this info is
not available.

Approach (3) raises other questions, such as: what is the meaning of
'operator==' when one value is a mask and the other is an 'all' enum?
If we let 'operator==' discard the bitwidth such that a 64-bit all-true
bitmask == LaneBitmask::all() (using the 'all' enum), then we could end
up in a situation where:

  X == LaneBitmask::all() && Y == LaneBitmask::all()

but `X != Y`.

I considered replacing the equality operators with methods that take
a RegisterInfo pointer, but the LaneBitmask struct is used in STL
containers, which require a plain 'operator==' or 'operator<'. We
could work around that by providing custom lambdas (that call the
method with the TargetInfo pointer), but this gets increasingly
hacky.

Perhaps just using more bits isn't actually that bad in practice.
---
 llvm/include/llvm/CodeGen/RDFLiveness.h     |   4 +-
 llvm/include/llvm/CodeGen/RDFRegisters.h    |   4 +-
 llvm/include/llvm/MC/LaneBitmask.h          | 146 +++++++++++++-------
 llvm/lib/CodeGen/MIRParser/MIParser.cpp     |  45 ++++--
 llvm/lib/CodeGen/MIRPrinter.cpp             |  10 +-
 llvm/lib/CodeGen/RDFRegisters.cpp           |  10 +-
 llvm/lib/Target/AMDGPU/SIRegisterInfo.h     |   5 +-
 llvm/test/CodeGen/AArch64/lanebitmask.mir   |  18 +++
 llvm/unittests/CodeGen/MFCommon.inc         |   2 +-
 llvm/unittests/MC/CMakeLists.txt            |   2 +-
 llvm/unittests/MC/LaneBitmaskTest.cpp       |  68 +++++++++
 llvm/utils/TableGen/RegisterInfoEmitter.cpp |  23 +--
 12 files changed, 253 insertions(+), 84 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/lanebitmask.mir
 create mode 100644 llvm/unittests/MC/LaneBitmaskTest.cpp

diff --git a/llvm/include/llvm/CodeGen/RDFLiveness.h b/llvm/include/llvm/CodeGen/RDFLiveness.h
index fe1034f9b6f8ef..46ec82e77ba492 100644
--- a/llvm/include/llvm/CodeGen/RDFLiveness.h
+++ b/llvm/include/llvm/CodeGen/RDFLiveness.h
@@ -43,8 +43,8 @@ namespace std {
 
 template <> struct hash<llvm::rdf::detail::NodeRef> {
   std::size_t operator()(llvm::rdf::detail::NodeRef R) const {
-    return std::hash<llvm::rdf::NodeId>{}(R.first) ^
-           std::hash<llvm::LaneBitmask::Type>{}(R.second.getAsInteger());
+    return llvm::hash_value<llvm::rdf::NodeId>(R.first) ^
+           llvm::hash_value(R.second.getAsPair());
   }
 };
 
diff --git a/llvm/include/llvm/CodeGen/RDFRegisters.h b/llvm/include/llvm/CodeGen/RDFRegisters.h
index 7eed0b4e1e7b8f..87c38a215e006e 100644
--- a/llvm/include/llvm/CodeGen/RDFRegisters.h
+++ b/llvm/include/llvm/CodeGen/RDFRegisters.h
@@ -106,8 +106,8 @@ struct RegisterRef {
   }
 
   size_t hash() const {
-    return std::hash<RegisterId>{}(Reg) ^
-           std::hash<LaneBitmask::Type>{}(Mask.getAsInteger());
+    return llvm::hash_value<RegisterId>(Reg) ^
+           llvm::hash_value(Mask.getAsPair());
   }
 
   static constexpr bool isRegId(unsigned Id) {
diff --git a/llvm/include/llvm/MC/LaneBitmask.h b/llvm/include/llvm/MC/LaneBitmask.h
index c06ca7dd5b8fcd..ba22257c650a70 100644
--- a/llvm/include/llvm/MC/LaneBitmask.h
+++ b/llvm/include/llvm/MC/LaneBitmask.h
@@ -29,72 +29,120 @@
 #ifndef LLVM_MC_LANEBITMASK_H
 #define LLVM_MC_LANEBITMASK_H
 
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Format.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Printable.h"
 #include "llvm/Support/raw_ostream.h"
+#include <utility>
 
 namespace llvm {
 
-  struct LaneBitmask {
-    // When changing the underlying type, change the format string as well.
-    using Type = uint64_t;
-    enum : unsigned { BitWidth = 8*sizeof(Type) };
-    constexpr static const char *const FormatStr = "%016llX";
+struct LaneBitmask {
+  static constexpr unsigned int BitWidth = 128;
 
-    constexpr LaneBitmask() = default;
-    explicit constexpr LaneBitmask(Type V) : Mask(V) {}
-
-    constexpr bool operator== (LaneBitmask M) const { return Mask == M.Mask; }
-    constexpr bool operator!= (LaneBitmask M) const { return Mask != M.Mask; }
-    constexpr bool operator< (LaneBitmask M)  const { return Mask < M.Mask; }
-    constexpr bool none() const { return Mask == 0; }
-    constexpr bool any()  const { return Mask != 0; }
-    constexpr bool all()  const { return ~Mask == 0; }
-
-    constexpr LaneBitmask operator~() const {
-      return LaneBitmask(~Mask);
-    }
-    constexpr LaneBitmask operator|(LaneBitmask M) const {
-      return LaneBitmask(Mask | M.Mask);
-    }
-    constexpr LaneBitmask operator&(LaneBitmask M) const {
-      return LaneBitmask(Mask & M.Mask);
-    }
-    LaneBitmask &operator|=(LaneBitmask M) {
-      Mask |= M.Mask;
-      return *this;
-    }
-    LaneBitmask &operator&=(LaneBitmask M) {
-      Mask &= M.Mask;
-      return *this;
+  explicit LaneBitmask(APInt V) {
+    switch (V.getBitWidth()) {
+    case BitWidth:
+      Mask[0] = V.getRawData()[0];
+      Mask[1] = V.getRawData()[1];
+      break;
+    default:
+      llvm_unreachable("Unsupported bitwidth");
     }
+  }
+  constexpr explicit LaneBitmask(uint64_t Lo = 0, uint64_t Hi = 0) : Mask{Lo, Hi} {}
 
-    constexpr Type getAsInteger() const { return Mask; }
+  constexpr bool operator==(LaneBitmask M) const {
+    return Mask[0] == M.Mask[0] && Mask[1] == M.Mask[1];
+  }
+  constexpr bool operator!=(LaneBitmask M) const {
+    return Mask[0] != M.Mask[0] || Mask[1] != M.Mask[1];
+  }
+  constexpr bool operator<(LaneBitmask M) const {
+    return Mask[1] != M.Mask[1] ? Mask[1] < M.Mask[1] : Mask[0] < M.Mask[0];
+  }
+  constexpr bool none() const { return Mask[0] == 0 && Mask[1] == 0; }
+  constexpr bool any() const { return Mask[0] != 0 || Mask[1] != 0; }
+  constexpr bool all() const { return ~Mask[0] == 0 && ~Mask[1] == 0; }
 
-    unsigned getNumLanes() const { return llvm::popcount(Mask); }
-    unsigned getHighestLane() const {
-      return Log2_64(Mask);
-    }
+  constexpr LaneBitmask operator~() const { return LaneBitmask(~Mask[0], ~Mask[1]); }
+  constexpr LaneBitmask operator|(LaneBitmask M) const {
+    return LaneBitmask(Mask[0] | M.Mask[0], Mask[1] | M.Mask[1]);
+  }
+  constexpr LaneBitmask operator&(LaneBitmask M) const {
+    return LaneBitmask(Mask[0] & M.Mask[0], Mask[1] & M.Mask[1]);
+  }
+  LaneBitmask &operator|=(LaneBitmask M) {
+    Mask[0] |= M.Mask[0];
+    Mask[1] |= M.Mask[1];
+    return *this;
+  }
+  LaneBitmask &operator&=(LaneBitmask M) {
+    Mask[0] &= M.Mask[0];
+    Mask[1] &= M.Mask[1];
+    return *this;
+  }
 
-    static constexpr LaneBitmask getNone() { return LaneBitmask(0); }
-    static constexpr LaneBitmask getAll() { return ~LaneBitmask(0); }
-    static constexpr LaneBitmask getLane(unsigned Lane) {
-      return LaneBitmask(Type(1) << Lane);
-    }
+  APInt getAsAPInt() const { return APInt(BitWidth, {Mask[0], Mask[1]}); }
+  constexpr std::pair<uint64_t, uint64_t> getAsPair() const { return {Mask[0], Mask[1]}; }
 
-  private:
-    Type Mask = 0;
-  };
+  unsigned getNumLanes() const {
+    return Mask[1] ? llvm::popcount(Mask[1]) + llvm::popcount(Mask[0])
+                   : llvm::popcount(Mask[0]);
+  }
+  unsigned getHighestLane() const {
+    return Mask[1] ? Log2_64(Mask[1]) + 64 : Log2_64(Mask[0]);
+  }
 
-  /// Create Printable object to print LaneBitmasks on a \ref raw_ostream.
-  inline Printable PrintLaneMask(LaneBitmask LaneMask) {
-    return Printable([LaneMask](raw_ostream &OS) {
-      OS << format(LaneBitmask::FormatStr, LaneMask.getAsInteger());
-    });
+  static constexpr LaneBitmask getNone() { return LaneBitmask(0, 0); }
+  static constexpr LaneBitmask getAll() { return ~LaneBitmask(0, 0); }
+  static constexpr LaneBitmask getLane(unsigned Lane) {
+    return Lane >= 64 ? LaneBitmask(0, 1ULL << (Lane % 64))
+                      : LaneBitmask(1ULL << Lane, 0);
   }
 
+private:
+  uint64_t Mask[2];
+};
+
+/// Create Printable object to print LaneBitmasks on a \ref raw_ostream.
+/// If \p FormatAsCLiterals is true, it will print the bitmask as
+/// a hexadecimal C literal with zero padding, or a list of such C literals if
+/// the value cannot be represented in 64 bits.
+/// For example (FormatAsCLiterals == true)
+///   bitmask '1'       => "0x0000000000000001"
+///   bitmask '1 << 64' => "0x0000000000000000,0x0000000000000001"
+/// (FormatAsCLiterals == false)
+///   bitmask '1'       => "00000000000000000000000000000001"
+///   bitmask '1 << 64' => "00000000000000010000000000000000"
+inline Printable PrintLaneMask(LaneBitmask LaneMask,
+                               bool FormatAsCLiterals = false) {
+  return Printable([LaneMask, FormatAsCLiterals](raw_ostream &OS) {
+    SmallString<64> Buffer;
+    APInt V = LaneMask.getAsAPInt();
+    while (true) {
+      unsigned Bitwidth = FormatAsCLiterals ? 64 : LaneBitmask::BitWidth;
+      APInt VToPrint = V.trunc(Bitwidth);
+
+      Buffer.clear();
+      VToPrint.toString(Buffer, 16, /*Signed=*/false,
+                        /*formatAsCLiteral=*/false);
+      unsigned NumZeroesToPad =
+          (VToPrint.countLeadingZeros() / 4) - VToPrint.isZero();
+      OS << (FormatAsCLiterals ? "0x" : "") << std::string(NumZeroesToPad, '0')
+         << Buffer.str();
+      V = V.lshr(Bitwidth);
+      if (V.getActiveBits())
+        OS << ",";
+      else
+        break;
+    }
+  });
+}
+
 } // end namespace llvm
 
 #endif // LLVM_MC_LANEBITMASK_H
diff --git a/llvm/lib/CodeGen/MIRParser/MIParser.cpp b/llvm/lib/CodeGen/MIRParser/MIParser.cpp
index 27f0a9331a3e3e..6b6b5be910fdc4 100644
--- a/llvm/lib/CodeGen/MIRParser/MIParser.cpp
+++ b/llvm/lib/CodeGen/MIRParser/MIParser.cpp
@@ -870,17 +870,40 @@ bool MIParser::parseBasicBlockLiveins(MachineBasicBlock &MBB) {
     lex();
     LaneBitmask Mask = LaneBitmask::getAll();
     if (consumeIfPresent(MIToken::colon)) {
-      // Parse lane mask.
-      if (Token.isNot(MIToken::IntegerLiteral) &&
-          Token.isNot(MIToken::HexLiteral))
-        return error("expected a lane mask");
-      static_assert(sizeof(LaneBitmask::Type) == sizeof(uint64_t),
-                    "Use correct get-function for lane mask");
-      LaneBitmask::Type V;
-      if (getUint64(V))
-        return error("invalid lane mask value");
-      Mask = LaneBitmask(V);
-      lex();
+      if (consumeIfPresent(MIToken::lparen)) {
+        // We need to parse a list of literals
+        SmallVector<uint64_t, 2> Literals;
+        while (true) {
+          if (Token.isNot(MIToken::HexLiteral))
+            return error("expected a lane mask");
+          APInt V;
+          getHexUint(V);
+          Literals.push_back(V.getZExtValue());
+          // Lex past literal
+          lex();
+          if (Token.is(MIToken::rparen))
+            break;
+          else if (Token.isNot(MIToken::comma))
+            return error("expected a comma");
+          // Lex past comma
+          lex();
+        }
+        // Lex past rparen
+        lex();
+        Mask = LaneBitmask(APInt(LaneBitmask::BitWidth, Literals));
+      } else {
+        // Parse lane mask.
+        APInt V;
+        if (Token.is(MIToken::IntegerLiteral)) {
+          uint64_t UV;
+          if (getUint64(UV))
+            return error("invalid lane mask value");
+          V = APInt(LaneBitmask::BitWidth, UV);
+        } else if (getHexUint(V))
+          return error("expected a lane mask");
+        Mask = LaneBitmask(APInt(LaneBitmask::BitWidth, V.getZExtValue()));
+        lex();
+      }
     }
     MBB.addLiveIn(Reg, Mask);
   } while (consumeIfPresent(MIToken::comma));
diff --git a/llvm/lib/CodeGen/MIRPrinter.cpp b/llvm/lib/CodeGen/MIRPrinter.cpp
index cf6122bce22364..349e3aaa87f72a 100644
--- a/llvm/lib/CodeGen/MIRPrinter.cpp
+++ b/llvm/lib/CodeGen/MIRPrinter.cpp
@@ -732,8 +732,14 @@ void MIPrinter::print(const MachineBasicBlock &MBB) {
         OS << ", ";
       First = false;
       OS << printReg(LI.PhysReg, &TRI);
-      if (!LI.LaneMask.all())
-        OS << ":0x" << PrintLaneMask(LI.LaneMask);
+      if (!LI.LaneMask.all()) {
+        OS << ":";
+        if (LI.LaneMask.getAsAPInt().getActiveBits() <= 64)
+          OS << PrintLaneMask(LI.LaneMask, /*FormatAsCLiterals=*/true);
+        else
+          OS << '(' << PrintLaneMask(LI.LaneMask, /*FormatAsCLiterals=*/true)
+             << ')';
+      }
     }
     OS << "\n";
     HasLineAttributes = true;
diff --git a/llvm/lib/CodeGen/RDFRegisters.cpp b/llvm/lib/CodeGen/RDFRegisters.cpp
index 7ce00a66b3ae6c..3229647379207f 100644
--- a/llvm/lib/CodeGen/RDFRegisters.cpp
+++ b/llvm/lib/CodeGen/RDFRegisters.cpp
@@ -412,11 +412,11 @@ raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskShort &P) {
   if (P.Mask.none())
     return OS << ":*none*";
 
-  LaneBitmask::Type Val = P.Mask.getAsInteger();
-  if ((Val & 0xffff) == Val)
-    return OS << ':' << format("%04llX", Val);
-  if ((Val & 0xffffffff) == Val)
-    return OS << ':' << format("%08llX", Val);
+  APInt Val = P.Mask.getAsAPInt();
+  if (Val.getActiveBits() <= 16)
+    return OS << ':' << format("%04llX", Val.getZExtValue());
+  if (Val.getActiveBits() <= 32)
+    return OS << ':' << format("%08llX", Val.getZExtValue());
   return OS << ':' << PrintLaneMask(P.Mask);
 }
 
diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.h b/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
index 88d5686720985e..a7796906779594 100644
--- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
+++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.h
@@ -377,7 +377,10 @@ class SIRegisterInfo final : public AMDGPUGenRegisterInfo {
   static unsigned getNumCoveredRegs(LaneBitmask LM) {
     // The assumption is that every lo16 subreg is an even bit and every hi16
     // is an adjacent odd bit or vice versa.
-    uint64_t Mask = LM.getAsInteger();
+    APInt MaskV = LM.getAsAPInt();
+    assert(MaskV.getActiveBits() <= 64 &&
+           "uint64_t is insufficient to represent lane bitmask operation");
+    uint64_t Mask = MaskV.getZExtValue();
     uint64_t Even = Mask & 0xAAAAAAAAAAAAAAAAULL;
     Mask = (Even >> 1) | Mask;
     uint64_t Odd = Mask & 0x5555555555555555ULL;
diff --git a/llvm/test/CodeGen/AArch64/lanebitmask.mir b/llvm/test/CodeGen/AArch64/lanebitmask.mir
new file mode 100644
index 00000000000000..13178ea5cff4e3
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/lanebitmask.mir
@@ -0,0 +1,18 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
+# RUN: llc -o - %s -mtriple=aarch64 -stop-before=greedy | FileCheck %s
+---
+name:            test_parse_lanebitmask
+tracksRegLiveness: true
+liveins:
+  - { reg: '$h0' }
+  - { reg: '$s1' }
+body:             |
+  bb.0:
+    liveins: $h0:0x0000000000000001, $s1:(0x0000000000000001,0x0000000000000000)
+    ; CHECK-LABEL: name: test_parse_lanebitmask
+    ; CHECK: liveins: $h0:0x0000000000000001, $s1:0x0000000000000001, $h0, $s1
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: RET_ReallyLR
+    RET_ReallyLR
+...
+
diff --git a/llvm/unittests/CodeGen/MFCommon.inc b/llvm/unittests/CodeGen/MFCommon.inc
index 5d5720c3162da9..720c5e43586d7e 100644
--- a/llvm/unittests/CodeGen/MFCommon.inc
+++ b/llvm/unittests/CodeGen/MFCommon.inc
@@ -23,7 +23,7 @@ class BogusRegisterInfo : public TargetRegisterInfo {
 public:
   BogusRegisterInfo()
       : TargetRegisterInfo(nullptr, BogusRegisterClasses, BogusRegisterClasses,
-                           nullptr, nullptr, nullptr, LaneBitmask(~0u), nullptr,
+                           nullptr, nullptr, nullptr, LaneBitmask::getAll(), nullptr,
                            nullptr) {
     InitMCRegisterInfo(nullptr, 0, 0, 0, nullptr, 0, nullptr, 0, nullptr,
                        nullptr, nullptr, nullptr, nullptr, 0, nullptr);
diff --git a/llvm/unittests/MC/CMakeLists.txt b/llvm/unittests/MC/CMakeLists.txt
index da8e219113f465..fabf0f8512786a 100644
--- a/llvm/unittests/MC/CMakeLists.txt
+++ b/llvm/unittests/MC/CMakeLists.txt
@@ -17,9 +17,9 @@ add_llvm_unittest(MCTests
   Disassembler.cpp
   DwarfLineTables.cpp
   DwarfLineTableHeaders.cpp
+  LaneBitmaskTest.cpp
   MCInstPrinter.cpp
   StringTableBuilderTest.cpp
   TargetRegistry.cpp
   MCDisassemblerTest.cpp
   )
-
diff --git a/llvm/unittests/MC/LaneBitmaskTest.cpp b/llvm/unittests/MC/LaneBitmaskTest.cpp
new file mode 100644
index 00000000000000..d7762122e03220
--- /dev/null
+++ b/llvm/unittests/MC/LaneBitmaskTest.cpp
@@ -0,0 +1,68 @@
+//===------------------ LaneBitmaskTest.cpp -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+#include "llvm/MC/LaneBitmask.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+
+using namespace llvm;
+
+TEST(LaneBitmaskTest, Basic) {
+  EXPECT_EQ(LaneBitmask::getAll(), ~LaneBitmask::getNone());
+  EXPECT_EQ(LaneBitmask::getNone(), ~LaneBitmask::getAll());
+  EXPECT_EQ(LaneBitmask::getLane(0) | LaneBitmask::getLane(1), LaneBitmask(3));
+  EXPECT_EQ(LaneBitmask(3) & LaneBitmask::getLane(1), LaneBitmask::getLane(1));
+
+  EXPECT_EQ(LaneBitmask(APInt(128, 42)).getAsAPInt(), APInt(128, 42));
+  EXPECT_EQ(LaneBitmask(3).getNumLanes(), 2);
+  EXPECT_EQ(LaneBitmask::getLane(0).getHighestLane(), 0);
+  EXPECT_EQ(LaneBitmask::getLane(64).getHighestLane(), 64);
+  EXPECT_EQ(LaneBitmask::getLane(127).getHighestLane(), 127);
+
+  EXPECT_LT(LaneBitmask::getLane(64), LaneBitmask::getLane(65));
+  EXPECT_LT(LaneBitmask::getLane(63), LaneBitmask::getLane(64));
+  EXPECT_LT(LaneBitmask::getLane(62), LaneBitmask::getLane(63));
+
+  LaneBitmask X(1);
+  X |= LaneBitmask(2);
+  EXPECT_EQ(X, LaneBitmask(3));
+
+  LaneBitmask Y(3);
+  Y &= LaneBitmask(1);
+  EXPECT_EQ(Y, LaneBitmask(1));
+}
+
+TEST(LaneBitmaskTest, Print) {
+  std::string buffer;
+  raw_string_ostream OS(buffer);
+
+  buffer = "";
+  OS << PrintLaneMask(LaneBitmask::getAll(), /*FormatAsCLiterals=*/true);
+  EXPECT_STREQ(OS.str().data(), "0xFFFFFFFFFFFFFFFF,0xFFFFFFFFFFFFFFFF");
+
+  buffer = "";
+  OS << PrintLaneMask(LaneBitmask::getAll(), /*FormatAsCLiterals=*/false);
+  EXPECT_STREQ(OS.str().data(), "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF");
+
+  buffer = "";
+  OS << PrintLaneMask(LaneBitmask::getLane(0), /*FormatAsCLiterals=*/true);
+  EXPECT_STREQ(OS.str().data(), "0x0000000000000001");
+
+  buffer = "";
+  OS << PrintLaneMask(LaneBitmask::getLane(63), /*FormatAsCLiterals=*/true);
+  EXPECT_STREQ(OS.str().data(), "0x8000000000000000");
+
+  buffer = "";
+  OS << PrintLaneMask(LaneBitmask::getNone(), /*FormatAsCLiterals=*/true);
+  EXPECT_STREQ(OS.str().data(), "0x0000000000000000");
+
+  buffer = "";
+  OS << PrintLaneMask(LaneBitmask::getLane(64), /*FormatAsCLiterals=*/true);
+  EXPECT_STREQ(OS.str().data(), "0x0000000000000000,0x0000000000000001");
+}
diff --git a/llvm/utils/TableGen/RegisterInfoEmitter.cpp b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
index 7d81a83ef2b0a6..bf8aea096c1cf1 100644
--- a/llvm/utils/TableGen/RegisterInfoEmitter.cpp
+++ b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
@@ -653,7 +653,8 @@ static DiffVec &diffEncode(DiffVec &V, unsigned InitVal, Iter Begin, Iter End) {
 static void printDiff16(raw_ostream &OS, int16_t Val) { OS << Val; }
 
 static void printMask(raw_ostream &OS, LaneBitmask Val) {
-  OS << "LaneBitmask(0x" << PrintLaneMask(Val) << ')';
+  OS << "LaneBitmask("
+     << PrintLaneMask(Val, /*FormatAsCLiteral=*/true) << ")";
 }
 
 // Try to combine Idx's compose map into Vec if it is compatible.
@@ -818,11 +819,11 @@ void RegisterInfoEmitter::emitComposeSubRegIndexLaneMask(raw_ostream &OS,
         "  for (const MaskRolOp *Ops =\n"
         "       &LaneMaskComposeSequences[CompositeSequences[IdxA]];\n"
         "       Ops->Mask.any(); ++Ops) {\n"
-        "    LaneBitmask::Type M = LaneMask.getAsInteger() & "
-        "Ops->Mask.getAsInteger();\n"
+        "    APInt M = LaneMask.getAsAPInt() & "
+        "Ops->Mask.getAsAPInt();\n"
         "    if (unsigned S = Ops->RotateLeft)\n"
-        "      Result |= LaneBitmask((M << S) | (M >> (LaneBitmask::BitWidth - "
-        "S)));\n"
+        "      Result |= LaneBitmask(M.shl(S) | M.lshr(LaneBitmask::BitWidth - "
+        "S));\n"
         "    else\n"
         "      Result |= LaneBitmask(M);\n"
         "  }\n"
@@ -840,10 +841,10 @@ void RegisterInfoEmitter::emitComposeSubRegIndexLaneMask(raw_ostream &OS,
         "  for (const MaskRolOp *Ops =\n"
         "       &LaneMaskComposeSequences[CompositeSequences[IdxA]];\n"
         "       Ops->Mask.any(); ++Ops) {\n"
-        "    LaneBitmask::Type M = LaneMask.getAsInteger();\n"
+        "    APInt M = LaneMask.getAsAPInt();\n"
         "    if (unsigned S = Ops->RotateLeft)\n"
-        "      Result |= LaneBitmask((M >> S) | (M << (LaneBitmask::BitWidth - "
-        "S)));\n"
+        "      Result |= LaneBitmask(M.lshr(S) | M.shl(LaneBitmask::BitWidth - "
+        "S));\n"
         "    else\n"
         "      Result |= LaneBitmask(M);\n"
         "  }\n"
@@ -1836,7 +1837,8 @@ void RegisterInfoEmitter::debugDump(raw_ostream &OS) {
     for (unsigned M = 0; M != NumModes; ++M)
       OS << ' ' << getModeName(M) << ':' << RC.RSI.get(M).SpillAlignment;
     OS << " }\n\tNumRegs: " << RC.getMembers().size() << '\n';
-    OS << "\tLaneMask: " << PrintLaneMask(RC.LaneMask) << '\n';
+    OS << "\tLaneMask: {"
+       << PrintLaneMask(RC.LaneMask, /*FormatAsCLiteral=*/true) << "}\n";
     OS << "\tHasDisjunctSubRegs: " << RC.HasDisjunctSubRegs << '\n';
     OS << "\tCoveredBySubRegs: " << RC.CoveredBySubRegs << '\n';
     OS << "\tAllocatable: " << RC.Allocatable << '\n';
@@ -1864,7 +1866,8 @@ void RegisterInfoEmitter::debugDump(raw_ostream &OS) {
 
   for (const CodeGenSubRegIndex &SRI : RegBank.getSubRegIndices()) {
     OS << "SubRegIndex " << SRI.getName() << ":\n";
-    OS << "\tLaneMask: " << PrintLaneMask(SRI.LaneMask) << '\n';
+    OS << "\tLaneMask: {"
+       << PrintLaneMask(SRI.LaneMask, /*FormatAsCLiteral=*/true) << "}\n";
     OS << "\tAllSuperRegsCovered: " << SRI.AllSuperRegsCovered << '\n';
     OS << "\tOffset: {";
     for (unsigned M = 0; M != NumModes; ++M)


