[llvm] [RISCV] Extend zvqdot matching to handle reduction trees (PR #138965)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri May 9 07:44:14 PDT 2025


https://github.com/preames updated https://github.com/llvm/llvm-project/pull/138965

>From 970983fe52fab8113ac3dc277185067105ee538a Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 7 May 2025 08:45:26 -0700
Subject: [PATCH 1/2] [RISCV] Extend zvqdot matching to handle reduction trees

Now that we have matching for vqdot in its basic variants, we can
extend the matcher to handle reduction trees instead of individual
reductions.  This is important because we canonicalize reductions
by summing partial results as a tree in the vector domain before
the root reduction instruction.

The particular approach taken here has the unfortunate implication
that, on a non-match, the entire reduction tree is visited once for
each time the reduction root is visited in the DAG.  While conceptually
problematic for compile time, this is probably fine in practice as we
should only visit the root once per pass of DAGCombine.  I don't really
see a better solution - suggestions welcome.
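For illustration, the canonical form being matched looks roughly like
the vqdot_vv_split test updated below: the partial products are summed
with a vector add before the single reduction root (a hand-written IR
sketch of that shape, not taken verbatim from the test suite):

  define i32 @reduction_tree(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
    ; First partial product, sign extended to the accumulator width.
    %a.sext = sext <16 x i8> %a to <16 x i32>
    %b.sext = sext <16 x i8> %b to <16 x i32>
    %mul1 = mul <16 x i32> %a.sext, %b.sext
    ; Second partial product.
    %c.sext = sext <16 x i8> %c to <16 x i32>
    %d.sext = sext <16 x i8> %d to <16 x i32>
    %mul2 = mul <16 x i32> %c.sext, %d.sext
    ; The tree: partial sums are combined in the vector domain...
    %sum = add <16 x i32> %mul1, %mul2
    ; ...before the single reduction root.
    %res = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %sum)
    ret i32 %res
  }

With this change, each mul feeding the add tree can be folded to a
vqdot partial sum independently, and the add is rebuilt over the
(possibly narrower) partial-sum types.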
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  43 +++++
 .../RISCV/rvv/fixed-vectors-zvqdotq.ll        | 149 ++++++++++++------
 2 files changed, 145 insertions(+), 47 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 756c563f0194d..a4d778c054ecf 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18131,6 +18131,29 @@ static MVT getQDOTXResultType(MVT OpVT) {
   return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
 }
 
+/// Given fixed length vectors A and B with equal element types, but possibly
+/// different number of elements, return A + B where either A or B is zero
+/// padded to the larger number of elements.
+static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
+                                SelectionDAG &DAG) {
+  // NOTE: Manually doing the extract/add/insert scheme produces
+  // significantly better coegen than the naive pad with zeros
+  // and add scheme.
+  EVT AVT = A.getValueType();
+  EVT BVT = B.getValueType();
+  assert(AVT.getVectorElementType() == BVT.getVectorElementType());
+  if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
+    std::swap(A, B);
+    std::swap(AVT, BVT);
+  }
+
+  SDValue BPart = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, AVT, B,
+                              DAG.getVectorIdxConstant(0, DL));
+  SDValue Res = DAG.getNode(ISD::ADD, DL, AVT, A, BPart);
+  return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, BVT, B, Res,
+                     DAG.getVectorIdxConstant(0, DL));
+}
+
 static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
                                          SelectionDAG &DAG,
                                          const RISCVSubtarget &Subtarget,
@@ -18142,6 +18165,26 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
       !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
     return SDValue();
 
+  // Recurse through adds (since generic dag canonicalizes to that
+  // form).
+  if (InVec->getOpcode() == ISD::ADD) {
+    SDValue A = InVec.getOperand(0);
+    SDValue B = InVec.getOperand(1);
+    SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
+    SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
+    if (AOpt || BOpt) {
+      if (AOpt)
+        A = AOpt;
+      if (BOpt)
+        B = BOpt;
+      // From here, we're doing A + B with mixed types, implicitly zero
+      // padded to the wider type.  Note that we *don't* need the result
+      // type to be the original VT, and in fact prefer narrower ones
+      // if possible.
+      return getZeroPaddedAdd(DL, A, B, DAG);
+    }
+  }
+
   // reduce (zext a) <--> reduce (mul zext a. zext 1)
   // reduce (sext a) <--> reduce (mul sext a. sext 1)
   if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index edc9886abc3b9..e5546ad404c1b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -299,17 +299,31 @@ entry:
 }
 
 define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdot_vv_accum:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vsext.vf2 v16, v9
-; CHECK-NEXT:    vwmacc.vv v12, v10, v16
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdot_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vsext.vf2 v16, v9
+; NODOT-NEXT:    vwmacc.vv v12, v10, v16
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v9
+; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.sext = sext <16 x i8> %b to <16 x i32>
@@ -320,17 +334,31 @@ entry:
 }
 
 define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotu_vv_accum:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT:    vwmulu.vv v10, v8, v9
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vwaddu.wv v12, v12, v10
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotu_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT:    vwmulu.vv v10, v8, v9
+; NODOT-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT:    vwaddu.wv v12, v12, v10
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotu_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotu.vv v10, v8, v9
+; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.zext = zext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -341,17 +369,31 @@ entry:
 }
 
 define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
-; CHECK-LABEL: vqdotsu_vv_accum:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v16, v9
-; CHECK-NEXT:    vwmaccsu.vv v12, v10, v16
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv_accum:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v10, v8
+; NODOT-NEXT:    vzext.vf2 v16, v9
+; NODOT-NEXT:    vwmaccsu.vv v12, v10, v16
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_accum:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vadd.vv v8, v10, v12
+; DOT-NEXT:    vsetivli zero, 4, e32, m4, tu, ma
+; DOT-NEXT:    vmv.v.v v12, v8
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -455,20 +497,33 @@ entry:
 }
 
 define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
-; CHECK-LABEL: vqdot_vv_split:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vsext.vf2 v14, v9
-; CHECK-NEXT:    vsext.vf2 v16, v10
-; CHECK-NEXT:    vsext.vf2 v18, v11
-; CHECK-NEXT:    vwmul.vv v8, v12, v14
-; CHECK-NEXT:    vwmacc.vv v8, v16, v18
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdot_vv_split:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vsext.vf2 v14, v9
+; NODOT-NEXT:    vsext.vf2 v16, v10
+; NODOT-NEXT:    vsext.vf2 v18, v11
+; NODOT-NEXT:    vwmul.vv v8, v12, v14
+; NODOT-NEXT:    vwmacc.vv v8, v16, v18
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv_split:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vmv.v.i v13, 0
+; DOT-NEXT:    vqdot.vv v12, v8, v9
+; DOT-NEXT:    vqdot.vv v13, v10, v11
+; DOT-NEXT:    vadd.vv v8, v12, v13
+; DOT-NEXT:    vmv.s.x v9, zero
+; DOT-NEXT:    vredsum.vs v8, v8, v9
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.sext = sext <16 x i8> %b to <16 x i32>

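For reference, the extract/add/insert scheme in getZeroPaddedAdd
corresponds to roughly the following IR-level dataflow, where %A is the
narrower partial sum and %B the wider vector (a hand-written sketch
under those assumed types, not compiler output):

  ; %A : <4 x i32> (narrow partial sum), %B : <16 x i32> (wider value)
  ; Pull the low elements of %B down to %A's width.
  %bpart = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %B, i64 0)
  ; Do the add at the narrow width.
  %sum = add <4 x i32> %A, %bpart
  ; Put the result back; %B's upper elements pass through untouched.
  %res = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> %B, <4 x i32> %sum, i64 0)

Zero padding %A to <16 x i32> and adding would produce the same value;
performing the add at the narrow width and leaving %B's upper elements
alone is what gives the better codegen noted in the comment.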
>From a4b8d2a6d4a24c1d541ef2668cf75b53b458262c Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 9 May 2025 07:44:06 -0700
Subject: [PATCH 2/2] Update llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Co-authored-by: Luke Lau <luke_lau at icloud.com>
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a4d778c054ecf..7c2fbb2a7b43f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18137,7 +18137,7 @@ static MVT getQDOTXResultType(MVT OpVT) {
 static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
                                 SelectionDAG &DAG) {
   // NOTE: Manually doing the extract/add/insert scheme produces
-  // significantly better coegen than the naive pad with zeros
+  // significantly better codegen than the naive pad with zeros
   // and add scheme.
   EVT AVT = A.getValueType();
   EVT BVT = B.getValueType();

