[llvm] [AArch64][GlobalISel] Combine vecreduce(ext) to {U/S}ADDLV (PR #75832)

David Green via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 8 03:36:18 PST 2024


================
@@ -418,6 +411,157 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
+// Ensure that the type coming from the extend instruction is the right size
+bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  // Check if the last instruction is an extend
+  MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  auto ExtOpc = ExtMI->getOpcode();
+
+  if (ExtOpc == TargetOpcode::G_ZEXT)
+    std::get<1>(MatchInfo) = 0;
+  else if (ExtOpc == TargetOpcode::G_SEXT)
+    std::get<1>(MatchInfo) = 1;
+  else
+    return false;
+
+  // Check if the source register is a valid type
+  Register ExtSrcReg = ExtMI->getOperand(1).getReg();
+  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  if ((DstTy.getScalarSizeInBits() == 16 &&
+       ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
+      (DstTy.getScalarSizeInBits() == 32 &&
+       ExtSrcTy.getNumElements() % 4 == 0 &&
+       ExtSrcTy.getNumElements() < 65536) ||
+      (DstTy.getScalarSizeInBits() == 64 &&
+       ExtSrcTy.getNumElements() % 4 == 0 &&
+       ExtSrcTy.getNumElements() < 4294967296)) {
+    std::get<0>(MatchInfo) = ExtSrcReg;
+    return true;
+  }
+  return false;
+}
+
+void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &B, GISelChangeObserver &Observer,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+  Register SrcReg = std::get<0>(MatchInfo);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  LLT DstTy = MRI.getType(DstReg);
+
+  // If SrcTy has more elements than expected, split them into multiple
+  // instructions and sum the results
+  LLT MainTy;
+  SmallVector<Register, 1> WorkingRegisters;
+  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+  unsigned SrcNumElem = SrcTy.getNumElements();
+  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+      (SrcScalSize == 16 && SrcNumElem > 8) ||
+      (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+    LLT LeftoverTy;
+    SmallVector<Register, 4> LeftoverRegs;
+    if (SrcScalSize == 8)
+      MainTy = LLT::fixed_vector(16, 8);
+    else if (SrcScalSize == 16)
+      MainTy = LLT::fixed_vector(8, 16);
+    else if (SrcScalSize == 32)
+      MainTy = LLT::fixed_vector(4, 32);
+    else
+      llvm_unreachable("Source's Scalar Size not supported");
+
+    // Extract the parts, put each extracted source through U/SADDLV, and
+    // put the values inside a small vector
+    extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+                 LeftoverRegs, B, MRI);
+    for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+      WorkingRegisters.push_back(LeftoverRegs[I]);
+    }
+  } else {
+    WorkingRegisters.push_back(SrcReg);
+    MainTy = SrcTy;
+  }
+
+  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+  LLT MidScalarLLT = LLT::scalar(MidScalarSize);
+  Register zeroReg =
+      B.buildConstant(LLT::scalar(64), 0)->getOperand(0).getReg();
----------------
davemgreen wrote:

getReg(0)? (i.e. `B.buildConstant(LLT::scalar(64), 0).getReg(0)` instead of `->getOperand(0).getReg()`)

https://github.com/llvm/llvm-project/pull/75832


More information about the llvm-commits mailing list