[llvm] [RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure (PR #142212)
via llvm-commits
llvm-commits at lists.llvm.org
Fri May 30 13:51:34 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Philip Reames (preames)
<details>
<summary>Changes</summary>
This change currently introduces a codegen regression because of the issue described in commit 443cdd0b, but it aligns the lowering paths for this case and makes it less likely that future bugs go undetected.
---
Full diff: https://github.com/llvm/llvm-project/pull/142212.diff
3 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+30-26)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+24-48)
- (modified) llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll (+24-48)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b7fd0c93fa93f..ff2500b8adbc2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18338,17 +18338,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
InVec.getOpcode() == ISD::SIGN_EXTEND) {
SDValue A = InVec.getOperand(0);
- if (A.getValueType().getVectorElementType() != MVT::i8 ||
- !TLI.isTypeLegal(A.getValueType()))
+ EVT OpVT = A.getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT))
return SDValue();
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
- A = DAG.getBitcast(ResVT, A);
- SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
-
+ SDValue B = DAG.getConstant(0x1, DL, OpVT);
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
}
// mul (sext, sext) -> vqdot
@@ -18362,32 +18360,38 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SDValue A = InVec.getOperand(0);
SDValue B = InVec.getOperand(1);
- unsigned Opc = 0;
+
+ if (!ISD::isExtOpcode(A.getOpcode()))
+ return SDValue();
+
+ EVT OpVT = A.getOperand(0).getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 ||
+ OpVT != B.getOperand(0).getValueType() ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
+ // Use the partial_reduce_*mla path if possible
if (A.getOpcode() == B.getOpcode()) {
- if (A.getOpcode() == ISD::SIGN_EXTEND)
- Opc = RISCVISD::VQDOT_VL;
- else if (A.getOpcode() == ISD::ZERO_EXTEND)
- Opc = RISCVISD::VQDOTU_VL;
- else
- return SDValue();
- } else {
- if (B.getOpcode() != ISD::ZERO_EXTEND)
- std::swap(A, B);
- if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+ // TODO: handle ANY_EXTEND and zext nonneg here
+ if (A.getOpcode() != ISD::SIGN_EXTEND &&
+ A.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- Opc = RISCVISD::VQDOTSU_VL;
- }
- assert(Opc);
- if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
- !TLI.isTypeLegal(A.getValueType()))
+ bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
+ }
+ // We don't yet have a partial_reduce_sumla node, so directly lower to the
+ // target node instead.
+ if (B.getOpcode() != ISD::ZERO_EXTEND)
+ std::swap(A, B);
+ if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
A = DAG.getBitcast(ResVT, A.getOperand(0));
B = DAG.getBitcast(ResVT, B.getOperand(0));
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ return lowerVQDOT(RISCVISD::VQDOTSU_VL, A, B, DL, DAG, Subtarget);
}
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 0237faea9efb7..8ef691622415c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT: vmv.v.i v9, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdot.vx v9, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v9, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT: vmv.v.i v9, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdot.vx v9, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v9, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = sext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT: vmv.v.i v9, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdotu.vx v9, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v9, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT: vmv.v.i v9, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdotu.vx v9, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v9, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = zext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index d0fc915a0d07e..1948904493e8f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
define i32 @vqdot_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT: vmv.v.i v10, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdot.vx v10, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v10, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT: vmv.v.i v10, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdot.vx v10, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v10, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT: vmv.v.i v10, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdotu.vx v10, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v10, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT: vmv.v.i v10, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdotu.vx v10, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v10, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
``````````
</details>
https://github.com/llvm/llvm-project/pull/142212
More information about the llvm-commits mailing list