[llvm] [AArch64][GlobalISel] Support udot lowering for vecreduce add (PR #70784)

David Green via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 2 00:37:39 PDT 2023


================
@@ -228,6 +228,146 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
       B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
 }
 
+// Combines vecreduce_add(mul(ext, ext)) -> vecreduce_add(dot)
+// Or vecreduce_add(ext) -> vecreduce_add(dot)
+// where ext is G_ZEXT (lowered to G_UDOT) or G_SEXT (lowered to G_SDOT) from a
+// vector of 8-bit elements whose element count is a multiple of 8.
+// Similar to performVecReduceAddCombine in SelectionDAG
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  Register DstReg = MI.getOperand(0).getReg();
+  Register MidReg = I1->getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  LLT MidTy = MRI.getType(MidReg);
+  if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+    return false;
+
+  // True for the extend opcodes a dot product can absorb.
+  auto IsExt = [](unsigned Opc) {
+    return Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT;
+  };
+
+  LLT SrcTy;
+  if (I1->getOpcode() == TargetOpcode::G_MUL) {
+    MachineInstr *ExtMI1 =
+        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    MachineInstr *ExtMI2 =
+        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+    // Check the opcodes before reading operand 1: for an arbitrary defining
+    // instruction (e.g. G_INTRINSIC, G_IMPLICIT_DEF) operand 1 may be missing
+    // or may not be a register operand.
+    if (!IsExt(ExtMI1->getOpcode()) ||
+        ExtMI1->getOpcode() != ExtMI2->getOpcode())
+      return false;
+    LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+    LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+    SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
+    // Both dot-product inputs must share one type. Equal extend result types
+    // alone do not guarantee that (e.g. zext <8 x s8> vs zext <8 x s16>).
+    if (Ext1DstTy != Ext2DstTy ||
+        SrcTy != MRI.getType(ExtMI2->getOperand(1).getReg()))
+      return false;
+  } else {
+    // Same reasoning: only read operand 1 once we know I1 is an extend.
+    if (!IsExt(I1->getOpcode()))
+      return false;
+    SrcTy = MRI.getType(I1->getOperand(1).getReg());
+  }
+
+  // Only vectors of 8-bit elements whose count is a multiple of 8 are handled.
+  return SrcTy.isVector() && SrcTy.getScalarSizeInBits() == 8 &&
+         SrcTy.getNumElements() % 8 == 0;
+}
+
+void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                            MachineIRBuilder &Builder,
+                            GISelChangeObserver &Observer) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  Register Ext1SrcReg, Ext2SrcReg;
+  unsigned DotOpcode;
+  if (I1->getOpcode() == TargetOpcode::G_MUL) {
+    auto Ext1MI = getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    auto Ext2MI = getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+    Ext1SrcReg = Ext1MI->getOperand(1).getReg();
+    Ext2SrcReg = Ext2MI->getOperand(1).getReg();
+    DotOpcode = Ext1MI->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+                                                            : AArch64::G_SDOT;
+  } else if (I1->getOpcode() == TargetOpcode::G_ZEXT ||
+             I1->getOpcode() == TargetOpcode::G_SEXT) {
+    Ext1SrcReg = I1->getOperand(1).getReg();
+    Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
+                     ->getOperand(0)
+                     .getReg();
+    DotOpcode = I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UDOT
+                                                        : AArch64::G_SDOT;
+  } else
+    return;
+
+  LLT SrcTy = MRI.getType(Ext1SrcReg);
+  LLT MidTy;
+  unsigned NumOfVecReduce;
+  if (SrcTy.getNumElements() % 16 == 0) {
+    NumOfVecReduce = SrcTy.getNumElements() / 16;
+    MidTy = LLT::fixed_vector(4, 32);
+  } else if (SrcTy.getNumElements() % 8 == 0) {
+    NumOfVecReduce = SrcTy.getNumElements() / 8;
+    MidTy = LLT::fixed_vector(2, 32);
+  } else
+    return;
+
+  // Handle case where one DOT instruction is needed
+  if (NumOfVecReduce == 1) {
+    auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
+    auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
+                                  {Zeroes, Ext1SrcReg, Ext2SrcReg});
+    Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
+  } else {
+    // Get the number of output vectors needed
+    SmallVector<LLT, 4> DotVecLLT;
+    auto SrcVecNum = SrcTy.getNumElements();
+    while (SrcVecNum - 16 >= 16 || SrcVecNum - 16 == 0) {
----------------
davemgreen wrote:

Do you have tests for 24x vector types?

https://github.com/llvm/llvm-project/pull/70784


More information about the llvm-commits mailing list