[llvm] [RISCV][ISel] Combine scalable vector fadd/fsub/fmul with fp extend. (PR #88615)
via llvm-commits
llvm-commits at lists.llvm.org
Sat Apr 13 04:43:30 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-backend-risc-v
Author: Chia (sun-jacobi)
### Changes
Extend D133739, #76785 and #81248 to support combining scalable vector fadd/fsub/fmul with fp extend.
Specifically, this patch enables the optimization shown below:
### Source code
```
define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %b2) {
%c = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
%d = fpext <vscale x 2 x float> %b to <vscale x 2 x double>
%d2 = fpext <vscale x 2 x float> %b2 to <vscale x 2 x double>
%e = fmul <vscale x 2 x double> %c, %d
%f = fadd <vscale x 2 x double> %c, %d2
%g = fsub <vscale x 2 x double> %d, %d2
store <vscale x 2 x double> %e, ptr %x
store <vscale x 2 x double> %f, ptr %y
store <vscale x 2 x double> %g, ptr %z
ret void
}
```
### Before this patch
[Compiler Explorer](https://godbolt.org/z/jMeYsjoGq)
```
vfwmul_v2f32_multiple_users:
vsetvli a3, zero, e32, m1, ta, ma
vfwcvt.f.f.v v12, v8
vfwcvt.f.f.v v14, v9
vfwcvt.f.f.v v8, v10
vsetvli zero, zero, e64, m2, ta, ma
vfmul.vv v10, v12, v14
vfadd.vv v12, v12, v8
vfsub.vv v8, v14, v8
vs2r.v v10, (a0)
vs2r.v v12, (a1)
vs2r.v v8, (a2)
ret
```
### After this patch
```
vfwmul_v2f32_multiple_users:
vsetvli a3, zero, e32, m1, ta, ma
vfwmul.vv v12, v8, v9
vfwadd.vv v14, v8, v10
vfwsub.vv v16, v9, v10
vs2r.v v12, (a0)
vs2r.v v14, (a1)
vs2r.v v16, (a2)
ret
```
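Beyond the vector-vector cases in the tests, the patch's new `ISD::SPLAT_VECTOR` handling in `NodeExtensionHelper` also looks through a splat of an `fpext`'ed scalar. The sketch below is a hypothetical illustration of that shape, not taken from the PR's test file; the function and value names are made up, and whether this particular input folds all the way to a `vfwmul.vf` depends on the rest of the lowering.
```
; Hypothetical vector-scalar pattern (illustrative only): one operand is an
; fpext'ed vector, the other is a splat of an fpext'ed scalar, which the new
; ISD::SPLAT_VECTOR case can look through.
define <vscale x 2 x double> @vfwmul_vf_sketch(<vscale x 2 x float> %a, float %s) {
  %se = fpext float %s to double
  %head = insertelement <vscale x 2 x double> poison, double %se, i64 0
  %splat = shufflevector <vscale x 2 x double> %head, <vscale x 2 x double> poison, <vscale x 2 x i32> zeroinitializer
  %ae = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
  %r = fmul <vscale x 2 x double> %ae, %splat
  ret <vscale x 2 x double> %r
}
```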
---
Full diff: https://github.com/llvm/llvm-project/pull/88615.diff
2 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+60-10)
- (added) llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll (+99)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 5a572002091ff3..b8b926a54ea908 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1430,6 +1430,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
ISD::INSERT_VECTOR_ELT, ISD::ABS});
+ if (Subtarget.hasVInstructionsAnyF())
+ setTargetDAGCombine({ISD::FADD, ISD::FSUB, ISD::FMUL});
if (Subtarget.hasVendorXTHeadMemPair())
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
if (Subtarget.useRVVForFixedLengthVectors())
@@ -13597,6 +13599,13 @@ struct NodeExtensionHelper {
case RISCVISD::VZEXT_VL:
case RISCVISD::FP_EXTEND_VL:
return OrigOperand.getOperand(0);
+ case ISD::SPLAT_VECTOR: {
+ SDValue Op = OrigOperand.getOperand(0);
+ if (Op.getOpcode() == ISD::FP_EXTEND)
+ return Op;
+ return OrigOperand;
+ }
+
default:
return OrigOperand;
}
@@ -13735,12 +13744,15 @@ struct NodeExtensionHelper {
/// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
static unsigned getFPExtOpcode(unsigned Opcode) {
switch (Opcode) {
+ case ISD::FADD:
case RISCVISD::FADD_VL:
case RISCVISD::VFWADD_W_VL:
return RISCVISD::VFWADD_VL;
+ case ISD::FSUB:
case RISCVISD::FSUB_VL:
case RISCVISD::VFWSUB_W_VL:
return RISCVISD::VFWSUB_VL;
+ case ISD::FMUL:
case RISCVISD::FMUL_VL:
return RISCVISD::VFWMUL_VL;
default:
@@ -13769,8 +13781,10 @@ struct NodeExtensionHelper {
case RISCVISD::SUB_VL:
return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
: RISCVISD::VWSUBU_W_VL;
+ case ISD::FADD:
case RISCVISD::FADD_VL:
return RISCVISD::VFWADD_W_VL;
+ case ISD::FSUB:
case RISCVISD::FSUB_VL:
return RISCVISD::VFWSUB_W_VL;
default:
@@ -13824,6 +13838,10 @@ struct NodeExtensionHelper {
APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
SupportsZExt = true;
+ if (Op.getOpcode() == ISD::FP_EXTEND &&
+ NarrowSize >= (Subtarget.hasVInstructionsF16() ? 16 : 32))
+ SupportsFPExt = true;
+
EnforceOneUse = false;
}
@@ -13854,6 +13872,7 @@ struct NodeExtensionHelper {
SupportsZExt = Opc == ISD::ZERO_EXTEND;
SupportsSExt = Opc == ISD::SIGN_EXTEND;
+ SupportsFPExt = Opc == ISD::FP_EXTEND;
break;
}
case RISCVISD::VZEXT_VL:
@@ -13862,9 +13881,18 @@ struct NodeExtensionHelper {
case RISCVISD::VSEXT_VL:
SupportsSExt = true;
break;
- case RISCVISD::FP_EXTEND_VL:
+ case RISCVISD::FP_EXTEND_VL: {
+ SDValue NarrowElt = OrigOperand.getOperand(0);
+ MVT NarrowVT = NarrowElt.getSimpleValueType();
+
+ if (!Subtarget.hasVInstructionsF16() &&
+ NarrowVT.getVectorElementType() == MVT::f16)
+ break;
+
SupportsFPExt = true;
break;
+ }
+
case ISD::SPLAT_VECTOR:
case RISCVISD::VMV_V_X_VL:
fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
@@ -13880,13 +13908,16 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
return Root->getValueType(0).isScalableVector();
- }
- case ISD::OR: {
+ case ISD::OR:
return Root->getValueType(0).isScalableVector() &&
Root->getFlags().hasDisjoint();
- }
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
+ return Root->getValueType(0).isScalableVector() &&
+ Subtarget.hasVInstructionsAnyF();
// Vector Widening Integer Add/Sub/Mul Instructions
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
@@ -13963,7 +13994,10 @@ struct NodeExtensionHelper {
case ISD::SUB:
case ISD::MUL:
case ISD::OR:
- case ISD::SHL: {
+ case ISD::SHL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL: {
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13980,6 +14014,8 @@ struct NodeExtensionHelper {
case ISD::ADD:
case ISD::MUL:
case ISD::OR:
+ case ISD::FADD:
+ case ISD::FMUL:
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -13989,6 +14025,7 @@ struct NodeExtensionHelper {
case RISCVISD::VFWADD_W_VL:
return true;
case ISD::SUB:
+ case ISD::FSUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
@@ -14050,6 +14087,9 @@ struct CombineResult {
case ISD::MUL:
case ISD::OR:
case ISD::SHL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
Merge = DAG.getUNDEF(Root->getValueType(0));
break;
}
@@ -14192,6 +14232,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case ISD::ADD:
case ISD::SUB:
case ISD::OR:
+ case ISD::FADD:
+ case ISD::FSUB:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
case RISCVISD::FADD_VL:
@@ -14201,6 +14243,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
// add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w}|vfwadd_w|vfwsub_w
Strategies.push_back(canFoldToVW_W);
break;
+ case ISD::FMUL:
case RISCVISD::FMUL_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
break;
@@ -14244,9 +14287,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
/// mul | mul_vl -> vwmul(u) | vwmul_su
/// shl | shl_vl -> vwsll
-/// fadd_vl -> vfwadd | vfwadd_w
-/// fsub_vl -> vfwsub | vfwsub_w
-/// fmul_vl -> vfwmul
+/// fadd | fadd_vl -> vfwadd | vfwadd_w
+/// fsub | fsub_vl -> vfwsub | vfwsub_w
+/// fmul | fmul_vl -> vfwmul
/// vwadd_w(u) -> vwadd(u)
/// vwsub_w(u) -> vwsub(u)
/// vfwadd_w -> vfwadd
@@ -15921,7 +15964,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOpOfZExt(N, DAG))
return V;
break;
- case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
+ return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+ case ISD::FADD: {
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ return V;
+ [[fallthrough]];
+ }
case ISD::UMAX:
case ISD::UMIN:
case ISD::SMAX:
diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
new file mode 100644
index 00000000000000..0d1713acfc0cd0
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vfw-web-simplification.ll
@@ -0,0 +1,99 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFH
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f116_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x half> %a, <vscale x 2 x half> %b, <vscale x 2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f116_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetvli a3, zero, e16, mf2, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT: vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT: vs1r.v v10, (a0)
+; NO_FOLDING-NEXT: vs1r.v v11, (a1)
+; NO_FOLDING-NEXT: vs1r.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; ZVFH-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFH: # %bb.0:
+; ZVFH-NEXT: vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFH-NEXT: vfwmul.vv v11, v8, v9
+; ZVFH-NEXT: vfwadd.vv v12, v8, v10
+; ZVFH-NEXT: vfwsub.vv v8, v9, v10
+; ZVFH-NEXT: vs1r.v v11, (a0)
+; ZVFH-NEXT: vs1r.v v12, (a1)
+; ZVFH-NEXT: vs1r.v v8, (a2)
+; ZVFH-NEXT: ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFHMIN: # %bb.0:
+; ZVFHMIN-NEXT: vsetvli a3, zero, e16, mf2, ta, ma
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; ZVFHMIN-NEXT: vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT: vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT: vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT: vs1r.v v10, (a0)
+; ZVFHMIN-NEXT: vs1r.v v11, (a1)
+; ZVFHMIN-NEXT: vs1r.v v8, (a2)
+; ZVFHMIN-NEXT: ret
+ %c = fpext <vscale x 2 x half> %a to <vscale x 2 x float>
+ %d = fpext <vscale x 2 x half> %b to <vscale x 2 x float>
+ %d2 = fpext <vscale x 2 x half> %b2 to <vscale x 2 x float>
+ %e = fmul <vscale x 2 x float> %c, %d
+ %f = fadd <vscale x 2 x float> %c, %d2
+ %g = fsub <vscale x 2 x float> %d, %d2
+ store <vscale x 2 x float> %e, ptr %x
+ store <vscale x 2 x float> %f, ptr %y
+ store <vscale x 2 x float> %g, ptr %z
+ ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <vscale x 2 x float> %a, <vscale x 2 x float> %b, <vscale x 2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetvli a3, zero, e32, m1, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v12, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v14, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m2, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v12, v14
+; NO_FOLDING-NEXT: vfadd.vv v12, v12, v8
+; NO_FOLDING-NEXT: vfsub.vv v8, v14, v8
+; NO_FOLDING-NEXT: vs2r.v v10, (a0)
+; NO_FOLDING-NEXT: vs2r.v v12, (a1)
+; NO_FOLDING-NEXT: vs2r.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetvli a3, zero, e32, m1, ta, ma
+; FOLDING-NEXT: vfwmul.vv v12, v8, v9
+; FOLDING-NEXT: vfwadd.vv v14, v8, v10
+; FOLDING-NEXT: vfwsub.vv v16, v9, v10
+; FOLDING-NEXT: vs2r.v v12, (a0)
+; FOLDING-NEXT: vs2r.v v14, (a1)
+; FOLDING-NEXT: vs2r.v v16, (a2)
+; FOLDING-NEXT: ret
+ %c = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+ %d = fpext <vscale x 2 x float> %b to <vscale x 2 x double>
+ %d2 = fpext <vscale x 2 x float> %b2 to <vscale x 2 x double>
+ %e = fmul <vscale x 2 x double> %c, %d
+ %f = fadd <vscale x 2 x double> %c, %d2
+ %g = fsub <vscale x 2 x double> %d, %d2
+ store <vscale x 2 x double> %e, ptr %x
+ store <vscale x 2 x double> %f, ptr %y
+ store <vscale x 2 x double> %g, ptr %z
+ ret void
+}
``````````
https://github.com/llvm/llvm-project/pull/88615
More information about the llvm-commits mailing list