[llvm-branch-commits] [llvm] [AArch64] Generalize the instruction size checking in AsmPrinter (PR #110108)

Anatoly Trosinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Sep 26 04:23:52 PDT 2024


https://github.com/atrosinenko created https://github.com/llvm/llvm-project/pull/110108

Most of the PAuth-related code counts the instructions it inserts and asserts that no more bytes are emitted than the size returned by the getInstSizeInBytes(MI) method. This check is useful not only for PAuth-related instructions. Also, reimplementing it once, globally, in AArch64AsmPrinter makes it more robust and simplifies further refactoring of the PAuth-related code.

>From 4dfd901151b9ecde9e3795a6d4dba932d60859ee Mon Sep 17 00:00:00 2001
From: Anatoly Trosinenko <atrosinenko at accesssoftek.com>
Date: Wed, 25 Sep 2024 16:16:29 +0300
Subject: [PATCH] [AArch64] Generalize the instruction size checking in
 AsmPrinter

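Most of the PAuth-related code counts the instructions it inserts and
asserts that no more bytes are emitted than the size returned by the
getInstSizeInBytes(MI) method. This check is useful not only for
PAuth-related instructions. Also, reimplementing it once, globally, in
AArch64AsmPrinter makes it more robust and simplifies further
refactoring of the PAuth-related code. Two adjustments were needed for
the global check: the MOVMCSym pseudo now declares an explicit size of
8 bytes, and the TLSDESCCALL pseudo, which emits no code, is excluded
from the instruction count.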
---
 llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp | 121 +++++++-----------
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |   1 +
 2 files changed, 44 insertions(+), 78 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 47dd32ad2adc2f..c6ee8d43bd8f2d 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -24,6 +24,7 @@
 #include "MCTargetDesc/AArch64TargetStreamer.h"
 #include "TargetInfo/AArch64TargetInfo.h"
 #include "Utils/AArch64BaseInfo.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
@@ -86,6 +87,9 @@ class AArch64AsmPrinter : public AsmPrinter {
   FaultMaps FM;
   const AArch64Subtarget *STI;
   bool ShouldEmitWeakSwiftAsyncExtendedFramePointerFlags = false;
+#ifndef NDEBUG
+  unsigned InstsEmitted = 0; // Counted by EmitToStreamer(); reset per MI.
+#endif
 
 public:
   AArch64AsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer)
@@ -150,8 +154,7 @@ class AArch64AsmPrinter : public AsmPrinter {
   void emitPtrauthAuthResign(const MachineInstr *MI);
 
   // Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
-  unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc,
-                                    unsigned &InstsEmitted);
+  unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc);
 
   // Emit the sequence for LOADauthptrstatic
   void LowerLOADauthptrstatic(const MachineInstr &MI);
@@ -1338,8 +1341,6 @@ void AArch64AsmPrinter::LowerJumpTableDest(llvm::MCStreamer &OutStreamer,
 }
 
 void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
-  unsigned InstsEmitted = 0;
-
   const MachineJumpTableInfo *MJTI = MF->getJumpTableInfo();
   assert(MJTI && "Can't lower jump-table dispatch without JTI");
 
@@ -1377,10 +1378,8 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
                                      .addReg(AArch64::X16)
                                      .addImm(MaxTableEntry)
                                      .addImm(0));
-    ++InstsEmitted;
   } else {
     emitMOVZ(AArch64::X17, static_cast<uint16_t>(MaxTableEntry), 0);
-    ++InstsEmitted;
     // It's sad that we have to manually materialize instructions, but we can't
     // trivially reuse the main pseudo expansion logic.
     // A MOVK sequence is easy enough to generate and handles the general case.
@@ -1389,14 +1388,12 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
         break;
       emitMOVK(AArch64::X17, static_cast<uint16_t>(MaxTableEntry >> Offset),
                Offset);
-      ++InstsEmitted;
     }
     EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
                                      .addReg(AArch64::XZR)
                                      .addReg(AArch64::X16)
                                      .addReg(AArch64::X17)
                                      .addImm(0));
-    ++InstsEmitted;
   }
 
   // This picks entry #0 on failure.
@@ -1406,7 +1403,6 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
                                    .addReg(AArch64::X16)
                                    .addReg(AArch64::XZR)
                                    .addImm(AArch64CC::LS));
-  ++InstsEmitted;
 
   // Prepare the @PAGE/@PAGEOFF low/high operands.
   MachineOperand JTMOHi(JTOp), JTMOLo(JTOp);
@@ -1421,14 +1417,12 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
   EmitToStreamer(
       *OutStreamer,
       MCInstBuilder(AArch64::ADRP).addReg(AArch64::X17).addOperand(JTMCHi));
-  ++InstsEmitted;
 
   EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXri)
                                    .addReg(AArch64::X17)
                                    .addReg(AArch64::X17)
                                    .addOperand(JTMCLo)
                                    .addImm(0));
-  ++InstsEmitted;
 
   EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::LDRSWroX)
                                    .addReg(AArch64::X16)
@@ -1436,7 +1430,6 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
                                    .addReg(AArch64::X16)
                                    .addImm(0)
                                    .addImm(1));
-  ++InstsEmitted;
 
   MCSymbol *AdrLabel = MF->getContext().createTempSymbol();
   const auto *AdrLabelE = MCSymbolRefExpr::create(AdrLabel, MF->getContext());
@@ -1446,20 +1439,14 @@ void AArch64AsmPrinter::LowerHardenedBRJumpTable(const MachineInstr &MI) {
   EmitToStreamer(
       *OutStreamer,
       MCInstBuilder(AArch64::ADR).addReg(AArch64::X17).addExpr(AdrLabelE));
-  ++InstsEmitted;
 
   EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ADDXrs)
                                    .addReg(AArch64::X16)
                                    .addReg(AArch64::X17)
                                    .addReg(AArch64::X16)
                                    .addImm(0));
-  ++InstsEmitted;
 
   EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::BR).addReg(AArch64::X16));
-  ++InstsEmitted;
-
-  (void)InstsEmitted;
-  assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
 }
 
 void AArch64AsmPrinter::LowerMOPS(llvm::MCStreamer &OutStreamer,
@@ -1710,8 +1697,7 @@ void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
 }
 
 unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
-                                                     unsigned AddrDisc,
-                                                     unsigned &InstsEmitted) {
+                                                     unsigned AddrDisc) {
   // So far we've used NoRegister in pseudos.  Now we need real encodings.
   if (AddrDisc == AArch64::NoRegister)
     AddrDisc = AArch64::XZR;
@@ -1724,20 +1710,16 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
   // If there's only a constant discriminator, MOV it into x17.
   if (AddrDisc == AArch64::XZR) {
     emitMOVZ(AArch64::X17, Disc, 0);
-    ++InstsEmitted;
     return AArch64::X17;
   }
 
   // If there are both, emit a blend into x17.
   emitMovXReg(AArch64::X17, AddrDisc);
-  ++InstsEmitted;
   emitMOVK(AArch64::X17, Disc, 48);
-  ++InstsEmitted;
   return AArch64::X17;
 }
 
 void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
-  unsigned InstsEmitted = 0;
   const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;
 
   // We can expand AUT/AUTPAC into 3 possible sequences:
@@ -1822,8 +1804,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
 
   // Compute aut discriminator into x17
   assert(isUInt<16>(AUTDisc));
-  unsigned AUTDiscReg =
-      emitPtrauthDiscriminator(AUTDisc, AUTAddrDisc, InstsEmitted);
+  unsigned AUTDiscReg = emitPtrauthDiscriminator(AUTDisc, AUTAddrDisc);
   bool AUTZero = AUTDiscReg == AArch64::XZR;
   unsigned AUTOpc = getAUTOpcodeForKey(AUTKey, AUTZero);
 
@@ -1836,13 +1817,10 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
   if (!AUTZero)
     AUTInst.addOperand(MCOperand::createReg(AUTDiscReg));
   EmitToStreamer(*OutStreamer, AUTInst);
-  ++InstsEmitted;
 
   // Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done.
-  if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap)) {
-    assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
+  if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap))
     return;
-  }
 
   MCSymbol *EndSym = nullptr;
 
@@ -1853,13 +1831,11 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
     // XPAC has tied src/dst: use x17 as a temporary copy.
     //  mov x17, x16
     emitMovXReg(AArch64::X17, AArch64::X16);
-    ++InstsEmitted;
 
     //  xpaci x17
     EmitToStreamer(
         *OutStreamer,
         MCInstBuilder(XPACOpc).addReg(AArch64::X17).addReg(AArch64::X17));
-    ++InstsEmitted;
 
     //  cmp x16, x17
     EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::SUBSXrs)
@@ -1867,21 +1843,18 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
                                      .addReg(AArch64::X16)
                                      .addReg(AArch64::X17)
                                      .addImm(0));
-    ++InstsEmitted;
 
     //  b.eq Lsuccess
     EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::Bcc)
                                      .addImm(AArch64CC::EQ)
                                      .addExpr(MCSymbolRefExpr::create(
                                          SuccessSym, OutContext)));
-    ++InstsEmitted;
 
     if (ShouldTrap) {
       // Trapping sequences do a 'brk'.
       //  brk #<0xc470 + aut key>
       EmitToStreamer(*OutStreamer,
                      MCInstBuilder(AArch64::BRK).addImm(0xc470 | AUTKey));
-      ++InstsEmitted;
     } else {
       // Non-trapping checked sequences return the stripped result in x16,
       // skipping over the PAC if there is one.
@@ -1890,7 +1863,6 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
       //        ..traps this is usable as an oracle anyway, based on high bits
       //  mov x17, x16
       emitMovXReg(AArch64::X16, AArch64::X17);
-      ++InstsEmitted;
 
       if (IsAUTPAC) {
         EndSym = createTempSymbol("resign_end_");
@@ -1899,7 +1871,6 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
         EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::B)
                                          .addExpr(MCSymbolRefExpr::create(
                                              EndSym, OutContext)));
-        ++InstsEmitted;
       }
     }
 
@@ -1911,10 +1882,8 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
   // We already emitted unchecked and checked-but-non-trapping AUTs.
   // That left us with trapping AUTs, and AUTPACs.
   // Trapping AUTs don't need PAC: we're done.
-  if (!IsAUTPAC) {
-    assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
+  if (!IsAUTPAC)
     return;
-  }
 
   auto PACKey = (AArch64PACKey::ID)MI->getOperand(3).getImm();
   uint64_t PACDisc = MI->getOperand(4).getImm();
@@ -1922,8 +1891,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
 
   // Compute pac discriminator into x17
   assert(isUInt<16>(PACDisc));
-  unsigned PACDiscReg =
-      emitPtrauthDiscriminator(PACDisc, PACAddrDisc, InstsEmitted);
+  unsigned PACDiscReg = emitPtrauthDiscriminator(PACDisc, PACAddrDisc);
   bool PACZero = PACDiscReg == AArch64::XZR;
   unsigned PACOpc = getPACOpcodeForKey(PACKey, PACZero);
 
@@ -1936,16 +1904,13 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
   if (!PACZero)
     PACInst.addOperand(MCOperand::createReg(PACDiscReg));
   EmitToStreamer(*OutStreamer, PACInst);
-  ++InstsEmitted;
 
-  assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
   //  Lend:
   if (EndSym)
     OutStreamer->emitLabel(EndSym);
 }
 
 void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
-  unsigned InstsEmitted = 0;
   bool IsCall = MI->getOpcode() == AArch64::BLRA;
   unsigned BrTarget = MI->getOperand(0).getReg();
 
@@ -1959,7 +1924,7 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
   unsigned AddrDisc = MI->getOperand(3).getReg();
 
   // Compute discriminator into x17
-  unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, InstsEmitted);
+  unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc);
   bool IsZeroDisc = DiscReg == AArch64::XZR;
 
   unsigned Opc;
@@ -1981,9 +1946,6 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
   if (!IsZeroDisc)
     BRInst.addOperand(MCOperand::createReg(DiscReg));
   EmitToStreamer(*OutStreamer, BRInst);
-  ++InstsEmitted;
-
-  assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
 }
 
 const MCExpr *
@@ -2091,12 +2053,6 @@ void AArch64AsmPrinter::LowerLOADauthptrstatic(const MachineInstr &MI) {
 }
 
 void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
-  unsigned InstsEmitted = 0;
-  auto EmitAndIncrement = [this, &InstsEmitted](const MCInst &Inst) {
-    EmitToStreamer(*OutStreamer, Inst);
-    ++InstsEmitted;
-  };
-
   const bool IsGOTLoad = MI.getOpcode() == AArch64::LOADgotPAC;
   MachineOperand GAOp = MI.getOperand(0);
   const uint64_t KeyC = MI.getOperand(1).getImm();
@@ -2158,20 +2114,20 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
   MCInstLowering.lowerOperand(GAMOHi, GAMCHi);
   MCInstLowering.lowerOperand(GAMOLo, GAMCLo);
 
-  EmitAndIncrement(
+  EmitToStreamer(
       MCInstBuilder(AArch64::ADRP).addReg(AArch64::X16).addOperand(GAMCHi));
 
   if (IsGOTLoad) {
-    EmitAndIncrement(MCInstBuilder(AArch64::LDRXui)
-                         .addReg(AArch64::X16)
-                         .addReg(AArch64::X16)
-                         .addOperand(GAMCLo));
+    EmitToStreamer(MCInstBuilder(AArch64::LDRXui)
+                       .addReg(AArch64::X16)
+                       .addReg(AArch64::X16)
+                       .addOperand(GAMCLo));
   } else {
-    EmitAndIncrement(MCInstBuilder(AArch64::ADDXri)
-                         .addReg(AArch64::X16)
-                         .addReg(AArch64::X16)
-                         .addOperand(GAMCLo)
-                         .addImm(0));
+    EmitToStreamer(MCInstBuilder(AArch64::ADDXri)
+                       .addReg(AArch64::X16)
+                       .addReg(AArch64::X16)
+                       .addOperand(GAMCLo)
+                       .addImm(0));
   }
 
   if (Offset != 0) {
@@ -2180,7 +2136,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
     if (isUInt<24>(AbsOffset)) {
       for (int BitPos = 0; BitPos != 24 && (AbsOffset >> BitPos);
            BitPos += 12) {
-        EmitAndIncrement(
+        EmitToStreamer(
             MCInstBuilder(IsNeg ? AArch64::SUBXri : AArch64::ADDXri)
                 .addReg(AArch64::X16)
                 .addReg(AArch64::X16)
@@ -2189,10 +2145,10 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
       }
     } else {
       const uint64_t UOffset = Offset;
-      EmitAndIncrement(MCInstBuilder(IsNeg ? AArch64::MOVNXi : AArch64::MOVZXi)
-                           .addReg(AArch64::X17)
-                           .addImm((IsNeg ? ~UOffset : UOffset) & 0xffff)
-                           .addImm(/*shift=*/0));
+      EmitToStreamer(MCInstBuilder(IsNeg ? AArch64::MOVNXi : AArch64::MOVZXi)
+                         .addReg(AArch64::X17)
+                         .addImm((IsNeg ? ~UOffset : UOffset) & 0xffff)
+                         .addImm(/*shift=*/0));
       auto NeedMovk = [IsNeg, UOffset](int BitPos) -> bool {
         assert(BitPos == 16 || BitPos == 32 || BitPos == 48);
         uint64_t Shifted = UOffset >> BitPos;
@@ -2206,11 +2162,11 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
       for (int BitPos = 16; BitPos != 64 && NeedMovk(BitPos); BitPos += 16)
         emitMOVK(AArch64::X17, (UOffset >> BitPos) & 0xffff, BitPos);
 
-      EmitAndIncrement(MCInstBuilder(AArch64::ADDXrs)
-                           .addReg(AArch64::X16)
-                           .addReg(AArch64::X16)
-                           .addReg(AArch64::X17)
-                           .addImm(/*shift=*/0));
+      EmitToStreamer(MCInstBuilder(AArch64::ADDXrs)
+                         .addReg(AArch64::X16)
+                         .addReg(AArch64::X16)
+                         .addReg(AArch64::X17)
+                         .addImm(/*shift=*/0));
     }
   }
 
@@ -2230,9 +2186,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
                  .addReg(AArch64::X16);
   if (DiscReg != AArch64::XZR)
     MIB.addReg(DiscReg);
-  EmitAndIncrement(MIB);
-
-  assert(STI->getInstrInfo()->getInstSizeInBytes(MI) >= InstsEmitted * 4);
+  EmitToStreamer(MIB);
 }
 
 const MCExpr *
@@ -2254,11 +2208,21 @@ AArch64AsmPrinter::lowerBlockAddressConstant(const BlockAddress &BA) {
 
 void AArch64AsmPrinter::EmitToStreamer(MCStreamer &S, const MCInst &Inst) {
   S.emitInstruction(Inst, *STI);
+#ifndef NDEBUG
+  ++InstsEmitted;
+#endif
 }
 
 void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
   AArch64_MC::verifyInstructionPredicates(MI->getOpcode(), STI->getFeatureBits());
 
+#ifndef NDEBUG
+  InstsEmitted = 0;
+  auto CheckMISize = make_scope_exit([&]() {
+    assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
+  });
+#endif
+
   // Do any auto-generated pseudo lowerings.
   if (MCInst OutInst; lowerPseudoInstExpansion(MI, OutInst)) {
     EmitToStreamer(*OutStreamer, OutInst);
@@ -2546,6 +2510,9 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
     TLSDescCall.setOpcode(AArch64::TLSDESCCALL);
     TLSDescCall.addOperand(Sym);
     EmitToStreamer(*OutStreamer, TLSDescCall);
+#ifndef NDEBUG
+    --InstsEmitted; // no code emitted
+#endif
 
     MCInst Blr;
     Blr.setOpcode(AArch64::BLR);
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index c70e835d1619ff..b674f595761cfe 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -9697,6 +9697,7 @@ def : Pat<(AArch64tcret tglobaladdr:$dst, (i32 timm:$FPDiff)),
 def : Pat<(AArch64tcret texternalsym:$dst, (i32 timm:$FPDiff)),
           (TCRETURNdi texternalsym:$dst, imm:$FPDiff)>;
 
+let Size = 8 in
 def MOVMCSym : Pseudo<(outs GPR64:$dst), (ins i64imm:$sym), []>, Sched<[]>;
 def : Pat<(i64 (AArch64LocalRecover mcsym:$sym)), (MOVMCSym mcsym:$sym)>;
 



More information about the llvm-branch-commits mailing list