[llvm] [AArch64][SME] Make getRegAllocationHints stricter for multi-vector loads (PR #123081)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 29 06:46:27 PST 2025
================
@@ -1108,25 +1114,82 @@ bool AArch64RegisterInfo::getRegAllocationHints(
// instructions over reducing the number of clobbered callee-save registers,
// so we add the strided registers as a hint.
unsigned RegID = MRI.getRegClass(VirtReg)->getID();
- // Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
- if ((RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
- RegID == AArch64::ZPR4StridedOrContiguousRegClassID) &&
- any_of(MRI.use_nodbg_instructions(VirtReg), [](const MachineInstr &Use) {
- return Use.getOpcode() ==
- AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
- Use.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
- })) {
- const TargetRegisterClass *StridedRC =
- RegID == AArch64::ZPR2StridedOrContiguousRegClassID
- ? &AArch64::ZPR2StridedRegClass
- : &AArch64::ZPR4StridedRegClass;
-
- for (MCPhysReg Reg : Order)
- if (StridedRC->contains(Reg))
- Hints.push_back(Reg);
+ if (RegID == AArch64::ZPR2StridedOrContiguousRegClassID ||
+ RegID == AArch64::ZPR4StridedOrContiguousRegClassID) {
+
+ // Look through uses of the register for FORM_TRANSPOSED_REG_TUPLE.
+ for (const MachineInstr &Use : MRI.use_nodbg_instructions(VirtReg)) {
+ if (Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
+ Use.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
+ continue;
+
+ unsigned LdOps = Use.getNumOperands() - 1;
+ const TargetRegisterClass *StridedRC =
+ LdOps == 2 ? &AArch64::ZPR2StridedRegClass
+ : &AArch64::ZPR4StridedRegClass;
+
+ SmallVector<MCPhysReg, 4> StridedOrder;
+ for (MCPhysReg Reg : Order)
+ if (StridedRC->contains(Reg))
+ StridedOrder.push_back(Reg);
+
+ int OpIdx = Use.findRegisterUseOperandIdx(VirtReg, this);
+ assert(OpIdx != -1 && "Expected operand index from register use.");
+
+ unsigned TupleID = MRI.getRegClass(Use.getOperand(0).getReg())->getID();
+ bool IsMulZPR = TupleID == AArch64::ZPR2Mul2RegClassID ||
+ TupleID == AArch64::ZPR4Mul4RegClassID;
+
+ const MachineOperand *AssignedRegOp = llvm::find_if(
+ make_range(Use.operands_begin() + 1, Use.operands_end()),
+ [&VRM](const MachineOperand &Op) {
+ return VRM->hasPhys(Op.getReg());
+ });
+
+ if (AssignedRegOp == Use.operands_end()) {
+ // There are no registers already assigned to any of the pseudo
+ // operands. Look for a valid starting register for the group.
+ for (unsigned I = 0; I < StridedOrder.size(); ++I) {
+ MCPhysReg Reg = StridedOrder[I];
+ SmallVector<MCPhysReg> Regs;
+ unsigned FirstStridedReg = Reg - OpIdx + 1;
+
+ // If the FORM_TRANSPOSE nodes use the ZPRMul classes, the starting
+ // register of the first load should be a multiple of 2 or 4.
+ unsigned FirstSubReg = getSubReg(FirstStridedReg, AArch64::zsub0);
+ if (IsMulZPR && (FirstSubReg - AArch64::Z0) % LdOps != 0)
+ continue;
+
+ for (unsigned Op = 0; Op < LdOps; ++Op) {
+ if (!is_contained(StridedOrder, FirstStridedReg + Op) ||
+ getSubReg(FirstStridedReg + Op, AArch64::zsub0) !=
+ FirstSubReg + Op)
+ break;
+ Regs.push_back(FirstStridedReg + Op);
+ }
+
+ if (Regs.size() == LdOps && all_of(Regs, [&](MCPhysReg R) {
+ return !Matrix->isPhysRegUsed(R);
+ }))
----------------
sdesmalen-arm wrote:
nit: This could be rewritten in such a way that it doesn't need `SmallVector<MCPhysReg> Regs` as an intermediate step, e.g.
```
auto IsFreeConsecutiveReg = [&](unsigned I) {
// conditions
};
if (all_of(iota_range<unsigned>(0U, UseOps, /*Inclusive=*/false),
IsFreeConsecutiveReg))
Hints.push_back(Reg);
```
https://github.com/llvm/llvm-project/pull/123081
More information about the llvm-commits
mailing list