[llvm] [AArch64][SME] Make getRegAllocationHints stricter for multi-vector loads (PR #123081)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 06:46:26 PST 2025


================
@@ -1108,25 +1114,82 @@ bool AArch64RegisterInfo::getRegAllocationHints(
   // instructions over reducing the number of clobbered callee-save registers,
   // so we add the strided registers as a hint.
   unsigned RegID = MRI.getRegClass(VirtReg)->getID();
-  // Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
-  if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
-       RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
-      any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
-        return Use.getOpcode() ==
-                   AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
-               Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
-      })) {
-    const TargetRegisterClass *StridedRC =
-        RegID == AArch64::ZPR2StridedOrContiguousRegClassID
-            ? &AArch64::ZPR2StridedRegClass
-            : &AArch64::ZPR4StridedRegClass;
-
-    for (MCPhysReg Reg : Order)
-      if (StridedRC->contains(Reg))
-        Hints.push_back(Reg);
+  if (RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
+      RegID == AArch64::ZPR4StridedOrContiguousRegClassID) {
+
+    // Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
+    for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
+      if (Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
+          Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
+        continue;
+
+      unsigned LdOps = Use.getNumOperands() - 1;
+      const TargetRegisterClass *StridedRC =
+          LdOps == 2 ? &AArch64::ZPR2StridedRegClass
+                     : &AArch64::ZPR4StridedRegClass;
+
+      SmallVector<MCPhysReg, 4> StridedOrder;
+      for (MCPhysReg Reg : Order)
+        if (StridedRC->contains(Reg))
+          StridedOrder.push_back(Reg);
+
+      int OpIdx = Use.findRegisterUseOperandIdx(VirtReg, this);
+      assert(OpIdx != -1 && "Expected operand index from register use.");
+
+      unsigned TupleID = MRI.getRegClass(Use.getOperand(0).getReg())->getID();
+      bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
+                      TupleID == AArch64::ZPR4Mul4RegClassID;
+
+      const MachineOperand *AssignedRegOp = llvm::find_if(
+          make_range(Use.operands_begin() + 1, Use.operands_end()),
+          [&VRM](const MachineOperand &Op) {
+            return VRM->hasPhys(Op.getReg());
+          });
+
+      if (AssignedRegOp == Use.operands_end()) {
+        // There are no registers already assigned to any of the pseudo
+        // operands. Look for a valid starting register for the group.
+        for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+          MCPhysReg Reg = StridedOrder[I];
+          SmallVector<MCPhysReg> Regs;
+          unsigned FirstStridedReg = Reg - OpIdx + 1;
----------------
sdesmalen-arm wrote:

I would avoid doing this, because `Reg - OpIdx + 1` may not be an SVE tuple register, which means that `getSubReg(FirstStridedReg, AArch64::zsub0)` might fail.

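For example, if the first tuple register in the list is `Z0_Z1` and we're looking at the second input operand of the FORM_TRANSPOSED_REG_TUPLE pseudo (i.e. `OpIdx = 2`, since operand 0 is the result tuple), then `FirstStridedReg` would be `Z0_Z1 - 1`, which is `X26_X27` and not a Z-register tuple.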

You can instead write this as:
```
unsigned SubRegIdx = Use.getOperand(OpIdx).getSubReg();
if (IsMulZPR && (getSubReg(Reg, SubRegIdx) - AArch64::Z0) % UseOps !=
                    ((unsigned)OpIdx - 1))
  continue;
```

https://github.com/llvm/llvm-project/pull/123081


More information about the llvm-commits mailing list