[llvm] [AArch64][GlobalISel] Support udot lowering for vecreduce add (PR #70784)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 14 01:43:09 PST 2023
================
@@ -228,6 +228,199 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
}
+// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
+// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1))
+// Similar to performVecReduceAddCombine in SelectionDAG
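+// For example (roughly, for a zext of <16 x i8>):
+//   vecreduce_add(zext <16 x i8> %x to <16 x i32>)
+// becomes
+//   vecreduce_add(udot(<4 x i32> zeroes, %x, <16 x i8> splat of 1))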
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+ const AArch64Subtarget &STI,
+ std::tuple<Register, Register, bool> &MatchInfo) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected a G_VECREDUCE_ADD instruction");
+ assert(STI.hasDotProd() && "Target should have Dot Product feature");
+
+ MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+ Register DstReg = MI.getOperand(0).getReg();
+ Register MidReg = I1->getOperand(0).getReg();
+ LLT DstTy = MRI.getType(DstReg);
+ LLT MidTy = MRI.getType(MidReg);
+ if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+ return false;
+
+ LLT SrcTy;
+ auto I1Opc = I1->getOpcode();
+ if (I1Opc == TargetOpcode::G_MUL) {
+ // If the result of the mul has more than one use, there is no point in
+ // creating a udot instruction
+ if (!MRI.hasOneNonDBGUse(MidReg))
+ return false;
+
+ MachineInstr *ExtMI1 =
+ getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+ MachineInstr *ExtMI2 =
+ getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+ LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+ LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+
+ if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
+ return false;
+ I1Opc = ExtMI1->getOpcode();
+ SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
+ std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
+ std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
+ } else {
+ SrcTy = MRI.getType(I1->getOperand(1).getReg());
+ std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
+ std::get<1>(MatchInfo) = 0;
+ }
+
+ if (I1Opc == TargetOpcode::G_ZEXT)
+ std::get<2>(MatchInfo) = 0;
+ else if (I1Opc == TargetOpcode::G_SEXT)
+ std::get<2>(MatchInfo) = 1;
+ else
+ return false;
+
+ if (SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
+ return false;
+
+ return true;
+}
+
+void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &Builder,
+ GISelChangeObserver &Observer,
+ const AArch64Subtarget &STI,
+ std::tuple<Register, Register, bool> &MatchInfo) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected a G_VECREDUCE_ADD instruction");
+ assert(STI.hasDotProd() && "Target should have Dot Product feature");
+
+ // Initialise the variables
+ unsigned DotOpcode =
+ std::get<2>(MatchInfo) ? AArch64::G_SDOT : AArch64::G_UDOT;
+ Register Ext1SrcReg = std::get<0>(MatchInfo);
+
+ // If there is only one source register, create a vector of 1s as the second
+ // source register so the dot product reduces to a plain sum of the elements
+ Register Ext2SrcReg;
+ if (std::get<1>(MatchInfo) == 0)
+ Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
+ ->getOperand(0)
+ .getReg();
+ else
+ Ext2SrcReg = std::get<1>(MatchInfo);
+
+ // Find out how many DOT instructions are needed
+ LLT SrcTy = MRI.getType(Ext1SrcReg);
+ LLT MidTy;
+ unsigned NumOfDotMI;
+ if (SrcTy.getNumElements() % 16 == 0) {
+ NumOfDotMI = SrcTy.getNumElements() / 16;
+ MidTy = LLT::fixed_vector(4, 32);
+ } else if (SrcTy.getNumElements() % 8 == 0) {
+ NumOfDotMI = SrcTy.getNumElements() / 8;
+ MidTy = LLT::fixed_vector(2, 32);
+ } else {
+ llvm_unreachable("Source type number of elements is not a multiple of 8");
+ }
+
+ // Handle case where one DOT instruction is needed
+ if (NumOfDotMI == 1) {
+ auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
+ auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
+ {Zeroes, Ext1SrcReg, Ext2SrcReg});
+ Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
+ } else {
+ // If not, pad the trailing v8i8 with 0s to make a v16i8
+ SmallVector<Register, 4> Ext1UnmergeReg;
+ SmallVector<Register, 4> Ext2UnmergeReg;
+ if (SrcTy.getNumElements() % 16 != 0) {
+ // Unmerge the source to v8i8, append a v8i8 of 0s, then merge back to v16i8
+ SmallVector<Register, 4> PadUnmergeDstReg1;
+ SmallVector<Register, 4> PadUnmergeDstReg2;
+ unsigned NumOfVec = SrcTy.getNumElements() / 8;
+
+ // Unmerge the source to v8i8
+ MachineInstr *PadUnmerge1 =
+ Builder.buildUnmerge(LLT::fixed_vector(8, 8), Ext1SrcReg);
+ MachineInstr *PadUnmerge2 =
+ Builder.buildUnmerge(LLT::fixed_vector(8, 8), Ext2SrcReg);
+ for (unsigned i = 0; i < NumOfVec; i++) {
+ PadUnmergeDstReg1.push_back(PadUnmerge1->getOperand(i).getReg());
+ PadUnmergeDstReg2.push_back(PadUnmerge2->getOperand(i).getReg());
+ }
+
+ // Pad the vectors with a v8i8 constant of 0s
+ MachineInstr *v8Zeroes =
+ Builder.buildConstant(LLT::fixed_vector(8, 8), 0);
+ PadUnmergeDstReg1.push_back(v8Zeroes->getOperand(0).getReg());
+ PadUnmergeDstReg2.push_back(v8Zeroes->getOperand(0).getReg());
+
+ // Merge them all back to v16i8
+ NumOfVec = (NumOfVec + 1) / 2;
+ for (unsigned i = 0; i < NumOfVec; i++) {
+ Ext1UnmergeReg.push_back(
+ Builder
+ .buildMergeLikeInstr(
+ LLT::fixed_vector(16, 8),
+ {PadUnmergeDstReg1[i * 2], PadUnmergeDstReg1[(i * 2) + 1]})
+ ->getOperand(0)
+ .getReg());
----------------
davemgreen wrote:
I think that ->getOperand(0).getReg() can be .getReg(0) from a MIBuilder. Here and 2 other places below.
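For example, at this site that would look something like (sketch only, untested):

    Ext1UnmergeReg.push_back(
        Builder
            .buildMergeLikeInstr(
                LLT::fixed_vector(16, 8),
                {PadUnmergeDstReg1[i * 2], PadUnmergeDstReg1[(i * 2) + 1]})
            .getReg(0));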
https://github.com/llvm/llvm-project/pull/70784
More information about the llvm-commits mailing list