[llvm] [RISCV] Extend zvqdot matching to handle disjoint or (PR #157901)
Hongyu Chen via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 10 09:59:57 PDT 2025
https://github.com/XChy created https://github.com/llvm/llvm-project/pull/157901
This patch uses pattern matching (SDPatternMatch) to handle disjoint or in addition to add when searching for vqdot candidates. It also simplifies the multiplication matching.
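For context, an or carrying the disjoint flag guarantees its operands share no set bits, so there are no carries and the or is equivalent to an add; that is why the existing add-recursion applies unchanged. A minimal standalone sketch of the identity (the concrete values here are hypothetical, not taken from the patch):

  #include <cassert>
  #include <cstdint>

  int main() {
    // "or disjoint x, y" promises (x & y) == 0; with no overlapping
    // bits there are no carries, so OR and ADD agree bit-for-bit.
    uint32_t Mul = 0x00F0; // hypothetical product lanes
    uint32_t Acc = 0x0F00; // hypothetical accumulator, disjoint from Mul
    assert((Mul & Acc) == 0);            // the disjoint precondition
    assert((Mul | Acc) == (Mul + Acc));  // hence the or is an add
    return 0;
  }

On the matching side, SDPatternMatch's m_AddLike accepts either form (add, or disjoint or), which is what lets the combine below recurse through disjoint ors exactly as it already did for adds.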
From faabcc56865d8b8c4e325b972da6f685f88dd0ff Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Thu, 11 Sep 2025 00:51:05 +0800
Subject: [PATCH 1/2] Precommit tests
---
.../RISCV/rvv/fixed-vectors-zvqdotq.ll | 88 +++++++++++++++++++
1 file changed, 88 insertions(+)
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index a189711d11471..2f22d6519a85a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1552,6 +1552,94 @@ entry:
%res = call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> zeroinitializer, <16 x i32> %a.ext)
ret <4 x i32> %res
}
+
+define i32 @vqdot_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdot_vv_accum_disjoint_or:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: vsext.vf2 v18, v9
+; CHECK-NEXT: vwmul.vv v8, v16, v18
+; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT: vor.vv v8, v8, v12
+; CHECK-NEXT: vmv.s.x v12, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v12
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <16 x i8> %a to <16 x i32>
+ %b.sext = sext <16 x i8> %b to <16 x i32>
+ %mul = mul <16 x i32> %a.sext, %b.sext
+ %add = or disjoint <16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdot_vv_accum_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdot_vv_accum_or:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: vsext.vf2 v18, v9
+; CHECK-NEXT: vwmul.vv v8, v16, v18
+; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT: vor.vv v8, v8, v12
+; CHECK-NEXT: vmv.s.x v12, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v12
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <16 x i8> %a to <16 x i32>
+ %b.sext = sext <16 x i8> %b to <16 x i32>
+ %mul = mul <16 x i32> %a.sext, %b.sext
+ %add = or <16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdotu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdotu_vv_accum_disjoint_or:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; CHECK-NEXT: vwmulu.vv v10, v8, v9
+; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT: vwaddu.wv v12, v12, v10
+; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT: vmv.s.x v8, zero
+; CHECK-NEXT: vredsum.vs v8, v12, v8
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.zext = zext <16 x i8> %a to <16 x i32>
+ %b.zext = zext <16 x i8> %b to <16 x i32>
+ %mul = mul <16 x i32> %a.zext, %b.zext
+ %add = or disjoint <16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdotsu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdotsu_vv_accum_disjoint_or:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: vzext.vf2 v18, v9
+; CHECK-NEXT: vwmulsu.vv v8, v16, v18
+; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT: vor.vv v8, v8, v12
+; CHECK-NEXT: vmv.s.x v12, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v12
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <16 x i8> %a to <16 x i32>
+ %b.zext = zext <16 x i8> %b to <16 x i32>
+ %mul = mul <16 x i32> %a.sext, %b.zext
+ %add = or disjoint <16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+ ret i32 %sum
+}
+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; DOT32: {{.*}}
; DOT64: {{.*}}
From 1ca1d7adf3221de0270b40e33d324716cf90e4fd Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Thu, 11 Sep 2025 00:53:51 +0800
Subject: [PATCH 2/2] [RISCV] Extend zvqdot matching to handle disjoint or
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 15 +--
.../RISCV/rvv/fixed-vectors-zvqdotq.ll | 109 ++++++++++++------
2 files changed, 80 insertions(+), 44 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index af9430a23e2c9..616c6fc73f65c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19121,6 +19121,7 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
+ using namespace SDPatternMatch;
// Note: We intentionally do not check the legality of the reduction type.
// We want to handle the m4/m8 *src* types, and thus need to let illegal
// intermediate types flow through here.
@@ -19128,11 +19129,10 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
return SDValue();
- // Recurse through adds (since generic dag canonicalizes to that
- // form). TODO: Handle disjoint or here.
- if (InVec->getOpcode() == ISD::ADD) {
- SDValue A = InVec.getOperand(0);
- SDValue B = InVec.getOperand(1);
+ // Recurse through adds/disjoint ors (since generic dag canonicalizes to that
+ // form).
+ SDValue A, B;
+ if (sd_match(InVec, m_AddLike(m_Value(A), m_Value(B)))) {
SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
if (AOpt || BOpt) {
@@ -19169,12 +19169,9 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
// mul (zext a, zext b) -> partial_reduce_umla 0, a, b
// mul (sext a, zext b) -> partial_reduce_ssmla 0, a, b
// mul (zext a, sext b) -> partial_reduce_smla 0, b, a (swapped)
- if (InVec.getOpcode() != ISD::MUL)
+ if (!sd_match(InVec, m_Mul(m_Value(A), m_Value(B))))
return SDValue();
- SDValue A = InVec.getOperand(0);
- SDValue B = InVec.getOperand(1);
-
if (!ISD::isExtOpcode(A.getOpcode()))
return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 2f22d6519a85a..684eb609635ef 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1554,18 +1554,31 @@ entry:
}
define i32 @vqdot_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdot_vv_accum_disjoint_or:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v16, v8
-; CHECK-NEXT: vsext.vf2 v18, v9
-; CHECK-NEXT: vwmul.vv v8, v16, v18
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vor.vv v8, v8, v12
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdot_vv_accum_disjoint_or:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vsext.vf2 v18, v9
+; NODOT-NEXT: vwmul.vv v8, v16, v18
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vor.vv v8, v8, v12
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_accum_disjoint_or:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdot.vv v16, v8, v9
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -1598,17 +1611,30 @@ entry:
}
define i32 @vqdotu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotu_vv_accum_disjoint_or:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT: vwmulu.vv v10, v8, v9
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT: vwaddu.wv v12, v12, v10
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.s.x v8, zero
-; CHECK-NEXT: vredsum.vs v8, v12, v8
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotu_vv_accum_disjoint_or:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT: vwaddu.wv v12, v12, v10
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv_accum_disjoint_or:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdotu.vv v16, v8, v9
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.zext = zext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -1619,18 +1645,31 @@ entry:
}
define i32 @vqdotsu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotsu_vv_accum_disjoint_or:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v16, v8
-; CHECK-NEXT: vzext.vf2 v18, v9
-; CHECK-NEXT: vwmulsu.vv v8, v16, v18
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vor.vv v8, v8, v12
-; CHECK-NEXT: vmv.s.x v12, zero
-; CHECK-NEXT: vredsum.vs v8, v8, v12
-; CHECK-NEXT: vmv.x.s a0, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: vqdotsu_vv_accum_disjoint_or:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vzext.vf2 v18, v9
+; NODOT-NEXT: vwmulsu.vv v8, v16, v18
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vor.vv v8, v8, v12
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_accum_disjoint_or:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdotsu.vv v16, v8, v9
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.zext = zext <16 x i8> %b to <16 x i32>