[llvm] [AArch64][SME] Add common helper for expanding conditional pseudos (NFC) (PR #155398)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 26 04:52:21 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

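Factors out the common, slightly tricky block-splitting and branch-insertion code shared by the conditional pseudo expansions (`RestoreZAPseudo`, `CommitZASavePseudo`, and `MSRpstatePseudo`) into a single helper, `expandConditionalPseudo`. This should make adding new conditional pseudos simpler.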

---
Full diff: https://github.com/llvm/llvm-project/pull/155398.diff


1 File Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp (+106-85) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
index 57dcd68595ff1..9e83515aa536e 100644
--- a/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
@@ -92,9 +92,17 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
   bool expandCALL_BTI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI);
   bool expandStoreSwiftAsyncContext(MachineBasicBlock &MBB,
                                     MachineBasicBlock::iterator MBBI);
+
   MachineBasicBlock *
-  expandCommitOrRestoreZASave(MachineBasicBlock &MBB,
-                              MachineBasicBlock::iterator MBBI);
+  expandConditionalPseudo(MachineBasicBlock &MBB,
+                          MachineBasicBlock::iterator MBBI, DebugLoc DL,
+                          MachineInstrBuilder &Branch,
+                          function_ref<void(MachineBasicBlock &)> InsertBody);
+
+  MachineBasicBlock *expandRestoreZASave(MachineBasicBlock &MBB,
+                                         MachineBasicBlock::iterator MBBI);
+  MachineBasicBlock *expandCommitZASave(MachineBasicBlock &MBB,
+                                        MachineBasicBlock::iterator MBBI);
   MachineBasicBlock *expandCondSMToggle(MachineBasicBlock &MBB,
                                         MachineBasicBlock::iterator MBBI);
 };
@@ -991,74 +999,99 @@ bool AArch64ExpandPseudo::expandStoreSwiftAsyncContext(
   return true;
 }
 
-static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
-
-MachineBasicBlock *AArch64ExpandPseudo::expandCommitOrRestoreZASave(
-    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
+MachineBasicBlock *AArch64ExpandPseudo::expandConditionalPseudo(
+    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, DebugLoc DL,
+    MachineInstrBuilder &Branch,
+    function_ref<void(MachineBasicBlock &)> InsertBody) {
   MachineInstr &MI = *MBBI;
-  bool IsRestoreZA = MI.getOpcode() == AArch64::RestoreZAPseudo;
-  assert((MI.getOpcode() == AArch64::RestoreZAPseudo ||
-          MI.getOpcode() == AArch64::CommitZASavePseudo) &&
-         "Expected ZA commit or restore");
   assert((std::next(MBBI) != MBB.end() ||
           MI.getParent()->successors().begin() !=
               MI.getParent()->successors().end()) &&
-         "Unexpected unreachable in block that restores ZA");
-
-  // Compare TPIDR2_EL0 value against 0.
-  DebugLoc DL = MI.getDebugLoc();
-  MachineInstrBuilder Branch =
-      BuildMI(MBB, MBBI, DL,
-              TII->get(IsRestoreZA ? AArch64::CBZX : AArch64::CBNZX))
-          .add(MI.getOperand(0));
+         "Unexpected unreachable in block");
 
   // Split MBB and create two new blocks:
-  //  - MBB now contains all instructions before RestoreZAPseudo.
-  //  - SMBB contains the [Commit|RestoreZA]Pseudo instruction only.
-  //  - EndBB contains all instructions after [Commit|RestoreZA]Pseudo.
+  //  - MBB now contains all instructions before the conditional pseudo.
+  //  - SMBB contains the conditional pseudo instruction only.
+  //  - EndBB contains all instructions after the conditional pseudo.
   MachineInstr &PrevMI = *std::prev(MBBI);
   MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
   MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end()
                                  ? *SMBB->successors().begin()
                                  : SMBB->splitAt(MI, /*UpdateLiveIns*/ true);
 
-  // Add the SMBB label to the CB[N]Z instruction & create a branch to EndBB.
+  // Add the SMBB label to the branch instruction & create a branch to EndBB.
   Branch.addMBB(SMBB);
   BuildMI(&MBB, DL, TII->get(AArch64::B))
       .addMBB(EndBB);
   MBB.addSuccessor(EndBB);
 
-  // Replace the pseudo with a call (BL).
-  MachineInstrBuilder MIB =
-      BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::BL));
-  // Copy operands (mainly the regmask) from the pseudo.
-  for (unsigned I = 2; I < MI.getNumOperands(); ++I)
-    MIB.add(MI.getOperand(I));
-
-  if (IsRestoreZA) {
-    // Mark the TPIDR2 block pointer (X0) as an implicit use.
-    MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
-  } else /*CommitZA*/ {
-    [[maybe_unused]] auto *TRI =
-        MBB.getParent()->getSubtarget().getRegisterInfo();
-    // Clear TPIDR2_EL0.
-    BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::MSR))
-        .addImm(AArch64SysReg::TPIDR2_EL0)
-        .addReg(AArch64::XZR);
-    bool ZeroZA = MI.getOperand(1).getImm() != 0;
-    if (ZeroZA) {
-      assert(MI.definesRegister(AArch64::ZAB0, TRI) && "should define ZA!");
-      BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::ZERO_M))
-          .addImm(ZERO_ALL_ZA_MASK)
-          .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
-    }
-  }
+  // Insert the conditional pseudo expansion.
+  InsertBody(*SMBB);
 
   BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
   MI.eraseFromParent();
   return EndBB;
 }
 
+MachineBasicBlock *
+AArch64ExpandPseudo::expandRestoreZASave(MachineBasicBlock &MBB,
+                                         MachineBasicBlock::iterator MBBI) {
+  MachineInstr &MI = *MBBI;
+  DebugLoc DL = MI.getDebugLoc();
+
+  // Compare TPIDR2_EL0 against 0. Restore ZA if TPIDR2_EL0 is zero.
+  MachineInstrBuilder Branch =
+      BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBZX)).add(MI.getOperand(0));
+
+  return expandConditionalPseudo(
+      MBB, MBBI, DL, Branch, [&](MachineBasicBlock &SMBB) {
+        // Replace the pseudo with a call (BL).
+        MachineInstrBuilder MIB =
+            BuildMI(SMBB, SMBB.end(), DL, TII->get(AArch64::BL));
+        // Copy operands (mainly the regmask) from the pseudo.
+        for (unsigned I = 2; I < MI.getNumOperands(); ++I)
+          MIB.add(MI.getOperand(I));
+        // Mark the TPIDR2 block pointer (X0) as an implicit use.
+        MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
+      });
+}
+
+static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;
+
+MachineBasicBlock *
+AArch64ExpandPseudo::expandCommitZASave(MachineBasicBlock &MBB,
+                                        MachineBasicBlock::iterator MBBI) {
+  MachineInstr &MI = *MBBI;
+  DebugLoc DL = MI.getDebugLoc();
+  [[maybe_unused]] auto *TRI =
+      MBB.getParent()->getSubtarget().getRegisterInfo();
+
+  // Compare TPIDR2_EL0 against 0. Commit ZA if TPIDR2_EL0 is non-zero.
+  MachineInstrBuilder Branch =
+      BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBNZX)).add(MI.getOperand(0));
+
+  return expandConditionalPseudo(
+      MBB, MBBI, DL, Branch, [&](MachineBasicBlock &SMBB) {
+        // Replace the pseudo with a call (BL).
+        MachineInstrBuilder MIB =
+            BuildMI(SMBB, SMBB.end(), DL, TII->get(AArch64::BL));
+        // Copy operands (mainly the regmask) from the pseudo.
+        for (unsigned I = 2; I < MI.getNumOperands(); ++I)
+          MIB.add(MI.getOperand(I));
+        // Clear TPIDR2_EL0.
+        BuildMI(SMBB, SMBB.end(), DL, TII->get(AArch64::MSR))
+            .addImm(AArch64SysReg::TPIDR2_EL0)
+            .addReg(AArch64::XZR);
+        bool ZeroZA = MI.getOperand(1).getImm() != 0;
+        if (ZeroZA) {
+          assert(MI.definesRegister(AArch64::ZAB0, TRI) && "should define ZA!");
+          BuildMI(SMBB, SMBB.end(), DL, TII->get(AArch64::ZERO_M))
+              .addImm(ZERO_ALL_ZA_MASK)
+              .addDef(AArch64::ZAB0, RegState::ImplicitDefine);
+        }
+      });
+}
+
 MachineBasicBlock *
 AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
                                         MachineBasicBlock::iterator MBBI) {
@@ -1130,37 +1163,19 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
   MachineInstrBuilder Tbx =
       BuildMI(MBB, MBBI, DL, TII->get(Opc)).addReg(SMReg32).addImm(0);
 
-  // Split MBB and create two new blocks:
-  //  - MBB now contains all instructions before MSRcond_pstatesvcrImm1.
-  //  - SMBB contains the MSRcond_pstatesvcrImm1 instruction only.
-  //  - EndBB contains all instructions after MSRcond_pstatesvcrImm1.
-  MachineInstr &PrevMI = *std::prev(MBBI);
-  MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
-  MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end()
-                                 ? *SMBB->successors().begin()
-                                 : SMBB->splitAt(MI, /*UpdateLiveIns*/ true);
-
-  // Add the SMBB label to the TB[N]Z instruction & create a branch to EndBB.
-  Tbx.addMBB(SMBB);
-  BuildMI(&MBB, DL, TII->get(AArch64::B))
-      .addMBB(EndBB);
-  MBB.addSuccessor(EndBB);
-
-  // Create the SMSTART/SMSTOP (MSRpstatesvcrImm1) instruction in SMBB.
-  MachineInstrBuilder MIB = BuildMI(*SMBB, SMBB->begin(), MI.getDebugLoc(),
-                                    TII->get(AArch64::MSRpstatesvcrImm1));
-  // Copy all but the second and third operands of MSRcond_pstatesvcrImm1 (as
-  // these contain the CopyFromReg for the first argument and the flag to
-  // indicate whether the callee is streaming or normal).
-  MIB.add(MI.getOperand(0));
-  MIB.add(MI.getOperand(1));
-  for (unsigned i = 4; i < MI.getNumOperands(); ++i)
-    MIB.add(MI.getOperand(i));
-
-  BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
-
-  MI.eraseFromParent();
-  return EndBB;
+  return expandConditionalPseudo(
+      MBB, MBBI, DL, Tbx, [&](MachineBasicBlock &SMBB) {
+        // Create the SMSTART/SMSTOP (MSRpstatesvcrImm1) instruction in SMBB.
+        MachineInstrBuilder MIB = BuildMI(SMBB, SMBB.begin(), MI.getDebugLoc(),
+                                          TII->get(AArch64::MSRpstatesvcrImm1));
+        // Copy all but the second and third operands of MSRcond_pstatesvcrImm1
+        // (as these contain the CopyFromReg for the first argument and the flag
+        // to indicate whether the callee is streaming or normal).
+        MIB.add(MI.getOperand(0));
+        MIB.add(MI.getOperand(1));
+        for (unsigned i = 4; i < MI.getNumOperands(); ++i)
+          MIB.add(MI.getOperand(i));
+      });
 }
 
 bool AArch64ExpandPseudo::expandMultiVecPseudo(
@@ -1673,15 +1688,21 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
      return expandCALL_BTI(MBB, MBBI);
    case AArch64::StoreSwiftAsyncContext:
      return expandStoreSwiftAsyncContext(MBB, MBBI);
+   case AArch64::RestoreZAPseudo:
    case AArch64::CommitZASavePseudo:
-   case AArch64::RestoreZAPseudo: {
-     auto *NewMBB = expandCommitOrRestoreZASave(MBB, MBBI);
-     if (NewMBB != &MBB)
-       NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
-     return true;
-   }
    case AArch64::MSRpstatePseudo: {
-     auto *NewMBB = expandCondSMToggle(MBB, MBBI);
+     auto *NewMBB = [&] {
+       switch (Opcode) {
+       case AArch64::RestoreZAPseudo:
+         return expandRestoreZASave(MBB, MBBI);
+       case AArch64::CommitZASavePseudo:
+         return expandCommitZASave(MBB, MBBI);
+       case AArch64::MSRpstatePseudo:
+         return expandCondSMToggle(MBB, MBBI);
+       default:
+         llvm_unreachable("Unexpected conditional pseudo!");
+       }
+     }();
      if (NewMBB != &MBB)
        NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
      return true;

``````````

</details>


https://github.com/llvm/llvm-project/pull/155398


More information about the llvm-commits mailing list