[llvm] [AArch64] Lower partial add reduction to udot or sdot (PR #101010)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 30 07:37:35 PDT 2024
================
@@ -21757,6 +21768,70 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+ const AArch64Subtarget *Subtarget,
+ SelectionDAG &DAG) {
+
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+ getIntrinsicID(N) ==
+ Intrinsic::experimental_vector_partial_reduce_add &&
+ "Expected a partial reduction node");
+
+ if (!Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // The narrower of the two operands. Used as the accumulator
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ return SDValue();
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
+
+ unsigned Opcode = 0;
+
+ if (IsSExt)
+ Opcode = AArch64ISD::SDOT;
+ else if (IsZExt)
+ Opcode = AArch64ISD::UDOT;
----------------
paulwalker-arm wrote:
This seems overly complicated, but I guess it'll help when we want to add support for USDOT.
https://github.com/llvm/llvm-project/pull/101010
More information about the llvm-commits mailing list