[llvm] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. (PR #81248)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Feb 18 04:45:31 PST 2024
https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/81248
>From 22ecc6a32efa41ae4de5139641ea018292c1ff58 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Fri, 9 Feb 2024 20:47:50 +0800
Subject: [PATCH 1/2] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp
extend.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 356 +++++++++---------
.../fixed-vectors-vfw-web-simplification.ll | 88 +++++
.../CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll | 8 +-
.../CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll | 14 +-
.../CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll | 14 +-
5 files changed, 285 insertions(+), 195 deletions(-)
create mode 100644 llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0e1ea8ac75fe0e..2c71538f1b7795 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13292,12 +13292,16 @@ namespace {
// apply a combine.
struct CombineResult;
+enum class ExtKind { ZExt, SExt, FPExt };
+
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl -> vwadd(u) | vwadd(u)_w
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
/// mul | mul_vl -> vwmul(u) | vwmul_su
-///
+/// fadd -> vfwadd | vfwadd_w
+/// fsub -> vfwsub | vfwsub_w
+/// fmul -> vfwmul
/// An object of this class represents an operand of the operation we want to
/// combine.
/// E.g., when trying to combine `mul_vl a, b`, we will have one instance of
@@ -13311,7 +13315,8 @@ struct CombineResult;
/// - VWADDU_W == add(op0, zext(op1))
/// - VWSUB_W == sub(op0, sext(op1))
/// - VWSUBU_W == sub(op0, zext(op1))
-///
+/// - VFWADD_W == fadd(op0, fpext(op1))
+/// - VFWSUB_W == fsub(op0, fpext(op1))
/// And VMV_V_X_VL, depending on the value, is conceptually equivalent to
/// zext|sext(smaller_value).
struct NodeExtensionHelper {
@@ -13322,6 +13327,8 @@ struct NodeExtensionHelper {
/// instance, a splat constant (e.g., 3), would support being both sign and
/// zero extended.
bool SupportsSExt;
+ /// Records if this operand is like being floating-point extended.
+ bool SupportsFPExt;
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
@@ -13345,6 +13352,7 @@ struct NodeExtensionHelper {
case ISD::SIGN_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
+ case RISCVISD::FP_EXTEND_VL:
return OrigOperand.getOperand(0);
default:
return OrigOperand;
@@ -13356,22 +13364,34 @@ struct NodeExtensionHelper {
return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
}
+ /// Get the extended opcode.
+ unsigned getExtOpc(ExtKind SupportsExt) const {
+ switch (SupportsExt) {
+ case ExtKind::SExt:
+ return RISCVISD::VSEXT_VL;
+ case ExtKind::ZExt:
+ return RISCVISD::VZEXT_VL;
+ case ExtKind::FPExt:
+ return RISCVISD::FP_EXTEND_VL;
+ }
+ }
+
/// Get or create a value that can feed \p Root with the given extension \p
- /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
- /// \see ::getSource().
+ /// SupportsExt. If \p SupportsExt is std::nullopt, this returns the source of this
+ /// operand. \see ::getSource().
SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
- std::optional<bool> SExt) const {
- if (!SExt.has_value())
+ std::optional<ExtKind> SupportsExt) const {
+ if (!SupportsExt.has_value())
return OrigOperand;
- MVT NarrowVT = getNarrowType(Root);
+ MVT NarrowVT = getNarrowType(Root, *SupportsExt);
SDValue Source = getSource();
if (Source.getValueType() == NarrowVT)
return Source;
- unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+ unsigned ExtOpc = getExtOpc(*SupportsExt);
// If we need an extension, we should be changing the type.
SDLoc DL(OrigOperand);
@@ -13381,6 +13401,7 @@ struct NodeExtensionHelper {
case ISD::SIGN_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
+ case RISCVISD::FP_EXTEND_VL:
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
case RISCVISD::VMV_V_X_VL:
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
@@ -13396,41 +13417,57 @@ struct NodeExtensionHelper {
/// Helper function to get the narrow type for \p Root.
/// The narrow type is the type of \p Root where we divided the size of each
/// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
- /// \pre The size of the type of the elements of Root must be a multiple of 2
- /// and be greater than 16.
- static MVT getNarrowType(const SDNode *Root) {
+ /// \pre Both the narrow type and the original type should be legal.
+ static MVT getNarrowType(const SDNode *Root, ExtKind SupportsExt) {
MVT VT = Root->getSimpleValueType(0);
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
- MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
- VT.getVectorElementCount());
+
+ unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
+
+ MVT EltVT = SupportsExt == ExtKind::FPExt
+ ? MVT::getFloatingPointVT(NarrowSize)
+ : MVT::getIntegerVT(NarrowSize);
+
+ assert(NarrowSize >= NarrowMinSize &&
+ "Trying to extend something we can't represent");
+ MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
return NarrowVT;
}
- /// Return the opcode required to materialize the folding of the sign
- /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+ /// Return the opcode required to materialize the folding for
/// both operands for \p Opcode.
/// Put differently, get the opcode to materialize:
- /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
- /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+ /// - ExtKind::SExt: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+ /// - ExtKind::ZExt: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
/// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
- static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+ static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
- return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
+ : RISCVISD::VWADDU_VL;
case ISD::MUL:
case RISCVISD::MUL_VL:
- return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
+ : RISCVISD::VWMULU_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
- return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
+ : RISCVISD::VWSUBU_VL;
+ case RISCVISD::FADD_VL:
+ case RISCVISD::VFWADD_W_VL:
+ return RISCVISD::VFWADD_VL;
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ return RISCVISD::VFWSUB_VL;
+ case RISCVISD::FMUL_VL:
+ return RISCVISD::VFWMUL_VL;
default:
llvm_unreachable("Unexpected opcode");
}
@@ -13444,16 +13481,22 @@ struct NodeExtensionHelper {
return RISCVISD::VWMULSU_VL;
}
- /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
- /// newOpcode(a, b).
- static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+ /// Get the opcode to materialize
+ /// \p Opcode(a, s|z|fpext(b)) -> newOpcode(a, b).
+ static unsigned getWOpcode(unsigned Opcode, ExtKind SupportsExt) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
- return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
+ : RISCVISD::VWADDU_W_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
- return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
+ : RISCVISD::VWSUBU_W_VL;
+ case RISCVISD::FADD_VL:
+ return RISCVISD::VFWADD_W_VL;
+ case RISCVISD::FSUB_VL:
+ return RISCVISD::VFWSUB_W_VL;
default:
llvm_unreachable("Unexpected opcode");
}
@@ -13473,6 +13516,7 @@ struct NodeExtensionHelper {
const RISCVSubtarget &Subtarget) {
SupportsZExt = false;
SupportsSExt = false;
+ SupportsFPExt = false;
EnforceOneUse = true;
CheckMask = true;
unsigned Opc = OrigOperand.getOpcode();
@@ -13514,6 +13558,11 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
+ case RISCVISD::FP_EXTEND_VL:
+ SupportsFPExt = true;
+ Mask = OrigOperand.getOperand(1);
+ VL = OrigOperand.getOperand(2);
+ break;
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
@@ -13560,15 +13609,16 @@ struct NodeExtensionHelper {
/// Check if \p Root supports any extension folding combines.
static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL: {
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isTypeLegal(Root->getValueType(0)))
return false;
return Root->getValueType(0).isScalableVector();
}
+ // Vector Widening Integer Add/Sub/Mul Instructions
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -13576,7 +13626,13 @@ struct NodeExtensionHelper {
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
- return true;
+ // Vector Widening Floating-Point Add/Sub/Mul Instructions
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ return TLI.isTypeLegal(Root->getValueType(0));
default:
return false;
}
@@ -13592,16 +13648,23 @@ struct NodeExtensionHelper {
unsigned Opc = Root->getOpcode();
switch (Opc) {
- // We consider VW<ADD|SUB>(U)_W(LHS, RHS) as if they were
- // <ADD|SUB>(LHS, S|ZEXT(RHS))
+ // We consider
+ // VW<ADD|SUB>_W(LHS, RHS) -> <ADD|SUB>(LHS, SEXT(RHS))
+ // VW<ADD|SUB>U_W(LHS, RHS) -> <ADD|SUB>(LHS, ZEXT(RHS))
+ // VFW<ADD|SUB>_W(LHS, RHS) -> F<ADD|SUB>(LHS, FPEXT(RHS))
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
if (OperandIdx == 1) {
SupportsZExt =
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
- SupportsSExt = !SupportsZExt;
+ SupportsSExt =
+ Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+ SupportsFPExt =
+ Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
CheckMask = true;
// There's no existing extension here, so we don't have to worry about
@@ -13661,11 +13724,16 @@ struct NodeExtensionHelper {
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
return true;
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::VFWSUB_W_VL:
return false;
default:
llvm_unreachable("Unexpected opcode");
@@ -13687,10 +13755,9 @@ struct NodeExtensionHelper {
struct CombineResult {
/// Opcode to be generated when materializing the combine.
unsigned TargetOpcode;
- // No value means no extension is needed. If extension is needed, the value
- // indicates if it needs to be sign extended.
- std::optional<bool> SExtLHS;
- std::optional<bool> SExtRHS;
+ // No value means no extension is needed.
+ std::optional<ExtKind> LHSExt;
+ std::optional<ExtKind> RHSExt;
/// Root of the combine.
SDNode *Root;
/// LHS of the TargetOpcode.
@@ -13699,10 +13766,10 @@ struct CombineResult {
NodeExtensionHelper RHS;
CombineResult(unsigned TargetOpcode, SDNode *Root,
- const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
- const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
- : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
- Root(Root), LHS(LHS), RHS(RHS) {}
+ const NodeExtensionHelper &LHS, std::optional<ExtKind> LHSExt,
+ const NodeExtensionHelper &RHS, std::optional<ExtKind> RHSExt)
+ : TargetOpcode(TargetOpcode), LHSExt(LHSExt), RHSExt(RHSExt), Root(Root),
+ LHS(LHS), RHS(RHS) {}
/// Return a value that uses TargetOpcode and that can be used to replace
/// Root.
@@ -13723,8 +13790,8 @@ struct CombineResult {
break;
}
return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
- LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
- RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+ LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, LHSExt),
+ RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, RHSExt),
Merge, Mask, VL);
}
};
@@ -13739,24 +13806,30 @@ struct CombineResult {
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS, bool AllowSExt,
- bool AllowZExt, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
+ SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
+ bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+ assert((AllowSExt || AllowZExt || AllowFPExt) &&
+ "Forgot to set what you want?");
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
- Root->getOpcode(), /*IsSExt=*/false),
- Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+ Root->getOpcode(), ExtKind::ZExt),
+ Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
+ /*RHSExt=*/{ExtKind::ZExt});
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
- Root->getOpcode(), /*IsSExt=*/true),
- Root, LHS, /*SExtLHS=*/true, RHS,
- /*SExtRHS=*/true);
+ Root->getOpcode(), ExtKind::SExt),
+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+ /*RHSExt=*/{ExtKind::SExt});
+ if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+ return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
+ Root->getOpcode(), ExtKind::FPExt),
+ Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
+ /*RHSExt=*/{ExtKind::FPExt});
return std::nullopt;
}
@@ -13771,7 +13844,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/true, DAG, Subtarget);
+ /*AllowZExt=*/true,
+ /*AllowFPExt=*/true, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13785,18 +13859,23 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
+ if (RHS.SupportsFPExt)
+ return CombineResult(
+ NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
+ Root, LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::FPExt});
+
// FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
// sext/zext?
// Control this behavior behind an option (AllowSplatInVW_W) for testing
// purposes.
if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
return CombineResult(
- NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
- Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false);
+ NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::ZExt), Root,
+ LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::ZExt});
if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
return CombineResult(
- NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
- Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true);
+ NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::SExt), Root,
+ LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::SExt});
return std::nullopt;
}
@@ -13809,7 +13888,8 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/false, DAG, Subtarget);
+ /*AllowZExt=*/false,
+ /*AllowFPExt=*/false, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13821,7 +13901,21 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
- /*AllowZExt=*/true, DAG, Subtarget);
+ /*AllowZExt=*/true,
+ /*AllowFPExt=*/false, DAG, Subtarget);
+}
+
+/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
+ /*AllowZExt=*/false,
+ /*AllowFPExt=*/true, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13839,7 +13933,8 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
- Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+ /*RHSExt=*/{ExtKind::ZExt});
}
SmallVector<NodeExtensionHelper::CombineToTry>
@@ -13850,11 +13945,16 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case ISD::SUB:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
- // add|sub -> vwadd(u)|vwsub(u)
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FSUB_VL:
+ // add|sub|fadd|fsub -> vwadd(u)|vwsub(u)|vfwadd|vfwsub
Strategies.push_back(canFoldToVWWithSameExtension);
- // add|sub -> vwadd(u)_w|vwsub(u)_w
+ // add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w|vfwadd_w|vfwsub_w
Strategies.push_back(canFoldToVW_W);
break;
+ case RISCVISD::FMUL_VL:
+ Strategies.push_back(canFoldToVWWithSameExtension);
+ break;
case ISD::MUL:
case RISCVISD::MUL_VL:
// mul -> vwmul(u)
@@ -13872,6 +13972,11 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
// vwaddu_w|vwsubu_w -> vwaddu|vwsubu
Strategies.push_back(canFoldToVWWithZEXT);
break;
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ // vfwadd_w|vfwsub_w -> vfwadd|vfwsub
+ Strategies.push_back(canFoldToVWWithFPEXT);
+ break;
default:
llvm_unreachable("Unexpected opcode");
}
@@ -13884,8 +13989,13 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// add_vl -> vwadd(u) | vwadd(u)_w
/// sub_vl -> vwsub(u) | vwsub(u)_w
/// mul_vl -> vwmul(u) | vwmul_su
+/// fadd_vl -> vfwadd | vfwadd_w
+/// fsub_vl -> vfwsub | vfwsub_w
+/// fmul_vl -> vfwmul
/// vwadd_w(u) -> vwadd(u)
-/// vwub_w(u) -> vwadd(u)
+/// vwsub_w(u) -> vwsub(u)
+/// vfwadd_w -> vfwadd
+/// vfwsub_w -> vfwsub
static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
@@ -13941,9 +14051,9 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
// All the inputs that are extended need to be folded, otherwise
// we would be leaving the old input (since it is may still be used),
// and the new one.
- if (Res->SExtLHS.has_value())
+ if (Res->LHSExt.has_value())
AppendUsersIfNeeded(LHS);
- if (Res->SExtRHS.has_value())
+ if (Res->RHSExt.has_value())
AppendUsersIfNeeded(RHS);
break;
}
@@ -14508,107 +14618,6 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
N->getOperand(2), Mask, VL);
}
-static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- if (N->getValueType(0).isScalableVector() &&
- N->getValueType(0).getVectorElementType() == MVT::f32 &&
- (Subtarget.hasVInstructionsF16Minimal() &&
- !Subtarget.hasVInstructionsF16())) {
- return SDValue();
- }
-
- // FIXME: Ignore strict opcodes for now.
- assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode");
-
- // Try to form widening multiply.
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- SDValue Merge = N->getOperand(2);
- SDValue Mask = N->getOperand(3);
- SDValue VL = N->getOperand(4);
-
- if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
- Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
- return SDValue();
-
- // TODO: Refactor to handle more complex cases similar to
- // combineBinOp_VLToVWBinOp_VL.
- if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
- (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
- return SDValue();
-
- // Check the mask and VL are the same.
- if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
- Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
- return SDValue();
-
- Op0 = Op0.getOperand(0);
- Op1 = Op1.getOperand(0);
-
- return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0,
- Op1, Merge, Mask, VL);
-}
-
-static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- if (N->getValueType(0).isScalableVector() &&
- N->getValueType(0).getVectorElementType() == MVT::f32 &&
- (Subtarget.hasVInstructionsF16Minimal() &&
- !Subtarget.hasVInstructionsF16())) {
- return SDValue();
- }
-
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- SDValue Merge = N->getOperand(2);
- SDValue Mask = N->getOperand(3);
- SDValue VL = N->getOperand(4);
-
- bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL;
-
- // Look for foldable FP_EXTENDS.
- bool Op0IsExtend =
- Op0.getOpcode() == RISCVISD::FP_EXTEND_VL &&
- (Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0)));
- bool Op1IsExtend =
- (Op0 == Op1 && Op0IsExtend) ||
- (Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse());
-
- // Check the mask and VL.
- if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL))
- Op0IsExtend = false;
- if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL))
- Op1IsExtend = false;
-
- // Canonicalize.
- if (!Op1IsExtend) {
- // Sub requires at least operand 1 to be an extend.
- if (!IsAdd)
- return SDValue();
-
- // Add is commutable, if the other operand is foldable, swap them.
- if (!Op0IsExtend)
- return SDValue();
-
- std::swap(Op0, Op1);
- std::swap(Op0IsExtend, Op1IsExtend);
- }
-
- // Op1 is a foldable extend. Op0 might be foldable.
- Op1 = Op1.getOperand(0);
- if (Op0IsExtend)
- Op0 = Op0.getOperand(0);
-
- unsigned Opc;
- if (IsAdd)
- Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL;
- else
- Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL;
-
- return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask,
- VL);
-}
-
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -16141,11 +16150,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::STRICT_VFMSUB_VL:
case RISCVISD::STRICT_VFNMSUB_VL:
return performVFMADD_VLCombine(N, DAG, Subtarget);
- case RISCVISD::FMUL_VL:
- return performVFMUL_VLCombine(N, DAG, Subtarget);
case RISCVISD::FADD_VL:
case RISCVISD::FSUB_VL:
- return performFADDSUB_VLCombine(N, DAG, Subtarget);
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL: {
+ if (N->getValueType(0).isScalableVector() &&
+ N->getValueType(0).getVectorElementType() == MVT::f32 &&
+ (Subtarget.hasVInstructionsF16Minimal() &&
+ !Subtarget.hasVInstructionsF16()))
+ return SDValue();
+ return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+ }
case ISD::LOAD:
case ISD::STORE: {
if (DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
new file mode 100644
index 00000000000000..26f77225dbb0e1
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f116_multiple_users(ptr %x, ptr %y, ptr %z, <2 x half> %a, <2 x half> %b, <2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f116_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT: vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT: vse32.v v10, (a0)
+; NO_FOLDING-NEXT: vse32.v v11, (a1)
+; NO_FOLDING-NEXT: vse32.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f116_multiple_users:
+; ZVFHMIN: # %bb.0:
+; ZVFHMIN-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT: vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
+; ZVFHMIN-NEXT: vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT: vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT: vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT: vse32.v v10, (a0)
+; ZVFHMIN-NEXT: vse32.v v11, (a1)
+; ZVFHMIN-NEXT: vse32.v v8, (a2)
+; ZVFHMIN-NEXT: ret
+ %c = fpext <2 x half> %a to <2 x float>
+ %d = fpext <2 x half> %b to <2 x float>
+ %d2 = fpext <2 x half> %b2 to <2 x float>
+ %e = fmul <2 x float> %c, %d
+ %f = fadd <2 x float> %c, %d2
+ %g = fsub <2 x float> %d, %d2
+ store <2 x float> %e, ptr %x
+ store <2 x float> %f, ptr %y
+ store <2 x float> %g, ptr %z
+ ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT: vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT: vse64.v v10, (a0)
+; NO_FOLDING-NEXT: vse64.v v11, (a1)
+; NO_FOLDING-NEXT: vse64.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT: vfwmul.vv v11, v8, v9
+; FOLDING-NEXT: vfwadd.vv v12, v8, v10
+; FOLDING-NEXT: vfwsub.vv v8, v9, v10
+; FOLDING-NEXT: vse64.v v11, (a0)
+; FOLDING-NEXT: vse64.v v12, (a1)
+; FOLDING-NEXT: vse64.v v8, (a2)
+; FOLDING-NEXT: ret
+ %c = fpext <2 x float> %a to <2 x double>
+ %d = fpext <2 x float> %b to <2 x double>
+ %d2 = fpext <2 x float> %b2 to <2 x double>
+ %e = fmul <2 x double> %c, %d
+ %f = fadd <2 x double> %c, %d2
+ %g = fsub <2 x double> %d, %d2
+ store <2 x double> %e, ptr %x
+ store <2 x double> %f, ptr %y
+ store <2 x double> %g, ptr %z
+ ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
index c9dc75e18774f8..dd3a50cfd77377 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
@@ -396,12 +396,10 @@ define <32 x double> @vfwadd_vf_v32f32(ptr %x, float %y) {
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v0, v24, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
-; CHECK-NEXT: vfwadd.wv v16, v8, v0
-; CHECK-NEXT: vfwadd.wv v8, v8, v24
+; CHECK-NEXT: vfwadd.vf v16, v8, fa0
+; CHECK-NEXT: vfwadd.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
index 8ad858d4c76598..7eaa1856ce2218 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwmul_vf_v32f32(ptr %x, float %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT: vle32.v v16, (a0)
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
+; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v16, v16, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v24, v16
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v0, v16
-; CHECK-NEXT: vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT: vfmul.vv v16, v24, v0
-; CHECK-NEXT: vfmul.vv v8, v8, v0
+; CHECK-NEXT: vfwmul.vf v16, v8, fa0
+; CHECK-NEXT: vfwmul.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
index d22781d6a97ac2..8cf7c5f1758654 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwsub_vf_v32f32(ptr %x, float %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT: vle32.v v16, (a0)
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
+; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v16, v16, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v24, v16
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v0, v16
-; CHECK-NEXT: vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT: vfsub.vv v16, v24, v0
-; CHECK-NEXT: vfsub.vv v8, v8, v0
+; CHECK-NEXT: vfwsub.vf v16, v8, fa0
+; CHECK-NEXT: vfwsub.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
>From 93ded43e072cad86c77e8fd73e0c65aee2787307 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Sun, 18 Feb 2024 21:45:10 +0900
Subject: [PATCH 2/2] add AllowExtMask
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 47 +++++++++------------
1 file changed, 20 insertions(+), 27 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2c71538f1b7795..caf7325d766c9e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13292,8 +13292,7 @@ namespace {
// apply a combine.
struct CombineResult;
-enum class ExtKind { ZExt, SExt, FPExt };
-
+enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -13424,13 +13423,11 @@ struct NodeExtensionHelper {
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
-
MVT EltVT = SupportsExt == ExtKind::FPExt
? MVT::getFloatingPointVT(NarrowSize)
: MVT::getIntegerVT(NarrowSize);
- assert(NarrowSize >= NarrowMinSize &&
+ assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
"Trying to extend something we can't represent");
MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
return NarrowVT;
@@ -13799,33 +13796,32 @@ struct CombineResult {
/// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
/// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
/// are zext) and LHS and RHS can be folded into Root.
-/// AllowSExt and AllozZExt define which form `ext` can take in this pattern.
+/// AllowExtMask defines which forms `ext` can take in this pattern.
///
/// \note If the pattern can match with both zext and sext, the returned
/// CombineResult will feature the zext result.
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
-static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
- SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
- bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
- assert((AllowSExt || AllowZExt || AllowFPExt) &&
- "Forgot to set what you want?");
+static std::optional<CombineResult>
+canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS,
+ uint8_t AllowExtMask, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
- if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
+ if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
Root->getOpcode(), ExtKind::ZExt),
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
/*RHSExt=*/{ExtKind::ZExt});
- if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
+ if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
Root->getOpcode(), ExtKind::SExt),
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
/*RHSExt=*/{ExtKind::SExt});
- if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+ if (AllowExtMask & ExtKind::FPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
Root->getOpcode(), ExtKind::FPExt),
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
@@ -13843,9 +13839,9 @@ static std::optional<CombineResult>
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/true,
- /*AllowFPExt=*/true, DAG, Subtarget);
+ return canFoldToVWWithSameExtensionImpl(
+ Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
+ Subtarget);
}
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13887,9 +13883,8 @@ static std::optional<CombineResult>
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/false,
- /*AllowFPExt=*/false, DAG, Subtarget);
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
+ Subtarget);
}
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13900,9 +13895,8 @@ static std::optional<CombineResult>
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
- /*AllowZExt=*/true,
- /*AllowFPExt=*/false, DAG, Subtarget);
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
+ Subtarget);
}
/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
@@ -13913,9 +13907,8 @@ static std::optional<CombineResult>
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
- /*AllowZExt=*/false,
- /*AllowFPExt=*/true, DAG, Subtarget);
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
+ Subtarget);
}
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
More information about the llvm-commits
mailing list