[llvm] [RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure (PR #142212)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Fri May 30 13:59:12 PDT 2025
https://github.com/preames updated https://github.com/llvm/llvm-project/pull/142212
>From f6d8378dad6f7faf4d88853f02aa19bbd60579f4 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 30 May 2025 12:28:02 -0700
Subject: [PATCH 1/2] [RISCV] Migrate zvqdotq reduce matching to use
partial_reduce infrastructure
This currently causes a slight codegen regression due to the issue
described in 443cdd0b, but it aligns the lowering paths for the
reduce-of-extend case and makes it less likely that future bugs go undetected.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 56 ++++++++-------
.../RISCV/rvv/fixed-vectors-zvqdotq.ll | 72 +++++++------------
llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 72 +++++++------------
3 files changed, 78 insertions(+), 122 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b7fd0c93fa93f..ff2500b8adbc2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18338,17 +18338,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
InVec.getOpcode() == ISD::SIGN_EXTEND) {
SDValue A = InVec.getOperand(0);
- if (A.getValueType().getVectorElementType() != MVT::i8 ||
- !TLI.isTypeLegal(A.getValueType()))
+ EVT OpVT = A.getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT))
return SDValue();
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
- A = DAG.getBitcast(ResVT, A);
- SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
-
+ SDValue B = DAG.getConstant(0x1, DL, OpVT);
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
}
// mul (sext, sext) -> vqdot
@@ -18362,32 +18360,38 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SDValue A = InVec.getOperand(0);
SDValue B = InVec.getOperand(1);
- unsigned Opc = 0;
+
+ if (!ISD::isExtOpcode(A.getOpcode()))
+ return SDValue();
+
+ EVT OpVT = A.getOperand(0).getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 ||
+ OpVT != B.getOperand(0).getValueType() ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
+ // Use the partial_reduce_*mla path if possible
if (A.getOpcode() == B.getOpcode()) {
- if (A.getOpcode() == ISD::SIGN_EXTEND)
- Opc = RISCVISD::VQDOT_VL;
- else if (A.getOpcode() == ISD::ZERO_EXTEND)
- Opc = RISCVISD::VQDOTU_VL;
- else
- return SDValue();
- } else {
- if (B.getOpcode() != ISD::ZERO_EXTEND)
- std::swap(A, B);
- if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+ // TODO: handle ANY_EXTEND and zext nonneg here
+ if (A.getOpcode() != ISD::SIGN_EXTEND &&
+ A.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- Opc = RISCVISD::VQDOTSU_VL;
- }
- assert(Opc);
- if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
- !TLI.isTypeLegal(A.getValueType()))
+ bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
+ }
+ // We don't yet have a partial_reduce_sumla node, so directly lower to the
+ // target node instead.
+ if (B.getOpcode() != ISD::ZERO_EXTEND)
+ std::swap(A, B);
+ if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
A = DAG.getBitcast(ResVT, A.getOperand(0));
B = DAG.getBitcast(ResVT, B.getOperand(0));
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ return lowerVQDOT(RISCVISD::VQDOTSU_VL, A, B, DL, DAG, Subtarget);
}
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 0237faea9efb7..8ef691622415c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT: vmv.v.i v9, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdot.vx v9, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v9, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT: vmv.v.i v9, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdot.vx v9, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v9, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = sext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT: vmv.v.i v9, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdotu.vx v9, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v9, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT: vmv.v.i v9, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdotu.vx v9, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v9, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = zext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index d0fc915a0d07e..1948904493e8f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
define i32 @vqdot_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT: vmv.v.i v10, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdot.vx v10, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v10, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT: vmv.v.i v10, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdot.vx v10, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v10, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT: vmv.v.i v10, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdotu.vx v10, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v10, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT: vmv.v.i v10, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdotu.vx v10, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v10, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
>From 5f5472b93663775cef47f9e1a9076746496668b6 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 30 May 2025 13:55:23 -0700
Subject: [PATCH 2/2] clang-format
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ff2500b8adbc2..d452279a671d5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18345,7 +18345,8 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
SDValue B = DAG.getConstant(0x1, DL, OpVT);
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ unsigned Opc =
+ IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
}
@@ -18374,13 +18375,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
// Use the partial_reduce_*mla path if possible
if (A.getOpcode() == B.getOpcode()) {
// TODO: handle ANY_EXTEND and zext nonneg here
- if (A.getOpcode() != ISD::SIGN_EXTEND &&
- A.getOpcode() != ISD::ZERO_EXTEND)
+ if (A.getOpcode() != ISD::SIGN_EXTEND && A.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
- return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
+ unsigned Opc =
+ IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(
+ Opc, DL, ResVT,
+ {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
}
// We don't yet have a partial_reduce_sumla node, so directly lower to the
// target node instead.
More information about the llvm-commits
mailing list