[llvm] [AArch64][SME] Make getRegAllocationHints stricter for multi-vector loads (PR #123081)
Kerry McLaughlin via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 28 07:02:09 PST 2025
================
@@ -1109,24 +1110,93 @@ bool AArch64RegisterInfo::getRegAllocationHints(
// so we add the strided registers as a hint.
unsigned RegID = MRI.getRegClass(VirtReg)->getID();
// Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
- if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
- RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
- any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
- return Use.getOpcode() ==
- AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
- Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
- })) {
- const TargetRegisterClass *StridedRC =
- RegID == AArch64::ZPR2StridedOrContiguousRegClassID
- ? &AArch64::ZPR2StridedRegClass
- : &AArch64::ZPR4StridedRegClass;
+ for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
+ if ((RegID != AArch64::ZPR2StridedOrContiguousRegClassID &&
+ RegID != AArch64::ZPR4StridedOrContiguousRegClassID) ||
+ (Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
+ Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO))
+ continue;
+
+ unsigned LdOps = Use.getNumOperands() - 1;
+ const TargetRegisterClass *StridedRC = LdOps == 2
+ ? &AArch64::ZPR2StridedRegClass
+ : &AArch64::ZPR4StridedRegClass;
+ SmallVector<MCPhysReg, 4> StridedOrder;
for (MCPhysReg Reg : Order)
if (StridedRC->contains(Reg))
- Hints.push_back(Reg);
+ StridedOrder.push_back(Reg);
+
+ auto GetRegStartingAt = [&](MCPhysReg FirstReg) -> MCPhysReg {
+ for (MCPhysReg Strided : StridedOrder)
+ if (getSubReg(Strided, AArch64::zsub0) == FirstReg)
+ return Strided;
+ return (MCPhysReg)AArch64::NoRegister;
+ };
+
+ int OpIdx = Use.findRegisterUseOperandIdx(VirtReg, this);
+ assert(OpIdx != -1 && "Expected operand index from register use.");
+
+ unsigned TupleID = MRI.getRegClass(Use.getOperand(0).getReg())->getID();
+ bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
+ TupleID == AArch64::ZPR4Mul4RegClassID;
+
+ unsigned AssignedOp = 0;
+ if (!any_of(make_range(Use.operands_begin() + 1, Use.operands_end()),
+ [&](const MachineOperand &Op) {
+ if (!VRM->hasPhys(Op.getReg()))
+ return false;
+ AssignedOp = Op.getOperandNo();
+ return true;
+ })) {
+ // There are no registers already assigned to any of the pseudo operands.
+ // Look for a valid starting register for the group.
+ for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+ MCPhysReg Reg = StridedOrder[I];
+ unsigned FirstReg = getSubReg(Reg, AArch64::zsub0);
+
+ // If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
+ // register of the first load should be a multiple of 2 or 4.
+ if (IsMulZPR && (FirstReg - AArch64::Z0) % LdOps != 0)
+ continue;
+ // Skip this register if it has any live intervals assigned.
+ if (Matrix->isPhysRegUsed(Reg))
+ continue;
+
+ // Look for registers in StridedOrder which start with sub-registers
+ // following sequentially from FirstReg. If all are found and none are
+ // already live, add Reg to Hints.
+ MCPhysReg RegToAssign = Reg;
+ for (unsigned Next = 1; Next < LdOps; ++Next) {
+ MCPhysReg Strided = GetRegStartingAt(FirstReg + Next);
+ if (Strided == AArch64::NoRegister ||
+ Matrix->isPhysRegUsed(Strided)) {
+ RegToAssign = AArch64::NoRegister;
+ break;
+ }
+ if (Next == (unsigned)OpIdx - 1)
+ RegToAssign = Strided;
+ }
+ if (RegToAssign != AArch64::NoRegister)
+ Hints.push_back(RegToAssign);
----------------
kmclaughlin-arm wrote:
I've rewritten this as suggested; however, I included `!is_contained(StridedOrder, FirstReg + I)` in the first loop because I think it could be possible for this register to exist outside of the list. For this reason I've also left in the filtering by start register here, as I need StridedOrder to contain every possible strided register for the `is_contained` check.
I've also added a check in the loop to make sure that the starting registers are consecutive, and created a test in sme2-multivec-regalloc.mir where this check is necessary.
https://github.com/llvm/llvm-project/pull/123081
More information about the llvm-commits mailing list