[llvm] [AArch64][Machine-Combiner] Split loads into lanes of neon vectors into multiple vectors when possible (PR #142941)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 8 12:04:45 PDT 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- llvm/lib/Target/AArch64/AArch64InstrInfo.cpp llvm/lib/Target/AArch64/AArch64InstrInfo.h
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 520b11bd5..bf8e78c71 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -7373,9 +7373,8 @@ static bool getMiscPatterns(MachineInstr &Root,
}
static bool getGatherPattern(MachineInstr &Root,
- SmallVectorImpl<unsigned> &Patterns,
- unsigned LoadLaneOpCode,
- unsigned NumLanes) {
+ SmallVectorImpl<unsigned> &Patterns,
+ unsigned LoadLaneOpCode, unsigned NumLanes) {
const MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
const TargetRegisterInfo *TRI =
Root.getMF()->getSubtarget().getRegisterInfo();
@@ -7414,17 +7413,17 @@ static bool getGatherPattern(MachineInstr &Root,
return false;
switch (NumLanes) {
- case 4:
- Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i32);
- break;
- case 8:
- Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i16);
- break;
- case 16:
- Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i8);
- break;
- default:
- llvm_unreachable("Got bad number of lanes for gather pattern.");
+ case 4:
+ Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i32);
+ break;
+ case 8:
+ Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i16);
+ break;
+ case 16:
+ Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i8);
+ break;
+ default:
+ llvm_unreachable("Got bad number of lanes for gather pattern.");
}
return true;
@@ -7442,23 +7441,24 @@ static bool getLoadPatterns(MachineInstr &Root,
// The pattern searches for loads into single lanes.
switch (Root.getOpcode()) {
- case AArch64::LD1i32:
- return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
- case AArch64::LD1i16:
- return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
- case AArch64::LD1i8:
- return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
- default:
- return false;
+ case AArch64::LD1i32:
+ return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
+ case AArch64::LD1i16:
+ return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
+ case AArch64::LD1i8:
+ return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
+ default:
+ return false;
}
}
-static void generateGatherPattern(
- MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
- SmallVectorImpl<MachineInstr *> &DelInstrs,
- DenseMap<Register, unsigned> &InstrIdxForVirtReg, unsigned Pattern,
- unsigned NumLanes) {
-
+static void
+generateGatherPattern(MachineInstr &Root,
+ SmallVectorImpl<MachineInstr *> &InsInstrs,
+ SmallVectorImpl<MachineInstr *> &DelInstrs,
+ DenseMap<Register, unsigned> &InstrIdxForVirtReg,
+ unsigned Pattern, unsigned NumLanes) {
+
MachineFunction &MF = *Root.getParent()->getParent();
MachineRegisterInfo &MRI = MF.getRegInfo();
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
@@ -7470,11 +7470,12 @@ static void generateGatherPattern(
LoadToLaneInstrs.push_back(CurrInstr);
CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
}
-
+
MachineInstr *SubregToReg = CurrInstr;
- LoadToLaneInstrs.push_back(MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
+ LoadToLaneInstrs.push_back(
+ MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
auto OriginalLoadInstrs = llvm::reverse(LoadToLaneInstrs);
-
+
const TargetRegisterClass *FPR128RegClass =
MRI.getRegClass(Root.getOperand(0).getReg());
@@ -7494,33 +7495,37 @@ static void generateGatherPattern(
};
// Helper to create load instruction based on opcode
- auto CreateLoadInstruction = [&](unsigned NumLanes, Register DestReg,
- Register OffsetReg) -> MachineInstrBuilder {
- unsigned Opcode;
- switch (NumLanes) {
- case 4:
- Opcode = AArch64::LDRSui;
- break;
- case 8:
- Opcode = AArch64::LDRHui;
- break;
- case 16:
- Opcode = AArch64::LDRBui;
- break;
- default:
- llvm_unreachable("Got unsupported number of lanes in machine-combiner gather pattern");
- }
- // Immediate offset load
- return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
- .addReg(OffsetReg)
- .addImm(0); // immediate offset
+ auto CreateLoadInstruction = [&](unsigned NumLanes, Register DestReg,
+ Register OffsetReg) -> MachineInstrBuilder {
+ unsigned Opcode;
+ switch (NumLanes) {
+ case 4:
+ Opcode = AArch64::LDRSui;
+ break;
+ case 8:
+ Opcode = AArch64::LDRHui;
+ break;
+ case 16:
+ Opcode = AArch64::LDRBui;
+ break;
+ default:
+ llvm_unreachable(
+ "Got unsupported number of lanes in machine-combiner gather pattern");
+ }
+ // Immediate offset load
+ return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
+ .addReg(OffsetReg)
+ .addImm(0); // immediate offset
};
// Load index 1 into register 0 lane 1
- auto LanesToLoadToReg0 = llvm::make_range(OriginalLoadInstrs.begin() + 1, OriginalLoadInstrs.begin() + NumLanes / 2);
+ auto LanesToLoadToReg0 =
+ llvm::make_range(OriginalLoadInstrs.begin() + 1,
+ OriginalLoadInstrs.begin() + NumLanes / 2);
auto PrevReg = SubregToReg->getOperand(0).getReg();
for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
- PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1, LoadInstr->getOperand(3).getReg());
+ PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
+ LoadInstr->getOperand(3).getReg());
DelInstrs.push_back(LoadInstr);
}
auto LastLoadReg0 = PrevReg;
@@ -7530,12 +7535,13 @@ static void generateGatherPattern(
auto OriginalSplitLoad = *std::next(OriginalLoadInstrs.begin(), NumLanes / 2);
auto DestRegForMiddleIndex = MRI.createVirtualRegister(
MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
-
- MachineInstrBuilder MiddleIndexLoadInstr = CreateLoadInstruction(
- NumLanes, DestRegForMiddleIndex,
- OriginalSplitLoad->getOperand(3).getReg());
-
- InstrIdxForVirtReg.insert(std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
+
+ MachineInstrBuilder MiddleIndexLoadInstr =
+ CreateLoadInstruction(NumLanes, DestRegForMiddleIndex,
+ OriginalSplitLoad->getOperand(3).getReg());
+
+ InstrIdxForVirtReg.insert(
+ std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
InsInstrs.push_back(MiddleIndexLoadInstr);
DelInstrs.push_back(OriginalSplitLoad);
@@ -7543,32 +7549,36 @@ static void generateGatherPattern(
auto DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
unsigned SubregType;
switch (NumLanes) {
- case 4:
- SubregType = AArch64::ssub;
- break;
- case 8:
- SubregType = AArch64::hsub;
- break;
- case 16:
- SubregType = AArch64::bsub;
- break;
- default:
- llvm_unreachable("Got invalid NumLanes for machine-combiner gather pattern");
+ case 4:
+ SubregType = AArch64::ssub;
+ break;
+ case 8:
+ SubregType = AArch64::hsub;
+ break;
+ case 16:
+ SubregType = AArch64::bsub;
+ break;
+ default:
+ llvm_unreachable(
+ "Got invalid NumLanes for machine-combiner gather pattern");
}
- auto SubRegToRegInstr = BuildMI(MF, MIMetadata(Root),
- TII->get(SubregToReg->getOpcode()),
- DestRegForSubregToReg)
- .addImm(0)
- .addReg(DestRegForMiddleIndex, getKillRegState(true))
- .addImm(SubregType);
- InstrIdxForVirtReg.insert(std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
+ auto SubRegToRegInstr =
+ BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
+ DestRegForSubregToReg)
+ .addImm(0)
+ .addReg(DestRegForMiddleIndex, getKillRegState(true))
+ .addImm(SubregType);
+ InstrIdxForVirtReg.insert(
+ std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
InsInstrs.push_back(SubRegToRegInstr);
// Load index 3 into register 1 lane 1
- auto LanesToLoadToReg1 = llvm::make_range(OriginalLoadInstrs.begin() + NumLanes / 2 + 1, OriginalLoadInstrs.end());
+ auto LanesToLoadToReg1 = llvm::make_range(
+ OriginalLoadInstrs.begin() + NumLanes / 2 + 1, OriginalLoadInstrs.end());
PrevReg = SubRegToRegInstr->getOperand(0).getReg();
for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
- PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1, LoadInstr->getOperand(3).getReg());
+ PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
+ LoadInstr->getOperand(3).getReg());
if (Index == NumLanes / 2 - 2) {
break;
}
@@ -8957,18 +8967,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
break;
}
case AArch64MachineCombinerPattern::GATHER_i32: {
- generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 4);
+ generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+ Pattern, 4);
for (const auto Instr : DelInstrs) {
Instr->print(llvm::errs());
}
break;
}
case AArch64MachineCombinerPattern::GATHER_i16: {
- generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 8);
+ generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+ Pattern, 8);
break;
}
case AArch64MachineCombinerPattern::GATHER_i8: {
- generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 16);
+ generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+ Pattern, 16);
break;
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/142941
More information about the llvm-commits mailing list