[llvm] [AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input (PR #120207)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 20 07:59:45 PST 2024
================
@@ -21953,36 +21953,46 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
SDLoc DL(N);
SDValue Op2 = N->getOperand(2);
- if (Op2->getOpcode() != ISD::MUL ||
- !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
- !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
- return SDValue();
+ unsigned Op2Opcode = Op2->getOpcode();
+ SDValue MulOpLHS, MulOpRHS;
+ bool MulOpLHSIsSigned, MulOpRHSIsSigned;
+ if (ISD::isExtOpcode(Op2Opcode)) {
+ MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
+ MulOpLHS = Op2->getOperand(0);
+ MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
+ } else if (Op2Opcode == ISD::MUL) {
+ SDValue ExtMulOpLHS = Op2->getOperand(0);
+ SDValue ExtMulOpRHS = Op2->getOperand(1);
+
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+ return SDValue();
- SDValue Acc = N->getOperand(1);
- SDValue Mul = N->getOperand(2);
- SDValue ExtMulOpLHS = Mul->getOperand(0);
- SDValue ExtMulOpRHS = Mul->getOperand(1);
+ MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+ MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
- SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
- SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
- if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+ MulOpLHS = ExtMulOpLHS->getOperand(0);
+ MulOpRHS = ExtMulOpRHS->getOperand(0);
+ } else
return SDValue();
+ SDValue Acc = N->getOperand(1);
EVT ReducedVT = N->getValueType(0);
EVT MulSrcVT = MulOpLHS.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
- !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
- !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
- !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
- !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
- !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+ if ((!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+ !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8)) ||
+ (MulOpLHS.getValueType() != MulOpRHS.getValueType()))
----------------
MacDue wrote:
nit: Maybe keep this early exit (the operand value-type mismatch check) above all these type checks, as it was before, since the "dot products operate on chunks of four elements" comment does not really relate to this check.
```
if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
return SDValue();
```
https://github.com/llvm/llvm-project/pull/120207
More information about the llvm-commits mailing list