[llvm] [RISCV] Lower PARTIAL_REDUCE_[S/U]MLA via zvqdotq (PR #140950)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu May 22 08:02:34 PDT 2025
https://github.com/preames updated https://github.com/llvm/llvm-project/pull/140950
>From 07590242f4b13be0aa7f76e3055490c9a209ac6e Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Tue, 20 May 2025 13:06:46 -0700
Subject: [PATCH 1/3] [RISCV] Lower PARTIAL_REDUCE_[S/U]MLA via zvqdotq
The semantics of PARTIAL_REDUCE_SMLA with an i32 result element and i8
sources correspond to vqdot. Analogously, PARTIAL_REDUCE_UMLA
corresponds to vqdotu. There is currently no vqdotsu equivalent.
This patch is a starting point. We can extend this quite a bit more;
I plan to look at the fixed-vector lowering, the TTI hook to drive the
loop vectorizer, and integrating the reduction-based lowering I'd
added for zvqdotq into this flow.
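For illustration, the scalable-vector pattern this lowering targets looks
roughly like the hand-written IR sketch below, in the spirit of the
zvqdotq-sdnode.ll tests updated by this patch (the intrinsic overload
mangling and function name here are illustrative, not copied from the
tests): sign-extend two i8 sources to i32, multiply elementwise, and feed
the product into a partial reduction that folds each group of four lanes
into one accumulator lane.

define <vscale x 1 x i32> @partial_reduce_example(<vscale x 4 x i8> %a, <vscale x 4 x i8> %b) {
entry:
  ; Widen both i8 inputs to i32 and multiply elementwise.
  %a.sext = sext <vscale x 4 x i8> %a to <vscale x 4 x i32>
  %b.sext = sext <vscale x 4 x i8> %b to <vscale x 4 x i32>
  %mul = mul <vscale x 4 x i32> %a.sext, %b.sext
  ; Accumulate every group of 4 lanes into one result lane. This is
  ; selected as PARTIAL_REDUCE_SMLA in the DAG and, with this patch,
  ; lowers to vqdot.vv when zvqdotq is available.
  %res = call <vscale x 1 x i32> @llvm.experimental.vector.partial.reduce.add.nxv1i32.nxv4i32(
              <vscale x 1 x i32> zeroinitializer, <vscale x 4 x i32> %mul)
  ret <vscale x 1 x i32> %res
}

With zvqdotq enabled, the DOT check lines in the test diff below show this
collapsing to a short vsetvli + vmv.v.i + vqdot.vv sequence instead of the
widening multiply and add tree emitted today.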
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 32 ++
llvm/lib/Target/RISCV/RISCVISelLowering.h | 1 +
llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 531 +++++++++++-------
3 files changed, 355 insertions(+), 209 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d69e04a9912a2..59f43761b4105 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1571,6 +1571,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setIndexedStoreAction(ISD::POST_INC, MVT::i32, Legal);
}
+ if (Subtarget.hasStdExtZvqdotq()) {
+ setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
+ setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
+ setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
+ setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
+ }
+
// Function alignments.
const Align FunctionAlignment(Subtarget.hasStdExtCOrZca() ? 2 : 4);
setMinFunctionAlignment(FunctionAlignment);
@@ -8229,6 +8237,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerINIT_TRAMPOLINE(Op, DAG);
case ISD::ADJUST_TRAMPOLINE:
return lowerADJUST_TRAMPOLINE(Op, DAG);
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ return lowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -8364,6 +8375,27 @@ SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
return Op.getOperand(0);
}
+SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
+ SelectionDAG &DAG) const {
+ // Currently, only the vqdot and vqdotu case (from zvqdotq) hould be legal.
+ // TODO: There are many other sub-cases we could potentially lower, are
+ // any of them worthwhile? Ex: via vredsum, vwredsum, vwwmaccu, etc..
+ // TODO: PARTIAL_REDUCE_*MLA can't represent a vqdotsu currently.
+ SDLoc DL(Op);
+ MVT VT = Op.getSimpleValueType();
+ SDValue Accum = Op.getOperand(0);
+ assert(Accum.getSimpleValueType() == VT &&
+ VT.getVectorElementType() == MVT::i32);
+ SDValue A = Op.getOperand(1);
+ SDValue B = Op.getOperand(2);
+ assert(A.getSimpleValueType() == B.getSimpleValueType() &&
+ A.getSimpleValueType().getVectorElementType() == MVT::i8);
+ bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+ unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+ auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+ return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
+}
+
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
SelectionDAG &DAG, unsigned Flags) {
return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index fc8d8b8ce1b56..78f2044ba83a7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -552,6 +552,7 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
bool isEligibleForTailCallOptimization(
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 6df628e3bd812..2bd2ef2878fd5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -524,22 +524,30 @@ entry:
define <vscale x 1 x i32> @partial_reduce_nf2(<vscale x 4 x i8> %a, <vscale x 4 x i8> %b) {
-; CHECK-LABEL: partial_reduce_nf2:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
-; CHECK-NEXT: vsext.vf2 v10, v8
-; CHECK-NEXT: vsext.vf2 v11, v9
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: vwmul.vv v8, v10, v11
-; CHECK-NEXT: srli a0, a0, 3
-; CHECK-NEXT: vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-NEXT: vslidedown.vx v10, v9, a0
-; CHECK-NEXT: vslidedown.vx v11, v8, a0
-; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v10, v8
-; CHECK-NEXT: vadd.vv v9, v11, v9
-; CHECK-NEXT: vadd.vv v8, v9, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_nf2:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m1, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vsext.vf2 v11, v9
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: vwmul.vv v8, v10, v11
+; NODOT-NEXT: srli a0, a0, 3
+; NODOT-NEXT: vsetvli a1, zero, e32, m1, ta, ma
+; NODOT-NEXT: vslidedown.vx v10, v9, a0
+; NODOT-NEXT: vslidedown.vx v11, v8, a0
+; NODOT-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v10, v8
+; NODOT-NEXT: vadd.vv v9, v11, v9
+; NODOT-NEXT: vadd.vv v8, v9, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_nf2:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv1r.v v8, v10
+; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 4 x i8> %a to <vscale x 4 x i32>
%b.sext = sext <vscale x 4 x i8> %b to <vscale x 4 x i32>
@@ -549,17 +557,25 @@ entry:
}
define <vscale x 2 x i32> @partial_reduce_m1(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m1:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT: vadd.vv v8, v11, v8
-; CHECK-NEXT: vadd.vv v9, v9, v10
-; CHECK-NEXT: vadd.vv v8, v9, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_m1:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: vsetvli a0, zero, e32, m1, ta, ma
+; NODOT-NEXT: vadd.vv v8, v11, v8
+; NODOT-NEXT: vadd.vv v9, v9, v10
+; NODOT-NEXT: vadd.vv v8, v9, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_m1:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.v.v v8, v10
+; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.sext = sext <vscale x 8 x i8> %b to <vscale x 8 x i32>
@@ -569,17 +585,25 @@ entry:
}
define <vscale x 4 x i32> @partial_reduce_m2(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m2:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
-; CHECK-NEXT: vsext.vf2 v16, v8
-; CHECK-NEXT: vsext.vf2 v20, v10
-; CHECK-NEXT: vwmul.vv v8, v16, v20
-; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v14, v8
-; CHECK-NEXT: vadd.vv v10, v10, v12
-; CHECK-NEXT: vadd.vv v8, v10, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_m2:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vsext.vf2 v20, v10
+; NODOT-NEXT: vwmul.vv v8, v16, v20
+; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v14, v8
+; NODOT-NEXT: vadd.vv v10, v10, v12
+; NODOT-NEXT: vadd.vv v8, v10, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_m2:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vmv.v.v v8, v12
+; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -589,20 +613,28 @@ entry:
}
define <vscale x 8 x i32> @partial_reduce_m4(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m4:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
-; CHECK-NEXT: vsext.vf2 v24, v8
-; CHECK-NEXT: vsext.vf2 v16, v10
-; CHECK-NEXT: vsext.vf2 v28, v12
-; CHECK-NEXT: vsext.vf2 v20, v14
-; CHECK-NEXT: vwmul.vv v8, v16, v20
-; CHECK-NEXT: vwmul.vv v16, v24, v28
-; CHECK-NEXT: vsetvli a0, zero, e32, m4, ta, ma
-; CHECK-NEXT: vadd.vv v16, v20, v16
-; CHECK-NEXT: vadd.vv v8, v12, v8
-; CHECK-NEXT: vadd.vv v8, v8, v16
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_m4:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v24, v8
+; NODOT-NEXT: vsext.vf2 v16, v10
+; NODOT-NEXT: vsext.vf2 v28, v12
+; NODOT-NEXT: vsext.vf2 v20, v14
+; NODOT-NEXT: vwmul.vv v8, v16, v20
+; NODOT-NEXT: vwmul.vv v16, v24, v28
+; NODOT-NEXT: vsetvli a0, zero, e32, m4, ta, ma
+; NODOT-NEXT: vadd.vv v16, v20, v16
+; NODOT-NEXT: vadd.vv v8, v12, v8
+; NODOT-NEXT: vadd.vv v8, v8, v16
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_m4:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m4, ta, ma
+; DOT-NEXT: vmv.v.i v16, 0
+; DOT-NEXT: vqdot.vv v16, v8, v12
+; DOT-NEXT: vmv.v.v v8, v16
+; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 32 x i8> %a to <vscale x 32 x i32>
%b.sext = sext <vscale x 32 x i8> %b to <vscale x 32 x i32>
@@ -612,38 +644,46 @@ entry:
}
define <vscale x 16 x i32> @partial_reduce_m8(<vscale x 64 x i8> %a, <vscale x 64 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m8:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: addi sp, sp, -16
-; CHECK-NEXT: .cfi_def_cfa_offset 16
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 2
-; CHECK-NEXT: sub sp, sp, a0
-; CHECK-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x04, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 4 * vlenb
-; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
-; CHECK-NEXT: vsext.vf2 v24, v10
-; CHECK-NEXT: addi a0, sp, 16
-; CHECK-NEXT: vs4r.v v24, (a0) # vscale x 32-byte Folded Spill
-; CHECK-NEXT: vsext.vf2 v0, v8
-; CHECK-NEXT: vsext.vf2 v8, v18
-; CHECK-NEXT: vsext.vf2 v4, v16
-; CHECK-NEXT: vwmul.vv v24, v0, v4
-; CHECK-NEXT: vl4r.v v16, (a0) # vscale x 32-byte Folded Reload
-; CHECK-NEXT: vwmacc.vv v24, v16, v8
-; CHECK-NEXT: vsext.vf2 v8, v12
-; CHECK-NEXT: vsext.vf2 v16, v20
-; CHECK-NEXT: vwmacc.vv v24, v8, v16
-; CHECK-NEXT: vsext.vf2 v8, v14
-; CHECK-NEXT: vsext.vf2 v12, v22
-; CHECK-NEXT: vwmacc.vv v24, v8, v12
-; CHECK-NEXT: vmv8r.v v8, v24
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 2
-; CHECK-NEXT: add sp, sp, a0
-; CHECK-NEXT: .cfi_def_cfa sp, 16
-; CHECK-NEXT: addi sp, sp, 16
-; CHECK-NEXT: .cfi_def_cfa_offset 0
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_m8:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: addi sp, sp, -16
+; NODOT-NEXT: .cfi_def_cfa_offset 16
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 2
+; NODOT-NEXT: sub sp, sp, a0
+; NODOT-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x04, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 4 * vlenb
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v24, v10
+; NODOT-NEXT: addi a0, sp, 16
+; NODOT-NEXT: vs4r.v v24, (a0) # vscale x 32-byte Folded Spill
+; NODOT-NEXT: vsext.vf2 v0, v8
+; NODOT-NEXT: vsext.vf2 v8, v18
+; NODOT-NEXT: vsext.vf2 v4, v16
+; NODOT-NEXT: vwmul.vv v24, v0, v4
+; NODOT-NEXT: vl4r.v v16, (a0) # vscale x 32-byte Folded Reload
+; NODOT-NEXT: vwmacc.vv v24, v16, v8
+; NODOT-NEXT: vsext.vf2 v8, v12
+; NODOT-NEXT: vsext.vf2 v16, v20
+; NODOT-NEXT: vwmacc.vv v24, v8, v16
+; NODOT-NEXT: vsext.vf2 v8, v14
+; NODOT-NEXT: vsext.vf2 v12, v22
+; NODOT-NEXT: vwmacc.vv v24, v8, v12
+; NODOT-NEXT: vmv8r.v v8, v24
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 2
+; NODOT-NEXT: add sp, sp, a0
+; NODOT-NEXT: .cfi_def_cfa sp, 16
+; NODOT-NEXT: addi sp, sp, 16
+; NODOT-NEXT: .cfi_def_cfa_offset 0
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_m8:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT: vmv.v.i v24, 0
+; DOT-NEXT: vqdot.vv v24, v8, v16
+; DOT-NEXT: vmv.v.v v8, v24
+; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 64 x i8> %a to <vscale x 64 x i32>
%b.sext = sext <vscale x 64 x i8> %b to <vscale x 64 x i32>
@@ -653,103 +693,161 @@ entry:
}
define <vscale x 32 x i32> @partial_reduce_m16(<vscale x 128 x i8> %a, <vscale x 128 x i8> %b) {
-; CHECK-LABEL: partial_reduce_m16:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: addi sp, sp, -16
-; CHECK-NEXT: .cfi_def_cfa_offset 16
-; CHECK-NEXT: csrr a1, vlenb
-; CHECK-NEXT: slli a1, a1, 3
-; CHECK-NEXT: mv a2, a1
-; CHECK-NEXT: slli a1, a1, 1
-; CHECK-NEXT: add a1, a1, a2
-; CHECK-NEXT: sub sp, sp, a1
-; CHECK-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x18, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 24 * vlenb
-; CHECK-NEXT: csrr a1, vlenb
-; CHECK-NEXT: slli a1, a1, 4
-; CHECK-NEXT: add a1, sp, a1
-; CHECK-NEXT: addi a1, a1, 16
-; CHECK-NEXT: vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT: addi a1, sp, 16
-; CHECK-NEXT: vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT: vl8r.v v16, (a0)
-; CHECK-NEXT: csrr a1, vlenb
-; CHECK-NEXT: slli a1, a1, 3
-; CHECK-NEXT: add a1, sp, a1
-; CHECK-NEXT: addi a1, a1, 16
-; CHECK-NEXT: vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
-; CHECK-NEXT: vsetvli a1, zero, e16, m4, ta, ma
-; CHECK-NEXT: vsext.vf2 v4, v8
-; CHECK-NEXT: vsext.vf2 v0, v16
-; CHECK-NEXT: vwmul.vv v24, v4, v0
-; CHECK-NEXT: vsext.vf2 v4, v10
-; CHECK-NEXT: vsext.vf2 v8, v18
-; CHECK-NEXT: vwmacc.vv v24, v4, v8
-; CHECK-NEXT: csrr a1, vlenb
-; CHECK-NEXT: slli a1, a1, 3
-; CHECK-NEXT: add a0, a0, a1
-; CHECK-NEXT: vsext.vf2 v0, v12
-; CHECK-NEXT: vl8r.v v8, (a0)
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 3
-; CHECK-NEXT: add a0, sp, a0
-; CHECK-NEXT: addi a0, a0, 16
-; CHECK-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT: vsext.vf2 v4, v20
-; CHECK-NEXT: vwmacc.vv v24, v0, v4
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 4
-; CHECK-NEXT: add a0, sp, a0
-; CHECK-NEXT: addi a0, a0, 16
-; CHECK-NEXT: vl8r.v v0, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT: vsext.vf2 v20, v0
-; CHECK-NEXT: vsext.vf2 v16, v8
-; CHECK-NEXT: vwmul.vv v0, v20, v16
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 4
-; CHECK-NEXT: add a0, sp, a0
-; CHECK-NEXT: addi a0, a0, 16
-; CHECK-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT: vsext.vf2 v20, v18
-; CHECK-NEXT: vsext.vf2 v16, v10
-; CHECK-NEXT: vwmacc.vv v0, v20, v16
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 4
-; CHECK-NEXT: add a0, sp, a0
-; CHECK-NEXT: addi a0, a0, 16
-; CHECK-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT: vsext.vf2 v8, v20
-; CHECK-NEXT: vsext.vf2 v16, v12
-; CHECK-NEXT: vwmacc.vv v0, v8, v16
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 4
-; CHECK-NEXT: add a0, sp, a0
-; CHECK-NEXT: addi a0, a0, 16
-; CHECK-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT: vsext.vf2 v8, v22
-; CHECK-NEXT: vsext.vf2 v16, v14
-; CHECK-NEXT: vwmacc.vv v0, v8, v16
-; CHECK-NEXT: addi a0, sp, 16
-; CHECK-NEXT: vl8r.v v8, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT: vsext.vf2 v8, v14
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 3
-; CHECK-NEXT: add a0, sp, a0
-; CHECK-NEXT: addi a0, a0, 16
-; CHECK-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
-; CHECK-NEXT: vsext.vf2 v12, v22
-; CHECK-NEXT: vwmacc.vv v24, v8, v12
-; CHECK-NEXT: vmv8r.v v8, v24
-; CHECK-NEXT: vmv8r.v v16, v0
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: slli a0, a0, 3
-; CHECK-NEXT: mv a1, a0
-; CHECK-NEXT: slli a0, a0, 1
-; CHECK-NEXT: add a0, a0, a1
-; CHECK-NEXT: add sp, sp, a0
-; CHECK-NEXT: .cfi_def_cfa sp, 16
-; CHECK-NEXT: addi sp, sp, 16
-; CHECK-NEXT: .cfi_def_cfa_offset 0
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_m16:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: addi sp, sp, -16
+; NODOT-NEXT: .cfi_def_cfa_offset 16
+; NODOT-NEXT: csrr a1, vlenb
+; NODOT-NEXT: slli a1, a1, 3
+; NODOT-NEXT: mv a2, a1
+; NODOT-NEXT: slli a1, a1, 1
+; NODOT-NEXT: add a1, a1, a2
+; NODOT-NEXT: sub sp, sp, a1
+; NODOT-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x18, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 24 * vlenb
+; NODOT-NEXT: csrr a1, vlenb
+; NODOT-NEXT: slli a1, a1, 4
+; NODOT-NEXT: add a1, sp, a1
+; NODOT-NEXT: addi a1, a1, 16
+; NODOT-NEXT: vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT: addi a1, sp, 16
+; NODOT-NEXT: vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT: vl8r.v v16, (a0)
+; NODOT-NEXT: csrr a1, vlenb
+; NODOT-NEXT: slli a1, a1, 3
+; NODOT-NEXT: add a1, sp, a1
+; NODOT-NEXT: addi a1, a1, 16
+; NODOT-NEXT: vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
+; NODOT-NEXT: vsetvli a1, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v4, v8
+; NODOT-NEXT: vsext.vf2 v0, v16
+; NODOT-NEXT: vwmul.vv v24, v4, v0
+; NODOT-NEXT: vsext.vf2 v4, v10
+; NODOT-NEXT: vsext.vf2 v8, v18
+; NODOT-NEXT: vwmacc.vv v24, v4, v8
+; NODOT-NEXT: csrr a1, vlenb
+; NODOT-NEXT: slli a1, a1, 3
+; NODOT-NEXT: add a0, a0, a1
+; NODOT-NEXT: vsext.vf2 v0, v12
+; NODOT-NEXT: vl8r.v v8, (a0)
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 3
+; NODOT-NEXT: add a0, sp, a0
+; NODOT-NEXT: addi a0, a0, 16
+; NODOT-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT: vsext.vf2 v4, v20
+; NODOT-NEXT: vwmacc.vv v24, v0, v4
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 4
+; NODOT-NEXT: add a0, sp, a0
+; NODOT-NEXT: addi a0, a0, 16
+; NODOT-NEXT: vl8r.v v0, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT: vsext.vf2 v20, v0
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vwmul.vv v0, v20, v16
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 4
+; NODOT-NEXT: add a0, sp, a0
+; NODOT-NEXT: addi a0, a0, 16
+; NODOT-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT: vsext.vf2 v20, v18
+; NODOT-NEXT: vsext.vf2 v16, v10
+; NODOT-NEXT: vwmacc.vv v0, v20, v16
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 4
+; NODOT-NEXT: add a0, sp, a0
+; NODOT-NEXT: addi a0, a0, 16
+; NODOT-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT: vsext.vf2 v8, v20
+; NODOT-NEXT: vsext.vf2 v16, v12
+; NODOT-NEXT: vwmacc.vv v0, v8, v16
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 4
+; NODOT-NEXT: add a0, sp, a0
+; NODOT-NEXT: addi a0, a0, 16
+; NODOT-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT: vsext.vf2 v8, v22
+; NODOT-NEXT: vsext.vf2 v16, v14
+; NODOT-NEXT: vwmacc.vv v0, v8, v16
+; NODOT-NEXT: addi a0, sp, 16
+; NODOT-NEXT: vl8r.v v8, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT: vsext.vf2 v8, v14
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 3
+; NODOT-NEXT: add a0, sp, a0
+; NODOT-NEXT: addi a0, a0, 16
+; NODOT-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; NODOT-NEXT: vsext.vf2 v12, v22
+; NODOT-NEXT: vwmacc.vv v24, v8, v12
+; NODOT-NEXT: vmv8r.v v8, v24
+; NODOT-NEXT: vmv8r.v v16, v0
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: slli a0, a0, 3
+; NODOT-NEXT: mv a1, a0
+; NODOT-NEXT: slli a0, a0, 1
+; NODOT-NEXT: add a0, a0, a1
+; NODOT-NEXT: add sp, sp, a0
+; NODOT-NEXT: .cfi_def_cfa sp, 16
+; NODOT-NEXT: addi sp, sp, 16
+; NODOT-NEXT: .cfi_def_cfa_offset 0
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_m16:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: addi sp, sp, -16
+; DOT-NEXT: .cfi_def_cfa_offset 16
+; DOT-NEXT: csrr a1, vlenb
+; DOT-NEXT: slli a1, a1, 3
+; DOT-NEXT: mv a2, a1
+; DOT-NEXT: slli a1, a1, 1
+; DOT-NEXT: add a1, a1, a2
+; DOT-NEXT: sub sp, sp, a1
+; DOT-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x18, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 24 * vlenb
+; DOT-NEXT: csrr a1, vlenb
+; DOT-NEXT: slli a1, a1, 4
+; DOT-NEXT: add a1, sp, a1
+; DOT-NEXT: addi a1, a1, 16
+; DOT-NEXT: vs8r.v v16, (a1) # vscale x 64-byte Folded Spill
+; DOT-NEXT: csrr a1, vlenb
+; DOT-NEXT: slli a1, a1, 3
+; DOT-NEXT: add a1, sp, a1
+; DOT-NEXT: addi a1, a1, 16
+; DOT-NEXT: vs8r.v v8, (a1) # vscale x 64-byte Folded Spill
+; DOT-NEXT: csrr a1, vlenb
+; DOT-NEXT: slli a1, a1, 3
+; DOT-NEXT: add a1, a0, a1
+; DOT-NEXT: vl8r.v v8, (a0)
+; DOT-NEXT: vl8r.v v16, (a1)
+; DOT-NEXT: addi a0, sp, 16
+; DOT-NEXT: vs8r.v v16, (a0) # vscale x 64-byte Folded Spill
+; DOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
+; DOT-NEXT: vmv.v.i v24, 0
+; DOT-NEXT: vmv.v.i v0, 0
+; DOT-NEXT: csrr a0, vlenb
+; DOT-NEXT: slli a0, a0, 3
+; DOT-NEXT: add a0, sp, a0
+; DOT-NEXT: addi a0, a0, 16
+; DOT-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; DOT-NEXT: vqdot.vv v0, v16, v8
+; DOT-NEXT: csrr a0, vlenb
+; DOT-NEXT: slli a0, a0, 4
+; DOT-NEXT: add a0, sp, a0
+; DOT-NEXT: addi a0, a0, 16
+; DOT-NEXT: vl8r.v v8, (a0) # vscale x 64-byte Folded Reload
+; DOT-NEXT: addi a0, sp, 16
+; DOT-NEXT: vl8r.v v16, (a0) # vscale x 64-byte Folded Reload
+; DOT-NEXT: vqdot.vv v24, v8, v16
+; DOT-NEXT: vmv.v.v v8, v0
+; DOT-NEXT: vmv.v.v v16, v24
+; DOT-NEXT: csrr a0, vlenb
+; DOT-NEXT: slli a0, a0, 3
+; DOT-NEXT: mv a1, a0
+; DOT-NEXT: slli a0, a0, 1
+; DOT-NEXT: add a0, a0, a1
+; DOT-NEXT: add sp, sp, a0
+; DOT-NEXT: .cfi_def_cfa sp, 16
+; DOT-NEXT: addi sp, sp, 16
+; DOT-NEXT: .cfi_def_cfa_offset 0
+; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 128 x i8> %a to <vscale x 128 x i32>
%b.sext = sext <vscale x 128 x i8> %b to <vscale x 128 x i32>
@@ -759,18 +857,25 @@ entry:
}
define <vscale x 4 x i32> @partial_reduce_accum(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i32> %accum) {
-; CHECK-LABEL: partial_reduce_accum:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
-; CHECK-NEXT: vsext.vf2 v24, v8
-; CHECK-NEXT: vsext.vf2 v28, v10
-; CHECK-NEXT: vwmul.vv v16, v24, v28
-; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v18, v20
-; CHECK-NEXT: vadd.vv v10, v12, v16
-; CHECK-NEXT: vadd.vv v10, v22, v10
-; CHECK-NEXT: vadd.vv v8, v8, v10
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v24, v8
+; NODOT-NEXT: vsext.vf2 v28, v10
+; NODOT-NEXT: vwmul.vv v16, v24, v28
+; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v18, v20
+; NODOT-NEXT: vadd.vv v10, v12, v16
+; NODOT-NEXT: vadd.vv v10, v22, v10
+; NODOT-NEXT: vadd.vv v8, v8, v10
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vmv.v.v v8, v12
+; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.sext = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -796,22 +901,30 @@ entry:
}
define <vscale x 1 x i32> @partial_reduce_vqdotu(<vscale x 4 x i8> %a, <vscale x 4 x i8> %b) {
-; CHECK-LABEL: partial_reduce_vqdotu:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e8, mf2, ta, ma
-; CHECK-NEXT: vwmulu.vv v10, v8, v9
-; CHECK-NEXT: csrr a0, vlenb
-; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
-; CHECK-NEXT: vzext.vf2 v8, v10
-; CHECK-NEXT: srli a0, a0, 3
-; CHECK-NEXT: vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-NEXT: vslidedown.vx v10, v9, a0
-; CHECK-NEXT: vslidedown.vx v11, v8, a0
-; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v10, v8
-; CHECK-NEXT: vadd.vv v9, v11, v9
-; CHECK-NEXT: vadd.vv v8, v9, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_vqdotu:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e8, mf2, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: csrr a0, vlenb
+; NODOT-NEXT: vsetvli zero, zero, e32, m2, ta, ma
+; NODOT-NEXT: vzext.vf2 v8, v10
+; NODOT-NEXT: srli a0, a0, 3
+; NODOT-NEXT: vsetvli a1, zero, e32, m1, ta, ma
+; NODOT-NEXT: vslidedown.vx v10, v9, a0
+; NODOT-NEXT: vslidedown.vx v11, v8, a0
+; NODOT-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v10, v8
+; NODOT-NEXT: vadd.vv v9, v11, v9
+; NODOT-NEXT: vadd.vv v8, v9, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_vqdotu:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vmv1r.v v8, v10
+; DOT-NEXT: ret
entry:
%a.sext = zext <vscale x 4 x i8> %a to <vscale x 4 x i32>
%b.sext = zext <vscale x 4 x i8> %b to <vscale x 4 x i32>
>From 63f1df69706fea84596cfaaa7f35b9f449088a98 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 21 May 2025 14:10:19 -0700
Subject: [PATCH 2/3] Address review comment
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 59f43761b4105..6d99481c4c9bc 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1571,7 +1571,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setIndexedStoreAction(ISD::POST_INC, MVT::i32, Legal);
}
- if (Subtarget.hasStdExtZvqdotq()) {
+ // zve32x is broken for partial_reduce_umla, but let's not make it worse.
+ if (Subtarget.hasStdExtZvqdotq() && Subtarget.getRealMinVLen() >= 64) {
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
@@ -8377,7 +8378,7 @@ SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
- // Currently, only the vqdot and vqdotu case (from zvqdotq) hould be legal.
+ // Currently, only the vqdot and vqdotu case (from zvqdotq) should be legal.
// TODO: There are many other sub-cases we could potentially lower, are
// any of them worthwhile? Ex: via vredsum, vwredsum, vwwmaccu, etc..
// TODO: PARTIAL_REDUCE_*MLA can't represent a vqdotsu currently.
>From 7ce44f4198eb50744325349bbe169a5769b7d876 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 22 May 2025 08:01:53 -0700
Subject: [PATCH 3/3] Second try at addressing review comment
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6d99481c4c9bc..286e4e1eea40f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1572,7 +1572,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
}
// zve32x is broken for partial_reduce_umla, but let's not make it worse.
- if (Subtarget.hasStdExtZvqdotq() && Subtarget.getRealMinVLen() >= 64) {
+ if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);