[llvm] [AArch64][GlobalISel] Combine vecreduce(ext) to {U/S}ADDLV (PR #75832)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 19 07:22:27 PST 2023


================
@@ -418,6 +411,158 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
+// On a match, MatchInfo holds the register feeding the extend and whether the
+// extend was signed. The match is rejected unless the result's scalar width
+// and the extend source's element count form one of the accepted pairings.
+bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  // The reduction input has to be defined by a zero or sign extend; record
+  // which of the two it was in MatchInfo.
+  MachineInstr *ExtendDef =
+      getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  switch (ExtendDef->getOpcode()) {
+  case TargetOpcode::G_ZEXT:
+    MatchInfo.second = false;
+    break;
+  case TargetOpcode::G_SEXT:
+    MatchInfo.second = true;
+    break;
+  default:
+    return false;
+  }
+
+  const Register PreExtReg = ExtendDef->getOperand(1).getReg();
+  const LLT PreExtTy = MRI.getType(PreExtReg);
+  const LLT ResultTy = MRI.getType(MI.getOperand(0).getReg());
+
+  // Element-count multiple required for each supported result width.
+  unsigned RequiredMultiple;
+  switch (ResultTy.getScalarSizeInBits()) {
+  case 16:
+    RequiredMultiple = 8;
+    break;
+  case 32:
+  case 64:
+    RequiredMultiple = 4;
+    break;
+  default:
+    return false;
+  }
+  if (PreExtTy.getNumElements() % RequiredMultiple != 0)
+    return false;
+
+  // Only publish the source register once the types have been accepted.
+  MatchInfo.first = PreExtReg;
+  return true;
+}
+
+void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &B, GISelChangeObserver &Observer,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+  Register SrcReg = std::get<0>(MatchInfo);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  LLT DstTy = MRI.getType(DstReg);
+
+  // If SrcTy has more elements than expected, split them into multiple
+  // instructions and sum the results
+  LLT MainTy;
+  SmallVector<Register, 1> WorkingRegisters;
+  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+  unsigned SrcNumElem = SrcTy.getNumElements();
+  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+      (SrcScalSize == 16 && SrcNumElem > 8) ||
+      (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+    LLT LeftoverTy;
+    SmallVector<Register, 4> LeftoverRegs;
+    if (SrcScalSize == 8)
+      MainTy = LLT::fixed_vector(16, 8);
+    else if (SrcScalSize == 16)
+      MainTy = LLT::fixed_vector(8, 16);
+    else if (SrcScalSize == 32)
+      MainTy = LLT::fixed_vector(4, 32);
+    else
+      llvm_unreachable("Source's Scalar Size not supported");
+
+    // Extract the parts, put each extracted source through U/SADDLV, and put
+    // the values inside a small vec
+    extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+                 LeftoverRegs, B, MRI);
+    for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+      WorkingRegisters.push_back(LeftoverRegs[I]);
+    }
+  } else {
+    WorkingRegisters.push_back(SrcReg);
+    MainTy = SrcTy;
+  }
+
+  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+  LLT MidScalarLLT = LLT::scalar(MidScalarSize);
+  Register zeroReg =
+      B.buildConstant(LLT::scalar(64), 0)->getOperand(0).getReg();
+  for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
+    // If the number of elements is too small to build an instruction, extend
+    // its size before applying addlv
+    LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
+    if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
+        (WorkingRegTy.getNumElements() == 4)) {
+      WorkingRegisters[I] =
+          B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                              : TargetOpcode::G_ZEXT,
+                       {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
+              ->getOperand(0)
+              .getReg();
+    }
+
+    // Generate the {U/S}ADDLV instruction, whose output is always double of the
+    // Src's Scalar size
+    LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
+                                      : LLT::fixed_vector(2, 64);
+    Register addlvReg = B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]})
+                            ->getOperand(0)
+                            .getReg();
+
+    // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
+    // v2i64 register.
+    //     i16, i32 results uses v4i32 registers
+    //     i64      results uses v2i64 registers
+    // Therefore we have to extract/truncate the value to the right type
+    if (MidScalarSize == 32 || MidScalarSize == 64) {
+      WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {MidScalarLLT}, {addlvReg, zeroReg})
+                                ->getOperand(0)
+                                .getReg();
+    } else {
+      Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {LLT::scalar(32)}, {addlvReg, zeroReg})
+                                ->getOperand(0)
+                                .getReg();
+      WorkingRegisters[I] =
+          B.buildTrunc({MidScalarLLT}, {extractReg})->getOperand(0).getReg();
+    }
+  }
+
+  Register outReg;
+  if (WorkingRegisters.size() > 1) {
+    outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
+                 ->getOperand(0)
+                 .getReg();
+    for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
+      outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I])
+                   ->getOperand(0)
+                   .getReg();
----------------
davemgreen wrote:

`->getOperand(0).getReg()` -> `.getReg(0)`? Same above.

https://github.com/llvm/llvm-project/pull/75832


More information about the llvm-commits mailing list