[llvm] [AArch64] Lower partial add reduction to udot or sdot (PR #101010)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 27 10:41:27 PDT 2024
================
@@ -21229,6 +21249,101 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+ const AArch64Subtarget *Subtarget,
+ SelectionDAG &DAG) {
+
+ if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable())
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // The narrower of the two operands. Used as the accumulator
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ return SDValue();
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
+
+ // The fully-reduced type. Should be a vector of i32 or i64
+ EVT FullType = N->getValueType(0);
+ // The type that is extended to the wide type. Should be an i8 or i16
+ EVT ExtendedType = A.getValueType();
+ // The wide type with four times as many elements as the reduced type. Should
+ // be a vector of i32 or i64, the same as the fully-reduced type
+ EVT WideType = MulOp.getValueType();
+ if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits())
+ return SDValue();
+ // Dot products operate on chunks of four elements so there must be four times
+ // as many elements in the wide type
+ if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() !=
+ 4)
+ return SDValue();
+ switch (FullType.getScalarSizeInBits()) {
+ case 32:
+ if (ExtendedType.getScalarSizeInBits() != 8)
+ return SDValue();
+ break;
+ case 64:
+ // i8 to i64 can be done with an extended i32 dot product
+ if (ExtendedType.getScalarSizeInBits() != 8 &&
+ ExtendedType.getScalarSizeInBits() != 16)
+ return SDValue();
+ break;
+ default:
+ return SDValue();
+ }
+
+ unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+ if (IsSExt)
+ DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+ else if (IsZExt)
+ DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+ assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+ "Unexpected dot product case encountered.");
+
+ EVT Type = NarrowOp.getValueType();
+
+ // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
+ // and extending the output
+ bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
+ Type.getScalarSizeInBits() == 64;
+ SDValue Accumulator = NarrowOp;
+ if (Extend) {
+ Type =
+ Type.changeVectorElementType(EVT::getIntegerVT(*DAG.getContext(), 32));
+ // The accumulator is of the wider type so we insert a 0 accumulator and
+ // add the proper one after extending
+ Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
+ DAG.getConstant(0, DL, MVT::i32));
----------------
paulwalker-arm wrote:
You should be able to use `DAG.getConstant(0, DL, MVT::nxv4i32);` here.
https://github.com/llvm/llvm-project/pull/101010
More information about the llvm-commits
mailing list