[llvm] [RISCV] Don't check extop VL in vfwred{u, o}sum patterns (PR #125799)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 4 20:07:06 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

<details>
<summary>Changes</summary>

Because riscv_fpextend_vl doesn't have a passthru operand, its tail elements are undef, so we can treat them as if they were active and ignore the extend's VL when matching.

Relaxing this VL check allows us to match widening reductions where the fpextend isn't a VP intrinsic.

This same reasoning is already used for riscv_fpextend_vl in RISCVInstrInfoVSDPatterns.td.


---
Full diff: https://github.com/llvm/llvm-project/pull/125799.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+2-2) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll (+20-24) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index f35dc6eb2cb8be..77a5f8b22fc26b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1443,7 +1443,7 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$passthru),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
+                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), (XLenVT srcvalue))),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
                                    (XLenVT timm:$policy))),
                (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW#"_MASK")
@@ -1462,7 +1462,7 @@ multiclass VPatWidenReductionVL_Ext_VL_RM<SDNode vop, PatFrags extop, string ins
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
       def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$passthru),
-                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
+                                   (wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), (XLenVT srcvalue))),
                                    VR:$rs2, (vti.Mask V0), VLOpFrag,
                                    (XLenVT timm:$policy))),
                (!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW#"_MASK")
diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll
index b47edf94d3bf5c..ccea5b05fb03cc 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll
@@ -533,13 +533,12 @@ define double @vpreduce_ord_fadd_fpext_vp_fpext_nxv1f32_nxv1f64(double %s, <vsca
 define float @vpreduce_fadd_fpext_nxv1f16_nxv1f32(float %s, <vscale x 1 x half> %v, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vpreduce_fadd_fpext_nxv1f16_nxv1f32:
 ; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; CHECK-NEXT:    vfmv.s.f v9, fa0
 ; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v9, v8
-; CHECK-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
-; CHECK-NEXT:    vfmv.s.f v8, fa0
-; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
-; CHECK-NEXT:    vfredusum.vs v8, v9, v8, v0.t
-; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    vfwredusum.vs v9, v8, v9, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vfmv.f.s fa0, v9
 ; CHECK-NEXT:    ret
   %w = fpext <vscale x 1 x half> %v to <vscale x 1 x float>
   %r = call reassoc float @llvm.vp.reduce.fadd(float %s, <vscale x 1 x float> %w, <vscale x 1 x i1> %m, i32 %evl)
@@ -549,13 +548,12 @@ define float @vpreduce_fadd_fpext_nxv1f16_nxv1f32(float %s, <vscale x 1 x half>
 define float @vpreduce_ord_fadd_fpext_nxv1f16_nxv1f32(float %s, <vscale x 1 x half> %v, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vpreduce_ord_fadd_fpext_nxv1f16_nxv1f32:
 ; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; CHECK-NEXT:    vfmv.s.f v9, fa0
 ; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v9, v8
-; CHECK-NEXT:    vsetivli zero, 1, e32, mf2, ta, ma
-; CHECK-NEXT:    vfmv.s.f v8, fa0
-; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
-; CHECK-NEXT:    vfredosum.vs v8, v9, v8, v0.t
-; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    vfwredosum.vs v9, v8, v9, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vfmv.f.s fa0, v9
 ; CHECK-NEXT:    ret
   %w = fpext <vscale x 1 x half> %v to <vscale x 1 x float>
   %r = call float @llvm.vp.reduce.fadd(float %s, <vscale x 1 x float> %w, <vscale x 1 x i1> %m, i32 %evl)
@@ -565,13 +563,12 @@ define float @vpreduce_ord_fadd_fpext_nxv1f16_nxv1f32(float %s, <vscale x 1 x ha
 define double @vpreduce_fadd_fpext_nxv1f32_nxv1f64(double %s, <vscale x 1 x float> %v, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vpreduce_fadd_fpext_nxv1f32_nxv1f64:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v9, v8
 ; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
-; CHECK-NEXT:    vfmv.s.f v8, fa0
-; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
-; CHECK-NEXT:    vfredusum.vs v8, v9, v8, v0.t
-; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    vfmv.s.f v9, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    vfwredusum.vs v9, v8, v9, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vfmv.f.s fa0, v9
 ; CHECK-NEXT:    ret
   %w = fpext <vscale x 1 x float> %v to <vscale x 1 x double>
   %r = call reassoc double @llvm.vp.reduce.fadd(double %s, <vscale x 1 x double> %w, <vscale x 1 x i1> %m, i32 %evl)
@@ -581,13 +578,12 @@ define double @vpreduce_fadd_fpext_nxv1f32_nxv1f64(double %s, <vscale x 1 x floa
 define double @vpreduce_ord_fadd_fpext_nxv1f32_nxv1f64(double %s, <vscale x 1 x float> %v, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vpreduce_ord_fadd_fpext_nxv1f32_nxv1f64:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v9, v8
 ; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
-; CHECK-NEXT:    vfmv.s.f v8, fa0
-; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
-; CHECK-NEXT:    vfredosum.vs v8, v9, v8, v0.t
-; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    vfmv.s.f v9, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e32, mf2, ta, ma
+; CHECK-NEXT:    vfwredosum.vs v9, v8, v9, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vfmv.f.s fa0, v9
 ; CHECK-NEXT:    ret
   %w = fpext <vscale x 1 x float> %v to <vscale x 1 x double>
   %r = call double @llvm.vp.reduce.fadd(double %s, <vscale x 1 x double> %w, <vscale x 1 x i1> %m, i32 %evl)

``````````

</details>


https://github.com/llvm/llvm-project/pull/125799


More information about the llvm-commits mailing list