[llvm] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. (PR #81248)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 9 04:58:30 PST 2024
https://github.com/sun-jacobi created https://github.com/llvm/llvm-project/pull/81248
Extend D133739 and #76785 to support the vector widening floating-point add/sub/mul instructions (#80477).
Specifically, this patch enables the following optimization:
### Source code
```
define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
%c = fpext <2 x float> %a to <2 x double>
%d = fpext <2 x float> %b to <2 x double>
%d2 = fpext <2 x float> %b2 to <2 x double>
%e = fmul <2 x double> %c, %d
%f = fadd <2 x double> %c, %d2
%g = fsub <2 x double> %d, %d2
store <2 x double> %e, ptr %x
store <2 x double> %f, ptr %y
store <2 x double> %g, ptr %z
ret void
}
```
### Before this patch
[Compiler Explorer](https://godbolt.org/z/aaEMs5s9h)
```
vfwmul_v2f32_multiple_users:
vsetivli zero, 2, e32, mf2, ta, ma
vfwcvt.f.f.v v11, v8
vfwcvt.f.f.v v8, v9
vfwcvt.f.f.v v9, v10
vsetvli zero, zero, e64, m1, ta, ma
vfmul.vv v10, v11, v8
vfadd.vv v11, v11, v9
vfsub.vv v8, v8, v9
vse64.v v10, (a0)
vse64.v v11, (a1)
vse64.v v8, (a2)
ret
```
### After this patch
```
vfwmul_v2f32_multiple_users:
vsetivli zero, 2, e32, mf2, ta, ma
vfwmul.vv v11, v8, v9
vfwadd.vv v12, v8, v10
vfwsub.vv v8, v9, v10
vse64.v v11, (a0)
vse64.v v12, (a1)
vse64.v v8, (a2)
ret
```
> [!NOTE]
> Support for scalable vectors will be introduced in a follow-up patch.
>From a71d733726ea449b87841615c8f3350b8423fe14 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Fri, 9 Feb 2024 20:47:50 +0800
Subject: [PATCH] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 356 +++++++++---------
.../fixed-vectors-vfw-web-simplification.ll | 88 +++++
.../CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll | 8 +-
.../CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll | 14 +-
.../CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll | 14 +-
5 files changed, 285 insertions(+), 195 deletions(-)
create mode 100644 llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 12c0cd53514dae..4ca834b8cf0f07 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13209,12 +13209,16 @@ namespace {
// apply a combine.
struct CombineResult;
+enum class ExtKind { ZExt, SExt, FPExt };
+
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl -> vwadd(u) | vwadd(u)_w
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
/// mul | mul_vl -> vwmul(u) | vwmul_su
-///
+/// fadd -> vfwadd | vfwadd_w
+/// fsub -> vfwsub | vfwsub_w
+/// fmul -> vfwmul
/// An object of this class represents an operand of the operation we want to
/// combine.
/// E.g., when trying to combine `mul_vl a, b`, we will have one instance of
@@ -13228,7 +13232,8 @@ struct CombineResult;
/// - VWADDU_W == add(op0, zext(op1))
/// - VWSUB_W == sub(op0, sext(op1))
/// - VWSUBU_W == sub(op0, zext(op1))
-///
+/// - VFWADD_W == fadd(op0, fpext(op1))
+/// - VFWSUB_W == fsub(op0, fpext(op1))
/// And VMV_V_X_VL, depending on the value, is conceptually equivalent to
/// zext|sext(smaller_value).
struct NodeExtensionHelper {
@@ -13239,6 +13244,8 @@ struct NodeExtensionHelper {
/// instance, a splat constant (e.g., 3), would support being both sign and
/// zero extended.
bool SupportsSExt;
+ /// Records if this operand is like being floating-point extended.
+ bool SupportsFPExt;
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
@@ -13262,6 +13269,7 @@ struct NodeExtensionHelper {
case ISD::SIGN_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
+ case RISCVISD::FP_EXTEND_VL:
return OrigOperand.getOperand(0);
default:
return OrigOperand;
@@ -13273,22 +13281,36 @@ struct NodeExtensionHelper {
return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
}
+ /// Get the opcode of the extension node that implements \p SupportsExt.
+ static unsigned getExtOpc(ExtKind SupportsExt) {
+ switch (SupportsExt) {
+ case ExtKind::SExt:
+ return RISCVISD::VSEXT_VL;
+ case ExtKind::ZExt:
+ return RISCVISD::VZEXT_VL;
+ case ExtKind::FPExt:
+ return RISCVISD::FP_EXTEND_VL;
+ }
+ // Fully covered switch with no default: keeps GCC's -Wreturn-type quiet.
+ llvm_unreachable("Unknown ExtKind enum");
+ }
+
/// Get or create a value that can feed \p Root with the given extension \p
- /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
- /// \see ::getSource().
+ /// SupportsExt. If \p SupportsExt is std::nullopt, this returns the source of
+ /// this operand. \see ::getSource().
SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
- std::optional<bool> SExt) const {
- if (!SExt.has_value())
+ std::optional<ExtKind> SupportsExt) const {
+ if (!SupportsExt.has_value())
return OrigOperand;
- MVT NarrowVT = getNarrowType(Root);
+ MVT NarrowVT = getNarrowType(Root, *SupportsExt);
SDValue Source = getSource();
if (Source.getValueType() == NarrowVT)
return Source;
- unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+ unsigned ExtOpc = getExtOpc(*SupportsExt);
// If we need an extension, we should be changing the type.
SDLoc DL(Root);
@@ -13298,6 +13318,7 @@ struct NodeExtensionHelper {
case ISD::SIGN_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
+ case RISCVISD::FP_EXTEND_VL:
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
case RISCVISD::VMV_V_X_VL:
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
@@ -13313,41 +13334,59 @@ struct NodeExtensionHelper {
/// Helper function to get the narrow type for \p Root.
/// The narrow type is the type of \p Root where we divided the size of each
/// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
- /// \pre The size of the type of the elements of Root must be a multiple of 2
- /// and be greater than 16.
- static MVT getNarrowType(const SDNode *Root) {
+ /// \pre Both the narrow type and the original type should be legal.
+ static MVT getNarrowType(const SDNode *Root, ExtKind SupportsExt) {
MVT VT = Root->getSimpleValueType(0);
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
- MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
- VT.getVectorElementCount());
+
+ unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
+ // Validate the width first: MVT::getFloatingPointVT() does not accept
+ // arbitrary widths and would otherwise mask this diagnostic.
+ assert(NarrowSize >= NarrowMinSize &&
+ "Trying to extend something we can't represent");
+
+ MVT EltVT = SupportsExt == ExtKind::FPExt
+ ? MVT::getFloatingPointVT(NarrowSize)
+ : MVT::getIntegerVT(NarrowSize);
+
+ MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
return NarrowVT;
}
- /// Return the opcode required to materialize the folding of the sign
- /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+ /// Return the opcode required to materialize the folding for
/// both operands for \p Opcode.
/// Put differently, get the opcode to materialize:
- /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
- /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+ /// - ExtKind::SExt: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+ /// - ExtKind::ZExt: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
/// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
- static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+ static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
- return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
+ : RISCVISD::VWADDU_VL;
case ISD::MUL:
case RISCVISD::MUL_VL:
- return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
+ : RISCVISD::VWMULU_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
- return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
+ : RISCVISD::VWSUBU_VL;
+ case RISCVISD::FADD_VL:
+ case RISCVISD::VFWADD_W_VL:
+ return RISCVISD::VFWADD_VL;
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ return RISCVISD::VFWSUB_VL;
+ case RISCVISD::FMUL_VL:
+ return RISCVISD::VFWMUL_VL;
default:
llvm_unreachable("Unexpected opcode");
}
@@ -13361,16 +13398,22 @@ struct NodeExtensionHelper {
return RISCVISD::VWMULSU_VL;
}
- /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
- /// newOpcode(a, b).
- static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+ /// Get the opcode to materialize
+ /// \p Opcode(a, s|z|fpext(b)) -> newOpcode(a, b).
+ static unsigned getWOpcode(unsigned Opcode, ExtKind SupportsExt) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
- return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
+ : RISCVISD::VWADDU_W_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
- return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
+ : RISCVISD::VWSUBU_W_VL;
+ case RISCVISD::FADD_VL:
+ return RISCVISD::VFWADD_W_VL;
+ case RISCVISD::FSUB_VL:
+ return RISCVISD::VFWSUB_W_VL;
default:
llvm_unreachable("Unexpected opcode");
}
@@ -13390,6 +13433,7 @@ struct NodeExtensionHelper {
const RISCVSubtarget &Subtarget) {
SupportsZExt = false;
SupportsSExt = false;
+ SupportsFPExt = false;
EnforceOneUse = true;
CheckMask = true;
unsigned Opc = OrigOperand.getOpcode();
@@ -13431,6 +13475,11 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
+ case RISCVISD::FP_EXTEND_VL:
+ SupportsFPExt = true;
+ Mask = OrigOperand.getOperand(1);
+ VL = OrigOperand.getOperand(2);
+ break;
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
@@ -13477,15 +13526,16 @@ struct NodeExtensionHelper {
/// Check if \p Root supports any extension folding combines.
static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL: {
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isTypeLegal(Root->getValueType(0)))
return false;
return Root->getValueType(0).isScalableVector();
}
+ // Vector Widening Integer Add/Sub/Mul Instructions
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -13493,7 +13543,13 @@ struct NodeExtensionHelper {
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
- return true;
+ // Vector Widening Floating-Point Add/Sub/Mul Instructions
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ return TLI.isTypeLegal(Root->getValueType(0));
default:
return false;
}
@@ -13509,16 +13565,23 @@ struct NodeExtensionHelper {
unsigned Opc = Root->getOpcode();
switch (Opc) {
- // We consider VW<ADD|SUB>(U)_W(LHS, RHS) as if they were
- // <ADD|SUB>(LHS, S|ZEXT(RHS))
+ // We consider
+ // VW<ADD|SUB>_W(LHS, RHS) -> <ADD|SUB>(LHS, SEXT(RHS))
+ // VW<ADD|SUB>U_W(LHS, RHS) -> <ADD|SUB>(LHS, ZEXT(RHS))
+ // VFW<ADD|SUB>_W(LHS, RHS) -> F<ADD|SUB>(LHS, FPEXT(RHS))
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
if (OperandIdx == 1) {
SupportsZExt =
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
- SupportsSExt = !SupportsZExt;
+ SupportsSExt =
+ Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+ SupportsFPExt =
+ Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
CheckMask = true;
// There's no existing extension here, so we don't have to worry about
@@ -13578,11 +13641,16 @@ struct NodeExtensionHelper {
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
return true;
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::VFWSUB_W_VL:
return false;
default:
llvm_unreachable("Unexpected opcode");
@@ -13604,10 +13672,9 @@ struct NodeExtensionHelper {
struct CombineResult {
/// Opcode to be generated when materializing the combine.
unsigned TargetOpcode;
- // No value means no extension is needed. If extension is needed, the value
- // indicates if it needs to be sign extended.
- std::optional<bool> SExtLHS;
- std::optional<bool> SExtRHS;
+ // No value means no extension is needed.
+ std::optional<ExtKind> LHSExt;
+ std::optional<ExtKind> RHSExt;
/// Root of the combine.
SDNode *Root;
/// LHS of the TargetOpcode.
@@ -13616,10 +13683,10 @@ struct CombineResult {
NodeExtensionHelper RHS;
CombineResult(unsigned TargetOpcode, SDNode *Root,
- const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
- const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
- : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
- Root(Root), LHS(LHS), RHS(RHS) {}
+ const NodeExtensionHelper &LHS, std::optional<ExtKind> LHSExt,
+ const NodeExtensionHelper &RHS, std::optional<ExtKind> RHSExt)
+ : TargetOpcode(TargetOpcode), LHSExt(LHSExt), RHSExt(RHSExt), Root(Root),
+ LHS(LHS), RHS(RHS) {}
/// Return a value that uses TargetOpcode and that can be used to replace
/// Root.
@@ -13640,8 +13707,8 @@ struct CombineResult {
break;
}
return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
- LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
- RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+ LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, LHSExt),
+ RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, RHSExt),
Merge, Mask, VL);
}
};
@@ -13656,24 +13723,30 @@ struct CombineResult {
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS, bool AllowSExt,
- bool AllowZExt, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
+ SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
+ bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+ assert((AllowSExt || AllowZExt || AllowFPExt) &&
+ "Forgot to set what you want?");
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
- Root->getOpcode(), /*IsSExt=*/false),
- Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+ Root->getOpcode(), ExtKind::ZExt),
+ Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
+ /*RHSExt=*/{ExtKind::ZExt});
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
- Root->getOpcode(), /*IsSExt=*/true),
- Root, LHS, /*SExtLHS=*/true, RHS,
- /*SExtRHS=*/true);
+ Root->getOpcode(), ExtKind::SExt),
+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+ /*RHSExt=*/{ExtKind::SExt});
+ if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+ return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
+ Root->getOpcode(), ExtKind::FPExt),
+ Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
+ /*RHSExt=*/{ExtKind::FPExt});
return std::nullopt;
}
@@ -13688,7 +13761,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/true, DAG, Subtarget);
+ /*AllowZExt=*/true,
+ /*AllowFPExt=*/true, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13702,18 +13776,23 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
+ if (RHS.SupportsFPExt)
+ return CombineResult(
+ NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
+ Root, LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::FPExt});
+
// FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
// sext/zext?
// Control this behavior behind an option (AllowSplatInVW_W) for testing
// purposes.
if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
return CombineResult(
- NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
- Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false);
+ NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::ZExt), Root,
+ LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::ZExt});
if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
return CombineResult(
- NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
- Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true);
+ NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::SExt), Root,
+ LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::SExt});
return std::nullopt;
}
@@ -13726,7 +13805,8 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/false, DAG, Subtarget);
+ /*AllowZExt=*/false,
+ /*AllowFPExt=*/false, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13738,7 +13818,21 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
- /*AllowZExt=*/true, DAG, Subtarget);
+ /*AllowZExt=*/true,
+ /*AllowFPExt=*/false, DAG, Subtarget);
+}
+
+/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
+ /*AllowZExt=*/false,
+ /*AllowFPExt=*/true, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13756,7 +13850,8 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
- Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+ /*RHSExt=*/{ExtKind::ZExt});
}
SmallVector<NodeExtensionHelper::CombineToTry>
@@ -13767,11 +13862,16 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case ISD::SUB:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
- // add|sub -> vwadd(u)|vwsub(u)
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FSUB_VL:
+ // add|sub|fadd|fsub -> vwadd(u)|vwsub(u)|vfwadd|vfwsub
Strategies.push_back(canFoldToVWWithSameExtension);
- // add|sub -> vwadd(u)_w|vwsub(u)_w
+ // add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w|vfwadd_w|vfwsub_w
Strategies.push_back(canFoldToVW_W);
break;
+ case RISCVISD::FMUL_VL:
+ Strategies.push_back(canFoldToVWWithSameExtension);
+ break;
case ISD::MUL:
case RISCVISD::MUL_VL:
// mul -> vwmul(u)
@@ -13789,6 +13889,11 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
// vwaddu_w|vwsubu_w -> vwaddu|vwsubu
Strategies.push_back(canFoldToVWWithZEXT);
break;
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ // vfwadd_w|vfwsub_w -> vfwadd|vfwsub
+ Strategies.push_back(canFoldToVWWithFPEXT);
+ break;
default:
llvm_unreachable("Unexpected opcode");
}
@@ -13801,8 +13906,13 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// add_vl -> vwadd(u) | vwadd(u)_w
/// sub_vl -> vwsub(u) | vwsub(u)_w
/// mul_vl -> vwmul(u) | vwmul_su
+/// fadd_vl -> vfwadd | vfwadd_w
+/// fsub_vl -> vfwsub | vfwsub_w
+/// fmul_vl -> vfwmul
/// vwadd_w(u) -> vwadd(u)
-/// vwub_w(u) -> vwadd(u)
+/// vwsub_w(u) -> vwsub(u)
+/// vfwadd_w -> vfwadd
+/// vfwsub_w -> vfwsub
static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
@@ -13858,9 +13968,9 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
// All the inputs that are extended need to be folded, otherwise
// we would be leaving the old input (since it is may still be used),
// and the new one.
- if (Res->SExtLHS.has_value())
+ if (Res->LHSExt.has_value())
AppendUsersIfNeeded(LHS);
- if (Res->SExtRHS.has_value())
+ if (Res->RHSExt.has_value())
AppendUsersIfNeeded(RHS);
break;
}
@@ -14425,107 +14535,6 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
N->getOperand(2), Mask, VL);
}
-static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- if (N->getValueType(0).isScalableVector() &&
- N->getValueType(0).getVectorElementType() == MVT::f32 &&
- (Subtarget.hasVInstructionsF16Minimal() &&
- !Subtarget.hasVInstructionsF16())) {
- return SDValue();
- }
-
- // FIXME: Ignore strict opcodes for now.
- assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode");
-
- // Try to form widening multiply.
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- SDValue Merge = N->getOperand(2);
- SDValue Mask = N->getOperand(3);
- SDValue VL = N->getOperand(4);
-
- if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
- Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
- return SDValue();
-
- // TODO: Refactor to handle more complex cases similar to
- // combineBinOp_VLToVWBinOp_VL.
- if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
- (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
- return SDValue();
-
- // Check the mask and VL are the same.
- if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
- Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
- return SDValue();
-
- Op0 = Op0.getOperand(0);
- Op1 = Op1.getOperand(0);
-
- return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0,
- Op1, Merge, Mask, VL);
-}
-
-static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- if (N->getValueType(0).isScalableVector() &&
- N->getValueType(0).getVectorElementType() == MVT::f32 &&
- (Subtarget.hasVInstructionsF16Minimal() &&
- !Subtarget.hasVInstructionsF16())) {
- return SDValue();
- }
-
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- SDValue Merge = N->getOperand(2);
- SDValue Mask = N->getOperand(3);
- SDValue VL = N->getOperand(4);
-
- bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL;
-
- // Look for foldable FP_EXTENDS.
- bool Op0IsExtend =
- Op0.getOpcode() == RISCVISD::FP_EXTEND_VL &&
- (Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0)));
- bool Op1IsExtend =
- (Op0 == Op1 && Op0IsExtend) ||
- (Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse());
-
- // Check the mask and VL.
- if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL))
- Op0IsExtend = false;
- if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL))
- Op1IsExtend = false;
-
- // Canonicalize.
- if (!Op1IsExtend) {
- // Sub requires at least operand 1 to be an extend.
- if (!IsAdd)
- return SDValue();
-
- // Add is commutable, if the other operand is foldable, swap them.
- if (!Op0IsExtend)
- return SDValue();
-
- std::swap(Op0, Op1);
- std::swap(Op0IsExtend, Op1IsExtend);
- }
-
- // Op1 is a foldable extend. Op0 might be foldable.
- Op1 = Op1.getOperand(0);
- if (Op0IsExtend)
- Op0 = Op0.getOperand(0);
-
- unsigned Opc;
- if (IsAdd)
- Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL;
- else
- Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL;
-
- return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask,
- VL);
-}
-
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -16045,11 +16054,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::STRICT_VFMSUB_VL:
case RISCVISD::STRICT_VFNMSUB_VL:
return performVFMADD_VLCombine(N, DAG, Subtarget);
- case RISCVISD::FMUL_VL:
- return performVFMUL_VLCombine(N, DAG, Subtarget);
case RISCVISD::FADD_VL:
case RISCVISD::FSUB_VL:
- return performFADDSUB_VLCombine(N, DAG, Subtarget);
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL: {
+ if (N->getValueType(0).isScalableVector() &&
+ N->getValueType(0).getVectorElementType() == MVT::f32 &&
+ (Subtarget.hasVInstructionsF16Minimal() &&
+ !Subtarget.hasVInstructionsF16()))
+ return SDValue();
+ return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+ }
case ISD::LOAD:
case ISD::STORE: {
if (DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
new file mode 100644
index 00000000000000..26f77225dbb0e1
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f16_multiple_users(ptr %x, ptr %y, ptr %z, <2 x half> %a, <2 x half> %b, <2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f16_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT: vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT: vse32.v v10, (a0)
+; NO_FOLDING-NEXT: vse32.v v11, (a1)
+; NO_FOLDING-NEXT: vse32.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f16_multiple_users:
+; ZVFHMIN: # %bb.0:
+; ZVFHMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
+; ZVFHMIN-NEXT: vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT: vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT: vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT: vse32.v v10, (a0)
+; ZVFHMIN-NEXT: vse32.v v11, (a1)
+; ZVFHMIN-NEXT: vse32.v v8, (a2)
+; ZVFHMIN-NEXT: ret
+ %c = fpext <2 x half> %a to <2 x float>
+ %d = fpext <2 x half> %b to <2 x float>
+ %d2 = fpext <2 x half> %b2 to <2 x float>
+ %e = fmul <2 x float> %c, %d
+ %f = fadd <2 x float> %c, %d2
+ %g = fsub <2 x float> %d, %d2
+ store <2 x float> %e, ptr %x
+ store <2 x float> %f, ptr %y
+ store <2 x float> %g, ptr %z
+ ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT: vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT: vse64.v v10, (a0)
+; NO_FOLDING-NEXT: vse64.v v11, (a1)
+; NO_FOLDING-NEXT: vse64.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT: vfwmul.vv v11, v8, v9
+; FOLDING-NEXT: vfwadd.vv v12, v8, v10
+; FOLDING-NEXT: vfwsub.vv v8, v9, v10
+; FOLDING-NEXT: vse64.v v11, (a0)
+; FOLDING-NEXT: vse64.v v12, (a1)
+; FOLDING-NEXT: vse64.v v8, (a2)
+; FOLDING-NEXT: ret
+ %c = fpext <2 x float> %a to <2 x double>
+ %d = fpext <2 x float> %b to <2 x double>
+ %d2 = fpext <2 x float> %b2 to <2 x double>
+ %e = fmul <2 x double> %c, %d
+ %f = fadd <2 x double> %c, %d2
+ %g = fsub <2 x double> %d, %d2
+ store <2 x double> %e, ptr %x
+ store <2 x double> %f, ptr %y
+ store <2 x double> %g, ptr %z
+ ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
index c9dc75e18774f8..dd3a50cfd77377 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
@@ -396,12 +396,10 @@ define <32 x double> @vfwadd_vf_v32f32(ptr %x, float %y) {
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v0, v24, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
-; CHECK-NEXT: vfwadd.wv v16, v8, v0
-; CHECK-NEXT: vfwadd.wv v8, v8, v24
+; CHECK-NEXT: vfwadd.vf v16, v8, fa0
+; CHECK-NEXT: vfwadd.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
index 8ad858d4c76598..7eaa1856ce2218 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwmul_vf_v32f32(ptr %x, float %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT: vle32.v v16, (a0)
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
+; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v16, v16, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v24, v16
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v0, v16
-; CHECK-NEXT: vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT: vfmul.vv v16, v24, v0
-; CHECK-NEXT: vfmul.vv v8, v8, v0
+; CHECK-NEXT: vfwmul.vf v16, v8, fa0
+; CHECK-NEXT: vfwmul.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
index d22781d6a97ac2..8cf7c5f1758654 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwsub_vf_v32f32(ptr %x, float %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT: vle32.v v16, (a0)
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
+; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v16, v16, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v24, v16
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v0, v16
-; CHECK-NEXT: vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT: vfsub.vv v16, v24, v0
-; CHECK-NEXT: vfsub.vv v8, v8, v0
+; CHECK-NEXT: vfwsub.vf v16, v8, fa0
+; CHECK-NEXT: vfwsub.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
More information about the llvm-commits
mailing list