[llvm] [AArch64][NEON][SVE] Lower mixed sign/zero extended partial reductions to usdot (PR #107566)
Sam Tebbs via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 19 03:05:50 PDT 2024
================
@@ -21824,37 +21830,52 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
- bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
- if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+
+ if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
+ !ISD::isExtOpcode(ExtB->getOpcode()))
return SDValue();
+ bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
return SDValue();
+ EVT ReducedType = N->getValueType(0);
+ EVT MulSrcType = A.getValueType();
+
+ // Dot products operate on chunks of four elements so there must be four times
+ // as many elements in the wide type
+ if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+ !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+ !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
+ !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ return SDValue();
+
+ // If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
+ if (AIsSigned != BIsSigned) {
+ if (!Subtarget->hasMatMulInt8())
+ return SDValue();
+
+ bool Scalable = N->getValueType(0).isScalableVT();
+ // There's no nxv2i64 version of usdot
+ if (Scalable && ReducedType != MVT::nxv4i32)
+ return SDValue();
- if (IsSExt)
+ Opcode = AArch64ISD::USDOT;
+ // USDOT expects the signed operand to be last
+ if (!BIsSigned)
+ std::swap(A, B);
+ } else if (AIsSigned)
Opcode = AArch64ISD::SDOT;
- else if (IsZExt)
+ else if (!AIsSigned)
----------------
SamTebbs33 wrote:
That's very true, it also lets us get rid of the assertion. Done.
https://github.com/llvm/llvm-project/pull/107566
More information about the llvm-commits
mailing list