[llvm] [RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure (PR #142212)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 9 08:10:16 PDT 2025
https://github.com/preames updated https://github.com/llvm/llvm-project/pull/142212
>From f6d8378dad6f7faf4d88853f02aa19bbd60579f4 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 30 May 2025 12:28:02 -0700
Subject: [PATCH 1/5] [RISCV] Migrate zvqdotq reduce matching to use
partial_reduce infrastructure
This involves a slight codegen regression at the moment due to the issue
described in 443cdd0b, but this aligns the lowering paths for this case
and makes it less likely that future bugs go undetected.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 56 ++++++++-------
.../RISCV/rvv/fixed-vectors-zvqdotq.ll | 72 +++++++------------
llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 72 +++++++------------
3 files changed, 78 insertions(+), 122 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b7fd0c93fa93f..ff2500b8adbc2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18338,17 +18338,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
InVec.getOpcode() == ISD::SIGN_EXTEND) {
SDValue A = InVec.getOperand(0);
- if (A.getValueType().getVectorElementType() != MVT::i8 ||
- !TLI.isTypeLegal(A.getValueType()))
+ EVT OpVT = A.getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT))
return SDValue();
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
- A = DAG.getBitcast(ResVT, A);
- SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
-
+ SDValue B = DAG.getConstant(0x1, DL, OpVT);
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
}
// mul (sext, sext) -> vqdot
@@ -18362,32 +18360,38 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SDValue A = InVec.getOperand(0);
SDValue B = InVec.getOperand(1);
- unsigned Opc = 0;
+
+ if (!ISD::isExtOpcode(A.getOpcode()))
+ return SDValue();
+
+ EVT OpVT = A.getOperand(0).getValueType();
+ if (OpVT.getVectorElementType() != MVT::i8 ||
+ OpVT != B.getOperand(0).getValueType() ||
+ !TLI.isTypeLegal(A.getValueType()))
+ return SDValue();
+
+ MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
+ // Use the partial_reduce_*mla path if possible
if (A.getOpcode() == B.getOpcode()) {
- if (A.getOpcode() == ISD::SIGN_EXTEND)
- Opc = RISCVISD::VQDOT_VL;
- else if (A.getOpcode() == ISD::ZERO_EXTEND)
- Opc = RISCVISD::VQDOTU_VL;
- else
- return SDValue();
- } else {
- if (B.getOpcode() != ISD::ZERO_EXTEND)
- std::swap(A, B);
- if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+ // TODO: handle ANY_EXTEND and zext nonneg here
+ if (A.getOpcode() != ISD::SIGN_EXTEND &&
+ A.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- Opc = RISCVISD::VQDOTSU_VL;
- }
- assert(Opc);
- if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
- !TLI.isTypeLegal(A.getValueType()))
+ bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
+ }
+ // We don't yet have a partial_reduce_sumla node, so directly lower to the
+ // target node instead.
+ if (B.getOpcode() != ISD::ZERO_EXTEND)
+ std::swap(A, B);
+ if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
A = DAG.getBitcast(ResVT, A.getOperand(0));
B = DAG.getBitcast(ResVT, B.getOperand(0));
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+ return lowerVQDOT(RISCVISD::VQDOTSU_VL, A, B, DL, DAG, Subtarget);
}
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 0237faea9efb7..8ef691622415c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT: vmv.v.i v9, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdot.vx v9, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v9, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT: vmv.v.i v9, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdot.vx v9, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v9, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = sext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT: vmv.v.i v9, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdotu.vx v9, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v9, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT: vmv.v.i v9, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdotu.vx v9, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v9, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 1
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdotu.vv v10, v8, v9
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v10, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = zext <16 x i8> %a to <16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index d0fc915a0d07e..1948904493e8f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -1,8 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
define i32 @vqdot_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT: vmv.v.i v10, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdot.vx v10, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v10, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT: vmv.v.i v10, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdot.vx v10, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v10, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_sext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdot.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
; NODOT-NEXT: vmv.x.s a0, v8
; NODOT-NEXT: ret
;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32: # %bb.0: # %entry
-; DOT32-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT: vmv.v.i v10, 0
-; DOT32-NEXT: lui a0, 4112
-; DOT32-NEXT: addi a0, a0, 257
-; DOT32-NEXT: vqdotu.vx v10, v8, a0
-; DOT32-NEXT: vmv.s.x v8, zero
-; DOT32-NEXT: vredsum.vs v8, v10, v8
-; DOT32-NEXT: vmv.x.s a0, v8
-; DOT32-NEXT: ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64: # %bb.0: # %entry
-; DOT64-NEXT: vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT: vmv.v.i v10, 0
-; DOT64-NEXT: lui a0, 4112
-; DOT64-NEXT: addiw a0, a0, 257
-; DOT64-NEXT: vqdotu.vx v10, v8, a0
-; DOT64-NEXT: vmv.s.x v8, zero
-; DOT64-NEXT: vredsum.vs v8, v10, v8
-; DOT64-NEXT: vmv.x.s a0, v8
-; DOT64-NEXT: ret
+; DOT-LABEL: reduce_of_zext:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 1
+; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT: vmv.v.i v12, 0
+; DOT-NEXT: vqdotu.vv v12, v8, v10
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
>From 5f5472b93663775cef47f9e1a9076746496668b6 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 30 May 2025 13:55:23 -0700
Subject: [PATCH 2/5] clang-format
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ff2500b8adbc2..d452279a671d5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18345,7 +18345,8 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
SDValue B = DAG.getConstant(0x1, DL, OpVT);
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ unsigned Opc =
+ IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
}
@@ -18374,13 +18375,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
// Use the partial_reduce_*mla path if possible
if (A.getOpcode() == B.getOpcode()) {
// TODO: handle ANY_EXTEND and zext nonneg here
- if (A.getOpcode() != ISD::SIGN_EXTEND &&
- A.getOpcode() != ISD::ZERO_EXTEND)
+ if (A.getOpcode() != ISD::SIGN_EXTEND && A.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
- return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
+ unsigned Opc =
+ IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(
+ Opc, DL, ResVT,
+ {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
}
// We don't yet have a partial_reduce_sumla node, so directly lower to the
// target node instead.
>From bde6cfd470086486c43d9467bde0fdf63bd4e233 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Mon, 9 Jun 2025 07:41:57 -0700
Subject: [PATCH 3/5] Use partial_reduce_sumla
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d7e5d6c5a1663..f1ec92d499916 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18494,7 +18494,6 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
return SDValue();
MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
- // Use the partial_reduce_*mla path if possible
if (A.getOpcode() == B.getOpcode()) {
// TODO: handle ANY_EXTEND and zext nonneg here
if (A.getOpcode() != ISD::SIGN_EXTEND && A.getOpcode() != ISD::ZERO_EXTEND)
@@ -18507,16 +18506,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
Opc, DL, ResVT,
{DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
}
- // We don't yet have a partial_reduce_sumla node, so directly lower to the
- // target node instead.
if (B.getOpcode() != ISD::ZERO_EXTEND)
std::swap(A, B);
if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
- A = DAG.getBitcast(ResVT, A.getOperand(0));
- B = DAG.getBitcast(ResVT, B.getOperand(0));
- return lowerVQDOT(RISCVISD::VQDOTSU_VL, A, B, DL, DAG, Subtarget);
+ unsigned Opc = ISD::PARTIAL_REDUCE_SUMLA;
+ return DAG.getNode(
+ Opc, DL, ResVT,
+ {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
}
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
>From e877de075304fab3be7ead112f16f476d32d155c Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Mon, 9 Jun 2025 07:51:12 -0700
Subject: [PATCH 4/5] Style and comment cleanup
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 44 ++++++++++-----------
1 file changed, 20 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f1ec92d499916..718e0faea2a22 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18455,8 +18455,8 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
}
}
- // reduce (zext a) <--> reduce (mul zext a. zext 1)
- // reduce (sext a) <--> reduce (mul sext a. sext 1)
+ // reduce (zext a) <--> partial_reduce_umla 0, a, 1
+ // reduce (sext a) <--> partial_reduce_smla 0, a, 1
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
InVec.getOpcode() == ISD::SIGN_EXTEND) {
SDValue A = InVec.getOperand(0);
@@ -18472,12 +18472,10 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
}
- // mul (sext, sext) -> vqdot
- // mul (zext, zext) -> vqdotu
- // mul (sext, zext) -> vqdotsu
- // mul (zext, sext) -> vqdotsu (swapped)
- // TODO: Improve .vx handling - we end up with a sub-vector insert
- // which confuses the splat pattern matching. Also, match vqdotus.vx
+ // mul (sext a, sext b) -> partial_reduce_smla 0, a, b
+ // mul (zext a, zext b) -> partial_reduce_umla 0, a, b
+ // mul (sext a, zext b) -> partial_reduce_sumla 0, a, b
+ // mul (zext a, sext b) -> partial_reduce_sumla 0, b, a (swapped)
if (InVec.getOpcode() != ISD::MUL)
return SDValue();
@@ -18493,25 +18491,23 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
!TLI.isTypeLegal(A.getValueType()))
return SDValue();
- MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
- if (A.getOpcode() == B.getOpcode()) {
- // TODO: handle ANY_EXTEND and zext nonneg here
- if (A.getOpcode() != ISD::SIGN_EXTEND && A.getOpcode() != ISD::ZERO_EXTEND)
- return SDValue();
-
- bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
- unsigned Opc =
- IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
- return DAG.getNode(
- Opc, DL, ResVT,
- {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
- }
- if (B.getOpcode() != ISD::ZERO_EXTEND)
+ unsigned Opc;
+ if (A.getOpcode() == ISD::SIGN_EXTEND && B.getOpcode() == ISD::SIGN_EXTEND)
+ Opc = ISD::PARTIAL_REDUCE_SMLA;
+ else if (A.getOpcode() == ISD::ZERO_EXTEND &&
+ B.getOpcode() == ISD::ZERO_EXTEND)
+ Opc = ISD::PARTIAL_REDUCE_UMLA;
+ else if (A.getOpcode() == ISD::SIGN_EXTEND &&
+ B.getOpcode() == ISD::ZERO_EXTEND)
+ Opc = ISD::PARTIAL_REDUCE_SUMLA;
+ else if (A.getOpcode() == ISD::ZERO_EXTEND &&
+ B.getOpcode() == ISD::SIGN_EXTEND) {
+ Opc = ISD::PARTIAL_REDUCE_SUMLA;
std::swap(A, B);
- if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+ } else
return SDValue();
- unsigned Opc = ISD::PARTIAL_REDUCE_SUMLA;
+ MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
return DAG.getNode(
Opc, DL, ResVT,
{DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
>From 91bb24e50771f6e2d720e789553a8d7ede5b405a Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Mon, 9 Jun 2025 08:06:21 -0700
Subject: [PATCH 5/5] Delete dead code
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 25 ---------------------
1 file changed, 25 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 718e0faea2a22..39d9bea063667 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18372,31 +18372,6 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
DAG.getBuildVector(VT, DL, RHSOps));
}
-static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
- const SDLoc &DL, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
- RISCVISD::VQDOTSU_VL == Opc);
- MVT VT = Op0.getSimpleValueType();
- assert(VT == Op1.getSimpleValueType() &&
- VT.getVectorElementType() == MVT::i32);
-
- SDValue Passthru = DAG.getConstant(0, DL, VT);
- MVT ContainerVT = VT;
- if (VT.isFixedLengthVector()) {
- ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
- Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
- Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
- Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
- }
- auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
- SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
- {Op0, Op1, Passthru, Mask, VL});
- if (VT.isFixedLengthVector())
- return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
- return LocalAccum;
-}
-
static MVT getQDOTXResultType(MVT OpVT) {
ElementCount OpEC = OpVT.getVectorElementCount();
assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
More information about the llvm-commits
mailing list