[llvm] [AArch64][Machine-Combiner] Split loads into lanes of neon vectors into multiple vectors when possible (PR #142941)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 8 12:04:45 PDT 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: C/C++ code formatter, clang-format found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- llvm/lib/Target/AArch64/AArch64InstrInfo.cpp llvm/lib/Target/AArch64/AArch64InstrInfo.h
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 520b11bd5..bf8e78c71 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -7373,9 +7373,8 @@ static bool getMiscPatterns(MachineInstr &Root,
 }
 
 static bool getGatherPattern(MachineInstr &Root,
-                                SmallVectorImpl<unsigned> &Patterns,
-                                unsigned LoadLaneOpCode,
-                                unsigned NumLanes) {
+                             SmallVectorImpl<unsigned> &Patterns,
+                             unsigned LoadLaneOpCode, unsigned NumLanes) {
   const MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
   const TargetRegisterInfo *TRI =
       Root.getMF()->getSubtarget().getRegisterInfo();
@@ -7414,17 +7413,17 @@ static bool getGatherPattern(MachineInstr &Root,
     return false;
 
   switch (NumLanes) {
-    case 4:
-      Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i32);
-      break;
-    case 8:
-      Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i16);
-      break;
-    case 16:
-      Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i8);
-      break;
-    default:
-      llvm_unreachable("Got bad number of lanes for gather pattern.");
+  case 4:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i32);
+    break;
+  case 8:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i16);
+    break;
+  case 16:
+    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_i8);
+    break;
+  default:
+    llvm_unreachable("Got bad number of lanes for gather pattern.");
   }
 
   return true;
@@ -7442,23 +7441,24 @@ static bool getLoadPatterns(MachineInstr &Root,
 
   // The pattern searches for loads into single lanes.
   switch (Root.getOpcode()) {
-    case AArch64::LD1i32:
-      return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
-    case AArch64::LD1i16:
-      return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
-    case AArch64::LD1i8:
-      return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
-    default:
-      return false;
+  case AArch64::LD1i32:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
+  case AArch64::LD1i16:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
+  case AArch64::LD1i8:
+    return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
+  default:
+    return false;
   }
 }
 
-static void generateGatherPattern(
-    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
-    SmallVectorImpl<MachineInstr *> &DelInstrs,
-    DenseMap<Register, unsigned> &InstrIdxForVirtReg, unsigned Pattern,
-    unsigned NumLanes) {
-  
+static void
+generateGatherPattern(MachineInstr &Root,
+                      SmallVectorImpl<MachineInstr *> &InsInstrs,
+                      SmallVectorImpl<MachineInstr *> &DelInstrs,
+                      DenseMap<Register, unsigned> &InstrIdxForVirtReg,
+                      unsigned Pattern, unsigned NumLanes) {
+
   MachineFunction &MF = *Root.getParent()->getParent();
   MachineRegisterInfo &MRI = MF.getRegInfo();
   const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
@@ -7470,11 +7470,12 @@ static void generateGatherPattern(
     LoadToLaneInstrs.push_back(CurrInstr);
     CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
   }
-  
+
   MachineInstr *SubregToReg = CurrInstr;
-  LoadToLaneInstrs.push_back(MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
+  LoadToLaneInstrs.push_back(
+      MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
   auto OriginalLoadInstrs = llvm::reverse(LoadToLaneInstrs);
-  
+
   const TargetRegisterClass *FPR128RegClass =
       MRI.getRegClass(Root.getOperand(0).getReg());
 
@@ -7494,33 +7495,37 @@ static void generateGatherPattern(
   };
 
   // Helper to create load instruction based on opcode
-  auto CreateLoadInstruction = [&](unsigned NumLanes, Register DestReg, 
-                                  Register OffsetReg) -> MachineInstrBuilder {
-      unsigned Opcode;
-      switch (NumLanes) {
-        case 4:
-          Opcode = AArch64::LDRSui;
-          break;
-        case 8:
-          Opcode = AArch64::LDRHui;
-          break;
-        case 16:
-          Opcode = AArch64::LDRBui;
-          break;
-        default:
-          llvm_unreachable("Got unsupported number of lanes in machine-combiner gather pattern");
-      }
-      // Immediate offset load
-      return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
-            .addReg(OffsetReg)
-            .addImm(0); // immediate offset
+  auto CreateLoadInstruction = [&](unsigned NumLanes, Register DestReg,
+                                   Register OffsetReg) -> MachineInstrBuilder {
+    unsigned Opcode;
+    switch (NumLanes) {
+    case 4:
+      Opcode = AArch64::LDRSui;
+      break;
+    case 8:
+      Opcode = AArch64::LDRHui;
+      break;
+    case 16:
+      Opcode = AArch64::LDRBui;
+      break;
+    default:
+      llvm_unreachable(
+          "Got unsupported number of lanes in machine-combiner gather pattern");
+    }
+    // Immediate offset load
+    return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
+        .addReg(OffsetReg)
+        .addImm(0); // immediate offset
   };
 
   // Load index 1 into register 0 lane 1
-  auto LanesToLoadToReg0 = llvm::make_range(OriginalLoadInstrs.begin() + 1, OriginalLoadInstrs.begin() + NumLanes / 2);
+  auto LanesToLoadToReg0 =
+      llvm::make_range(OriginalLoadInstrs.begin() + 1,
+                       OriginalLoadInstrs.begin() + NumLanes / 2);
   auto PrevReg = SubregToReg->getOperand(0).getReg();
   for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
-    PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1, LoadInstr->getOperand(3).getReg());
+    PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
+                                 LoadInstr->getOperand(3).getReg());
     DelInstrs.push_back(LoadInstr);
   }
   auto LastLoadReg0 = PrevReg;
@@ -7530,12 +7535,13 @@ static void generateGatherPattern(
   auto OriginalSplitLoad = *std::next(OriginalLoadInstrs.begin(), NumLanes / 2);
   auto DestRegForMiddleIndex = MRI.createVirtualRegister(
       MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
-  
-  MachineInstrBuilder MiddleIndexLoadInstr = CreateLoadInstruction(
-      NumLanes, DestRegForMiddleIndex, 
-      OriginalSplitLoad->getOperand(3).getReg());
-  
-  InstrIdxForVirtReg.insert(std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
+
+  MachineInstrBuilder MiddleIndexLoadInstr =
+      CreateLoadInstruction(NumLanes, DestRegForMiddleIndex,
+                            OriginalSplitLoad->getOperand(3).getReg());
+
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
   InsInstrs.push_back(MiddleIndexLoadInstr);
   DelInstrs.push_back(OriginalSplitLoad);
 
@@ -7543,32 +7549,36 @@ static void generateGatherPattern(
   auto DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
   unsigned SubregType;
   switch (NumLanes) {
-    case 4:
-      SubregType = AArch64::ssub;
-      break;
-    case 8:
-      SubregType = AArch64::hsub;
-      break;
-    case 16:
-      SubregType = AArch64::bsub;
-      break;
-    default:
-      llvm_unreachable("Got invalid NumLanes for machine-combiner gather pattern");
+  case 4:
+    SubregType = AArch64::ssub;
+    break;
+  case 8:
+    SubregType = AArch64::hsub;
+    break;
+  case 16:
+    SubregType = AArch64::bsub;
+    break;
+  default:
+    llvm_unreachable(
+        "Got invalid NumLanes for machine-combiner gather pattern");
   }
-  auto SubRegToRegInstr = BuildMI(MF, MIMetadata(Root), 
-                                TII->get(SubregToReg->getOpcode()), 
-                                DestRegForSubregToReg)
-      .addImm(0)
-      .addReg(DestRegForMiddleIndex, getKillRegState(true))
-      .addImm(SubregType);
-  InstrIdxForVirtReg.insert(std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
+  auto SubRegToRegInstr =
+      BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
+              DestRegForSubregToReg)
+          .addImm(0)
+          .addReg(DestRegForMiddleIndex, getKillRegState(true))
+          .addImm(SubregType);
+  InstrIdxForVirtReg.insert(
+      std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
   InsInstrs.push_back(SubRegToRegInstr);
 
   // Load index 3 into register 1 lane 1
-  auto LanesToLoadToReg1 = llvm::make_range(OriginalLoadInstrs.begin() + NumLanes / 2 + 1, OriginalLoadInstrs.end());
+  auto LanesToLoadToReg1 = llvm::make_range(
+      OriginalLoadInstrs.begin() + NumLanes / 2 + 1, OriginalLoadInstrs.end());
   PrevReg = SubRegToRegInstr->getOperand(0).getReg();
   for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
-    PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1, LoadInstr->getOperand(3).getReg());
+    PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
+                                 LoadInstr->getOperand(3).getReg());
     if (Index == NumLanes / 2 - 2) {
       break;
     }
@@ -8957,18 +8967,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     break;
   }
   case AArch64MachineCombinerPattern::GATHER_i32: {
-    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 4);
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 4);
     for (const auto Instr : DelInstrs) {
       Instr->print(llvm::errs());
     }
     break;
   }
   case AArch64MachineCombinerPattern::GATHER_i16: {
-    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 8);
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 8);
     break;
   }
   case AArch64MachineCombinerPattern::GATHER_i8: {
-    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg, Pattern, 16);
+    generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
+                          Pattern, 16);
     break;
   }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/142941


More information about the llvm-commits mailing list