[llvm] [AArch64][GlobalISel] Combine vecreduce(ext) to {U/S}ADDLV (PR #75832)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 8 03:36:18 PST 2024
================
@@ -418,6 +411,157 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
MI.eraseFromParent();
}
+// Matches G_VECREDUCE_ADD(G_ZEXT/G_SEXT(x)) => G_UADDLV/G_SADDLV(x).
+// On success, MatchInfo.first is set to the extend's source register x and
+// MatchInfo.second to true for a sign extend (false for a zero extend).
+// MatchInfo is left untouched when the pattern does not match.
+// The element count of x must fit one of the (result width, element count)
+// combinations checked below.
+bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  // The reduced vector must be produced (looking through copies) by a zero or
+  // sign extend.
+  MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  if (!ExtMI)
+    return false;
+
+  unsigned ExtOpc = ExtMI->getOpcode();
+  if (ExtOpc != TargetOpcode::G_ZEXT && ExtOpc != TargetOpcode::G_SEXT)
+    return false;
+  bool IsSigned = ExtOpc == TargetOpcode::G_SEXT;
+
+  // Check that the pre-extend vector has a supported element count for the
+  // width of the reduction's result.
+  Register ExtSrcReg = ExtMI->getOperand(1).getReg();
+  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
+  if (!ExtSrcTy.isVector())
+    return false;
+
+  unsigned NumElts = ExtSrcTy.getNumElements();
+  unsigned DstSize =
+      MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits();
+
+  bool Supported;
+  if (DstSize == 16)
+    Supported = NumElts % 8 == 0 && NumElts < 256;
+  else if (DstSize == 32)
+    Supported = NumElts % 4 == 0 && NumElts < 65536;
+  else if (DstSize == 64)
+    // No upper bound: an unsigned element count is always below 2^32, so a
+    // "< 4294967296" comparison here could never fail.
+    Supported = NumElts % 4 == 0;
+  else
+    Supported = false;
+
+  if (!Supported)
+    return false;
+
+  MatchInfo = {ExtSrcReg, IsSigned};
+  return true;
+}
+
+void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &B, GISelChangeObserver &Observer,
+ std::pair<Register, bool> &MatchInfo) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected G_VECREDUCE_ADD Opcode");
+
+ unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+ Register SrcReg = std::get<0>(MatchInfo);
+ Register DstReg = MI.getOperand(0).getReg();
+ LLT SrcTy = MRI.getType(SrcReg);
+ LLT DstTy = MRI.getType(DstReg);
+
+ // If SrcTy has more elements than expected, split them into multiple
+ // insructions and sum the results
+ LLT MainTy;
+ SmallVector<Register, 1> WorkingRegisters;
+ unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+ unsigned SrcNumElem = SrcTy.getNumElements();
+ if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+ (SrcScalSize == 16 && SrcNumElem > 8) ||
+ (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+ LLT LeftoverTy;
+ SmallVector<Register, 4> LeftoverRegs;
+ if (SrcScalSize == 8)
+ MainTy = LLT::fixed_vector(16, 8);
+ else if (SrcScalSize == 16)
+ MainTy = LLT::fixed_vector(8, 16);
+ else if (SrcScalSize == 32)
+ MainTy = LLT::fixed_vector(4, 32);
+ else
+ llvm_unreachable("Source's Scalar Size not supported");
+
+ // Extract the parts and put each extracted sources through U/SADDLV and put
+ // the values inside a small vec
+ extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+ LeftoverRegs, B, MRI);
+ for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+ WorkingRegisters.push_back(LeftoverRegs[I]);
+ }
+ } else {
+ WorkingRegisters.push_back(SrcReg);
+ MainTy = SrcTy;
+ }
+
+ unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+ LLT MidScalarLLT = LLT::scalar(MidScalarSize);
+ Register zeroReg =
+ B.buildConstant(LLT::scalar(64), 0)->getOperand(0).getReg();
----------------
davemgreen wrote:
getReg(0)?
https://github.com/llvm/llvm-project/pull/75832
More information about the llvm-commits mailing list