[llvm] [AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input (PR #120207)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 20 07:10:07 PST 2024
================
@@ -21741,45 +21741,63 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// The narrower of the two operands. Used as the accumulator
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
- if (MulOp->getOpcode() != ISD::MUL)
- return SDValue();
- auto ExtA = MulOp->getOperand(0);
- auto ExtB = MulOp->getOperand(1);
+ unsigned MulOpcode = MulOp->getOpcode();
+ EVT ReducedVT = N->getValueType(0);
+ EVT MulOpVT = MulOp->getValueType(0);
+ unsigned Opcode = 0;
+ bool AIsSigned, BIsSigned;
+ SDValue A, B;
+ if (MulOpcode != ISD::MUL && ReducedVT.getVectorElementCount() * 4 ==
+ MulOpVT.getVectorElementCount()) {
+ if (!ISD::isExtOpcode(MulOpcode))
+ return SDValue();
+ AIsSigned = MulOpcode == ISD::SIGN_EXTEND;
+ BIsSigned = AIsSigned;
+ SDValue NewMulOp = MulOp->getOperand(0);
+ Opcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+ A = NewMulOp;
+ B = DAG.getConstant(1, DL, NewMulOp.getValueType());
- if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
- !ISD::isExtOpcode(ExtB->getOpcode()))
- return SDValue();
- bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+ } else {
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
- auto A = ExtA->getOperand(0);
- auto B = ExtB->getOperand(0);
- if (A.getValueType() != B.getValueType())
- return SDValue();
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
- EVT ReducedType = N->getValueType(0);
- EVT MulSrcType = A.getValueType();
+ if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
+ !ISD::isExtOpcode(ExtB->getOpcode()))
+ return SDValue();
+ AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+
+ A = ExtA->getOperand(0);
+ B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
----------------
JamesChesterman wrote:
Done. There was missing test coverage, so I added tests for the case where the original (unextended) types differ, in `sve-partial-reduce-dot-product.ll` and `neon-partial-reduce-dot-product.ll`. I also moved this check into the large `if` statement below it, ORing it with the rest of the conditions.
https://github.com/llvm/llvm-project/pull/120207
More information about the llvm-commits
mailing list