[llvm] [AArch64][SME2] Add FORM_STRIDED_TUPLE pseudo nodes (PR #116399)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 10 01:56:22 PST 2024


================
@@ -1107,6 +1107,69 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC,
   }
 }
 
+// FORM_STRIDED_TUPLE nodes are created to improve register allocation where
+// a consecutive multi-vector tuple is constructed from the same indices of
+// multiple strided loads. This may still result in unnecessary copies between
+// the loads and the tuple. Here we try to return a hint to assign the
+// contiguous ZPRMulReg starting at the same register as the first operand of
+// the pseudo, which should be a subregister of the first strided load.
+//
+// For example, if the first strided load has been assigned $z16_z20_z24_z28
+// and the operands of the pseudo are each accessing subregister zsub2, we
+// should look through Order to find a contiguous register which
+// begins with $z24 (i.e. $z24_z25_z26_z27).
+//
+bool AArch64RegisterInfo::getRegAllocationHints(
+    Register VirtReg, ArrayRef<MCPhysReg> Order,
+    SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF,
+    const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
+  const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+  const TargetRegisterInfo *TRI = STI.getRegisterInfo();
+  const MachineRegisterInfo &MRI = MF.getRegInfo();
+  bool DefaultHints =
+      TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF, VRM);
+
+  unsigned RegID = MRI.getRegClass(VirtReg)->getID();
+  if (RegID != AArch64::ZPR2Mul2RegClassID &&
+      RegID != AArch64::ZPR4Mul4RegClassID)
+    return DefaultHints;
+
+  for (MachineInstr &MI : MRI.def_instructions(VirtReg)) {
+    if (MI.getOpcode() != AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO &&
+        MI.getOpcode() != AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO)
+      continue;
+
+    // Look up the physical register mapped to the first load of the pseudo.
+    Register FirstLoadVirtReg = MI.getOperand(1).getReg();
+    if (!VRM->hasPhys(FirstLoadVirtReg))
+      continue;
+
+    int64_t SubRegIdx = -1;
+    MCRegister FirstLoadPhysReg = VRM->getPhys(FirstLoadVirtReg);
+
+    // The subreg number is used to access the correct unit of the
+    // strided register found in the map above.
+    SubRegIdx = MI.getOperand(1).getSubReg() - AArch64::zsub0;
+    if (SubRegIdx < 0 || SubRegIdx > 3)
+      continue;
+
+    SmallVector<Register, 4> RegUnits;
+    for (MCRegUnit Unit : TRI->regunits(FirstLoadPhysReg))
+      RegUnits.push_back(Unit);
+
+    // Find the contiguous ZPRMul register which starts with the
+    // same register unit as the strided register and add to Hints.
+    Register StartReg = RegUnits[SubRegIdx];
+    for (unsigned I = 0; I < Order.size(); ++I) {
+      Register Reg = *TRI->regunits(Order[I]).begin();
+      if (Reg == StartReg)
+        Hints.push_back(Order[I]);
+    }
----------------
sdesmalen-arm wrote:

There is no need to iterate through all MCRegUnits for this register. It also feels rather fiddly to index into `RegUnits[SubRegIdx]`, because it makes assumptions about the order of register units in `RegUnits`.

You can do this instead using `getSubReg`, e.g.
```
MCRegister TupleStartReg = getSubReg(VRM->getPhys(FirstLoadVirtReg), MI.getOperand(1).getSubReg());
for (unsigned I = 0; I < Order.size(); ++I) {
  if (MCRegister R = getSubReg(Order[I], AArch64::zsub0))
    if (R == TupleStartReg)
      ....
}
```
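
To illustrate the lookup outside of LLVM, here is a minimal standalone sketch. Everything in it is a toy stand-in and an assumption made for illustration: `ToyReg`, `ToyTuple`, `toyGetSubReg`, `z`, `makeStrided`, `makeContiguous` and `collectHints` are hypothetical names, not LLVM APIs, and the register numbering is invented. It only models the idea of comparing the selected subregister of the strided tuple against `zsub0` of each candidate in `Order`, rather than indexing into a list of register units.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-ins for illustration only; these are NOT the LLVM types.
using ToyReg = unsigned; // 0 means "no register", mirroring an invalid MCRegister

// A multi-vector tuple; Parts[i] plays the role of subregister zsub<i>.
struct ToyTuple {
  std::vector<ToyReg> Parts;
};

// Hypothetical numbering: zN is encoded as N + 1 so that 0 stays invalid.
ToyReg z(unsigned N) {
  assert(N < 32 && "only z0..z31 exist");
  return N + 1;
}

// Models a subregister lookup: returns 0 if the tuple has no such part.
ToyReg toyGetSubReg(const ToyTuple &T, unsigned SubIdx) {
  return SubIdx < T.Parts.size() ? T.Parts[SubIdx] : 0;
}

// Builds a strided tuple such as z16_z20_z24_z28 (First = 16, Stride = 4).
ToyTuple makeStrided(unsigned First, unsigned Stride, unsigned Count) {
  ToyTuple T;
  for (unsigned I = 0; I < Count; ++I)
    T.Parts.push_back(z(First + I * Stride));
  return T;
}

// Builds a contiguous tuple such as z24_z25_z26_z27 (First = 24).
ToyTuple makeContiguous(unsigned First, unsigned Count) {
  return makeStrided(First, 1, Count);
}

// Returns the indices into Order of every candidate whose first element
// equals the subregister SubIdx of the first strided load.
std::vector<std::size_t> collectHints(const ToyTuple &FirstLoad,
                                      unsigned SubIdx,
                                      const std::vector<ToyTuple> &Order) {
  std::vector<std::size_t> Hints;
  ToyReg TupleStartReg = toyGetSubReg(FirstLoad, SubIdx);
  if (!TupleStartReg)
    return Hints; // no such subregister, so there is nothing to hint
  for (std::size_t I = 0; I < Order.size(); ++I)
    if (ToyReg R = toyGetSubReg(Order[I], 0))
      if (R == TupleStartReg)
        Hints.push_back(I);
  return Hints;
}
```

In this model, a first load assigned z16_z20_z24_z28 and accessed through index 2 selects z24, so the only matching candidate among z16_z17_z18_z19, z20_z21_z22_z23, z24_z25_z26_z27 and z28_z29_z30_z31 is z24_z25_z26_z27, which matches the example in the patch comment.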

https://github.com/llvm/llvm-project/pull/116399

