[llvm] [AArch64][GlobalISel] Support udot lowering for vecreduce add (PR #70784)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 2 01:50:06 PDT 2023
================
@@ -228,6 +228,146 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
}
+// Combines vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)
+// Or vecreduce_add(ext) -> vecreduce_add(udot)
+// Similar to performVecReduceAddCombine in SelectionDAG
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI) {
+ assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+ "Expected a G_VECREDUCE_ADD instruction");
+
+ MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+ Register DstReg = MI.getOperand(0).getReg();
+ Register MidReg = I1->getOperand(0).getReg();
+ LLT DstTy = MRI.getType(DstReg);
+ LLT MidTy = MRI.getType(MidReg);
+ if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+ return false;
+
+ LLT SrcTy;
+ auto I1Opc = I1->getOpcode();
+ if (I1Opc == TargetOpcode::G_MUL) {
+ MachineInstr *ExtMI1 =
+ getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+ MachineInstr *ExtMI2 =
+ getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+ LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+ LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+
+ if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
+ return false;
+ I1Opc = ExtMI1->getOpcode();
+ SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
+ } else
+ SrcTy = MRI.getType(I1->getOperand(1).getReg());
----------------
tschuett wrote:
The braces are unbalanced.
https://github.com/llvm/llvm-project/pull/70784
More information about the llvm-commits
mailing list