[llvm] [RISCV] Extend zvqdot matching to handle disjoint or (PR #157901)

Hongyu Chen via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 10 09:59:57 PDT 2025


https://github.com/XChy created https://github.com/llvm/llvm-project/pull/157901

This patch uses SDPatternMatch's `m_AddLike` so that the zvqdot reduction matching recurses through `or disjoint` as well as `add`. It also simplifies the multiplication matching by using `m_Mul`.

>From faabcc56865d8b8c4e325b972da6f685f88dd0ff Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Thu, 11 Sep 2025 00:51:05 +0800
Subject: [PATCH 1/2] Precommit tests

---
 .../RISCV/rvv/fixed-vectors-zvqdotq.ll        | 88 +++++++++++++++++++
 1 file changed, 88 insertions(+)

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index a189711d11471..2f22d6519a85a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1552,6 +1552,94 @@ entry:
   %res = call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> zeroinitializer, <16 x i32> %a.ext)
   ret <4 x i32> %res
 }
+
+define i32 @vqdot_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdot_vv_accum_disjoint_or:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v16, v8
+; CHECK-NEXT:    vsext.vf2 v18, v9
+; CHECK-NEXT:    vwmul.vv v8, v16, v18
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vor.vv v8, v8, v12
+; CHECK-NEXT:    vmv.s.x v12, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v12
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.sext = sext <16 x i8> %a to <16 x i32>
+  %b.sext = sext <16 x i8> %b to <16 x i32>
+  %mul = mul <16 x i32> %a.sext, %b.sext
+  %add = or disjoint <16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdot_vv_accum_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdot_vv_accum_or:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v16, v8
+; CHECK-NEXT:    vsext.vf2 v18, v9
+; CHECK-NEXT:    vwmul.vv v8, v16, v18
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vor.vv v8, v8, v12
+; CHECK-NEXT:    vmv.s.x v12, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v12
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.sext = sext <16 x i8> %a to <16 x i32>
+  %b.sext = sext <16 x i8> %b to <16 x i32>
+  %mul = mul <16 x i32> %a.sext, %b.sext
+  %add = or <16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdotu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdotu_vv_accum_disjoint_or:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; CHECK-NEXT:    vwmulu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; CHECK-NEXT:    vwaddu.wv v12, v12, v10
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vmv.s.x v8, zero
+; CHECK-NEXT:    vredsum.vs v8, v12, v8
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.zext = zext <16 x i8> %a to <16 x i32>
+  %b.zext = zext <16 x i8> %b to <16 x i32>
+  %mul = mul <16 x i32> %a.zext, %b.zext
+  %add = or disjoint <16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdotsu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdotsu_vv_accum_disjoint_or:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v16, v8
+; CHECK-NEXT:    vzext.vf2 v18, v9
+; CHECK-NEXT:    vwmulsu.vv v8, v16, v18
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vor.vv v8, v8, v12
+; CHECK-NEXT:    vmv.s.x v12, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v12
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.sext = sext <16 x i8> %a to <16 x i32>
+  %b.zext = zext <16 x i8> %b to <16 x i32>
+  %mul = mul <16 x i32> %a.sext, %b.zext
+  %add = or disjoint <16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+  ret i32 %sum
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; DOT32: {{.*}}
 ; DOT64: {{.*}}

>From 1ca1d7adf3221de0270b40e33d324716cf90e4fd Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Thu, 11 Sep 2025 00:53:51 +0800
Subject: [PATCH 2/2] [RISCV] Extend zvqdot matching to handle disjoint or

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  15 +--
 .../RISCV/rvv/fixed-vectors-zvqdotq.ll        | 109 ++++++++++++------
 2 files changed, 80 insertions(+), 44 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index af9430a23e2c9..616c6fc73f65c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19121,6 +19121,7 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
                                          SelectionDAG &DAG,
                                          const RISCVSubtarget &Subtarget,
                                          const RISCVTargetLowering &TLI) {
+  using namespace SDPatternMatch;
   // Note: We intentionally do not check the legality of the reduction type.
   // We want to handle the m4/m8 *src*  types, and thus need to let illegal
   // intermediate types flow through here.
@@ -19128,11 +19129,10 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
       !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
     return SDValue();
 
-  // Recurse through adds (since generic dag canonicalizes to that
-  // form). TODO: Handle disjoint or here.
-  if (InVec->getOpcode() == ISD::ADD) {
-    SDValue A = InVec.getOperand(0);
-    SDValue B = InVec.getOperand(1);
+  // Recurse through adds/disjoint ors (since generic dag canonicalizes to that
+  // form).
+  SDValue A, B;
+  if (sd_match(InVec, m_AddLike(m_Value(A), m_Value(B)))) {
     SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
     SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
     if (AOpt || BOpt) {
@@ -19169,12 +19169,9 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
   // mul (zext a, zext b) -> partial_reduce_umla 0, a, b
   // mul (sext a, zext b) -> partial_reduce_ssmla 0, a, b
   // mul (zext a, sext b) -> partial_reduce_smla 0, b, a (swapped)
-  if (InVec.getOpcode() != ISD::MUL)
+  if (!sd_match(InVec, m_Mul(m_Value(A), m_Value(B))))
     return SDValue();
 
-  SDValue A = InVec.getOperand(0);
-  SDValue B = InVec.getOperand(1);
-
   if (!ISD::isExtOpcode(A.getOpcode()))
     return SDValue();
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 2f22d6519a85a..684eb609635ef 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1554,18 +1554,31 @@ entry:
 }
 
 define i32 @vqdot_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdot_vv_accum_disjoint_or:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v16, v8
-; CHECK-NEXT:    vsext.vf2 v18, v9
-; CHECK-NEXT:    vwmul.vv v8, v16, v18
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vor.vv v8, v8, v12
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdot_vv_accum_disjoint_or:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vsext.vf2 v18, v9
+; NODOT-NEXT:    vwmul.vv v8, v16, v18
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vor.vv v8, v8, v12
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv_accum_disjoint_or:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdot.vv v16, v8, v9
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v16
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.sext = sext <16 x i8> %b to <16 x i32>
@@ -1598,17 +1611,30 @@ entry:
 }
 
 define i32 @vqdotu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotu_vv_accum_disjoint_or:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT:    vwmulu.vv v10, v8, v9
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vwaddu.wv v12, v12, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotu_vv_accum_disjoint_or:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT:    vwmulu.vv v10, v8, v9
+; NODOT-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT:    vwaddu.wv v12, v12, v10
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotu_vv_accum_disjoint_or:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotu.vv v16, v8, v9
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v16
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.zext = zext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -1619,18 +1645,31 @@ entry:
 }
 
 define i32 @vqdotsu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotsu_vv_accum_disjoint_or:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v16, v8
-; CHECK-NEXT:    vzext.vf2 v18, v9
-; CHECK-NEXT:    vwmulsu.vv v8, v16, v18
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vor.vv v8, v8, v12
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv_accum_disjoint_or:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vzext.vf2 v18, v9
+; NODOT-NEXT:    vwmulsu.vv v8, v16, v18
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vor.vv v8, v8, v12
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_accum_disjoint_or:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotsu.vv v16, v8, v9
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v16
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>



More information about the llvm-commits mailing list