[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 2 02:27:12 PDT 2024
================
@@ -21757,6 +21768,70 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+ const AArch64Subtarget *Subtarget,
+ SelectionDAG &DAG) {
+
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+ getIntrinsicID(N) ==
+ Intrinsic::experimental_vector_partial_reduce_add &&
+ "Expected a partial reduction node");
+
+ if (!Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // The narrower of the two operands. Used as the accumulator
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ return SDValue();
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
+
+ unsigned Opcode = 0;
+
+ if (IsSExt)
+ Opcode = AArch64ISD::SDOT;
+ else if (IsZExt)
+ Opcode = AArch64ISD::UDOT;
+
+ assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+ // The fully-reduced type. Should be a vector of i32 or i64
+ EVT ReducedType = N->getValueType(0);
+ // The type that is extended to the wide type. Should be an i8 or i16
+ EVT ExtendedType = A.getValueType();
+ // The wide type with four times as many elements as the reduced type. Should
+ // be a vector of i32 or i64, the same as the fully-reduced type
+ EVT WideType = MulOp.getValueType();
+
+ // Dot products operate on chunks of four elements so there must be four times
+ // as many elements in the wide type
+ if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 &&
+ ExtendedType == MVT::nxv16i8)
----------------
paulwalker-arm wrote:
DAG combines can and should assume the DAG is well formed.
In this instance you can be certain the element type of the mul will match the element type of the partial.add's result type (i.e. ReducedType), and you can be certain the mul will have the same number of elements as its operands, which by extension means the same number of elements as its pre-extended operands (i.e. ExtendedType).
https://github.com/llvm/llvm-project/pull/101010
More information about the llvm-commits mailing list