[llvm] [AArch64] Optimize when storing symmetry constants (PR #93717)

David Green via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 6 01:31:54 PDT 2024


================
@@ -2252,6 +2260,159 @@ MachineBasicBlock::iterator AArch64LoadStoreOpt::findMatchingUpdateInsnBackward(
   return E;
 }
 
+static bool isSymmetric(MachineInstr &MI, Register BaseReg) {
+  auto MatchBaseReg = [&](unsigned Count) {
+    for (unsigned I = 0; I < Count; I++) {
+      auto OpI = MI.getOperand(I);
+      if (OpI.isReg() && OpI.getReg() != BaseReg)
+        return false;
+    }
+    return true;
+  };
+
+  unsigned Opc = MI.getOpcode();
+  switch (Opc) {
+  default:
+    return false;
+  case AArch64::MOVZXi:
+    return MatchBaseReg(1);
+  case AArch64::MOVKXi:
+    return MatchBaseReg(2);
+  case AArch64::ORRXrs:
+    MachineOperand &Imm = MI.getOperand(3);
+    // The fourth operand of ORR must be 32, which means a
+    // 32-bit symmetric constant load.
+    // ex) renamable $x8 = ORRXrs $x8, $x8, 32
+    if (MatchBaseReg(3) && Imm.isImm() && Imm.getImm() == 32)
+      return true;
+  }
+
+  return false;
+}
+
+MachineBasicBlock::iterator AArch64LoadStoreOpt::doFoldSymmetryConstantLoad(
+    MachineInstr &MI, SmallVectorImpl<MachineBasicBlock::iterator> &MIs,
+    int SuccIndex, int Accumulated) {
+  MachineBasicBlock::iterator I = MI.getIterator();
+  MachineBasicBlock::iterator E = I->getParent()->end();
+  MachineBasicBlock::iterator NextI = next_nodbg(I, E);
+  MachineBasicBlock::iterator FirstMovI;
+  MachineBasicBlock *MBB = MI.getParent();
+  uint64_t Mask = 0xFFFFUL;
+  int Index = 0;
+
+  for (auto MI = MIs.begin(), E = MIs.end(); MI != E; ++MI, Index++) {
+    if (Index == SuccIndex - 1) {
+      FirstMovI = *MI;
+      break;
+    }
+    (*MI)->eraseFromParent();
+  }
+
+  Register DstRegW =
+      TRI->getSubReg(FirstMovI->getOperand(0).getReg(), AArch64::sub_32);
+  int Lower = Accumulated & Mask;
+  if (Lower) {
+    BuildMI(*MBB, FirstMovI, FirstMovI->getDebugLoc(),
+            TII->get(AArch64::MOVZWi), DstRegW)
+        .addImm(Lower)
+        .addImm(0);
+    Lower = Accumulated >> 16 & Mask;
+    if (Lower) {
+      BuildMI(*MBB, FirstMovI, FirstMovI->getDebugLoc(),
+              TII->get(AArch64::MOVKWi), DstRegW)
+          .addUse(DstRegW)
+          .addImm(Lower)
+          .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 16));
+    }
+  } else {
+    Lower = Accumulated >> 16 & Mask;
+    BuildMI(*MBB, FirstMovI, FirstMovI->getDebugLoc(),
+            TII->get(AArch64::MOVZWi), DstRegW)
+        .addImm(Lower)
+        .addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 16));
+  }
+  FirstMovI->eraseFromParent();
+  Register BaseReg = getLdStRegOp(MI).getReg();
+  const MachineOperand MO = AArch64InstrInfo::getLdStBaseOp(MI);
+  DstRegW = TRI->getSubReg(BaseReg, AArch64::sub_32);
+  unsigned DstRegState = getRegState(MI.getOperand(0));
+  BuildMI(*MBB, MI, MI.getDebugLoc(), TII->get(AArch64::STPWi))
+      .addReg(DstRegW, DstRegState)
+      .addReg(DstRegW, DstRegState)
+      .addReg(MO.getReg(), getRegState(MO))
+      .add(AArch64InstrInfo::getLdStOffsetOp(MI))
+      .setMemRefs(MI.memoperands())
+      .setMIFlags(MI.getFlags());
+  I->eraseFromParent();
+
+  return NextI;
+}
+
+bool AArch64LoadStoreOpt::tryFoldSymmetryConstantLoad(
+    MachineBasicBlock::iterator &I, unsigned Limit) {
+  MachineInstr &MI = *I;
+  if (MI.getOpcode() != AArch64::STRXui)
+    return false;
+
+  MachineBasicBlock::iterator MBBI = I;
+  MachineBasicBlock::iterator B = I->getParent()->begin();
+  if (MBBI == B)
+    return false;
+
+  Register BaseReg = getLdStRegOp(MI).getReg();
+  unsigned Count = 0, SuccIndex = 0;
+  bool hasORR = false;
+  SmallVector<MachineBasicBlock::iterator> MIs;
+  ModifiedRegUnits.clear();
+  UsedRegUnits.clear();
+
+  uint64_t IValue, IShift, Accumulated = 0, Mask = 0xFFFFUL;
+  do {
+    MBBI = prev_nodbg(MBBI, B);
+    MachineInstr &MI = *MBBI;
+    if (!MI.isTransient())
+      ++Count;
+    if (!isSymmetric(MI, BaseReg)) {
+      LiveRegUnits::accumulateUsedDefed(MI, ModifiedRegUnits, UsedRegUnits,
+                                        TRI);
+      if (!ModifiedRegUnits.available(BaseReg) ||
+          !UsedRegUnits.available(BaseReg))
+        break;
----------------
davemgreen wrote:

Should this return false? That way, when we break out of the loop, we know we are in a valid situation.
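
Something along these lines, as a rough sketch (assuming the rest of the loop body stays as in the hunk; the trailing `continue` is a guess, since the hunk is cut off here):

    if (!isSymmetric(MI, BaseReg)) {
      LiveRegUnits::accumulateUsedDefed(MI, ModifiedRegUnits, UsedRegUnits,
                                        TRI);
      // Returning false (instead of break) means the code after the loop
      // only runs when the backward scan ended in a known-good state.
      if (!ModifiedRegUnits.available(BaseReg) ||
          !UsedRegUnits.available(BaseReg))
        return false;
      continue; // guess: move on to the next earlier instruction
    }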

https://github.com/llvm/llvm-project/pull/93717

