[llvm] [AArch64] Lower partial add reduction to udot or sdot (PR #101010)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 30 07:37:35 PDT 2024
================
@@ -21757,6 +21768,70 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+ const AArch64Subtarget *Subtarget,
+ SelectionDAG &DAG) {
+
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+ getIntrinsicID(N) ==
+ Intrinsic::experimental_vector_partial_reduce_add &&
+ "Expected a partial reduction node");
+
+ if (!Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // The narrower of the two operands. Used as the accumulator
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ return SDValue();
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
+
+ unsigned Opcode = 0;
+
+ if (IsSExt)
+ Opcode = AArch64ISD::SDOT;
+ else if (IsZExt)
+ Opcode = AArch64ISD::UDOT;
+
+ assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+ // The fully-reduced type. Should be a vector of i32 or i64
+ EVT ReducedType = N->getValueType(0);
+ // The type that is extended to the wide type. Should be an i8 or i16
+ EVT ExtendedType = A.getValueType();
+ // The wide type with four times as many elements as the reduced type. Should
+ // be a vector of i32 or i64, the same as the fully-reduced type
+ EVT WideType = MulOp.getValueType();
+
+ // Dot products operate on chunks of four elements so there must be four times
+ // as many elements in the wide type
+ if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 &&
+ ExtendedType == MVT::nxv16i8)
+ return DAG.getNode(Opcode, DL, MVT::nxv4i32,
+ NarrowOp, A, B);
+
+ if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 &&
+ ExtendedType == MVT::nxv8i16)
+ return DAG.getNode(Opcode, DL, MVT::nxv2i64,
+ NarrowOp, A, B);
----------------
paulwalker-arm wrote:
As above, can this `DAG.getNode` call fit on one line?
https://github.com/llvm/llvm-project/pull/101010
More information about the llvm-commits mailing list