[llvm] [AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot (PR #107566)
Nicholas Guy via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 9 06:25:10 PDT 2024
================
@@ -21824,37 +21824,59 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
- bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
- if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
- return SDValue();
-
auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
return SDValue();
- unsigned Opcode = 0;
-
- if (IsSExt)
- Opcode = AArch64ISD::SDOT;
- else if (IsZExt)
- Opcode = AArch64ISD::UDOT;
-
- assert(Opcode != 0 && "Unexpected dot product case encountered.");
-
EVT ReducedType = N->getValueType(0);
EVT MulSrcType = A.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
- (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
- (ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
- (ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+ if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+ !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+ !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
+ !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ return SDValue();
- return SDValue();
+ bool AIsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool AIsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ bool BIsSExt = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+ bool BIsZExt = ExtB->getOpcode() == ISD::ZERO_EXTEND;
+ if (!(AIsSExt || AIsZExt) || !(BIsSExt || BIsZExt))
----------------
NickGuy-Arm wrote:
Not sure if this check is necessary (though it wasn't needed before this patch either). The code that emits this intrinsic already checks this condition (see https://github.com/llvm/llvm-project/pull/92418/files#diff-da321d454a7246f8ae276bf1db2782bf26b5210b8133cb59e4d7fd45d0905decR2156-R2158), so outside of hand-written IR the no-extends case is never taken.
Also, there don't seem to be any tests that exercise this condition.
That said, I'm fine with keeping this as a bit of defensive coding, since I can't say for sure that all future partial reduction cases will match on extends; I just figured it was worth bringing up.
https://github.com/llvm/llvm-project/pull/107566
More information about the llvm-commits
mailing list