[llvm] [AArch64][MachineCombiner] Reassociate long chains of accumulation instructions into a tree to increase ILP (PR #126060)

Jonathan Cohen via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 20 01:54:09 PDT 2025


================
@@ -899,6 +913,158 @@ bool TargetInstrInfo::isReassociationCandidate(const MachineInstr &Inst,
          hasReassociableSibling(Inst, Commuted);
 }
 
+// Returns true when \param MO is a virtual register whose unique definition
+// lives in the basic block \param MBB and whose result (operand 0 of the
+// defining instruction) has exactly one non-debug use.
+// When \param CombineOpc is non-zero the defining instruction must also have
+// that opcode; the default value of 0 skips the opcode check.
+static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO,
+                       unsigned CombineOpc = 0) {
+  // Anything that is not a virtual register operand is rejected outright.
+  if (!MO.isReg() || !MO.getReg().isVirtual())
+    return false;
+
+  MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
+  MachineInstr *Def = RegInfo.getUniqueVRegDef(MO.getReg());
+
+  // The definition needs to be in the trace (otherwise, it won't have a
+  // depth).
+  if (!Def || Def->getParent() != &MBB)
+    return false;
+
+  // Apply the opcode filter only when the caller supplied one.
+  if (CombineOpc != 0 && Def->getOpcode() != CombineOpc)
+    return false;
+
+  // The defined value must only be used by the user we combine with.
+  return RegInfo.hasOneNonDBGUse(Def->getOperand(0).getReg());
+}
+
+// Collects into \param Chain the registers that make up the accumulation
+// chain ending at \param CurrentInstr, walking from the bottom of the chain
+// upwards through operand 1 of each instruction.
+//
+// A chain of accumulation instructions will be selected IFF:
+//    1. All the accumulation instructions in the chain have the same opcode,
+//       besides the first that has a slightly different opcode because it does
+//       not perform the accumulation, just defines it.
+//    2. All the instructions in the chain are combinable (have a single use
+//       which itself is part of the chain).
+//    3. Meets the required minimum length.
+//
+// \param Chain is left untouched when getAccumulationStartOpcode() returns no
+// value for the opcode of \param CurrentInstr.
+void TargetInstrInfo::getAccumulatorChain(
+    MachineInstr *CurrentInstr, SmallVectorImpl<Register> &Chain) const {
+  MachineBasicBlock &MBB = *CurrentInstr->getParent();
+  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  const unsigned AccumulatorOpcode = CurrentInstr->getOpcode();
+
+  // Without a known chain-start opcode there is nothing to collect.
+  if (!getAccumulationStartOpcode(AccumulatorOpcode))
+    return;
+
+  // The result of the last accumulator is the first entry of the chain.
+  Chain.push_back(CurrentInstr->getOperand(0).getReg());
+
+  // Keep following the accumulator input (operand 1) for as long as it is
+  // produced by a combinable instruction with the same opcode.
+  while (CurrentInstr) {
+    MachineOperand &AccInput = CurrentInstr->getOperand(1);
+    if (!canCombine(MBB, AccInput, AccumulatorOpcode))
+      break;
+
+    const Register AccReg = AccInput.getReg();
+    Chain.push_back(AccReg);
+    CurrentInstr = MRI.getUniqueVRegDef(AccReg);
+  }
+
+  // Add the instruction at the top of the chain. Its defining opcode is not
+  // checked here, only that it is combinable.
+  if (CurrentInstr->getOpcode() == AccumulatorOpcode &&
+      canCombine(MBB, CurrentInstr->getOperand(1)))
+    Chain.push_back(CurrentInstr->getOperand(1).getReg());
+}
+
+/// Find chains of accumulations that can be rewritten as a tree for increased
+/// ILP.
+///
+/// \param Root is the candidate for the last instruction of a chain.
+/// \param Patterns receives MachineCombinerPattern::ACC_CHAIN on success.
+/// \returns true if a pattern was added. Returns false when the transform is
+/// disabled via EnableAccReassociation, when \p Root is not an accumulation
+/// opcode or is not the end of its chain, when the chain is shorter than
+/// MinAccumulatorDepth, or when the block holds another instruction with the
+/// same opcode that is not part of the collected chain.
+bool TargetInstrInfo::getAccumulatorReassociationPatterns(
+    MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns) const {
+  if (!EnableAccReassociation)
+    return false;
+
+  unsigned Opc = Root.getOpcode();
+  if (!isAccumulationOpcode(Opc))
+    return false;
+
+  // Verify that this is the end of the chain.
+  MachineBasicBlock &MBB = *Root.getParent();
+  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  if (!MRI.hasOneNonDBGUser(Root.getOperand(0).getReg()))
+    return false;
+
+  // Use the non-debug iterator: the plain use iterator also visits debug
+  // uses, so a DBG_VALUE could be inspected instead of the single real user
+  // and the decision would depend on the presence of debug info.
+  auto User = MRI.use_instr_nodbg_begin(Root.getOperand(0).getReg());
+  if (User->getOpcode() == Opc)
+    return false;
+
+  // Walk up the use chain and collect the reduction chain.
+  SmallVector<Register, 32> Chain;
+  getAccumulatorChain(&Root, Chain);
+
+  // Reject chains which are too short to be worth modifying.
+  if (Chain.size() < MinAccumulatorDepth)
+    return false;
+
+  // Check if the MBB this instruction is a part of contains any other chains.
+  // If so, don't apply it.
+  SmallSet<Register, 32> ReductionChain(Chain.begin(), Chain.end());
+  for (const auto &I : MBB) {
+    if (I.getOpcode() == Opc &&
+        !ReductionChain.contains(I.getOperand(0).getReg()))
+      return false;
+  }
+
+  Patterns.push_back(MachineCombinerPattern::ACC_CHAIN);
+  return true;
+}
+
+// Reduce branches of the accumulator tree by adding them together.
+void TargetInstrInfo::reduceAccumulatorTree(
+    SmallVectorImpl<Register> &RegistersToReduce,
+    SmallVectorImpl<MachineInstr *> &InsInstrs, MachineFunction &MF,
+    MachineInstr &Root, MachineRegisterInfo &MRI,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
+    Register ResultReg) const {
+  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
+  SmallVector<Register, 8> NewRegs;
+
+  // Get the opcode for the reduction instruction we will need to build.
+  // If for some reason it is not defined, early exit and don't apply this.
+  std::optional<unsigned> ReduceOpCode =
+      getReduceOpcodeForAccumulator(Root.getOpcode());
+
+  if (!ReduceOpCode.value())
----------------
jcohen-apple wrote:

I assumed that the code alteration would only be applied if the code pattern is applied successfully, but I agree it makes more sense to bail if this is used incorrectly. I changed them back to use only `llvm_unreachable`.

https://github.com/llvm/llvm-project/pull/126060


More information about the llvm-commits mailing list