[llvm] e42fdcb - [RISCV] Match widening fp instructions with same fpext used in multiple operands (#125803)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 10 09:11:47 PST 2025


Author: Luke Lau
Date: 2025-02-11T01:11:44+08:00
New Revision: e42fdcb41fdcfe7bf302b40f20afb4e9cda5602d

URL: https://github.com/llvm/llvm-project/commit/e42fdcb41fdcfe7bf302b40f20afb4e9cda5602d
DIFF: https://github.com/llvm/llvm-project/commit/e42fdcb41fdcfe7bf302b40f20afb4e9cda5602d.diff

LOG: [RISCV] Match widening fp instructions with same fpext used in multiple operands (#125803)

Because the fpext in these patterns carries a single-use constraint, we
can't match cases where the same fpext is used for both operands.

Introduce a new PatFrag that allows multiple uses as long as they all
come from a single user, and use it for the binary patterns and some of
the ternary patterns.

(In some of the ternary patterns there is an fneg that counts as a
separate user; these cases still need to be handled.)

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
    llvm/test/CodeGen/RISCV/rvv/vfwadd-sdnode.ll
    llvm/test/CodeGen/RISCV/rvv/vfwmacc-sdnode.ll
    llvm/test/CodeGen/RISCV/rvv/vfwmul-sdnode.ll
    llvm/test/CodeGen/RISCV/rvv/vfwsub-sdnode.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 8f77b2ce34d1f19..c588e047c2ac8fa 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -536,19 +536,19 @@ multiclass VPatWidenBinaryFPSDNode_VV_VF<SDNode op, string instruction_name> {
     defvar wti = vtiToWti.Wti;
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def : Pat<(op (wti.Vector (riscv_fpextend_vl_oneuse
+      def : Pat<(op (wti.Vector (riscv_fpextend_vl_sameuser
                                      (vti.Vector vti.RegClass:$rs2),
                                      (vti.Mask true_mask), (XLenVT srcvalue))),
-                    (wti.Vector (riscv_fpextend_vl_oneuse
+                    (wti.Vector (riscv_fpextend_vl_sameuser
                                      (vti.Vector vti.RegClass:$rs1),
                                      (vti.Mask true_mask), (XLenVT srcvalue)))),
                 (!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX)
                   (wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2,
                   vti.RegClass:$rs1, vti.AVL, vti.Log2SEW, TA_MA)>;
-      def : Pat<(op (wti.Vector (riscv_fpextend_vl_oneuse
+      def : Pat<(op (wti.Vector (riscv_fpextend_vl_sameuser
                                      (vti.Vector vti.RegClass:$rs2),
                                      (vti.Mask true_mask), (XLenVT srcvalue))),
-                    (wti.Vector (riscv_fpextend_vl_oneuse
+                    (wti.Vector (riscv_fpextend_vl_sameuser
                                      (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
                                      (vti.Mask true_mask), (XLenVT srcvalue)))),
                 (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#vti.LMul.MX)
@@ -571,10 +571,10 @@ multiclass VPatWidenBinaryFPSDNode_VV_VF_RM<SDNode op, string instruction_name>
     defvar wti = vtiToWti.Wti;
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def : Pat<(op (wti.Vector (riscv_fpextend_vl_oneuse
+      def : Pat<(op (wti.Vector (riscv_fpextend_vl_sameuser
                                      (vti.Vector vti.RegClass:$rs2),
                                      (vti.Mask true_mask), (XLenVT srcvalue))),
-                    (wti.Vector (riscv_fpextend_vl_oneuse
+                    (wti.Vector (riscv_fpextend_vl_sameuser
                                      (vti.Vector vti.RegClass:$rs1),
                                      (vti.Mask true_mask), (XLenVT srcvalue)))),
                 (!cast<Instruction>(instruction_name#"_VV_"#vti.LMul.MX#"_E"#vti.SEW)
@@ -584,10 +584,10 @@ multiclass VPatWidenBinaryFPSDNode_VV_VF_RM<SDNode op, string instruction_name>
                    // RISCVInsertReadWriteCSR
                    FRM_DYN,
                   vti.AVL, vti.Log2SEW, TA_MA)>;
-      def : Pat<(op (wti.Vector (riscv_fpextend_vl_oneuse
+      def : Pat<(op (wti.Vector (riscv_fpextend_vl_sameuser
                                      (vti.Vector vti.RegClass:$rs2),
                                      (vti.Mask true_mask), (XLenVT srcvalue))),
-                    (wti.Vector (riscv_fpextend_vl_oneuse
+                    (wti.Vector (riscv_fpextend_vl_sameuser
                                      (vti.Vector (SplatFPOp (vti.Scalar vti.ScalarRegClass:$rs1))),
                                      (vti.Mask true_mask), (XLenVT srcvalue)))),
                 (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
@@ -669,10 +669,10 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
                                  !if(!eq(vti.Scalar, bf16),
                                      [HasStdExtZvfbfwma],
                                      [])) in {
-      def : Pat<(fma (wti.Vector (riscv_fpextend_vl_oneuse
+      def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
                                       (vti.Vector vti.RegClass:$rs1),
                                       (vti.Mask true_mask), (XLenVT srcvalue))),
-                     (wti.Vector (riscv_fpextend_vl_oneuse
+                     (wti.Vector (riscv_fpextend_vl_sameuser
                                       (vti.Vector vti.RegClass:$rs2),
                                       (vti.Mask true_mask), (XLenVT srcvalue))),
                      (wti.Vector wti.RegClass:$rd)),
@@ -749,11 +749,11 @@ multiclass VPatWidenFPMulSacSDNode_VV_VF_RM<string instruction_name> {
     defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def : Pat<(fma (wti.Vector (riscv_fpextend_vl_oneuse
+      def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
                                       (vti.Vector vti.RegClass:$rs1),
                                       (vti.Mask true_mask), (XLenVT srcvalue))),
-                     (riscv_fpextend_vl_oneuse (vti.Vector vti.RegClass:$rs2),
-                                               (vti.Mask true_mask), (XLenVT srcvalue)),
+                     (riscv_fpextend_vl_sameuser (vti.Vector vti.RegClass:$rs2),
+                                                 (vti.Mask true_mask), (XLenVT srcvalue)),
                      (fneg wti.RegClass:$rd)),
                 (!cast<Instruction>(instruction_name#"_VV_"#suffix)
                    wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 333ae52534681a1..f63c1560f625382 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -554,6 +554,11 @@ def riscv_fpextend_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
   return N->hasOneUse();
 }]>;
 
+def riscv_fpextend_vl_sameuser : PatFrag<(ops node:$A, node:$B, node:$C),
+                                 (riscv_fpextend_vl node:$A, node:$B, node:$C), [{
+  return !N->use_empty() && all_equal(N->users());
+}]>;
+
 def riscv_vfmadd_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D,
                                           node:$E),
                                      (riscv_vfmadd_vl node:$A, node:$B,

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vfwadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vfwadd-sdnode.ll
index 68014ff4206f8ad..f7d287a088cc3ed 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwadd-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwadd-sdnode.ll
@@ -323,3 +323,15 @@ define <vscale x 8 x double> @vfwadd_wf_nxv8f64_2(<vscale x 8 x double> %va, flo
   %vd = fadd <vscale x 8 x double> %va, %splat
   ret <vscale x 8 x double> %vd
 }
+
+define <vscale x 1 x double> @vfwadd_vv_nxv1f64_same_op(<vscale x 1 x float> %va) {
+; CHECK-LABEL: vfwadd_vv_nxv1f64_same_op:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vfwadd.vv v9, v8, v8
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %vb = fpext <vscale x 1 x float> %va to <vscale x 1 x double>
+  %vc = fadd <vscale x 1 x double> %vb, %vb
+  ret <vscale x 1 x double> %vc
+}

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-sdnode.ll
index f69b2346226ee99..63113b878098943 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-sdnode.ll
@@ -1764,3 +1764,28 @@ define <vscale x 8 x double> @vfwnmsac_fv_nxv8f64(<vscale x 8 x double> %va, <vs
   %vg = call <vscale x 8 x double> @llvm.fma.v8f64(<vscale x 8 x double> %vd, <vscale x 8 x double> %vf, <vscale x 8 x double> %va)
   ret <vscale x 8 x double> %vg
 }
+
+define <vscale x 1 x double> @vfwma_vv_nxv1f64_same_op(<vscale x 1 x float> %va, <vscale x 1 x double> %vb) {
+; CHECK-LABEL: vfwma_vv_nxv1f64_same_op:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vfwmacc.vv v9, v8, v8
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %vc = fpext <vscale x 1 x float> %va to <vscale x 1 x double>
+  %vd = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> %vc, <vscale x 1 x double> %vc, <vscale x 1 x double> %vb)
+  ret <vscale x 1 x double> %vd
+}
+
+define <vscale x 1 x double> @vfwmsac_vv_nxv1f64_same_op(<vscale x 1 x float> %va, <vscale x 1 x double> %vb) {
+; CHECK-LABEL: vfwmsac_vv_nxv1f64_same_op:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vfwmsac.vv v9, v8, v8
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %vc = fpext <vscale x 1 x float> %va to <vscale x 1 x double>
+  %vd = fneg <vscale x 1 x double> %vb
+  %ve = call <vscale x 1 x double> @llvm.fma(<vscale x 1 x double> %vc, <vscale x 1 x double> %vc, <vscale x 1 x double> %vd)
+  ret <vscale x 1 x double> %ve
+}

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vfwmul-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmul-sdnode.ll
index f00ff4b6d2cec2a..8cc8c5cffca6b27 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmul-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmul-sdnode.ll
@@ -175,3 +175,15 @@ define <vscale x 8 x double> @vfwmul_vf_nxv8f64_2(<vscale x 8 x float> %va, floa
   %ve = fmul <vscale x 8 x double> %vc, %splat
   ret <vscale x 8 x double> %ve
 }
+
+define <vscale x 1 x double> @vfwmul_vv_nxv1f64_same_op(<vscale x 1 x float> %va) {
+; CHECK-LABEL: vfwmul_vv_nxv1f64_same_op:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vfwmul.vv v9, v8, v8
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %vb = fpext <vscale x 1 x float> %va to <vscale x 1 x double>
+  %vc = fmul <vscale x 1 x double> %vb, %vb
+  ret <vscale x 1 x double> %vc
+}

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vfwsub-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vfwsub-sdnode.ll
index b9f66d5d30825d4..d0cb64d98666187 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vfwsub-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwsub-sdnode.ll
@@ -323,3 +323,15 @@ define <vscale x 8 x double> @vfwsub_wf_nxv8f64_2(<vscale x 8 x double> %va, flo
   %vd = fsub <vscale x 8 x double> %va, %splat
   ret <vscale x 8 x double> %vd
 }
+
+define <vscale x 1 x double> @vfwsub_vv_nxv1f64_same_op(<vscale x 1 x float> %va) {
+; CHECK-LABEL: vfwsub_vv_nxv1f64_same_op:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vfwsub.vv v9, v8, v8
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %vb = fpext <vscale x 1 x float> %va to <vscale x 1 x double>
+  %vc = fsub <vscale x 1 x double> %vb, %vb
+  ret <vscale x 1 x double> %vc
+}


        


More information about the llvm-commits mailing list