[llvm] 2680afb - [RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure (#142212)
Mon Jun 9 17:47:11 PDT 2025
Author: Philip Reames
Date: 2025-06-09T17:47:08-07:00
New Revision: 2680afb76bea48b9635cfe1e2aee81b421b71cde
URL: https://github.com/llvm/llvm-project/commit/2680afb76bea48b9635cfe1e2aee81b421b71cde
DIFF: https://github.com/llvm/llvm-project/commit/2680afb76bea48b9635cfe1e2aee81b421b71cde.diff
LOG: [RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure (#142212)
This involves a codegen regression at the moment due to the issue
described in 443cdd0b, but it aligns the lowering paths for this case
and makes it less likely that future bugs go undetected.
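For context, the shape of IR this combine targets is sketched below. The
function signature matches the reduce_of_sext test touched by this patch,
but the body is an illustrative reconstruction rather than a verbatim copy
of the test file: a vector.reduce.add over a sign-extended i8 vector, which
is now rewritten to an ISD::PARTIAL_REDUCE_SMLA node (accumulator 0, the
source vector, and a splat of 1) and lowered to vqdot.vv.

    declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

    define i32 @reduce_of_sext(<16 x i8> %a) {
    entry:
      ; Sign-extend the i8 lanes, then sum them. After this patch the
      ; combine emits partial_reduce_smla(0, %a, splat(1)) instead of a
      ; RISCVISD::VQDOT_VL node, and it still lowers to vqdot.vv.
      %ext = sext <16 x i8> %a to <16 x i32>
      %res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %ext)
      ret i32 %res
    }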
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ab8b36df44d3f..cefcd914c5a8f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18372,31 +18372,6 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
DAG.getBuildVector(VT, DL, RHSOps));
}
-static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
- const SDLoc &DL, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
- RISCVISD::VQDOTSU_VL == Opc);
- MVT VT = Op0.getSimpleValueType();
- assert(VT == Op1.getSimpleValueType() &&
- VT.getVectorElementType() == MVT::i32);
-
- SDValue Passthru = DAG.getConstant(0, DL, VT);
- MVT ContainerVT = VT;
- if (VT.isFixedLengthVector()) {
- ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
- Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
- Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
- Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
- }
- auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
- SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
- {Op0, Op1, Passthru, Mask, VL});
- if (VT.isFixedLengthVector())
- return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
- return LocalAccum;
-}
-
static MVT getQDOTXResultType(MVT OpVT) {
ElementCount OpEC = OpVT.getVectorElementCount();
assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
@@ -18455,61 +18430,62 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
}
}
- // reduce (zext a) <--> reduce (mul zext a. zext 1)
- // reduce (sext a) <--> reduce (mul sext a. sext 1)
+ // zext a <--> partial_reduce_umla 0, a, 1
+ // sext a <--> partial_reduce_smla 0, a, 1
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
InVec.getOpcode() == ISD::SIGN_EXTEND) {
SDValue A = InVec.getOperand(0);
- if (A.getValueType().getVectorElementType() != MVT::i8 ||
- !TLI.isTypeLegal(A.getValueType()))
+ EVT OpVT = A.getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT))
return SDValue();
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
- A = DAG.getBitcast(ResVT, A);
- SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
-
+ SDValue B = DAG.getConstant(0x1, DL, OpVT);
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ unsigned Opc =
+ IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
}
- // mul (sext, sext) -> vqdot
- // mul (zext, zext) -> vqdotu
- // mul (sext, zext) -> vqdotsu
- // mul (zext, sext) -> vqdotsu (swapped)
- // TODO: Improve .vx handling - we end up with a sub-vector insert
- // which confuses the splat pattern matching. Also, match vqdotus.vx
+ // mul (sext a, sext b) -> partial_reduce_smla 0, a, b
+ // mul (zext a, zext b) -> partial_reduce_umla 0, a, b
+  // mul (sext a, zext b) -> partial_reduce_sumla 0, a, b
+  // mul (zext a, sext b) -> partial_reduce_sumla 0, b, a (swapped)
if (InVec.getOpcode() != ISD::MUL)
return SDValue();
SDValue A = InVec.getOperand(0);
SDValue B = InVec.getOperand(1);
- unsigned Opc = 0;
- if (A.getOpcode() == B.getOpcode()) {
- if (A.getOpcode() == ISD::SIGN_EXTEND)
- Opc = RISCVISD::VQDOT_VL;
- else if (A.getOpcode() == ISD::ZERO_EXTEND)
- Opc = RISCVISD::VQDOTU_VL;
- else
- return SDValue();
- } else {
- if (B.getOpcode() != ISD::ZERO_EXTEND)
- std::swap(A, B);
- if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
- return SDValue();
- Opc = RISCVISD::VQDOTSU_VL;
- }
- assert(Opc);
- if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
+ if (!ISD::isExtOpcode(A.getOpcode()))
+ return SDValue();
+
+ EVT OpVT = A.getOperand(0).getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 ||
+ OpVT != B.getOperand(0).getValueType() ||
!TLI.isTypeLegal(A.getValueType()))
return SDValue();
- MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
- A = DAG.getBitcast(ResVT, A.getOperand(0));
- B = DAG.getBitcast(ResVT, B.getOperand(0));
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ unsigned Opc;
+ if (A.getOpcode() == ISD::SIGN_EXTEND && B.getOpcode() == ISD::SIGN_EXTEND)
+ Opc = ISD::PARTIAL_REDUCE_SMLA;
+ else if (A.getOpcode() == ISD::ZERO_EXTEND &&
+ B.getOpcode() == ISD::ZERO_EXTEND)
+ Opc = ISD::PARTIAL_REDUCE_UMLA;
+ else if (A.getOpcode() == ISD::SIGN_EXTEND &&
+ B.getOpcode() == ISD::ZERO_EXTEND)
+ Opc = ISD::PARTIAL_REDUCE_SUMLA;
+ else if (A.getOpcode() == ISD::ZERO_EXTEND &&
+ B.getOpcode() == ISD::SIGN_EXTEND) {
+ Opc = ISD::PARTIAL_REDUCE_SUMLA;
+ std::swap(A, B);
+ } else
+ return SDValue();
+
+ MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
+ return DAG.getNode(
+ Opc, DL, ResVT,
+ {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
}
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 6e8eaa2ab6f74..a189711d11471 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -232,13 +232,13 @@ define i32 @reduce_of_sext(<16 x i8> %a) {
;
; DOT-LABEL: reduce_of_sext:
; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT: vmv.v.i v9, 0
-; DOT-NEXT: lui a0, 4112
-; DOT-NEXT: addi a0, a0, 257
-; DOT-NEXT: vqdot.vx v9, v8, a0
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
; DOT-NEXT: vmv.s.x v8, zero
-; DOT-NEXT: vredsum.vs v8, v9, v8
+; DOT-NEXT: vredsum.vs v8, v10, v8
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry:
@@ -259,13 +259,13 @@ define i32 @reduce_of_zext(<16 x i8> %a) {
;
; DOT-LABEL: reduce_of_zext:
; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT-NEXT: vmv.v.i v9, 0
-; DOT-NEXT: lui a0, 4112
-; DOT-NEXT: addi a0, a0, 257
-; DOT-NEXT: vqdotu.vx v9, v8, a0
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
; DOT-NEXT: vmv.s.x v8, zero
-; DOT-NEXT: vredsum.vs v8, v9, v8
+; DOT-NEXT: vredsum.vs v8, v10, v8
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry:
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 0b6f8a7a838bc..87a984bda1fee 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -232,13 +232,13 @@ define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
;
; DOT-LABEL: reduce_of_sext:
; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT-NEXT: vmv.v.i v10, 0
-; DOT-NEXT: lui a0, 4112
-; DOT-NEXT: addi a0, a0, 257
-; DOT-NEXT: vqdot.vx v10, v8, a0
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
; DOT-NEXT: vmv.s.x v8, zero
-; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vredsum.vs v8, v12, v8
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry:
@@ -259,13 +259,13 @@ define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
;
; DOT-LABEL: reduce_of_zext:
; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT-NEXT: vmv.v.i v10, 0
-; DOT-NEXT: lui a0, 4112
-; DOT-NEXT: addi a0, a0, 257
-; DOT-NEXT: vqdotu.vx v10, v8, a0
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotu.vv v12, v8, v10
; DOT-NEXT: vmv.s.x v8, zero
-; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vredsum.vs v8, v12, v8
; DOT-NEXT: vmv.x.s a0, v8
; DOT-NEXT: ret
entry: