[llvm] [AArch64][SME] Make getRegAllocationHints stricter for multi-vector loads (PR #123081)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 17 01:16:12 PST 2025
================
@@ -1107,23 +1108,83 @@ bool AArch64RegisterInfo::getRegAllocationHints(
// FORM_TRANSPOSED_REG_TUPLE pseudo, we want to favour reducing copy
// instructions over reducing the number of clobbered callee-save registers,
// so we add the strided registers as a hint.
+ const MachineInstr *TupleInst = nullptr;
unsigned RegID = MRI.getRegClass(VirtReg)->getID();
// Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
- any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
- return Use.getOpcode() ==
- AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
- Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
+ any_of(MRI.use_nodbg_instructions(VirtReg), [&TupleInst](
+ const MachineInstr &Use) {
+ bool IsTuple =
+ Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
+ Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
+ TupleInst = &Use;
+ return IsTuple;
})) {
- const TargetRegisterClass *StridedRC =
- RegID == AArch64::ZPR2StridedOrContiguousRegClassID
- ? &AArch64::ZPR2StridedRegClass
- : &AArch64::ZPR4StridedRegClass;
+ unsigned LdOps = TupleInst->getNumOperands() - 1;
+ const TargetRegisterClass *StridedRC = LdOps == 2
+ ? &AArch64::ZPR2StridedRegClass
+ : &AArch64::ZPR4StridedRegClass;
+ SmallVector<MCPhysReg, 4> StridedOrder;
for (MCPhysReg Reg : Order)
if (StridedRC->contains(Reg))
- Hints.push_back(Reg);
+ StridedOrder.push_back(Reg);
+
+ int OpIdx = TupleInst->findRegisterUseOperandIdx(VirtReg, this);
+ if (OpIdx == -1)
+ return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
+ MF, VRM);
+
+ unsigned TupleID =
+ MRI.getRegClass(TupleInst->getOperand(0).getReg())->getID();
+ bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
+ TupleID == AArch64::ZPR4Mul4RegClassID;
+
+ if (OpIdx == 1) {
+ for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+ MCPhysReg Reg = StridedOrder[I];
+ unsigned FirstReg = getSubReg(Reg, AArch64::zsub0);
+
+ // If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
+ // register of the first load should be a multiple of 2 or 4.
+ if (IsMulZPR &&
+ (getSubReg(Reg, AArch64::zsub0) - AArch64::Z0) % LdOps != 0)
+ continue;
+ // Skip this register if it has any live intervals assigned.
+ if (Matrix->isPhysRegUsed(Reg))
+ continue;
+
+ bool CanAssign = true;
+ for (unsigned Next = 1; Next < LdOps; ++Next) {
+ // Ensure we can assign enough registers from the list for all loads.
+ if (I + Next >= StridedOrder.size()) {
+ CanAssign = false;
+ break;
+ }
+ // Ensure the subsequent registers are not live and that the starting
+ // sub-registers are sequential.
+ MCPhysReg NextReg = StridedOrder[I + Next];
+ if (Matrix->isPhysRegUsed(NextReg) ||
+ (getSubReg(NextReg, AArch64::zsub0) != FirstReg + Next)) {
+ CanAssign = false;
+ break;
+ }
+ }
+ if (CanAssign)
+ Hints.push_back(Reg);
----------------
sdesmalen-arm wrote:
This loop can be written more compactly, along the lines of:
```
for (unsigned Next = 1; Next < LdOps; ++Next) {
if (!is_contained(StridedOrder, Reg + Next))
// cannot assign
}
```
https://github.com/llvm/llvm-project/pull/123081
More information about the llvm-commits
mailing list