[llvm] [AArch64][GlobalISel] Support udot lowering for vecreduce add (PR #70784)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 2 01:50:06 PDT 2023


================
@@ -228,6 +228,146 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
       B.buildConstant(LLT::scalar(64), -static_cast<int64_t>(MinOffset)));
 }
 
+// Combines vecreduce_add(mul(ext, ext)) -> vecreduce_add(udot)
+// Or vecreduce_add(ext) -> vecreduce_add(udot)
+// Similar to performVecReduceAddCombine in SelectionDAG
+bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected a G_VECREDUCE_ADD instruction");
+
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  Register DstReg = MI.getOperand(0).getReg();
+  Register MidReg = I1->getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+  LLT MidTy = MRI.getType(MidReg);
+  if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
+    return false;
+
+  LLT SrcTy;
+  auto I1Opc = I1->getOpcode();
+  if (I1Opc == TargetOpcode::G_MUL) {
+    MachineInstr *ExtMI1 =
+        getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
+    MachineInstr *ExtMI2 =
+        getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
+    LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
+    LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
+
+    if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
+      return false;
+    I1Opc = ExtMI1->getOpcode();
+    SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
+  } else
+    SrcTy = MRI.getType(I1->getOperand(1).getReg());
----------------
tschuett wrote:

The braces are unbalanced: the `if` branch is wrapped in braces but the `else` branch is not. Please brace both branches.

https://github.com/llvm/llvm-project/pull/70784


More information about the llvm-commits mailing list