[llvm] [AArch64][GlobalISel] Support udot lowering for vecreduce add (PR #70784)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 14 01:43:09 PST 2023


================
@@ -228,6 +228,199 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
       B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
 }
 
+// Match vecreduce_add(mul(ext(x), ext(y))) or vecreduce_add(ext(x)) so that it
+// can be rewritten with a dot-product instruction:
+//   vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
+//   vecreduce_add(ext(x))              -> vecreduce_add(udot(x, 1))
+// (sdot instead of udot when the extends are G_SEXT).
+// Similar to performVecReduceAddCombine in SelectionDAG.
+//
+// Requirements checked here:
+//  - the reduction result and its vector input have 32-bit elements;
+//  - with a G_MUL, the multiply has a single non-debug use and both operands
+//    come from extends with the same opcode, result type and source type;
+//  - the extend source is a fixed-length vector of 8-bit elements whose
+//    element count is a multiple of 8.
+//
+// On success MatchInfo is set to:
+//   <0> the extend source register x,
+//   <1> the second extend source register y, or an invalid (0) Register when
+//       the reduction input is a lone extend,
+//   <2> true for G_SEXT (signed dot product), false for G_ZEXT.
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                            const AArch64Subtarget &STI,
+                            std::tuple<Register, Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+  assert(STI.hasDotProd() && "Target should have Dot Product feature");
+
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  Register DstReg = MI.getOperand(0).getReg();
+  Register MidReg = I1->getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  LLT MidTy = MRI.getType(MidReg);
+  if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+    return false;
+
+  // Only zero/sign extends can feed a dot product. The opcode must be checked
+  // before reading operand 1: instructions such as G_IMPLICIT_DEF have no
+  // operand 1 at all.
+  auto IsExt = [](unsigned Opc) {
+    return Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT;
+  };
+
+  Register Src1;
+  Register Src2; // Stays invalid (0) when there is no multiply.
+  unsigned ExtOpc = I1->getOpcode();
+  if (ExtOpc == TargetOpcode::G_MUL) {
+    // If result of this has more than 1 use, then there is no point in creating
+    // udot instruction
+    if (!MRI.hasOneNonDBGUse(MidReg))
+      return false;
+
+    MachineInstr *ExtMI1 =
+        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    MachineInstr *ExtMI2 =
+        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+    ExtOpc = ExtMI1->getOpcode();
+    if (!IsExt(ExtOpc) || ExtMI2->getOpcode() != ExtOpc)
+      return false;
+
+    LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+    LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+    Src1 = ExtMI1->getOperand(1).getReg();
+    Src2 = ExtMI2->getOperand(1).getReg();
+    // Both dot-product inputs must share a type; otherwise e.g.
+    // zext(<8 x s8>) * zext(<8 x s16>) would be accepted by checking x alone.
+    if (Ext1DstTy != Ext2DstTy || MRI.getType(Src1) != MRI.getType(Src2))
+      return false;
+  } else {
+    if (!IsExt(ExtOpc))
+      return false;
+    Src1 = I1->getOperand(1).getReg();
+  }
+
+  // The source must be a fixed-length vector of i8 that splits evenly into
+  // 8-element chunks; getNumElements() is only queried on fixed vectors.
+  LLT SrcTy = MRI.getType(Src1);
+  if (!SrcTy.isVector() || SrcTy.isScalable() ||
+      SrcTy.getScalarSizeInBits() != 8 || SrcTy.getNumElements() % 8 != 0)
+    return false;
+
+  MatchInfo = std::make_tuple(Src1, Src2, ExtOpc == TargetOpcode::G_SEXT);
+  return true;
+}
+
+void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                            MachineIRBuilder &Builder,
+                            GISelChangeObserver &Observer,
+                            const AArch64Subtarget &STI,
+                            std::tuple<Register, Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+  assert(STI.hasDotProd() && "Target should have Dot Product feature");
+
+  // Initialise the variables
+  unsigned DotOpcode =
+      std::get<2>(MatchInfo) ? AArch64::G_SDOT : AArch64::G_UDOT;
+  Register Ext1SrcReg = std::get<0>(MatchInfo);
+
+  // If there is one source register, create a vector of 0s as the second
+  // source register
+  Register Ext2SrcReg;
+  if (std::get<1>(MatchInfo) == 0)
+    Ext2SrcReg = Builder.buildConstant(MRI.getType(Ext1SrcReg), 1)
+                     ->getOperand(0)
+                     .getReg();
+  else
+    Ext2SrcReg = std::get<1>(MatchInfo);
+
+  // Find out how many DOT instructions are needed
+  LLT SrcTy = MRI.getType(Ext1SrcReg);
+  LLT MidTy;
+  unsigned NumOfDotMI;
+  if (SrcTy.getNumElements() % 16 == 0) {
+    NumOfDotMI = SrcTy.getNumElements() / 16;
+    MidTy = LLT::fixed_vector(4, 32);
+  } else if (SrcTy.getNumElements() % 8 == 0) {
+    NumOfDotMI = SrcTy.getNumElements() / 8;
+    MidTy = LLT::fixed_vector(2, 32);
+  } else {
+    llvm_unreachable("Source type number of elements is not multiple of 8");
+  }
+
+  // Handle case where one DOT instruction is needed
+  if (NumOfDotMI == 1) {
+    auto Zeroes = Builder.buildConstant(MidTy, 0)->getOperand(0).getReg();
+    auto Dot = Builder.buildInstr(DotOpcode, {MidTy},
+                                  {Zeroes, Ext1SrcReg, Ext2SrcReg});
+    Builder.buildVecReduceAdd(MI.getOperand(0), Dot->getOperand(0));
+  } else {
+    // If not pad the last v8 element with 0s to a v16
+    SmallVector<Register, 4> Ext1UnmergeReg;
+    SmallVector<Register, 4> Ext2UnmergeReg;
+    if (SrcTy.getNumElements() % 16 != 0) {
+      // Unmerge source to v8i8, append a new v8i8 of 0s and the merge to v16s
+      SmallVector<Register, 4> PadUnmergeDstReg1;
+      SmallVector<Register, 4> PadUnmergeDstReg2;
+      unsigned NumOfVec = SrcTy.getNumElements() / 8;
+
+      // Unmerge the source to v8i8
+      MachineInstr *PadUnmerge1 =
+          Builder.buildUnmerge(LLT::fixed_vector(8, 8), Ext1SrcReg);
+      MachineInstr *PadUnmerge2 =
+          Builder.buildUnmerge(LLT::fixed_vector(8, 8), Ext2SrcReg);
+      for (unsigned i = 0; i < NumOfVec; i++) {
+        PadUnmergeDstReg1.push_back(PadUnmerge1->getOperand(i).getReg());
+        PadUnmergeDstReg2.push_back(PadUnmerge2->getOperand(i).getReg());
+      }
+
+      // Pad the vectors with a v8i8 constant of 0s
+      MachineInstr *v8Zeroes =
+          Builder.buildConstant(LLT::fixed_vector(8, 8), 0);
+      PadUnmergeDstReg1.push_back(v8Zeroes->getOperand(0).getReg());
+      PadUnmergeDstReg2.push_back(v8Zeroes->getOperand(0).getReg());
+
+      // Merge them all back to v16i8
+      NumOfVec = (NumOfVec + 1) / 2;
+      for (unsigned i = 0; i < NumOfVec; i++) {
+        Ext1UnmergeReg.push_back(
+            Builder
+                .buildMergeLikeInstr(
+                    LLT::fixed_vector(16, 8),
+                    {PadUnmergeDstReg1[i * 2], PadUnmergeDstReg1[(i * 2) + 1]})
+                ->getOperand(0)
+                .getReg());
----------------
davemgreen wrote:

I think `->getOperand(0).getReg()` can be replaced with `.getReg(0)` on the MachineInstrBuilder returned by the builder. This applies here and in 2 other places below.

https://github.com/llvm/llvm-project/pull/70784


More information about the llvm-commits mailing list