[llvm] [RISCV] Initial codegen support for zvqdotq extension (PR #137039)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 24 08:56:30 PDT 2025
================
@@ -18003,6 +18003,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
DAG.getBuildVector(VT, DL, RHSOps));
}
+/// Build a zvqdotq dot-product node of opcode \p Opc (one of
+/// RISCVISD::VQDOT_VL, VQDOTU_VL or VQDOTSU_VL) over the fixed-length i32
+/// vectors \p Op0 and \p Op1, returning a value of the same fixed-length type.
+///
+/// Each i32 lane of the operands presumably holds four packed i8 values (the
+/// reduction combine in this patch bitcasts i8 vectors to i32 before calling).
+/// The operands are moved into their scalable container type, the node is
+/// built with an all-zero passthru, the default mask/VL for the fixed-length
+/// type and a tail-agnostic/mask-agnostic policy, and the result is converted
+/// back to the fixed-length type.
+static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
+                          const SDLoc &DL, SelectionDAG &DAG,
+                          const RISCVSubtarget &Subtarget) {
+  assert((RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
+          RISCVISD::VQDOTSU_VL == Opc) &&
+         "Expected a VQDOT*_VL opcode");
+  MVT VT = Op0.getSimpleValueType();
+  assert(VT == Op1.getSimpleValueType() &&
+         VT.getVectorElementType() == MVT::i32 &&
+         "Expected matching i32 vector operands");
+
+  // Only fixed-length vectors are handled; they are lowered through their
+  // scalable container type and converted back at the end.
+  assert(VT.isFixedLengthVector() && "Expected a fixed-length vector type");
+  MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+  // NOTE(review): the zero passthru presumably serves as the accumulator
+  // input, so the node yields the plain dot product -- confirm against the
+  // VQDOT*_VL node definition.
+  SDValue Passthru = convertToScalableVector(
+      ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
+  Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+  Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+
+  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+  const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
+  SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
+  SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
+                                   {Op0, Op1, Passthru, Mask, VL, PolicyOp});
+  return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+}
+
+/// Return the i32 vector type that packs the i8 elements of \p OpVT four per
+/// lane (same total bit width). The reduction combine in this patch uses it
+/// as the bitcast type for the VQDOT*_VL operands.
+static MVT getQDOTXResultType(MVT OpVT) {
+  ElementCount OpEC = OpVT.getVectorElementCount();
+  assert(OpEC.isKnownMultipleOf(4) &&
+         OpVT.getVectorElementType() == MVT::i8 &&
+         "Expected an i8 vector whose element count is a multiple of 4");
+  // Four i8 source elements map onto one i32 result element.
+  return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
+}
+
+static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
+ SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget,
+ const RISCVTargetLowering &TLI) {
+ // Note: We intentionally do not check the legality of the reduction type.
+ // We want to handle the m4/m8 *src* types, and thus need to let illegal
+ // intermediate types flow through here.
+ if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
+ !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
+ return SDValue();
+
+ // reduce (sext a) <--> reduce (mul zext a. zext 1)
+ // reduce (zext a) <--> reduce (mul sext a. sext 1)
+ if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
+ InVec.getOpcode() == ISD::SIGN_EXTEND) {
+ SDValue A = InVec.getOperand(0);
+ if (A.getValueType().getVectorElementType() != MVT::i8 ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
+ A = DAG.getBitcast(ResVT, A);
+ SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
+
+ bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ }
+
+ // mul (sext, sext) -> vqdot
+ // mul (zext, zext) -> vqdotu
+ // mul (sext, zext) -> vqdotsu
+ // mul (zext, sext) -> vqdotsu (swapped)
+ // TODO: Improve .vx handling - we end up with a sub-vector insert
+ // which confuses the splat pattern matching. Also, match vqdotus.vx
+ if (InVec.getOpcode() != ISD::MUL)
----------------
preames wrote:
Annoyingly complicated, possible future work.
The problem is that we have to expand the shift as a multiply by 2^N, and the range of shift amounts we can handle is very limited due to the input being an i8.
https://github.com/llvm/llvm-project/pull/137039
More information about the llvm-commits mailing list