[llvm] [RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure (PR #142212)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 30 13:51:34 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

This currently causes a codegen regression due to the issue described in commit 443cdd0b, but it aligns the lowering paths for this case and makes it less likely that future bugs go undetected.

---
Full diff: https://github.com/llvm/llvm-project/pull/142212.diff


3 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+30-26) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll (+24-48) 
- (modified) llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll (+24-48) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b7fd0c93fa93f..ff2500b8adbc2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18338,17 +18338,15 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
   if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
       InVec.getOpcode() == ISD::SIGN_EXTEND) {
     SDValue A = InVec.getOperand(0);
-    if (A.getValueType().getVectorElementType() != MVT::i8 ||
-        !TLI.isTypeLegal(A.getValueType()))
+    EVT OpVT = A.getValueType();
+    if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT))
       return SDValue();
 
     MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
-    A = DAG.getBitcast(ResVT, A);
-    SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
-
+    SDValue B = DAG.getConstant(0x1, DL, OpVT);
     bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
-    unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
-    return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+    unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+    return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
   }
 
   // mul (sext, sext) -> vqdot
@@ -18362,32 +18360,38 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
 
   SDValue A = InVec.getOperand(0);
   SDValue B = InVec.getOperand(1);
-  unsigned Opc = 0;
+
+  if (!ISD::isExtOpcode(A.getOpcode()))
+    return SDValue();
+
+  EVT OpVT = A.getOperand(0).getValueType();
+  if (OpVT.getVectorElementType() != MVT::i8 ||
+      OpVT != B.getOperand(0).getValueType() ||
+      !TLI.isTypeLegal(A.getValueType()))
+    return SDValue();
+
+  MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
+  // Use the partial_reduce_*mla path if possible
   if (A.getOpcode() == B.getOpcode()) {
-    if (A.getOpcode() == ISD::SIGN_EXTEND)
-      Opc = RISCVISD::VQDOT_VL;
-    else if (A.getOpcode() == ISD::ZERO_EXTEND)
-      Opc = RISCVISD::VQDOTU_VL;
-    else
-      return SDValue();
-  } else {
-    if (B.getOpcode() != ISD::ZERO_EXTEND)
-      std::swap(A, B);
-    if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+    // TODO: handle ANY_EXTEND and zext nonneg here
+    if (A.getOpcode() != ISD::SIGN_EXTEND &&
+        A.getOpcode() != ISD::ZERO_EXTEND)
       return SDValue();
-    Opc = RISCVISD::VQDOTSU_VL;
-  }
-  assert(Opc);
 
-  if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
-      A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
-      !TLI.isTypeLegal(A.getValueType()))
+    bool IsSigned = A.getOpcode() == ISD::SIGN_EXTEND;
+    unsigned Opc = IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+    return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
+  }
+  // We don't yet have a partial_reduce_sumla node, so directly lower to the
+  // target node instead.
+  if (B.getOpcode() != ISD::ZERO_EXTEND)
+    std::swap(A, B);
+  if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
     return SDValue();
 
-  MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
   A = DAG.getBitcast(ResVT, A.getOperand(0));
   B = DAG.getBitcast(ResVT, B.getOperand(0));
-  return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+  return lowerVQDOT(RISCVISD::VQDOTSU_VL, A, B, DL, DAG, Subtarget);
 }
 
 static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 0237faea9efb7..8ef691622415c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
 ; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
 
 define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
 ; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<16 x i8> %a) {
 ; NODOT-NEXT:    vmv.x.s a0, v8
 ; NODOT-NEXT:    ret
 ;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32:       # %bb.0: # %entry
-; DOT32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT:    vmv.v.i v9, 0
-; DOT32-NEXT:    lui a0, 4112
-; DOT32-NEXT:    addi a0, a0, 257
-; DOT32-NEXT:    vqdot.vx v9, v8, a0
-; DOT32-NEXT:    vmv.s.x v8, zero
-; DOT32-NEXT:    vredsum.vs v8, v9, v8
-; DOT32-NEXT:    vmv.x.s a0, v8
-; DOT32-NEXT:    ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64:       # %bb.0: # %entry
-; DOT64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT:    vmv.v.i v9, 0
-; DOT64-NEXT:    lui a0, 4112
-; DOT64-NEXT:    addiw a0, a0, 257
-; DOT64-NEXT:    vqdot.vx v9, v8, a0
-; DOT64-NEXT:    vmv.s.x v8, zero
-; DOT64-NEXT:    vredsum.vs v8, v9, v8
-; DOT64-NEXT:    vmv.x.s a0, v8
-; DOT64-NEXT:    ret
+; DOT-LABEL: reduce_of_sext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v9, 1
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.ext = sext <16 x i8> %a to <16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<16 x i8> %a) {
 ; NODOT-NEXT:    vmv.x.s a0, v8
 ; NODOT-NEXT:    ret
 ;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32:       # %bb.0: # %entry
-; DOT32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT32-NEXT:    vmv.v.i v9, 0
-; DOT32-NEXT:    lui a0, 4112
-; DOT32-NEXT:    addi a0, a0, 257
-; DOT32-NEXT:    vqdotu.vx v9, v8, a0
-; DOT32-NEXT:    vmv.s.x v8, zero
-; DOT32-NEXT:    vredsum.vs v8, v9, v8
-; DOT32-NEXT:    vmv.x.s a0, v8
-; DOT32-NEXT:    ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64:       # %bb.0: # %entry
-; DOT64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; DOT64-NEXT:    vmv.v.i v9, 0
-; DOT64-NEXT:    lui a0, 4112
-; DOT64-NEXT:    addiw a0, a0, 257
-; DOT64-NEXT:    vqdotu.vx v9, v8, a0
-; DOT64-NEXT:    vmv.s.x v8, zero
-; DOT64-NEXT:    vredsum.vs v8, v9, v8
-; DOT64-NEXT:    vmv.x.s a0, v8
-; DOT64-NEXT:    ret
+; DOT-LABEL: reduce_of_zext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v9, 1
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotu.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.ext = zext <16 x i8> %a to <16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index d0fc915a0d07e..1948904493e8f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -1,8 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
 ; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT
 
 define i32 @vqdot_vv(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; NODOT-LABEL: vqdot_vv:
@@ -230,29 +230,17 @@ define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
 ; NODOT-NEXT:    vmv.x.s a0, v8
 ; NODOT-NEXT:    ret
 ;
-; DOT32-LABEL: reduce_of_sext:
-; DOT32:       # %bb.0: # %entry
-; DOT32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT:    vmv.v.i v10, 0
-; DOT32-NEXT:    lui a0, 4112
-; DOT32-NEXT:    addi a0, a0, 257
-; DOT32-NEXT:    vqdot.vx v10, v8, a0
-; DOT32-NEXT:    vmv.s.x v8, zero
-; DOT32-NEXT:    vredsum.vs v8, v10, v8
-; DOT32-NEXT:    vmv.x.s a0, v8
-; DOT32-NEXT:    ret
-;
-; DOT64-LABEL: reduce_of_sext:
-; DOT64:       # %bb.0: # %entry
-; DOT64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT:    vmv.v.i v10, 0
-; DOT64-NEXT:    lui a0, 4112
-; DOT64-NEXT:    addiw a0, a0, 257
-; DOT64-NEXT:    vqdot.vx v10, v8, a0
-; DOT64-NEXT:    vmv.s.x v8, zero
-; DOT64-NEXT:    vredsum.vs v8, v10, v8
-; DOT64-NEXT:    vmv.x.s a0, v8
-; DOT64-NEXT:    ret
+; DOT-LABEL: reduce_of_sext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 1
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdot.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)
@@ -269,29 +257,17 @@ define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
 ; NODOT-NEXT:    vmv.x.s a0, v8
 ; NODOT-NEXT:    ret
 ;
-; DOT32-LABEL: reduce_of_zext:
-; DOT32:       # %bb.0: # %entry
-; DOT32-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; DOT32-NEXT:    vmv.v.i v10, 0
-; DOT32-NEXT:    lui a0, 4112
-; DOT32-NEXT:    addi a0, a0, 257
-; DOT32-NEXT:    vqdotu.vx v10, v8, a0
-; DOT32-NEXT:    vmv.s.x v8, zero
-; DOT32-NEXT:    vredsum.vs v8, v10, v8
-; DOT32-NEXT:    vmv.x.s a0, v8
-; DOT32-NEXT:    ret
-;
-; DOT64-LABEL: reduce_of_zext:
-; DOT64:       # %bb.0: # %entry
-; DOT64-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
-; DOT64-NEXT:    vmv.v.i v10, 0
-; DOT64-NEXT:    lui a0, 4112
-; DOT64-NEXT:    addiw a0, a0, 257
-; DOT64-NEXT:    vqdotu.vx v10, v8, a0
-; DOT64-NEXT:    vmv.s.x v8, zero
-; DOT64-NEXT:    vredsum.vs v8, v10, v8
-; DOT64-NEXT:    vmv.x.s a0, v8
-; DOT64-NEXT:    ret
+; DOT-LABEL: reduce_of_zext:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e8, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 1
+; DOT-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
+; DOT-NEXT:    vmv.v.i v12, 0
+; DOT-NEXT:    vqdotu.vv v12, v8, v10
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v12, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<vscale x 16 x i32> %a.ext)

``````````

</details>


https://github.com/llvm/llvm-project/pull/142212


More information about the llvm-commits mailing list