[llvm] [RISCV] Initial codegen support for zvqdotq extension (PR #137039)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 24 08:56:30 PDT 2025


================
@@ -18003,6 +18003,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
                      DAG.getBuildVector(VT, DL, RHSOps));
 }
 
+/// Emit one of the VQDOT-family nodes (VQDOT_VL, VQDOTU_VL or VQDOTSU_VL) for
+/// two fixed-length operands that share the same i32-element vector type.
+/// Both operands are moved into the scalable container type, the node is
+/// built with an all-zero passthru plus the default mask and VL, and the
+/// result is converted back to the original fixed-length type.
+static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
+                          const SDLoc &DL, SelectionDAG &DAG,
+                          const RISCVSubtarget &Subtarget) {
+  assert(Opc == RISCVISD::VQDOT_VL || Opc == RISCVISD::VQDOTU_VL ||
+         Opc == RISCVISD::VQDOTSU_VL);
+
+  const MVT FixedVT = Op0.getSimpleValueType();
+  assert(FixedVT == Op1.getSimpleValueType());
+  assert(FixedVT.getVectorElementType() == MVT::i32);
+  assert(FixedVT.isFixedLengthVector());
+
+  const MVT ScalableVT =
+      getContainerForFixedLengthVector(DAG, FixedVT, Subtarget);
+
+  // The zero passthru comes first, then the two sources, all in container
+  // form.
+  SDValue Zero = DAG.getConstant(0, DL, FixedVT);
+  SDValue Passthru = convertToScalableVector(ScalableVT, Zero, DAG, Subtarget);
+  SDValue LHS = convertToScalableVector(ScalableVT, Op0, DAG, Subtarget);
+  SDValue RHS = convertToScalableVector(ScalableVT, Op1, DAG, Subtarget);
+
+  auto [Mask, VL] = getDefaultVLOps(FixedVT, ScalableVT, DL, DAG, Subtarget);
+
+  // Tail-agnostic, mask-agnostic policy, passed as an XLen target constant.
+  const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
+  SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
+
+  SDValue Ops[] = {LHS, RHS, Passthru, Mask, VL, PolicyOp};
+  SDValue Dot = DAG.getNode(Opc, DL, ScalableVT, Ops);
+  return convertFromScalableVector(FixedVT, Dot, DAG, Subtarget);
+}
+
+/// Given an i8-element vector type whose element count is a known multiple of
+/// four, return the i32-element vector type with one quarter as many
+/// elements.
+static MVT getQDOTXResultType(MVT OpVT) {
+  const ElementCount NumBytes = OpVT.getVectorElementCount();
+  assert(NumBytes.isKnownMultipleOf(4));
+  assert(OpVT.getVectorElementType() == MVT::i8);
+
+  // Four i8 lanes are grouped into each i32 result lane.
+  const ElementCount NumWords = NumBytes.divideCoefficientBy(4);
+  return MVT::getVectorVT(MVT::i32, NumWords);
+}
+
+static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
+                                         SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget,
+                                         const RISCVTargetLowering &TLI) {
+  // Note: We intentionally do not check the legality of the reduction type.
+  // We want to handle the m4/m8 *src*  types, and thus need to let illegal
+  // intermediate types flow through here.
+  if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
+      !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
+    return SDValue();
+
+  // reduce (sext a) <--> reduce (mul sext a, sext 1)
+  // reduce (zext a) <--> reduce (mul zext a, zext 1)
+  if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
+      InVec.getOpcode() == ISD::SIGN_EXTEND) {
+    SDValue A = InVec.getOperand(0);
+    if (A.getValueType().getVectorElementType() != MVT::i8 ||
+        !TLI.isTypeLegal(A.getValueType()))
+      return SDValue();
+
+    MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
+    A = DAG.getBitcast(ResVT, A);
+    SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
+
+    bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
+    unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+    return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+  }
+
+  // mul (sext, sext) -> vqdot
+  // mul (zext, zext) -> vqdotu
+  // mul (sext, zext) -> vqdotsu
+  // mul (zext, sext) -> vqdotsu (swapped)
+  // TODO: Improve .vx handling - we end up with a sub-vector insert
+  // which confuses the splat pattern matching.  Also, match vqdotus.vx
+  if (InVec.getOpcode() != ISD::MUL)
----------------
preames wrote:

Annoyingly complicated, possible future work.

The problem is that we have to expand the shift as a multiply by 2^N, and the range of shift amounts we can handle is very limited due to the input being an i8.  

https://github.com/llvm/llvm-project/pull/137039


More information about the llvm-commits mailing list