[llvm] [AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input (PR #120207)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 6 02:03:08 PST 2025


================
@@ -21953,36 +21953,46 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   SDLoc DL(N);
 
   SDValue Op2 = N->getOperand(2);
-  if (Op2->getOpcode() != ISD::MUL ||
-      !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
-      !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
-    return SDValue();
+  unsigned Op2Opcode = Op2->getOpcode();
+  SDValue MulOpLHS, MulOpRHS;
+  bool MulOpLHSIsSigned, MulOpRHSIsSigned;
+  if (ISD::isExtOpcode(Op2Opcode)) {
+    MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
+    MulOpLHS = Op2->getOperand(0);
+    MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
+  } else if (Op2Opcode == ISD::MUL) {
+    SDValue ExtMulOpLHS = Op2->getOperand(0);
+    SDValue ExtMulOpRHS = Op2->getOperand(1);
+
+    unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+    unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+    if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+        !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+      return SDValue();
 
-  SDValue Acc = N->getOperand(1);
-  SDValue Mul = N->getOperand(2);
-  SDValue ExtMulOpLHS = Mul->getOperand(0);
-  SDValue ExtMulOpRHS = Mul->getOperand(1);
+    MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+    MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
 
-  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
-  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
-  if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+    MulOpLHS = ExtMulOpLHS->getOperand(0);
+    MulOpRHS = ExtMulOpRHS->getOperand(0);
+  } else
     return SDValue();
 
+  SDValue Acc = N->getOperand(1);
   EVT ReducedVT = N->getValueType(0);
   EVT MulSrcVT = MulOpLHS.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
-      !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
-      !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
-      !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
-      !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
-      !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+  if ((!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+       !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+       !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+       !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+       !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+       !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8)) ||
+      (MulOpLHS.getValueType() != MulOpRHS.getValueType()))
----------------
JamesChesterman wrote:

Done. I moved the check `if (MulOpLHS.getValueType() != MulOpRHS.getValueType())` back into the `else if (Op2Opcode == ISD::MUL)` branch.

https://github.com/llvm/llvm-project/pull/120207


More information about the llvm-commits mailing list