[llvm] [RISCV] Initial codegen support for zvqdotq extension (PR #137039)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 23 11:53:21 PDT 2025
https://github.com/preames created https://github.com/llvm/llvm-project/pull/137039
This patch adds pattern matching for the basic usages of the dot product instructions introduced by the experimental zvqdotq extension. It specifically only handles the case where the pattern is feeding an i32 sum reduction, as we need to reassociate the reduction tree to use these instructions.
The vecreduce_add (sext) and vecreduce_add (zext) cases are included mostly to exercise the VX matchers. For the generic matching, we fail to match due to an order of combine issue which results in the bitcast being separated from the splat.
I chose to do this lowering as an early combine so as to avoid having to integrate the entire logic into the reduction lowering flow. In particular, that would get a lot more complicated as we extend this to handle add-trees feeding the reductions.
>From 381c8ac3cfbbbade495e037a2480e2a8ab694312 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Tue, 22 Apr 2025 13:23:57 -0700
Subject: [PATCH] [RISCV] Initial codegen support for zvqdotq extension
This patch adds pattern matching for the basic usages of the dot product
instructions introduced by the experimental zvqdotq extension. It
specifically only handles the case where the pattern is feeding an
i32 sum reduction as we need to reassociate the reduction tree to use
these instructions.
The vecreduce_add (sext) and vecreduce_add (zext) cases are included
mostly to exercise the VX matchers. For the generic matching, we
fail to match due to an order of combine issue which results in the
bitcast being separated from the splat.
I chose to do this lowering as an early combine so as to avoid
having to integrate the entire logic into the reduction lowering
flow. In particular, that would get a lot more complicated as
we extend this to handle add-trees feeding the reductions.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 124 +++++++++-
llvm/lib/Target/RISCV/RISCVISelLowering.h | 7 +-
.../lib/Target/RISCV/RISCVInstrInfoZvqdotq.td | 31 +++
.../RISCV/rvv/fixed-vectors-zvqdotq.ll | 216 ++++++++++++------
4 files changed, 310 insertions(+), 68 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 11f2095ac9bce..f0c80da123fb1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -6956,7 +6956,7 @@ static bool hasPassthruOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 133 &&
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 136 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6980,7 +6980,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(
- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 133 &&
+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 136 &&
RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
"adding target specific op should update this function");
if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18003,6 +18003,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
DAG.getBuildVector(VT, DL, RHSOps));
}
+/// Emit one of the zvqdotq VL nodes (\p Opc) for two fixed-length i32 vectors
+/// of the same type. The operands are moved into the scalable container type,
+/// the node is built with an all-zero passthru operand and the default
+/// mask/VL, and the result is converted back to the fixed-length type.
+static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
+                          const SDLoc &DL, SelectionDAG &DAG,
+                          const RISCVSubtarget &Subtarget) {
+  assert(Opc == RISCVISD::VQDOT_VL || Opc == RISCVISD::VQDOTU_VL ||
+         Opc == RISCVISD::VQDOTSU_VL);
+  const MVT VT = Op0.getSimpleValueType();
+  assert(VT == Op1.getSimpleValueType() &&
+         VT.getVectorElementType() == MVT::i32);
+  assert(VT.isFixedLengthVector());
+
+  // Everything below operates on the scalable container of the fixed type.
+  const MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+  SDValue ZeroPassthru = convertToScalableVector(
+      ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
+  SDValue LHS = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+  SDValue RHS = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+
+  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+  const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
+  SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
+
+  SDValue Dot = DAG.getNode(Opc, DL, ContainerVT,
+                            {LHS, RHS, ZeroPassthru, Mask, VL, PolicyOp});
+  return convertFromScalableVector(VT, Dot, DAG, Subtarget);
+}
+
+/// Given an i8 vector type whose element count is a multiple of four, return
+/// the i32 vector type with one quarter as many elements.
+static MVT getQDOTXResultType(MVT OpVT) {
+  const ElementCount SrcEC = OpVT.getVectorElementCount();
+  assert(SrcEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
+  const ElementCount ResEC = SrcEC.divideCoefficientBy(4);
+  return MVT::getVectorVT(MVT::i32, ResEC);
+}
+
+/// Try to re-express \p InVec, the operand of an add reduction, as a zvqdotq
+/// dot product over i8 sources.
+///
+/// Matches (sext a), (zext a) and (mul (sext|zext a), (sext|zext b)) where
+/// \p InVec has i32 elements and the sources have i8 elements. On success the
+/// returned vector has i32 elements and a quarter of the element count of
+/// \p InVec; the caller is expected to reduce that vector instead. Returns an
+/// empty SDValue when no pattern matches.
+static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
+                                         SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget,
+                                         const RISCVTargetLowering &TLI) {
+  // Note: We intentionally do not check the legality of the reduction type.
+  // We want to handle the m4/m8 *src* types, and thus need to let illegal
+  // intermediate types flow through here.
+  //
+  // lowerVQDOT only supports fixed-length vectors, so reject anything else
+  // here instead of tripping its assertion.
+  EVT InVT = InVec.getValueType();
+  if (!InVT.isFixedLengthVector() || InVT.getVectorElementType() != MVT::i32 ||
+      !InVT.getVectorElementCount().isKnownMultipleOf(4))
+    return SDValue();
+
+  // reduce (zext a) <--> reduce (mul zext a, zext 1)
+  // reduce (sext a) <--> reduce (mul sext a, sext 1)
+  if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
+      InVec.getOpcode() == ISD::SIGN_EXTEND) {
+    SDValue A = InVec.getOperand(0);
+    if (A.getValueType().getVectorElementType() != MVT::i8 ||
+        !TLI.isTypeLegal(A.getValueType()))
+      return SDValue();
+
+    MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
+    A = DAG.getBitcast(ResVT, A);
+    // Each i32 lane packs four i8 ones, i.e. the multiplier is a splat of 1.
+    SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
+
+    bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
+    unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+    return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+  }
+
+  // mul (sext, sext) -> vqdot
+  // mul (zext, zext) -> vqdotu
+  // mul (sext, zext) -> vqdotsu
+  // mul (zext, sext) -> vqdotsu (swapped)
+  // TODO: Improve .vx handling - we end up with a sub-vector insert
+  // which confuses the splat pattern matching.  Also, match vqdotus.vx
+  if (InVec.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue A = InVec.getOperand(0);
+  SDValue B = InVec.getOperand(1);
+  unsigned Opc = 0;
+  if (A.getOpcode() == B.getOpcode()) {
+    if (A.getOpcode() == ISD::SIGN_EXTEND)
+      Opc = RISCVISD::VQDOT_VL;
+    else if (A.getOpcode() == ISD::ZERO_EXTEND)
+      Opc = RISCVISD::VQDOTU_VL;
+    else
+      return SDValue();
+  } else {
+    // Canonicalize the mixed case so that A is the sext and B is the zext.
+    if (B.getOpcode() != ISD::ZERO_EXTEND)
+      std::swap(A, B);
+    if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+      return SDValue();
+    Opc = RISCVISD::VQDOTSU_VL;
+  }
+  assert(Opc);
+
+  // Legality is checked on the i8 *source* type (as in the sext/zext case
+  // above), not on the widened i32 type, which is allowed to be illegal.
+  EVT SrcVT = A.getOperand(0).getValueType();
+  if (SrcVT.getVectorElementType() != MVT::i8 ||
+      SrcVT != B.getOperand(0).getValueType() || !TLI.isTypeLegal(SrcVT))
+    return SDValue();
+
+  MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
+  A = DAG.getBitcast(ResVT, A.getOperand(0));
+  B = DAG.getBitcast(ResVT, B.getOperand(0));
+  return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+}
+
+/// Combine for VECREDUCE_ADD: when zvqdotq is available and the reduced
+/// operand can be rewritten as a dot product, reduce the rewritten vector
+/// instead. Returns an empty SDValue if nothing changed.
+static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
+                                       const RISCVSubtarget &Subtarget,
+                                       const RISCVTargetLowering &TLI) {
+  if (!Subtarget.hasStdExtZvqdotq())
+    return SDValue();
+
+  SDLoc DL(N);
+  MVT VT = N->getSimpleValueType(0);
+  SDValue Folded =
+      foldReduceOperandViaVQDOT(N->getOperand(0), DL, DAG, Subtarget, TLI);
+  if (!Folded)
+    return SDValue();
+  return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Folded);
+}
+
static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
@@ -19779,8 +19891,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return SDValue();
}
- case ISD::CTPOP:
case ISD::VECREDUCE_ADD:
+ if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
+ return V;
+ [[fallthrough]];
+ case ISD::CTPOP:
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
return V;
break;
@@ -22306,6 +22421,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(RI_VZIP2B_VL)
NODE_NAME_CASE(RI_VUNZIP2A_VL)
NODE_NAME_CASE(RI_VUNZIP2B_VL)
+ NODE_NAME_CASE(VQDOT_VL)
+ NODE_NAME_CASE(VQDOTU_VL)
+ NODE_NAME_CASE(VQDOTSU_VL)
NODE_NAME_CASE(READ_CSR)
NODE_NAME_CASE(WRITE_CSR)
NODE_NAME_CASE(SWAP_CSR)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 6e50ab8e1f296..c59baf1fd3e58 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -412,7 +412,12 @@ enum NodeType : unsigned {
RI_VUNZIP2A_VL,
RI_VUNZIP2B_VL,
- LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
+ // zvqdot instructions with additional passthru, mask and VL operands
+ VQDOT_VL,
+ VQDOTU_VL,
+ VQDOTSU_VL,
+
+ LAST_VL_VECTOR_OP = VQDOTSU_VL,
// Read VLENB CSR
READ_VLENB,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
index 205fffd5115ee..25192ec7e0db2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
@@ -26,3 +26,34 @@ let Predicates = [HasStdExtZvqdotq] in {
def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
} // Predicates = [HasStdExtZvqdotq]
+
+
+def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;
+
+multiclass VPseudoVQDOT_VV_VX { // Pseudo definitions shared by the vqdot variants.
+  foreach m = MxSet<32>.m in { // One set of definitions per entry of MxSet<32>.
+    defm "" : VPseudoBinaryV_VV<m>, // Vector-vector form, integer ALU (V) sched classes.
+              SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
+                          forcePassthruRead=true>;
+    defm "" : VPseudoBinaryV_VX<m>, // Vector-scalar form, integer ALU (X) sched classes.
+              SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
+                          forcePassthruRead=true>;
+  }
+}
+
+// TODO: Add pseudo and patterns for vqdotus.vx
+let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
+ hasSideEffects = 0 in {
+ defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
+ defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
+ defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
+}
+
+
+defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
+defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
+
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 25192ea19aab3..e48bc9cdfea4e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,21 +1,31 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdot_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -63,17 +73,27 @@ entry:
}
define i32 @vqdotu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotu_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT: vwmulu.vv v10, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: vwredsumu.vs v8, v10, v8
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotu_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT: vwredsumu.vs v8, v10, v8
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.zext = zext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -102,17 +122,27 @@ entry:
}
define i32 @vqdotsu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vzext.vf2 v14, v9
-; CHECK-NEXT: vwmulsu.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vzext.vf2 v14, v9
+; NODOT-NEXT: vwmulsu.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -122,17 +152,27 @@ entry:
}
define i32 @vqdotsu_vv_swapped(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv_swapped:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vzext.vf2 v14, v9
-; CHECK-NEXT: vwmulsu.vv v8, v12, v14
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv_swapped:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vzext.vf2 v14, v9
+; NODOT-NEXT: vwmulsu.vv v8, v12, v14
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_swapped:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotsu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -181,14 +221,38 @@ entry:
}
define i32 @reduce_of_sext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_sext:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vsext.vf4 v12, v8
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: reduce_of_sext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT: vsext.vf4 v12, v8
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT32-LABEL: reduce_of_sext:
+; DOT32: # %bb.0: # %entry
+; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT: vmv.v.i v9, 0
+; DOT32-NEXT: lui a0, 4112
+; DOT32-NEXT: addi a0, a0, 257
+; DOT32-NEXT: vqdot.vx v9, v8, a0
+; DOT32-NEXT: vmv.s.x v8, zero
+; DOT32-NEXT: vredsum.vs v8, v9, v8
+; DOT32-NEXT: vmv.x.s a0, v8
+; DOT32-NEXT: ret
+;
+; DOT64-LABEL: reduce_of_sext:
+; DOT64: # %bb.0: # %entry
+; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT: vmv.v.i v9, 0
+; DOT64-NEXT: lui a0, 4112
+; DOT64-NEXT: addiw a0, a0, 257
+; DOT64-NEXT: vqdot.vx v9, v8, a0
+; DOT64-NEXT: vmv.s.x v8, zero
+; DOT64-NEXT: vredsum.vs v8, v9, v8
+; DOT64-NEXT: vmv.x.s a0, v8
+; DOT64-NEXT: ret
entry:
%a.ext = sext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -196,14 +260,38 @@ entry:
}
define i32 @reduce_of_zext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_zext:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vzext.vf4 v12, v8
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: reduce_of_zext:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT: vzext.vf4 v12, v8
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT32-LABEL: reduce_of_zext:
+; DOT32: # %bb.0: # %entry
+; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT: vmv.v.i v9, 0
+; DOT32-NEXT: lui a0, 4112
+; DOT32-NEXT: addi a0, a0, 257
+; DOT32-NEXT: vqdotu.vx v9, v8, a0
+; DOT32-NEXT: vmv.s.x v8, zero
+; DOT32-NEXT: vredsum.vs v8, v9, v8
+; DOT32-NEXT: vmv.x.s a0, v8
+; DOT32-NEXT: ret
+;
+; DOT64-LABEL: reduce_of_zext:
+; DOT64: # %bb.0: # %entry
+; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT: vmv.v.i v9, 0
+; DOT64-NEXT: lui a0, 4112
+; DOT64-NEXT: addiw a0, a0, 257
+; DOT64-NEXT: vqdotu.vx v9, v8, a0
+; DOT64-NEXT: vmv.s.x v8, zero
+; DOT64-NEXT: vredsum.vs v8, v9, v8
+; DOT64-NEXT: vmv.x.s a0, v8
+; DOT64-NEXT: ret
entry:
%a.ext = zext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
More information about the llvm-commits
mailing list