[llvm] [AArch64][SME] Make getRegAllocationHints stricter for multi-vector loads (PR #123081)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 17 01:16:12 PST 2025


================
@@ -1107,23 +1108,83 @@ bool AArch64RegisterInfo::getRegAllocationHints(
   // FORM_TRANSPOSED_REG_TUPLE pseudo, we want to favour reducing copy
   // instructions over reducing the number of clobbered callee-save registers,
   // so we add the strided registers as a hint.
+  const MachineInstr *TupleInst = nullptr;
   unsigned RegID = MRI.getRegClass(VirtReg)->getID();
   // Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
   if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
        RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
-      any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
-        return Use.getOpcode() ==
-                   AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
-               Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
+      any_of(MRI.use_nodbg_instructions(VirtReg), [&TupleInst](
+                                                      const MachineInstr &Use) {
+        bool IsTuple =
+            Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
+            Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
+        TupleInst = &Use;
+        return IsTuple;
       })) {
-    const TargetRegisterClass *StridedRC =
-        RegID == AArch64::ZPR2StridedOrContiguousRegClassID
-            ? &AArch64::ZPR2StridedRegClass
-            : &AArch64::ZPR4StridedRegClass;
+    unsigned LdOps = TupleInst->getNumOperands() - 1;
+    const TargetRegisterClass *StridedRC = LdOps == 2
+                                               ? &AArch64::ZPR2StridedRegClass
+                                               : &AArch64::ZPR4StridedRegClass;
 
+    SmallVector<MCPhysReg, 4> StridedOrder;
     for (MCPhysReg Reg : Order)
       if (StridedRC->contains(Reg))
-        Hints.push_back(Reg);
+        StridedOrder.push_back(Reg);
+
+    int OpIdx = TupleInst->findRegisterUseOperandIdx(VirtReg, this);
+    if (OpIdx == -1)
+      return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
+                                                       MF, VRM);
+
+    unsigned TupleID =
+        MRI.getRegClass(TupleInst->getOperand(0).getReg())->getID();
+    bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
+                    TupleID == AArch64::ZPR4Mul4RegClassID;
+
+    if (OpIdx == 1) {
+      for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+        MCPhysReg Reg = StridedOrder[I];
+        unsigned FirstReg = getSubReg(Reg, AArch64::zsub0);
+
+        // If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
+        // register of the first load should be a multiple of 2 or 4.
+        if (IsMulZPR &&
+            (getSubReg(Reg, AArch64::zsub0) - AArch64::Z0) % LdOps != 0)
+          continue;
+        // Skip this register if it has any live intervals assigned.
+        if (Matrix->isPhysRegUsed(Reg))
+          continue;
+
+        bool CanAssign = true;
+        for (unsigned Next = 1; Next < LdOps; ++Next) {
+          // Ensure we can assign enough registers from the list for all loads.
+          if (I + Next >= StridedOrder.size()) {
+            CanAssign = false;
+            break;
+          }
+          // Ensure the subsequent registers are not live and that the starting
+          // sub-registers are sequential.
+          MCPhysReg NextReg = StridedOrder[I + Next];
+          if (Matrix->isPhysRegUsed(NextReg) ||
+              (getSubReg(NextReg, AArch64::zsub0) != FirstReg + Next)) {
+            CanAssign = false;
+            break;
+          }
+        }
+        if (CanAssign)
+          Hints.push_back(Reg);
----------------
sdesmalen-arm wrote:

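This loop can be written more compactly by doing something along the lines of the following, where the candidate register cannot be assigned if any of the subsequent registers `Reg + Next` is missing from `StridedOrder`: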
```
for (unsigned Next = 1; Next < LdOps; ++Next) {
  if (!is_contained(StridedOrder, Reg + Next))
    // cannot assign
}
```

https://github.com/llvm/llvm-project/pull/123081


More information about the llvm-commits mailing list