[llvm] 33c9236 - [RISCV] Extend zvqdot matching to handle disjoint or (#157901)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 10 10:27:55 PDT 2025


Author: Hongyu Chen
Date: 2025-09-10T17:27:51Z
New Revision: 33c9236bf870bc732a48a0256e90b907d1c21a49

URL: https://github.com/llvm/llvm-project/commit/33c9236bf870bc732a48a0256e90b907d1c21a49
DIFF: https://github.com/llvm/llvm-project/commit/33c9236bf870bc732a48a0256e90b907d1c21a49.diff

LOG: [RISCV] Extend zvqdot matching to handle disjoint or (#157901)

This patch uses SDPatternMatch (`m_AddLike`) so that the zvqdot reduction
matching recurses through `or disjoint` as well as `add`. It also simplifies
the multiplication matching by using `m_Mul`.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 21ef630e0f692..409f98b348903 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19123,6 +19123,7 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
                                          SelectionDAG &DAG,
                                          const RISCVSubtarget &Subtarget,
                                          const RISCVTargetLowering &TLI) {
+  using namespace SDPatternMatch;
   // Note: We intentionally do not check the legality of the reduction type.
   // We want to handle the m4/m8 *src*  types, and thus need to let illegal
   // intermediate types flow through here.
@@ -19130,11 +19131,10 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
       !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
     return SDValue();
 
-  // Recurse through adds (since generic dag canonicalizes to that
-  // form). TODO: Handle disjoint or here.
-  if (InVec->getOpcode() == ISD::ADD) {
-    SDValue A = InVec.getOperand(0);
-    SDValue B = InVec.getOperand(1);
+  // Recurse through adds/disjoint ors (since generic dag canonicalizes to that
+  // form).
+  SDValue A, B;
+  if (sd_match(InVec, m_AddLike(m_Value(A), m_Value(B)))) {
     SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
     SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
     if (AOpt || BOpt) {
@@ -19171,12 +19171,9 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
   // mul (zext a, zext b) -> partial_reduce_umla 0, a, b
   // mul (sext a, zext b) -> partial_reduce_ssmla 0, a, b
   // mul (zext a, sext b) -> partial_reduce_smla 0, b, a (swapped)
-  if (InVec.getOpcode() != ISD::MUL)
+  if (!sd_match(InVec, m_Mul(m_Value(A), m_Value(B))))
     return SDValue();
 
-  SDValue A = InVec.getOperand(0);
-  SDValue B = InVec.getOperand(1);
-
   if (!ISD::isExtOpcode(A.getOpcode()))
     return SDValue();
 

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index a189711d11471..684eb609635ef 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1552,6 +1552,133 @@ entry:
   %res = call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> zeroinitializer, <16 x i32> %a.ext)
   ret <4 x i32> %res
 }
+
+define i32 @vqdot_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; NODOT-LABEL: vqdot_vv_accum_disjoint_or:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vsext.vf2 v18, v9
+; NODOT-NEXT:    vwmul.vv v8, v16, v18
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vor.vv v8, v8, v12
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv_accum_disjoint_or:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdot.vv v16, v8, v9
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v16
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.sext = sext <16 x i8> %a to <16 x i32>
+  %b.sext = sext <16 x i8> %b to <16 x i32>
+  %mul = mul <16 x i32> %a.sext, %b.sext
+  %add = or disjoint <16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdot_vv_accum_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdot_vv_accum_or:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v16, v8
+; CHECK-NEXT:    vsext.vf2 v18, v9
+; CHECK-NEXT:    vwmul.vv v8, v16, v18
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vor.vv v8, v8, v12
+; CHECK-NEXT:    vmv.s.x v12, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v12
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+entry:
+  %a.sext = sext <16 x i8> %a to <16 x i32>
+  %b.sext = sext <16 x i8> %b to <16 x i32>
+  %mul = mul <16 x i32> %a.sext, %b.sext
+  %add = or <16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdotu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; NODOT-LABEL: vqdotu_vv_accum_disjoint_or:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT:    vwmulu.vv v10, v8, v9
+; NODOT-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT:    vwaddu.wv v12, v12, v10
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotu_vv_accum_disjoint_or:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotu.vv v16, v8, v9
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v16
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.zext = zext <16 x i8> %a to <16 x i32>
+  %b.zext = zext <16 x i8> %b to <16 x i32>
+  %mul = mul <16 x i32> %a.zext, %b.zext
+  %add = or disjoint <16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+  ret i32 %sum
+}
+
+define i32 @vqdotsu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; NODOT-LABEL: vqdotsu_vv_accum_disjoint_or:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v16, v8
+; NODOT-NEXT:    vzext.vf2 v18, v9
+; NODOT-NEXT:    vwmulsu.vv v8, v16, v18
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vor.vv v8, v8, v12
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_accum_disjoint_or:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv1r.v v16, v12
+; DOT-NEXT:    vqdotsu.vv v16, v8, v9
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v16
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
+entry:
+  %a.sext = sext <16 x i8> %a to <16 x i32>
+  %b.zext = zext <16 x i8> %b to <16 x i32>
+  %mul = mul <16 x i32> %a.sext, %b.zext
+  %add = or disjoint <16 x i32> %mul, %x
+  %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+  ret i32 %sum
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; DOT32: {{.*}}
 ; DOT64: {{.*}}


        


More information about the llvm-commits mailing list