[llvm] [AArch64][SME2] Add FORM_STRIDED_TUPLE pseudo nodes (PR #116399)
Matthew Devereau via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 29 07:07:02 PST 2024
================
@@ -1107,6 +1107,81 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC,
}
}
+// FORM_STRIDED_TUPLE nodes are created to improve register allocation where
+// a consecutive multi-vector tuple is constructed from the same indices of
+// multiple strided loads. This may still result in unnecessary copies between
+// the loads and the tuple. Here we try to return a hint to assign the
+// contiguous ZPRMulReg starting at the same register as the first operand of
+// the pseudo, which should be a subregister of the first strided load.
+//
+// For example, if the first strided load has been assigned $z16_z20_z24_z28
+// and the operands of the pseudo are each accessing subregister zsub2, we
+// should look through Order to find a contiguous register which
+// begins with $z24 (i.e. $z24_z25_z26_z27).
+//
+bool AArch64RegisterInfo::getRegAllocationHints(
+ Register VirtReg, ArrayRef<MCPhysReg> Order,
+ SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF,
+ const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
+ const AArch64Subtarget &STI = MF.getSubtarget<AArch64Subtarget>();
+ const TargetRegisterInfo *TRI = STI.getRegisterInfo();
+ const MachineRegisterInfo &MRI = MF.getRegInfo();
+ bool DefaultHints =
+ TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF, VRM);
+
+ unsigned RegID = MRI.getRegClass(VirtReg)->getID();
+ if (RegID != AArch64::ZPR2Mul2RegClassID &&
+ RegID != AArch64::ZPR4Mul4RegClassID)
+ return DefaultHints;
+
+ for (MachineInstr &MI : MRI.def_instructions(VirtReg)) {
+ if (MI.getOpcode() != AArch64::FORM_STRIDED_TUPLE_X2_PSEUDO &&
+ MI.getOpcode() != AArch64::FORM_STRIDED_TUPLE_X4_PSEUDO)
+ continue;
+
+ // Look up the physical register mapped to the first load of the pseudo.
+ Register FirstLoadVirtReg = MI.getOperand(1).getReg();
+ if (!VRM->hasPhys(FirstLoadVirtReg))
+ continue;
+
+ unsigned SubRegIdx = 0;
+ MCRegister FirstLoadPhysReg = VRM->getPhys(FirstLoadVirtReg);
+
+ // The subreg number is used to access the correct unit of the
+ // strided register found in the map above.
+ switch (MI.getOperand(1).getSubReg()) {
+ case AArch64::zsub0:
+ break;
+ case AArch64::zsub1:
+ SubRegIdx = 1;
+ break;
+ case AArch64::zsub2:
+ SubRegIdx = 2;
+ break;
+ case AArch64::zsub3:
+ SubRegIdx = 3;
+ break;
+ default:
+ continue;
+ }
----------------
MDevereau wrote:
Is it possible to shorten this to
```c++
unsigned SubRegIdx = MI.getOperand(1).getSubReg() - AArch64::zsub0;
if (SubRegIdx > 3)
continue;
```
?
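
A minimal standalone sketch of the range-check idea, for illustration only. The enum below is a hypothetical stand-in for the TableGen-generated subregister indices, and it assumes zsub0..zsub3 are numbered consecutively, which would need to be confirmed before relying on the shortcut. The base index is subtracted from the operand's subregister index, so anything below zsub0 wraps around to a large unsigned value and is rejected by the same comparison that rejects anything above zsub3.

```cpp
#include <cassert>
#include <optional>

namespace sketch {

// Hypothetical stand-ins for the AArch64 subregister indices. The numeric
// values are invented; only the ASSUMPTION that zsub0..zsub3 are consecutive
// matters for this sketch.
enum SubRegIndex : unsigned {
  NoSubRegister = 0,
  dsub = 5,
  zsub0 = 10,
  zsub1,
  zsub2,
  zsub3
};

// The shortcut is only valid if the four indices are contiguous.
static_assert(zsub3 - zsub0 == 3, "zsub0..zsub3 must be consecutive");

// Returns the unit (0..3) of the strided register selected by SubReg, or
// std::nullopt if SubReg is not one of zsub0..zsub3.
constexpr std::optional<unsigned> stridedUnitIndex(unsigned SubReg) {
  // Unsigned subtraction: values below zsub0 wrap to a large number, so a
  // single upper-bound check covers both out-of-range directions.
  unsigned Idx = SubReg - zsub0;
  if (Idx > 3)
    return std::nullopt;
  return Idx;
}

} // namespace sketch
```

In the hint code, the `std::nullopt` case would correspond to the `continue` in the `default:` branch of the switch.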
https://github.com/llvm/llvm-project/pull/116399