[llvm] [RISCV][ISel] Combine scalable vector fadd/fsub/fmul with fp extend. (PR #88615)

via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 13 04:43:30 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

<details>
<summary>Changes</summary>

Extend D133739, #<!-- -->76785 and #<!-- -->81248 to support combining scalable vector fadd/fsub/fmul with fp extend.

Specifically, this patch enables the following optimization:

### Source code
```
define void @<!-- -->vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %b2) {
  %c = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
  %d = fpext <vscale x 2 x float> %b to <vscale x 2 x double>
  %d2 = fpext <vscale x 2 x float> %b2 to <vscale x 2 x double>
  %e = fmul <vscale x 2 x double> %c, %d
  %f = fadd <vscale x 2 x double> %c, %d2
  %g = fsub <vscale x 2 x double> %d, %d2
  store <vscale x 2 x double> %e, ptr %x
  store <vscale x 2 x double> %f, ptr %y
  store <vscale x 2 x double> %g, ptr %z
  ret void
}
```

### Before this patch
[Compiler Explorer](https://godbolt.org/z/jMeYsjoGq)
```
vfwmul_v2f32_multiple_users: 
        vsetvli a3, zero, e32, m1, ta, ma
        vfwcvt.f.f.v    v12, v8
        vfwcvt.f.f.v    v14, v9
        vfwcvt.f.f.v    v8, v10
        vsetvli zero, zero, e64, m2, ta, ma
        vfmul.vv        v10, v12, v14
        vfadd.vv        v12, v12, v8
        vfsub.vv        v8, v14, v8
        vs2r.v  v10, (a0)
        vs2r.v  v12, (a1)
        vs2r.v  v8, (a2)
        ret
```

### After this patch
```
vfwmul_v2f32_multiple_users:
        vsetvli a3, zero, e32, m1, ta, ma
        vfwmul.vv v12, v8, v9
        vfwadd.vv v14, v8, v10
        vfwsub.vv v16, v9, v10
        vs2r.v v12, (a0)
        vs2r.v v14, (a1)
        vs2r.v v16, (a2)
        ret
```

---
Full diff: https://github.com/llvm/llvm-project/pull/88615.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+60-10) 
- (added) llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll (+99) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a572002091ff3..b8b926a54ea908 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1430,6 +1430,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
                          ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
                          ISD::INSERT_VECTOR_ELT, ISD::ABS});
+  if (Subtarget.hasVInstructionsAnyF())
+    setTargetDAGCombine({ISD::FADD, ISD::FSUB, ISD::FMUL});
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())
@@ -13597,6 +13599,13 @@ struct NodeExtensionHelper {
     case RISCVISD::VZEXT_VL:
     case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
+    case ISD::SPLAT_VECTOR: {
+      SDValue Op = OrigOperand.getOperand(0);
+      if (Op.getOpcode() == ISD::FP_EXTEND)
+        return Op;
+      return OrigOperand;
+    }
+
     default:
       return OrigOperand;
     }
@@ -13735,12 +13744,15 @@ struct NodeExtensionHelper {
   /// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
   static unsigned getFPExtOpcode(unsigned Opcode) {
     switch (Opcode) {
+    case ISD::FADD:
     case RISCVISD::FADD_VL:
     case RISCVISD::VFWADD_W_VL:
       return RISCVISD::VFWADD_VL;
+    case ISD::FSUB:
     case RISCVISD::FSUB_VL:
     case RISCVISD::VFWSUB_W_VL:
       return RISCVISD::VFWSUB_VL;
+    case ISD::FMUL:
     case RISCVISD::FMUL_VL:
       return RISCVISD::VFWMUL_VL;
     default:
@@ -13769,8 +13781,10 @@ struct NodeExtensionHelper {
     case RISCVISD::SUB_VL:
       return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
                                           : RISCVISD::VWSUBU_W_VL;
+    case ISD::FADD:
     case RISCVISD::FADD_VL:
       return RISCVISD::VFWADD_W_VL;
+    case ISD::FSUB:
     case RISCVISD::FSUB_VL:
       return RISCVISD::VFWSUB_W_VL;
     default:
@@ -13824,6 +13838,10 @@ struct NodeExtensionHelper {
                               APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
       SupportsZExt = true;
 
+    if (Op.getOpcode() == ISD::FP_EXTEND &&
+        NarrowSize >= (Subtarget.hasVInstructionsF16() ? 16 : 32))
+      SupportsFPExt = true;
+
     EnforceOneUse = false;
   }
 
@@ -13854,6 +13872,7 @@ struct NodeExtensionHelper {
 
       SupportsZExt = Opc == ISD::ZERO_EXTEND;
       SupportsSExt = Opc == ISD::SIGN_EXTEND;
+      SupportsFPExt = Opc == ISD::FP_EXTEND;
       break;
     }
     case RISCVISD::VZEXT_VL:
@@ -13862,9 +13881,18 @@ struct NodeExtensionHelper {
     case RISCVISD::VSEXT_VL:
       SupportsSExt = true;
       break;
-    case RISCVISD::FP_EXTEND_VL:
+    case RISCVISD::FP_EXTEND_VL: {
+      SDValue NarrowElt = OrigOperand.getOperand(0);
+      MVT NarrowVT = NarrowElt.getSimpleValueType();
+
+      if (!Subtarget.hasVInstructionsF16() &&
+          NarrowVT.getVectorElementType() == MVT::f16)
+        break;
+
       SupportsFPExt = true;
       break;
+    }
+
     case ISD::SPLAT_VECTOR:
     case RISCVISD::VMV_V_X_VL:
       fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -13880,13 +13908,16 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
       return Root->getValueType(0).isScalableVector();
-    }
-    case ISD::OR: {
+    case ISD::OR:
       return Root->getValueType(0).isScalableVector() &&
              Root->getFlags().hasDisjoint();
-    }
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL:
+      return Root->getValueType(0).isScalableVector() &&
+             Subtarget.hasVInstructionsAnyF();
     // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
@@ -13963,7 +13994,10 @@ struct NodeExtensionHelper {
     case ISD::SUB:
     case ISD::MUL:
     case ISD::OR:
-    case ISD::SHL: {
+    case ISD::SHL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13980,6 +14014,8 @@ struct NodeExtensionHelper {
     case ISD::ADD:
     case ISD::MUL:
     case ISD::OR:
+    case ISD::FADD:
+    case ISD::FMUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13989,6 +14025,7 @@ struct NodeExtensionHelper {
     case RISCVISD::VFWADD_W_VL:
       return true;
     case ISD::SUB:
+    case ISD::FSUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -14050,6 +14087,9 @@ struct CombineResult {
     case ISD::MUL:
     case ISD::OR:
     case ISD::SHL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL:
       Merge = DAG.getUNDEF(Root->getValueType(0));
       break;
     }
@@ -14192,6 +14232,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case ISD::ADD:
   case ISD::SUB:
   case ISD::OR:
+  case ISD::FADD:
+  case ISD::FSUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
   case RISCVISD::FADD_VL:
@@ -14201,6 +14243,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w}|vfwadd_w|vfwsub_w
     Strategies.push_back(canFoldToVW_W);
     break;
+  case ISD::FMUL:
   case RISCVISD::FMUL_VL:
     Strategies.push_back(canFoldToVWWithSameExtension);
     break;
@@ -14244,9 +14287,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
 /// shl | shl_vl -> vwsll
-/// fadd_vl ->  vfwadd | vfwadd_w
-/// fsub_vl ->  vfwsub | vfwsub_w
-/// fmul_vl ->  vfwmul
+/// fadd | fadd_vl ->  vfwadd | vfwadd_w
+/// fsub | fsub_vl ->  vfwsub | vfwsub_w
+/// fmul | fmul_vl ->  vfwmul
 /// vwadd_w(u) -> vwadd(u)
 /// vwsub_w(u) -> vwsub(u)
 /// vfwadd_w -> vfwadd
@@ -15921,7 +15964,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (SDValue V = combineBinOpOfZExt(N, DAG))
       return V;
     break;
-  case ISD::FADD:
+  case ISD::FSUB:
+  case ISD::FMUL:
+    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+  case ISD::FADD: {
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
+    [[fallthrough]];
+  }
   case ISD::UMAX:
   case ISD::UMIN:
   case ISD::SMAX:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
new file mode 100644
index 00000000000000..0d1713acfc0cd0
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFH
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f116_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x half> %a, <vscale x 2 x half> %b, <vscale x 2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f116_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vs1r.v v10, (a0)
+; NO_FOLDING-NEXT:    vs1r.v v11, (a1)
+; NO_FOLDING-NEXT:    vs1r.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; ZVFH-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFH:       # %bb.0:
+; ZVFH-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFH-NEXT:    vfwmul.vv v11, v8, v9
+; ZVFH-NEXT:    vfwadd.vv v12, v8, v10
+; ZVFH-NEXT:    vfwsub.vv v8, v9, v10
+; ZVFH-NEXT:    vs1r.v v11, (a0)
+; ZVFH-NEXT:    vs1r.v v12, (a1)
+; ZVFH-NEXT:    vs1r.v v8, (a2)
+; ZVFH-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; ZVFHMIN-NEXT:    vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT:    vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT:    vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT:    vs1r.v v10, (a0)
+; ZVFHMIN-NEXT:    vs1r.v v11, (a1)
+; ZVFHMIN-NEXT:    vs1r.v v8, (a2)
+; ZVFHMIN-NEXT:    ret
+  %c = fpext <vscale x 2 x half> %a to <vscale x 2 x float>
+  %d = fpext <vscale x 2 x half> %b to <vscale x 2 x float>
+  %d2 = fpext <vscale x 2 x half> %b2 to <vscale x 2 x float>
+  %e = fmul <vscale x 2 x float> %c, %d
+  %f = fadd <vscale x 2 x float> %c, %d2
+  %g = fsub <vscale x 2 x float> %d, %d2
+  store <vscale x 2 x float> %e, ptr %x
+  store <vscale x 2 x float> %f, ptr %y
+  store <vscale x 2 x float> %g, ptr %z
+  ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v12, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v14, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v12, v14
+; NO_FOLDING-NEXT:    vfadd.vv v12, v12, v8
+; NO_FOLDING-NEXT:    vfsub.vv v8, v14, v8
+; NO_FOLDING-NEXT:    vs2r.v v10, (a0)
+; NO_FOLDING-NEXT:    vs2r.v v12, (a1)
+; NO_FOLDING-NEXT:    vs2r.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a3, zero, e32, m1, ta, ma
+; FOLDING-NEXT:    vfwmul.vv v12, v8, v9
+; FOLDING-NEXT:    vfwadd.vv v14, v8, v10
+; FOLDING-NEXT:    vfwsub.vv v16, v9, v10
+; FOLDING-NEXT:    vs2r.v v12, (a0)
+; FOLDING-NEXT:    vs2r.v v14, (a1)
+; FOLDING-NEXT:    vs2r.v v16, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+  %d = fpext <vscale x 2 x float> %b to <vscale x 2 x double>
+  %d2 = fpext <vscale x 2 x float> %b2 to <vscale x 2 x double>
+  %e = fmul <vscale x 2 x double> %c, %d
+  %f = fadd <vscale x 2 x double> %c, %d2
+  %g = fsub <vscale x 2 x double> %d, %d2
+  store <vscale x 2 x double> %e, ptr %x
+  store <vscale x 2 x double> %f, ptr %y
+  store <vscale x 2 x double> %g, ptr %z
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/88615


More information about the llvm-commits mailing list