[llvm] [AArch64][SME] Make getRegAllocationHints stricter for multi-vector loads (PR #123081)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 24 09:27:22 PST 2025


================
@@ -1109,24 +1110,93 @@ bool AArch64RegisterInfo::getRegAllocationHints(
   // so we add the strided registers as a hint.
   unsigned RegID = MRI.getRegClass(VirtReg)->getID();
   // Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
-  if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
-       RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
-      any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
-        return Use.getOpcode() ==
-                   AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
-               Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
-      })) {
-    const TargetRegisterClass *StridedRC =
-        RegID == AArch64::ZPR2StridedOrContiguousRegClassID
-            ? &AArch64::ZPR2StridedRegClass
-            : &AArch64::ZPR4StridedRegClass;
+  for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
+    if ((RegID != AArch64::ZPR2StridedOrContiguousRegClassID &&
+         RegID != AArch64::ZPR4StridedOrContiguousRegClassID) ||
+        (Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
+         Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO))
+      continue;
+
+    unsigned LdOps = Use.getNumOperands() - 1;
+    const TargetRegisterClass *StridedRC = LdOps == 2
+                                               ? &AArch64::ZPR2StridedRegClass
+                                               : &AArch64::ZPR4StridedRegClass;
 
+    SmallVector<MCPhysReg, 4> StridedOrder;
     for (MCPhysReg Reg : Order)
       if (StridedRC->contains(Reg))
-        Hints.push_back(Reg);
+        StridedOrder.push_back(Reg);
+
+    auto GetRegStartingAt = [&](MCPhysReg FirstReg) -> MCPhysReg {
+      for (MCPhysReg Strided : StridedOrder)
+        if (getSubReg(Strided, AArch64::zsub0) == FirstReg)
+          return Strided;
+      return (MCPhysReg)AArch64::NoRegister;
+    };
+
+    int OpIdx = Use.findRegisterUseOperandIdx(VirtReg, this);
+    assert(OpIdx != -1 && "Expected operand index from register use.");
+
+    unsigned TupleID = MRI.getRegClass(Use.getOperand(0).getReg())->getID();
+    bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
+                    TupleID == AArch64::ZPR4Mul4RegClassID;
+
+    unsigned AssignedOp = 0;
+    if (!any_of(make_range(Use.operands_begin() + 1, Use.operands_end()),
+                [&](const MachineOperand &Op) {
+                  if (!VRM->hasPhys(Op.getReg()))
+                    return false;
+                  AssignedOp = Op.getOperandNo();
+                  return true;
+                })) {
+      // There are no registers already assigned to any of the pseudo operands.
+      // Look for a valid starting register for the group.
+      for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+        MCPhysReg Reg = StridedOrder[I];
+        unsigned FirstReg = getSubReg(Reg, AArch64::zsub0);
+
+        // If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
+        // register of the first load should be a multiple of 2 or 4.
+        if (IsMulZPR && (FirstReg - AArch64::Z0) % LdOps != 0)
+          continue;
+        // Skip this register if it has any live intervals assigned.
+        if (Matrix->isPhysRegUsed(Reg))
+          continue;
+
+        // Look for registers in StridedOrder which start with sub-registers
+        // following sequentially from FirstReg. If all are found and none are
+        // already live, add Reg to Hints.
+        MCPhysReg RegToAssign = Reg;
+        for (unsigned Next = 1; Next < LdOps; ++Next) {
+          MCPhysReg Strided = GetRegStartingAt(FirstReg + Next);
+          if (Strided == AArch64::NoRegister ||
+              Matrix->isPhysRegUsed(Strided)) {
+            RegToAssign = AArch64::NoRegister;
+            break;
+          }
+          if (Next == (unsigned)OpIdx - 1)
+            RegToAssign = Strided;
+        }
+        if (RegToAssign != AArch64::NoRegister)
+          Hints.push_back(RegToAssign);
----------------
sdesmalen-arm wrote:

This code is a little convoluted. I think you could also avoid the extra nested loop in `GetRegStartingAt` by doing the following:

If Reg is e.g. `Z1_Z5_Z9_Z13`, then loop from `Z0_Z4_Z8_Z12` to `Z3_Z7_Z11_Z15` and check whether any of them is already in use. If none are, then you can add `Z1_Z5_Z9_Z13` to the hints.

```
// Sketch of a replacement for the inner search loop: given a candidate Reg
// for the operand at position OpIdx of the pseudo, check the whole group of
// LdOps registers it would belong to and hint Reg only if all are unused.
// NOTE(review): this indexes registers arithmetically, so it presumably
// relies on the strided tuple registers of one group having consecutive
// register IDs — confirm against the generated register enum.
SmallVector<MCPhysReg> Regs;
// Operand 0 of the pseudo is the result, so the load operands start at
// index 1; stepping back by (OpIdx - 1) gives the first register of the group.
unsigned FirstReg = Reg - OpIdx + 1;
for (unsigned I = 0; I < LdOps; ++I)
  Regs.push_back(FirstReg + I);

// Only hint when none of the registers in the group is already in use.
if (all_of(Regs,
           [&](MCPhysReg R) { return !Matrix->isPhysRegUsed(R); }))
  // FirstReg + OpIdx - 1 is Reg again, i.e. the candidate for this operand.
  Hints.push_back(FirstReg + OpIdx - 1);
```

https://github.com/llvm/llvm-project/pull/123081


More information about the llvm-commits mailing list