[llvm] [MachineCombiner] Reassociate long chains of accumulation instructions into a tree to increase ILP (PR #126060)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 6 05:11:36 PST 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 7ebacf3a999fc9766c3f0ec4979e3ed08344c348 830c373b3d03c52f75da18d5ff40fa47efff38f7 --extensions cpp,h -- llvm/lib/Target/AArch64/AArch64InstrInfo.cpp llvm/lib/Target/AArch64/AArch64InstrInfo.h
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index af8d6a8031..fe8cedf912 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -79,17 +79,18 @@ static cl::opt<unsigned>
     BDisplacementBits("aarch64-b-offset-bits", cl::Hidden, cl::init(26),
                       cl::desc("Restrict range of B instructions (DEBUG)"));
 
-static cl::opt<bool>
-    EnableAccReassociation("aarch64-acc-reassoc", cl::Hidden, cl::init(true), 
-                      cl::desc("Enable reassociation of accumulation chains"));
+static cl::opt<bool> EnableAccReassociation(
+    "aarch64-acc-reassoc", cl::Hidden, cl::init(true),
+    cl::desc("Enable reassociation of accumulation chains"));
 
 static cl::opt<unsigned int>
-    MinAccumulatorDepth("aarch64-acc-min-depth", cl::Hidden, cl::init(8), 
-                      cl::desc("Minimum length of accumulator chains required for the optimization to kick in"));
-
-static cl::opt<unsigned int>
-    MaxAccumulatorWidth("aarch64-acc-max-width", cl::Hidden, cl::init(3), cl::desc("Maximum number of branches in the accumulator tree"));
+    MinAccumulatorDepth("aarch64-acc-min-depth", cl::Hidden, cl::init(8),
+                        cl::desc("Minimum length of accumulator chains "
+                                 "required for the optimization to kick in"));
 
+static cl::opt<unsigned int> MaxAccumulatorWidth(
+    "aarch64-acc-max-width", cl::Hidden, cl::init(3),
+    cl::desc("Maximum number of branches in the accumulator tree"));
 
 AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
     : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
@@ -6689,63 +6690,69 @@ static bool getMaddPatterns(MachineInstr &Root,
 }
 
 static bool isAccumulationOpcode(unsigned Opcode) {
-  switch(Opcode) {
-    default:
-      break;
-    case AArch64::UABALB_ZZZ_D:
-    case AArch64::UABALB_ZZZ_H:
-    case AArch64::UABALB_ZZZ_S:
-    case AArch64::UABALT_ZZZ_D:
-    case AArch64::UABALT_ZZZ_H:
-    case AArch64::UABALT_ZZZ_S:
-    case AArch64::UABALv16i8_v8i16:
-    case AArch64::UABALv2i32_v2i64:
-    case AArch64::UABALv4i16_v4i32:
-    case AArch64::UABALv4i32_v2i64:
-    case AArch64::UABALv8i16_v4i32:
-    case AArch64::UABALv8i8_v8i16:
-      return true;
+  switch (Opcode) {
+  default:
+    break;
+  case AArch64::UABALB_ZZZ_D:
+  case AArch64::UABALB_ZZZ_H:
+  case AArch64::UABALB_ZZZ_S:
+  case AArch64::UABALT_ZZZ_D:
+  case AArch64::UABALT_ZZZ_H:
+  case AArch64::UABALT_ZZZ_S:
+  case AArch64::UABALv16i8_v8i16:
+  case AArch64::UABALv2i32_v2i64:
+  case AArch64::UABALv4i16_v4i32:
+  case AArch64::UABALv4i32_v2i64:
+  case AArch64::UABALv8i16_v4i32:
+  case AArch64::UABALv8i8_v8i16:
+    return true;
   }
-  
+
   return false;
 }
 
 static unsigned getAccumulationStartOpCode(unsigned AccumulationOpcode) {
-  switch(AccumulationOpcode) {
-    default:
-      llvm_unreachable("Unknown accumulator opcode");
-    case AArch64::UABALB_ZZZ_D:
-      return AArch64::UABDLB_ZZZ_D;
-    case AArch64::UABALB_ZZZ_H:
-      return AArch64::UABDLB_ZZZ_H;
-    case AArch64::UABALB_ZZZ_S:
-      return AArch64::UABDLB_ZZZ_S;
-    case AArch64::UABALT_ZZZ_D:
-      return AArch64::UABDLT_ZZZ_D;
-    case AArch64::UABALT_ZZZ_H:
-        return AArch64::UABDLT_ZZZ_H;
-    case AArch64::UABALT_ZZZ_S:
-        return AArch64::UABDLT_ZZZ_S;
-    case AArch64::UABALv16i8_v8i16:
-        return AArch64::UABDLv16i8_v8i16;
-    case AArch64::UABALv2i32_v2i64:
-        return AArch64::UABDLv2i32_v2i64;
-    case AArch64::UABALv4i16_v4i32:
-        return AArch64::UABDLv4i16_v4i32;
-    case AArch64::UABALv4i32_v2i64:
-        return AArch64::UABDLv4i32_v2i64;
-    case AArch64::UABALv8i16_v4i32:
-        return AArch64::UABDLv8i16_v4i32;
-    case AArch64::UABALv8i8_v8i16:
-        return AArch64::UABDLv8i8_v8i16;
-  }
-}
-
-static void getAccumulatorChain(MachineInstr *CurrentInstr, MachineBasicBlock &MBB, MachineRegisterInfo &MRI, SmallVectorImpl<Register> &Chain) {
-  // Walk up the chain of accumulation instructions and collect them in the vector.
+  switch (AccumulationOpcode) {
+  default:
+    llvm_unreachable("Unknown accumulator opcode");
+  case AArch64::UABALB_ZZZ_D:
+    return AArch64::UABDLB_ZZZ_D;
+  case AArch64::UABALB_ZZZ_H:
+    return AArch64::UABDLB_ZZZ_H;
+  case AArch64::UABALB_ZZZ_S:
+    return AArch64::UABDLB_ZZZ_S;
+  case AArch64::UABALT_ZZZ_D:
+    return AArch64::UABDLT_ZZZ_D;
+  case AArch64::UABALT_ZZZ_H:
+    return AArch64::UABDLT_ZZZ_H;
+  case AArch64::UABALT_ZZZ_S:
+    return AArch64::UABDLT_ZZZ_S;
+  case AArch64::UABALv16i8_v8i16:
+    return AArch64::UABDLv16i8_v8i16;
+  case AArch64::UABALv2i32_v2i64:
+    return AArch64::UABDLv2i32_v2i64;
+  case AArch64::UABALv4i16_v4i32:
+    return AArch64::UABDLv4i16_v4i32;
+  case AArch64::UABALv4i32_v2i64:
+    return AArch64::UABDLv4i32_v2i64;
+  case AArch64::UABALv8i16_v4i32:
+    return AArch64::UABDLv8i16_v4i32;
+  case AArch64::UABALv8i8_v8i16:
+    return AArch64::UABDLv8i8_v8i16;
+  }
+}
+
+static void getAccumulatorChain(MachineInstr *CurrentInstr,
+                                MachineBasicBlock &MBB,
+                                MachineRegisterInfo &MRI,
+                                SmallVectorImpl<Register> &Chain) {
+  // Walk up the chain of accumulation instructions and collect them in the
+  // vector.
   unsigned AccumulatorOpcode = CurrentInstr->getOpcode();
   unsigned ChainStartOpCode = getAccumulationStartOpCode(AccumulatorOpcode);
-  while(CurrentInstr && (canCombine(MBB, CurrentInstr->getOperand(1), AccumulatorOpcode) || canCombine(MBB, CurrentInstr->getOperand(1), ChainStartOpCode))) {
+  while (CurrentInstr &&
+         (canCombine(MBB, CurrentInstr->getOperand(1), AccumulatorOpcode) ||
+          canCombine(MBB, CurrentInstr->getOperand(1), ChainStartOpCode))) {
     Chain.push_back(CurrentInstr->getOperand(0).getReg());
     CurrentInstr = MRI.getUniqueVRegDef(CurrentInstr->getOperand(1).getReg());
   }
@@ -6755,16 +6762,17 @@ static void getAccumulatorChain(MachineInstr *CurrentInstr, MachineBasicBlock &M
     Chain.push_back(CurrentInstr->getOperand(0).getReg());
 }
 
-/// Find chains of accumulations, likely linearized by reassocation pass, 
+/// Find chains of accumulations, likely linearized by reassocation pass,
 /// that can be rewritten as a tree for increased ILP.
-static bool getAccumulatorReassociationPatterns(MachineInstr &Root,
-                            SmallVectorImpl<unsigned> &Patterns) {
+static bool
+getAccumulatorReassociationPatterns(MachineInstr &Root,
+                                    SmallVectorImpl<unsigned> &Patterns) {
   // find a chain of depth 4, which would make it profitable to rewrite
   // as a tree. This pattern should be applied recursively in case we
   // have a longer chain.
   if (!EnableAccReassociation)
     return false;
-  
+
   unsigned Opc = Root.getOpcode();
   if (!isAccumulationOpcode(Opc))
     return false;
@@ -6778,7 +6786,7 @@ static bool getAccumulatorReassociationPatterns(MachineInstr &Root,
   auto User = MRI.use_instr_begin(Root.getOperand(0).getReg());
   if (User->getOpcode() == Opc)
     return false;
-  
+
   // Walk up the use chain and collect the reduction chain.
   SmallVector<Register, 32> Chain;
   getAccumulatorChain(&Root, MBB, MRI, Chain);
@@ -6787,10 +6795,12 @@ static bool getAccumulatorReassociationPatterns(MachineInstr &Root,
   if (Chain.size() < MinAccumulatorDepth)
     return false;
 
-  // Check if the MBB this instruction is a part of contains any other chains. If so, don't apply it.
+  // Check if the MBB this instruction is a part of contains any other chains.
+  // If so, don't apply it.
   SmallSet<Register, 32> ReductionChain(Chain.begin(), Chain.end());
   for (const auto &I : MBB) {
-    if (I.getOpcode() == Opc && !ReductionChain.contains(I.getOperand(0).getReg()))
+    if (I.getOpcode() == Opc &&
+        !ReductionChain.contains(I.getOperand(0).getReg()))
       return false;
   }
 
@@ -7564,44 +7574,47 @@ genSubAdd2SubSub(MachineFunction &MF, MachineRegisterInfo &MRI,
   DelInstrs.push_back(&Root);
 }
 
-static unsigned int getReduceOpCodeForAccumulator(unsigned int AccumulatorOpCode) {
+static unsigned int
+getReduceOpCodeForAccumulator(unsigned int AccumulatorOpCode) {
   switch (AccumulatorOpCode) {
-    case AArch64::UABALB_ZZZ_D:
-      return AArch64::ADD_ZZZ_D;
-    case AArch64::UABALB_ZZZ_H:
-      return AArch64::ADD_ZZZ_H;
-    case AArch64::UABALB_ZZZ_S:
-      return AArch64::ADD_ZZZ_S;
-    case AArch64::UABALT_ZZZ_D:
-      return AArch64::ADD_ZZZ_D;
-    case AArch64::UABALT_ZZZ_H:
-      return AArch64::ADD_ZZZ_H;
-    case AArch64::UABALT_ZZZ_S:
-      return AArch64::ADD_ZZZ_S;
-    case AArch64::UABALv16i8_v8i16:
-      return AArch64::ADDv8i16;
-    case AArch64::UABALv2i32_v2i64:
-      return AArch64::ADDv2i64;
-    case AArch64::UABALv4i16_v4i32:
-      return AArch64::ADDv4i32;
-    case AArch64::UABALv4i32_v2i64:
-      return AArch64::ADDv2i64;
-    case AArch64::UABALv8i16_v4i32:
-      return AArch64::ADDv4i32;
-    case AArch64::UABALv8i8_v8i16:
-      return AArch64::ADDv8i16;
-    default:
-      llvm_unreachable("Unknown accumulator opcode");
+  case AArch64::UABALB_ZZZ_D:
+    return AArch64::ADD_ZZZ_D;
+  case AArch64::UABALB_ZZZ_H:
+    return AArch64::ADD_ZZZ_H;
+  case AArch64::UABALB_ZZZ_S:
+    return AArch64::ADD_ZZZ_S;
+  case AArch64::UABALT_ZZZ_D:
+    return AArch64::ADD_ZZZ_D;
+  case AArch64::UABALT_ZZZ_H:
+    return AArch64::ADD_ZZZ_H;
+  case AArch64::UABALT_ZZZ_S:
+    return AArch64::ADD_ZZZ_S;
+  case AArch64::UABALv16i8_v8i16:
+    return AArch64::ADDv8i16;
+  case AArch64::UABALv2i32_v2i64:
+    return AArch64::ADDv2i64;
+  case AArch64::UABALv4i16_v4i32:
+    return AArch64::ADDv4i32;
+  case AArch64::UABALv4i32_v2i64:
+    return AArch64::ADDv2i64;
+  case AArch64::UABALv8i16_v4i32:
+    return AArch64::ADDv4i32;
+  case AArch64::UABALv8i8_v8i16:
+    return AArch64::ADDv8i16;
+  default:
+    llvm_unreachable("Unknown accumulator opcode");
   }
 }
 
 // Reduce branches of the accumulator tree by adding them together.
-static void reduceAccumulatorTree(SmallVectorImpl<Register> &RegistersToReduce, SmallVectorImpl<MachineInstr *> &InsInstrs, 
-                   MachineFunction &MF, MachineInstr &Root, MachineRegisterInfo &MRI, 
-                   DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, Register ResultReg) {
+static void reduceAccumulatorTree(
+    SmallVectorImpl<Register> &RegistersToReduce,
+    SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
+    MachineInstr &Root, MachineRegisterInfo &MRI,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, Register ResultReg) {
   const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
   SmallVector<Register, 8> NewRegs;
-  for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i+=2) {
+  for (unsigned int i = 1; i <= (RegistersToReduce.size() / 2); i += 2) {
     auto RHS = RegistersToReduce[i - 1];
     auto LHS = RegistersToReduce[i];
     Register Dest;
@@ -7610,20 +7623,26 @@ static void reduceAccumulatorTree(SmallVectorImpl<Register> &RegistersToReduce,
       Dest = ResultReg;
     // Otherwise, create a new virtual register to hold the partial sum.
     else {
-      auto NewVR = MRI.createVirtualRegister(MRI.getRegClass(Root.getOperand(0).getReg()));
+      auto NewVR = MRI.createVirtualRegister(
+          MRI.getRegClass(Root.getOperand(0).getReg()));
       Dest = NewVR;
       NewRegs.push_back(Dest);
       InstrIdxForVirtReg.insert(std::make_pair(Dest, InsInstrs.size()));
     }
-    
+
     // Create the new add instruction.
-    MachineInstrBuilder MIB = BuildMI(MF, MIMetadata(Root), TII->get(getReduceOpCodeForAccumulator(Root.getOpcode())), Dest).addReg(RHS, getKillRegState(true)).addReg(LHS, getKillRegState(true));
+    MachineInstrBuilder MIB =
+        BuildMI(MF, MIMetadata(Root),
+                TII->get(getReduceOpCodeForAccumulator(Root.getOpcode())), Dest)
+            .addReg(RHS, getKillRegState(true))
+            .addReg(LHS, getKillRegState(true));
     // Copy any flags needed from the original instruction.
     MIB->setFlags(Root.getFlags());
     InsInstrs.push_back(MIB);
   }
 
-  // If the number of registers to reduce is odd, add the reminaing register to the vector of registers to reduce.
+  // If the number of registers to reduce is odd, add the reminaing register to
+  // the vector of registers to reduce.
   if (RegistersToReduce.size() % 2 != 0)
     NewRegs.push_back(RegistersToReduce[RegistersToReduce.size() - 1]);
 
@@ -7870,9 +7889,12 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     getAccumulatorChain(&Root, MBB, MRI, ChainRegs);
 
     unsigned int Depth = ChainRegs.size();
-    assert(MaxAccumulatorWidth > 1 && "Max accumulator width set to illegal value");
-    unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth ? Log2_32(Depth) : MaxAccumulatorWidth;
-    
+    assert(MaxAccumulatorWidth > 1 &&
+           "Max accumulator width set to illegal value");
+    unsigned int MaxWidth = Log2_32(Depth) < MaxAccumulatorWidth
+                                ? Log2_32(Depth)
+                                : MaxAccumulatorWidth;
+
     // Walk down the chain and rewrite it as a tree.
     for (auto IndexedReg : llvm::enumerate(llvm::reverse(ChainRegs))) {
       // No need to rewrite the first node, it is already perfect as it is.
@@ -7885,13 +7907,31 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
       if (IndexedReg.index() < MaxWidth) {
         // Now we need to create new instructions for the first row.
         AccReg = Instr->getOperand(0).getReg();
-        MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(MRI.getUniqueVRegDef(ChainRegs.back())->getOpcode()), AccReg).addReg(Instr->getOperand(2).getReg(), getKillRegState(Instr->getOperand(2).isKill())).addReg(Instr->getOperand(3).getReg(), getKillRegState(Instr->getOperand(3).isKill()));       
+        MIB = BuildMI(
+                  MF, MIMetadata(*Instr),
+                  TII->get(MRI.getUniqueVRegDef(ChainRegs.back())->getOpcode()),
+                  AccReg)
+                  .addReg(Instr->getOperand(2).getReg(),
+                          getKillRegState(Instr->getOperand(2).isKill()))
+                  .addReg(Instr->getOperand(3).getReg(),
+                          getKillRegState(Instr->getOperand(3).isKill()));
       } else {
-        // For the remaining cases, we need ot use an output register of one of the newly inserted instuctions as operand 1
-        AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg() ? MRI.createVirtualRegister(MRI.getRegClass(Root.getOperand(0).getReg())) : Instr->getOperand(0).getReg();
+        // For the remaining cases, we need ot use an output register of one of
+        // the newly inserted instuctions as operand 1
+        AccReg = Instr->getOperand(0).getReg() == Root.getOperand(0).getReg()
+                     ? MRI.createVirtualRegister(
+                           MRI.getRegClass(Root.getOperand(0).getReg()))
+                     : Instr->getOperand(0).getReg();
         assert(IndexedReg.index() - MaxWidth >= 0);
-        auto AccumulatorInput = ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1];
-        MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()), AccReg).addReg(AccumulatorInput, getKillRegState(true)).addReg(Instr->getOperand(2).getReg(), getKillRegState(Instr->getOperand(2).isKill())).addReg(Instr->getOperand(3).getReg(), getKillRegState(Instr->getOperand(3).isKill()));
+        auto AccumulatorInput =
+            ChainRegs[Depth - (IndexedReg.index() - MaxWidth) - 1];
+        MIB = BuildMI(MF, MIMetadata(*Instr), TII->get(Instr->getOpcode()),
+                      AccReg)
+                  .addReg(AccumulatorInput, getKillRegState(true))
+                  .addReg(Instr->getOperand(2).getReg(),
+                          getKillRegState(Instr->getOperand(2).isKill()))
+                  .addReg(Instr->getOperand(3).getReg(),
+                          getKillRegState(Instr->getOperand(3).isKill()));
       }
 
       MIB->setFlags(Instr->getFlags());
@@ -7899,7 +7939,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
       InsInstrs.push_back(MIB);
       DelInstrs.push_back(Instr);
     }
-    
+
     SmallVector<Register, 8> RegistersToReduce;
     for (int i = (InsInstrs.size() - MaxWidth); i < InsInstrs.size(); ++i) {
       auto Reg = InsInstrs[i]->getOperand(0).getReg();
@@ -7907,9 +7947,11 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
     }
 
     while (RegistersToReduce.size() > 1)
-      reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI, InstrIdxForVirtReg, Root.getOperand(0).getReg());
-    
-    // We don't want to break, we handle setting flags and adding Root to DelInstrs from here.
+      reduceAccumulatorTree(RegistersToReduce, InsInstrs, MF, Root, MRI,
+                            InstrIdxForVirtReg, Root.getOperand(0).getReg());
+
+    // We don't want to break, we handle setting flags and adding Root to
+    // DelInstrs from here.
     return;
   }
   case AArch64MachineCombinerPattern::MULADDv8i8_OP1:

``````````

</details>


https://github.com/llvm/llvm-project/pull/126060


More information about the llvm-commits mailing list