[llvm] 33c9236 - [RISCV] Extend zvqdot matching to handle disjoint or (#157901)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 10 10:27:55 PDT 2025
Author: Hongyu Chen
Date: 2025-09-10T17:27:51Z
New Revision: 33c9236bf870bc732a48a0256e90b907d1c21a49
URL: https://github.com/llvm/llvm-project/commit/33c9236bf870bc732a48a0256e90b907d1c21a49
DIFF: https://github.com/llvm/llvm-project/commit/33c9236bf870bc732a48a0256e90b907d1c21a49.diff
LOG: [RISCV] Extend zvqdot matching to handle disjoint or (#157901)
This patch uses pattern matching to handle disjoint or in addition to add,
and it also simplifies the multiplication matching.
Added:

Modified:
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 21ef630e0f692..409f98b348903 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -19123,6 +19123,7 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
                                         SelectionDAG &DAG,
                                         const RISCVSubtarget &Subtarget,
                                         const RISCVTargetLowering &TLI) {
+  using namespace SDPatternMatch;
  // Note: We intentionally do not check the legality of the reduction type.
  // We want to handle the m4/m8 *src* types, and thus need to let illegal
  // intermediate types flow through here.
@@ -19130,11 +19131,10 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
      !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
    return SDValue();

-  // Recurse through adds (since generic dag canonicalizes to that
-  // form). TODO: Handle disjoint or here.
-  if (InVec->getOpcode() == ISD::ADD) {
-    SDValue A = InVec.getOperand(0);
-    SDValue B = InVec.getOperand(1);
+  // Recurse through adds/disjoint ors (since generic dag canonicalizes to that
+  // form).
+  SDValue A, B;
+  if (sd_match(InVec, m_AddLike(m_Value(A), m_Value(B)))) {
    SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
    SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
    if (AOpt || BOpt) {
@@ -19171,12 +19171,9 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
  // mul (sext a, sext b) -> partial_reduce_smla 0, a, b
  // mul (zext a, zext b) -> partial_reduce_umla 0, a, b
  // mul (sext a, zext b) -> partial_reduce_sumla 0, a, b
  // mul (zext a, sext b) -> partial_reduce_sumla 0, b, a (swapped)
-  if (InVec.getOpcode() != ISD::MUL)
+  if (!sd_match(InVec, m_Mul(m_Value(A), m_Value(B))))
    return SDValue();

-  SDValue A = InVec.getOperand(0);
-  SDValue B = InVec.getOperand(1);
-
  if (!ISD::isExtOpcode(A.getOpcode()))
    return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index a189711d11471..684eb609635ef 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1552,6 +1552,133 @@ entry:
%res = call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> zeroinitializer, <16 x i32> %a.ext)
ret <4 x i32> %res
}
+
+define i32 @vqdot_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; NODOT-LABEL: vqdot_vv_accum_disjoint_or:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vsext.vf2 v18, v9
+; NODOT-NEXT: vwmul.vv v8, v16, v18
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vor.vv v8, v8, v12
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_accum_disjoint_or:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdot.vv v16, v8, v9
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <16 x i8> %a to <16 x i32>
+ %b.sext = sext <16 x i8> %b to <16 x i32>
+ %mul = mul <16 x i32> %a.sext, %b.sext
+ %add = or disjoint <16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdot_vv_accum_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; CHECK-LABEL: vqdot_vv_accum_or:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT: vsext.vf2 v16, v8
+; CHECK-NEXT: vsext.vf2 v18, v9
+; CHECK-NEXT: vwmul.vv v8, v16, v18
+; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT: vor.vv v8, v8, v12
+; CHECK-NEXT: vmv.s.x v12, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v12
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <16 x i8> %a to <16 x i32>
+ %b.sext = sext <16 x i8> %b to <16 x i32>
+ %mul = mul <16 x i32> %a.sext, %b.sext
+ %add = or <16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdotu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; NODOT-LABEL: vqdotu_vv_accum_disjoint_or:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT: vwmulu.vv v10, v8, v9
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT: vwaddu.wv v12, v12, v10
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.s.x v8, zero
+; NODOT-NEXT: vredsum.vs v8, v12, v8
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotu_vv_accum_disjoint_or:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdotu.vv v16, v8, v9
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.zext = zext <16 x i8> %a to <16 x i32>
+ %b.zext = zext <16 x i8> %b to <16 x i32>
+ %mul = mul <16 x i32> %a.zext, %b.zext
+ %add = or disjoint <16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+ ret i32 %sum
+}
+
+define i32 @vqdotsu_vv_accum_disjoint_or(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
+; NODOT-LABEL: vqdotsu_vv_accum_disjoint_or:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vzext.vf2 v18, v9
+; NODOT-NEXT: vwmulsu.vv v8, v16, v18
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vor.vv v8, v8, v12
+; NODOT-NEXT: vmv.s.x v12, zero
+; NODOT-NEXT: vredsum.vs v8, v8, v12
+; NODOT-NEXT: vmv.x.s a0, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdotsu_vv_accum_disjoint_or:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv1r.v v16, v12
+; DOT-NEXT: vqdotsu.vv v16, v8, v9
+; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT: vmv.v.v v12, v16
+; DOT-NEXT: vmv.s.x v8, zero
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vredsum.vs v8, v12, v8
+; DOT-NEXT: vmv.x.s a0, v8
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <16 x i8> %a to <16 x i32>
+ %b.zext = zext <16 x i8> %b to <16 x i32>
+ %mul = mul <16 x i32> %a.sext, %b.zext
+ %add = or disjoint <16 x i32> %mul, %x
+ %sum = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %add)
+ ret i32 %sum
+}
+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; DOT32: {{.*}}
; DOT64: {{.*}}