[llvm] [RISCV] Initial codegen support for zvqdotq extension (PR #137039)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 24 08:59:33 PDT 2025
https://github.com/preames updated https://github.com/llvm/llvm-project/pull/137039
>From 381c8ac3cfbbbade495e037a2480e2a8ab694312 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Tue, 22 Apr 2025 13:23:57 -0700
Subject: [PATCH 1/3] [RISCV] Initial codegen support for zvqdotq extension
This patch adds pattern matching for the basic uses of the dot-product
instructions introduced by the experimental zvqdotq extension. It
specifically handles only the case where the pattern feeds an
i32 sum reduction, as we need to reassociate the reduction tree to use
these instructions.
The vecreduce_add (sext) and vecreduce_add (zext) cases are included
mostly to exercise the VX matchers. For the generic matching, the VX
forms currently fail to match due to a combine-ordering issue that
results in the bitcast being separated from the splat.
I chose to do this lowering as an early combine so as to avoid
having to integrate the entire logic into the reduction lowering
flow. In particular, that would get a lot more complicated as
we extend this to handle add-trees feeding the reductions.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 124 +++++++++-
llvm/lib/Target/RISCV/RISCVISelLowering.h | 7 +-
.../lib/Target/RISCV/RISCVInstrInfoZvqdotq.td | 31 +++
.../RISCV/rvv/fixed-vectors-zvqdotq.ll | 216 ++++++++++++------
4 files changed, 310 insertions(+), 68 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 11f2095ac9bce..f0c80da123fb1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -6956,7 +6956,7 @@ static bool hasPassthruOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 133 &&
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 136 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6980,7 +6980,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 133 &&
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 136 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18003,6 +18003,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
DAG.getBuildVector(VT, DL, RHSOps));
}
+// Build a zvqdotq dot-product node (RISCVISD::VQDOT_VL, VQDOTU_VL or
+// VQDOTSU_VL) over Op0 and Op1. Both operands must have the same vector type
+// with i32 elements; callers pass i8 data bitcast to i32 lanes (see
+// getQDOTXResultType). The node is given an all-zero passthru, the default
+// mask and VL for the type, and a tail-agnostic/mask-agnostic policy.
+//
+// Both fixed-length and scalable vector types are accepted: fixed-length
+// operands are converted to their scalable container type and the result is
+// converted back, while scalable operands are already in container form and
+// are used directly. performVECREDUCECombine does not filter out scalable
+// reductions, so both forms must be handled here.
+//
+// Returns SDValue() (meaning "no fold") if VT is not a legal type.
+static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
+                          const SDLoc &DL, SelectionDAG &DAG,
+                          const RISCVSubtarget &Subtarget) {
+  assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
+         RISCVISD::VQDOTSU_VL == Opc);
+  MVT VT = Op0.getSimpleValueType();
+  assert(VT == Op1.getSimpleValueType() &&
+         VT.getVectorElementType() == MVT::i32);
+
+  // foldReduceOperandViaVQDOT deliberately lets illegal intermediate types
+  // through, and this node may be built before type legalization. Refuse
+  // illegal result types (e.g. a scalable i32 type with no register class)
+  // rather than emitting a target node the legalizer cannot handle.
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
+    return SDValue();
+
+  SDValue Passthru = DAG.getConstant(0, DL, VT);
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+    Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
+    Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+    Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+  }
+
+  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+  const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
+  SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
+  SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
+                                   {Op0, Op1, Passthru, Mask, VL, PolicyOp});
+  // Scalable results are already in their final type.
+  if (!VT.isFixedLengthVector())
+    return LocalAccum;
+  return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+}
+
+// Map an i8-element operand type to the i32-element type the vqdot nodes
+// operate on: every four consecutive i8 lanes are viewed as one i32 lane, so
+// the element count is divided by four (e.g. v16i8 -> v4i32).
+static MVT getQDOTXResultType(MVT OpVT) {
+  assert(OpVT.getVectorElementType() == MVT::i8 &&
+         OpVT.getVectorElementCount().isKnownMultipleOf(4));
+  const ElementCount ResEC =
+      OpVT.getVectorElementCount().divideCoefficientBy(4);
+  return MVT::getVectorVT(MVT::i32, ResEC);
+}
+
+// Try to rewrite InVec, the i32-element operand of a VECREDUCE_ADD, into a
+// vector with a quarter as many i32 lanes computed by a zvqdotq dot-product
+// node, so the caller can reduce that vector instead. Each result lane sums
+// four of the original lanes, which is only sound because the consumer is a
+// sum reduction (the reduction tree is reassociated).
+//
+// Handled shapes (all extend sources must be i8 vectors):
+//   (sext a) / (zext a)          -> vqdot / vqdotu against splat(0x01010101)
+//   (mul (sext a), (sext b))     -> vqdot
+//   (mul (zext a), (zext b))     -> vqdotu
+//   (mul (sext a), (zext b))     -> vqdotsu (either operand order)
+// Returns SDValue() when InVec matches none of these.
+static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
+ SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget,
+ const RISCVTargetLowering &TLI) {
+ // Note: We intentionally do not check the legality of the reduction type.
+ // We want to handle the m4/m8 *src* types, and thus need to let illegal
+ // intermediate types flow through here.
+ if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
+ !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
+ return SDValue();
+
+ // reduce (sext a) <--> reduce (mul zext a. zext 1)
+ // reduce (zext a) <--> reduce (mul sext a. sext 1)
+ if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
+ InVec.getOpcode() == ISD::SIGN_EXTEND) {
+ SDValue A = InVec.getOperand(0);
+ if (A.getValueType().getVectorElementType() != MVT::i8 ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ // Reinterpret each group of four i8 lanes as one i32 lane. B holds 0x01 in
+ // every byte, so the dot product of each lane with B is the sum of its four
+ // bytes, extended the same way as InVec.
+ MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
+ A = DAG.getBitcast(ResVT, A);
+ SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
+
+ // sext source -> signed vqdot; zext source -> unsigned vqdotu.
+ bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ }
+
+ // mul (sext, sext) -> vqdot
+ // mul (zext, zext) -> vqdotu
+ // mul (sext, zext) -> vqdotsu
+ // mul (zext, sext) -> vqdotsu (swapped)
+ // TODO: Improve .vx handling - we end up with a sub-vector insert
+ // which confuses the splat pattern matching. Also, match vqdotus.vx
+ if (InVec.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue A = InVec.getOperand(0);
+ SDValue B = InVec.getOperand(1);
+ unsigned Opc = 0;
+ // Choose the opcode from the two extend kinds. In the mixed case, swap the
+ // operands if needed so that A is the sign-extended and B the zero-extended
+ // operand, which is the order VQDOTSU_VL is built with below.
+ if (A.getOpcode() == B.getOpcode()) {
+ if (A.getOpcode() == ISD::SIGN_EXTEND)
+ Opc = RISCVISD::VQDOT_VL;
+ else if (A.getOpcode() == ISD::ZERO_EXTEND)
+ Opc = RISCVISD::VQDOTU_VL;
+ else
+ return SDValue();
+ } else {
+ if (B.getOpcode() != ISD::ZERO_EXTEND)
+ std::swap(A, B);
+ if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
+ Opc = RISCVISD::VQDOTSU_VL;
+ }
+ assert(Opc);
+
+ // Both extends must be from the same i8 vector type.
+ // NOTE(review): this checks legality of A's extended (i32) type, whereas the
+ // sext/zext path above checks the i8 source type. Given the note at the top
+ // about wanting m4/m8 *src* types, confirm this asymmetry is intended.
+ if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+ A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ // Reinterpret both i8 sources as i32 lanes and emit the dot product.
+ MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
+ A = DAG.getBitcast(ResVT, A.getOperand(0));
+ B = DAG.getBitcast(ResVT, B.getOperand(0));
+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+}
+
+// DAG combine for ISD::VECREDUCE_ADD when zvqdotq is available: if the reduced
+// operand can be rewritten by foldReduceOperandViaVQDOT, rebuild the reduction
+// over the dot-product result, keeping the original scalar result type.
+// Returns SDValue() when no rewrite applies, letting the caller fall through
+// to its other VECREDUCE_ADD combines.
+static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget,
+ const RISCVTargetLowering &TLI) {
+ // The dot-product rewrite requires the zvqdotq extension.
+ if (!Subtarget.hasStdExtZvqdotq())
+ return SDValue();
+
+ SDLoc DL(N);
+ MVT VT = N->getSimpleValueType(0);
+ SDValue InVec = N->getOperand(0);
+ if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
+ return SDValue();
+}
+
static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
@@ -19779,8 +19891,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return SDValue();
}
- case ISD::CTPOP:
case ISD::VECREDUCE_ADD:
+ if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
+ return V;
+ [[fallthrough]];
+ case ISD::CTPOP:
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
return V;
break;
@@ -22306,6 +22421,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(RI_VZIP2B_VL)
NODE_NAME_CASE(RI_VUNZIP2A_VL)
NODE_NAME_CASE(RI_VUNZIP2B_VL)
+ NODE_NAME_CASE(VQDOT_VL)
+ NODE_NAME_CASE(VQDOTU_VL)
+ NODE_NAME_CASE(VQDOTSU_VL)
NODE_NAME_CASE(READ_CSR)
NODE_NAME_CASE(WRITE_CSR)
NODE_NAME_CASE(SWAP_CSR)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 6e50ab8e1f296..c59baf1fd3e58 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -412,7 +412,12 @@ enum NodeType : unsigned {
RI_VUNZIP2A_VL,
RI_VUNZIP2B_VL,
- LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
+ // zvqdot instructions with additional passthru, mask and VL operands
+ VQDOT_VL,
+ VQDOTU_VL,
+ VQDOTSU_VL,
+
+ LAST_VL_VECTOR_OP = VQDOTSU_VL,
// Read VLENB CSR
READ_VLENB,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
index 205fffd5115ee..25192ec7e0db2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
@@ -26,3 +26,34 @@ let Predicates = [HasStdExtZvqdotq] in {
def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
} // Predicates = [HasStdExtZvqdotq]
+
+
+// SelectionDAG nodes for the RISCVISD::VQDOT*_VL opcodes that
+// RISCVISelLowering.cpp builds (via lowerVQDOT) when combining i32 sum
+// reductions. They reuse the generic integer binary-op VL type profile.
+// NOTE(review): lowerVQDOT also appends a trailing policy operand after
+// passthru, mask and VL; confirm SDT_RISCVIntBinOp_VL's operand count is
+// meant to cover (or ignore) it.
+def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;
+
+// Pseudos for the .vv and .vx forms at each LMUL in MxSet<32>, scheduled like
+// integer ALU ops. forcePassthruRead=true is set on both; presumably this is
+// because vqdot also reads vd as its accumulator input (TODO confirm).
+multiclass VPseudoVQDOT_VV_VX {
+ foreach m = MxSet<32>.m in {
+ defm "" : VPseudoBinaryV_VV<m>,
+ SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
+ forcePassthruRead=true>;
+ defm "" : VPseudoBinaryV_VX<m>,
+ SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
+ forcePassthruRead=true>;
+ }
+}
+
+// TODO: Add pseudo and patterns for vqdotus.vx
+let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
+ hasSideEffects = 0 in {
+ defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
+ defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
+ defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
+}
+
+
+// Selection patterns mapping the VL nodes above onto the pseudos for every
+// SEW=32 vector type; the .vx patterns apply when an operand is a splat.
+// NOTE(review): unlike the pseudo definitions, these patterns are not wrapped
+// in `let Predicates = [HasStdExtZvqdotq]`. The nodes are only built when the
+// subtarget has zvqdotq, but confirm the missing predicate is intentional.
+defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
+defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
+
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 25192ea19aab3..e48bc9cdfea4e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,21 +1,31 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdot_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -63,17 +73,27 @@ entry:
}
define i32 @vqdotu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotu_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT: vwmulu.vv v10, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: vwredsumu.vs v8, v10, v8
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotu_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT: vwredsumu.vs v8, v10, v8
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.zext = zext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -102,17 +122,27 @@ entry:
}
define i32 @vqdotsu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vzext.vf2 v14, v9
-; CHECK-NEXT: vwmulsu.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vzext.vf2 v14, v9
+; NODOT-NEXT: vwmulsu.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -122,17 +152,27 @@ entry:
}
define i32 @vqdotsu_vv_swapped(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv_swapped:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vzext.vf2 v14, v9
-; CHECK-NEXT: vwmulsu.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv_swapped:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vzext.vf2 v14, v9
+; NODOT-NEXT: vwmulsu.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_swapped:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -181,14 +221,38 @@ entry:
}
define i32 @reduce_of_sext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_sext:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vsext.vf4 v12, v8
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: reduce_of_sext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT: vsext.vf4 v12, v8
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT32-LABEL: reduce_of_sext:
+; DOT32: # %bb.0: # %entry
+; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT: vmv.v.i v9, 0
+; DOT32-NEXT: lui a0, 4112
+; DOT32-NEXT: addi a0, a0, 257
+; DOT32-NEXT: vqdot.vx v9, v8, a0
+; DOT32-NEXT: vmv.s.x v8, zero
+; DOT32-NEXT: vredsum.vs v8, v9, v8
+; DOT32-NEXT: vmv.x.s a0, v8
+; DOT32-NEXT: ret
+;
+; DOT64-LABEL: reduce_of_sext:
+; DOT64: # %bb.0: # %entry
+; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT: vmv.v.i v9, 0
+; DOT64-NEXT: lui a0, 4112
+; DOT64-NEXT: addiw a0, a0, 257
+; DOT64-NEXT: vqdot.vx v9, v8, a0
+; DOT64-NEXT: vmv.s.x v8, zero
+; DOT64-NEXT: vredsum.vs v8, v9, v8
+; DOT64-NEXT: vmv.x.s a0, v8
+; DOT64-NEXT: ret
entry:
%a.ext = sext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -196,14 +260,38 @@ entry:
}
define i32 @reduce_of_zext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_zext:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vzext.vf4 v12, v8
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: reduce_of_zext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT: vzext.vf4 v12, v8
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT32-LABEL: reduce_of_zext:
+; DOT32: # %bb.0: # %entry
+; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT: vmv.v.i v9, 0
+; DOT32-NEXT: lui a0, 4112
+; DOT32-NEXT: addi a0, a0, 257
+; DOT32-NEXT: vqdotu.vx v9, v8, a0
+; DOT32-NEXT: vmv.s.x v8, zero
+; DOT32-NEXT: vredsum.vs v8, v9, v8
+; DOT32-NEXT: vmv.x.s a0, v8
+; DOT32-NEXT: ret
+;
+; DOT64-LABEL: reduce_of_zext:
+; DOT64: # %bb.0: # %entry
+; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT: vmv.v.i v9, 0
+; DOT64-NEXT: lui a0, 4112
+; DOT64-NEXT: addiw a0, a0, 257
+; DOT64-NEXT: vqdotu.vx v9, v8, a0
+; DOT64-NEXT: vmv.s.x v8, zero
+; DOT64-NEXT: vredsum.vs v8, v9, v8
+; DOT64-NEXT: vmv.x.s a0, v8
+; DOT64-NEXT: ret
entry:
%a.ext = zext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
>From f3e724db92dc8c337447455cf1dd03bd5cfe7749 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 23 Apr 2025 14:40:24 -0700
Subject: [PATCH 2/3] Address review comment
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0c80da123fb1..c087678ac4c21 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18108,7 +18108,7 @@ static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
SDLoc DL(N);
- MVT VT = N->getSimpleValueType(0);
+ EVT VT = N->getValueType(0);
SDValue InVec = N->getOperand(0);
if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
>From 231fed6bd639fa92e0c8c924d6c326020ae70af6 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 24 Apr 2025 08:57:28 -0700
Subject: [PATCH 3/3] Address review comment
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c087678ac4c21..8cb3579730e7d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18044,8 +18044,8 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
return SDValue();
- // reduce (sext a) <--> reduce (mul zext a. zext 1)
- // reduce (zext a) <--> reduce (mul sext a. sext 1)
+ // reduce (zext a) <--> reduce (mul zext a. zext 1)
+ // reduce (sext a) <--> reduce (mul sext a. sext 1)
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
InVec.getOpcode() == ISD::SIGN_EXTEND) {
SDValue A = InVec.getOperand(0);
More information about the llvm-commits
mailing list