[llvm] [AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input (PR #120207)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 20 07:10:07 PST 2024


================
@@ -21741,45 +21741,63 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   // The narrower of the two operands. Used as the accumulator
   auto NarrowOp = N->getOperand(1);
   auto MulOp = N->getOperand(2);
-  if (MulOp->getOpcode() != ISD::MUL)
-    return SDValue();
 
-  auto ExtA = MulOp->getOperand(0);
-  auto ExtB = MulOp->getOperand(1);
+  unsigned MulOpcode = MulOp->getOpcode();
+  EVT ReducedVT = N->getValueType(0);
+  EVT MulOpVT = MulOp->getValueType(0);
+  unsigned Opcode = 0;
+  bool AIsSigned, BIsSigned;
+  SDValue A, B;
+  if (MulOpcode != ISD::MUL && ReducedVT.getVectorElementCount() * 4 ==
+                                   MulOpVT.getVectorElementCount()) {
+    if (!ISD::isExtOpcode(MulOpcode))
+      return SDValue();
+    AIsSigned = MulOpcode == ISD::SIGN_EXTEND;
+    BIsSigned = AIsSigned;
+    SDValue NewMulOp = MulOp->getOperand(0);
+    Opcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+    A = NewMulOp;
+    B = DAG.getConstant(1, DL, NewMulOp.getValueType());
 
-  if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
-      !ISD::isExtOpcode(ExtB->getOpcode()))
-    return SDValue();
-  bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+  } else {
+    if (MulOp->getOpcode() != ISD::MUL)
+      return SDValue();
 
-  auto A = ExtA->getOperand(0);
-  auto B = ExtB->getOperand(0);
-  if (A.getValueType() != B.getValueType())
-    return SDValue();
+    auto ExtA = MulOp->getOperand(0);
+    auto ExtB = MulOp->getOperand(1);
 
-  EVT ReducedType = N->getValueType(0);
-  EVT MulSrcType = A.getValueType();
+    if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
+        !ISD::isExtOpcode(ExtB->getOpcode()))
+      return SDValue();
+    AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+    BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+
+    A = ExtA->getOperand(0);
+    B = ExtB->getOperand(0);
+    if (A.getValueType() != B.getValueType())
+      return SDValue();
----------------
JamesChesterman wrote:

Done. There was missing test coverage, so I added tests for the case where the original, unextended types are different, in the test files `sve-partial-reduce-dot-product.ll` and `neon-partial-reduce-dot-product.ll`. I also moved this check into the large if statement below it, combining it with the rest of the conditions using an OR.

https://github.com/llvm/llvm-project/pull/120207


More information about the llvm-commits mailing list