[llvm] [AArch64][GlobalISel] Support udot lowering for vecreduce add (PR #70784)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 2 00:37:40 PDT 2023
================
@@ -228,6 +228,146 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
}
+// Matches a G_VECREDUCE_ADD that applyExtAddvToUdotAddv can rewrite with a
+// dot-product instruction:
+//   vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(dot(0, x, y))
+//   vecreduce_add(ext(x))              -> vecreduce_add(dot(0, x, 1))
+// where ext is G_ZEXT (lowered with G_UDOT) or G_SEXT (lowered with G_SDOT).
+// Similar to performVecReduceAddCombine in SelectionDAG.
+//
+// Returns true only when the reduction result and its input have 32-bit
+// elements, and every extended source is a vector of 8-bit elements whose
+// element count is a multiple of 8. MI must be a G_VECREDUCE_ADD.
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  if (!I1)
+    return false;
+
+  Register DstReg = MI.getOperand(0).getReg();
+  Register MidReg = I1->getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  LLT MidTy = MRI.getType(MidReg);
+  if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+    return false;
+
+  auto IsExtend = [](const MachineInstr &Ext) {
+    return Ext.getOpcode() == TargetOpcode::G_ZEXT ||
+           Ext.getOpcode() == TargetOpcode::G_SEXT;
+  };
+
+  LLT SrcTy;
+  if (I1->getOpcode() == TargetOpcode::G_MUL) {
+    MachineInstr *ExtMI1 =
+        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    MachineInstr *ExtMI2 =
+        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+
+    // Verify the opcodes before reading operand 1: an arbitrary defining
+    // instruction is not guaranteed to have a register source operand.
+    if (!ExtMI1 || !ExtMI2 || !IsExtend(*ExtMI1) ||
+        ExtMI1->getOpcode() != ExtMI2->getOpcode())
+      return false;
+
+    LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+    LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+    LLT Src2Ty = MRI.getType(ExtMI2->getOperand(1).getReg());
+    SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
+
+    // Both extend sources are handed to the same dot instruction, so they
+    // must agree on type; only SrcTy is range-checked below.
+    if (Ext1DstTy != Ext2DstTy || SrcTy != Src2Ty)
+      return false;
+  } else {
+    if (!IsExtend(*I1))
+      return false;
+    SrcTy = MRI.getType(I1->getOperand(1).getReg());
+  }
+
+  return SrcTy.getScalarSizeInBits() == 8 && SrcTy.getNumElements() % 8 == 0;
+}
+
+void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+ MachineIRBuilder &Builder,
+ GISelChangeObserver &Observer) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected a G_VECREDUCE_ADD instruction");
+ MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+ Register Ext1SrcReg, Ext2SrcReg;
+ unsigned DotOpcode;
+ if (I1->getOpcode() == TargetOpcode::G_MUL) {
+ auto Ext1MI = getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+ auto Ext2MI = getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+ Ext1SrcReg = Ext1MI->getOperand(1).getReg();
+ Ext2SrcReg = Ext2MI->getOperand(1).getReg();
+ DotOpcode = Ext1MI->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+ : AArch64::G_SDOT;
+ } else if (I1->getOpcode() == TargetOpcode::G_ZEXT ||
+ I1->getOpcode() == TargetOpcode::G_SEXT) {
+ Ext1SrcReg = I1->getOperand(1).getReg();
+ Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
+ ->getOperand(0)
+ .getReg();
+ DotOpcode = I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+ : AArch64::G_SDOT;
+ } else
+ return;
+
+ LLT SrcTy = MRI.getType(Ext1SrcReg);
+ LLT MidTy;
+ unsigned NumOfVecReduce;
+ if (SrcTy.getNumElements() % 16 == 0) {
+ NumOfVecReduce = SrcTy.getNumElements() / 16;
+ MidTy = LLT::fixed_vector(4, 32);
+ } else if (SrcTy.getNumElements() % 8 == 0) {
+ NumOfVecReduce = SrcTy.getNumElements() / 8;
+ MidTy = LLT::fixed_vector(2, 32);
+ } else
+ return;
+
+ // Handle case where one DOT instruction is needed
+ if (NumOfVecReduce == 1) {
+ auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
+ auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
+ {Zeroes, Ext1SrcReg, Ext2SrcReg});
+ Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
+ } else {
+ // Get the number of output vectors needed
+ SmallVector<LLT, 4> DotVecLLT;
+ auto SrcVecNum = SrcTy.getNumElements();
+ while (SrcVecNum - 16 >= 16 || SrcVecNum - 16 == 0) {
+ DotVecLLT.push_back(LLT::fixed_vector(16, 8));
+ SrcVecNum = SrcVecNum - 16;
+ }
+ if (SrcVecNum == 8)
+ DotVecLLT.push_back(LLT::fixed_vector(8, 8));
+
+ // Unmerge the source vectors
+ auto Ext1Unmerge = Builder.buildUnmerge(DotVecLLT, Ext1SrcReg);
+ auto Ext2Unmerge = Builder.buildUnmerge(DotVecLLT, Ext2SrcReg);
+
+ // Build the UDOT instructions
+ SmallVector<Register, 2> DotReg;
+ unsigned NumElements = 0;
+ for (unsigned i = 0; i < DotVecLLT.size(); i++) {
+ LLT ZeroesLLT;
+ // Check if it is 16 or 8 elements. Set Zeroes to the according size
+ if (MRI.getType(Ext1Unmerge.getReg(i)).getNumElements() == 16) {
+ ZeroesLLT = LLT::fixed_vector(4, 32);
+ NumElements += 4;
+ } else {
+ ZeroesLLT = LLT::fixed_vector(2, 32);
+ NumElements += 2;
+ }
+ auto Zeroes = Builder.buildConstant(ZeroesLLT, 0)->getOperand(0).getReg();
+ DotReg.push_back(Builder
+ .buildInstr(DotOpcode, {MRI.getType(Zeroes)},
+ {Zeroes, Ext1Unmerge.getReg(i),
+ Ext2Unmerge.getReg(i)})
+ ->getOperand(0)
+ .getReg());
+ }
+
+ // Merge the output
+ // auto a = MI.getOperand(1).getReg().changeNumElements(NumElements);
----------------
davemgreen wrote:
Can be removed now
https://github.com/llvm/llvm-project/pull/70784
More information about the llvm-commits
mailing list