[llvm] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. (PR #76695)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 2 04:04:39 PST 2024
https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/76695
From 5f9a44a5d2ea90b5d7f990cf1db306753ba77ecb Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Sun, 31 Dec 2023 02:21:23 +0900
Subject: [PATCH 1/7] [RISCV][ISel] refactor combineBinOp_VLToVWBinOp_VL
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 165 ++++++++++++++------
1 file changed, 113 insertions(+), 52 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 51580d15451ca2..ca8b5d0d9422ef 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12848,6 +12848,8 @@ namespace {
// apply a combine.
struct CombineResult;
+enum class SupportExt { ZExt, SExt };
+
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
/// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -12912,13 +12914,22 @@ struct NodeExtensionHelper {
return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
}
+ unsigned getExtOpc(SupportExt Ext) const {
+ switch (Ext) {
+ case SupportExt::ZExt:
+ return RISCVISD::VZEXT_VL;
+ case SupportExt::SExt:
+ return RISCVISD::VSEXT_VL;
+ }
+ }
+
/// Get or create a value that can feed \p Root with the given extension \p
/// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
/// \see ::getSource().
SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
- std::optional<bool> SExt) const {
- if (!SExt.has_value())
+ std::optional<SupportExt> Ext) const {
+ if (!Ext.has_value())
return OrigOperand;
MVT NarrowVT = getNarrowType(Root);
@@ -12927,7 +12938,7 @@ struct NodeExtensionHelper {
if (Source.getValueType() == NarrowVT)
return Source;
- unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+ unsigned ExtOpc = getExtOpc(*Ext);
// If we need an extension, we should be changing the type.
SDLoc DL(Root);
@@ -12965,56 +12976,103 @@ struct NodeExtensionHelper {
return NarrowVT;
}
- /// Return the opcode required to materialize the folding of the sign
- /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
- /// both operands for \p Opcode.
- /// Put differently, get the opcode to materialize:
- /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
- /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
- /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
- static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+ static unsigned getSignedDoubleWidenOpcode(unsigned Opcode) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
- case RISCVISD::VWADDU_W_VL:
- return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+ return RISCVISD::VWADD_VL;
case ISD::MUL:
case RISCVISD::MUL_VL:
- return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+ return RISCVISD::VWMUL_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
+ return RISCVISD::VWSUB_VL;
+ default:
+ llvm_unreachable("Unexpected Opcode");
+ }
+ }
+
+ static unsigned getUnsignedDoubleWidenOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ case ISD::ADD:
+ case RISCVISD::ADD_VL:
+ case RISCVISD::VWADDU_W_VL:
+ return RISCVISD::VWADDU_VL;
+ case ISD::MUL:
+ case RISCVISD::MUL_VL:
+ return RISCVISD::VWMULU_VL;
+ case ISD::SUB:
+ case RISCVISD::SUB_VL:
case RISCVISD::VWSUBU_W_VL:
- return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+ return RISCVISD::VWSUBU_VL;
default:
- llvm_unreachable("Unexpected opcode");
+ llvm_unreachable("Unexpected Opcode");
+ }
+ }
+
+ /// Return the opcode required to materialize the folding of the sign
+ /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+ /// both operands for \p Opcode.
+ /// Put differently, get the opcode to materialize:
+ /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+ /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+ /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
+ static unsigned getDoubleWidenOpcode(unsigned OrigOpcode, SupportExt Ext) {
+ switch (Ext) {
+ case SupportExt::SExt:
+ return getSignedDoubleWidenOpcode(OrigOpcode);
+ case SupportExt::ZExt:
+ return getUnsignedDoubleWidenOpcode(OrigOpcode);
}
}
/// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
/// newOpcode(a, b).
- static unsigned getSUOpcode(unsigned Opcode) {
+ static unsigned getSignedUnsignedWidenOpcode(unsigned Opcode) {
assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
"SU is only supported for MUL");
return RISCVISD::VWMULSU_VL;
}
- /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
- /// newOpcode(a, b).
- static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+ static unsigned getSignedSingleWidenOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ case ISD::ADD:
+ case RISCVISD::ADD_VL:
+ return RISCVISD::VWADD_W_VL;
+ case ISD::SUB:
+ case RISCVISD::SUB_VL:
+ return RISCVISD::VWSUB_W_VL;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+ }
+
+ static unsigned getUnsignedSingleWidenOpcode(unsigned Opcode) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
- return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+ return RISCVISD::VWADDU_W_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
- return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+ return RISCVISD::VWSUBU_W_VL;
default:
llvm_unreachable("Unexpected opcode");
}
}
+ /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
+ /// newOpcode(a, b).
+ static unsigned getSingleWidenOpcode(unsigned Opcode, SupportExt Ext) {
+ switch (Ext) {
+ case SupportExt::SExt:
+ return getSignedSingleWidenOpcode(Opcode);
+ case SupportExt::ZExt:
+ return getUnsignedSingleWidenOpcode(Opcode);
+ }
+ }
+
using CombineToTry = std::function<std::optional<CombineResult>(
SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
@@ -13227,22 +13285,23 @@ struct NodeExtensionHelper {
struct CombineResult {
/// Opcode to be generated when materializing the combine.
unsigned TargetOpcode;
- // No value means no extension is needed. If extension is needed, the value
- // indicates if it needs to be sign extended.
- std::optional<bool> SExtLHS;
- std::optional<bool> SExtRHS;
- /// Root of the combine.
- SDNode *Root;
/// LHS of the TargetOpcode.
NodeExtensionHelper LHS;
+ /// Extension of the LHS
+ std::optional<SupportExt> ExtLHS;
/// RHS of the TargetOpcode.
NodeExtensionHelper RHS;
+ /// Extension of the RHS
+ std::optional<SupportExt> ExtRHS;
+ /// Root of the combine.
+ SDNode *Root;
- CombineResult(unsigned TargetOpcode, SDNode *Root,
- const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
- const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
- : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
- Root(Root), LHS(LHS), RHS(RHS) {}
+ CombineResult(
+ unsigned TargetOpcode, SDNode *Root,
+ std::pair<const NodeExtensionHelper &, std::optional<SupportExt>> LHS,
+ std::pair<const NodeExtensionHelper &, std::optional<SupportExt>> RHS)
+ : TargetOpcode(TargetOpcode), LHS(LHS.first), ExtLHS(LHS.second),
+ RHS(RHS.first), ExtRHS(RHS.second), Root(Root) {}
/// Return a value that uses TargetOpcode and that can be used to replace
/// Root.
@@ -13263,8 +13322,8 @@ struct CombineResult {
break;
}
return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
- LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
- RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+ LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtLHS),
+ RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtRHS),
Merge, Mask, VL);
}
};
@@ -13289,14 +13348,15 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
- return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
- Root->getOpcode(), /*IsSExt=*/false),
- Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+ return CombineResult(NodeExtensionHelper::getDoubleWidenOpcode(
+ Root->getOpcode(), SupportExt::ZExt),
+ Root, {LHS, SupportExt::ZExt},
+ {RHS, SupportExt::ZExt});
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
- return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
- Root->getOpcode(), /*IsSExt=*/true),
- Root, LHS, /*SExtLHS=*/true, RHS,
- /*SExtRHS=*/true);
+ return CombineResult(NodeExtensionHelper::getDoubleWidenOpcode(
+ Root->getOpcode(), SupportExt::SExt),
+ Root, {LHS, SupportExt::SExt},
+ {RHS, SupportExt::SExt});
return std::nullopt;
}
@@ -13330,13 +13390,13 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
// Control this behavior behind an option (AllowSplatInVW_W) for testing
// purposes.
if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
- return CombineResult(
- NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
- Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false);
+ return CombineResult(NodeExtensionHelper::getSingleWidenOpcode(
+ Root->getOpcode(), SupportExt::ZExt),
+ Root, {LHS, std::nullopt}, {RHS, SupportExt::ZExt});
if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
- return CombineResult(
- NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
- Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true);
+ return CombineResult(NodeExtensionHelper::getSingleWidenOpcode(
+ Root->getOpcode(), SupportExt::SExt),
+ Root, {LHS, std::nullopt}, {RHS, SupportExt::SExt});
return std::nullopt;
}
@@ -13378,8 +13438,9 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
- return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
- Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
+ return CombineResult(
+ NodeExtensionHelper::getSignedUnsignedWidenOpcode(Root->getOpcode()),
+ Root, {LHS, SupportExt::SExt}, {RHS, SupportExt::ZExt});
}
SmallVector<NodeExtensionHelper::CombineToTry>
@@ -13481,9 +13542,9 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
// All the inputs that are extended need to be folded, otherwise
// we would be leaving the old input (since it is may still be used),
// and the new one.
- if (Res->SExtLHS.has_value())
+ if (Res->ExtLHS.has_value())
AppendUsersIfNeeded(LHS);
- if (Res->SExtRHS.has_value())
+ if (Res->ExtRHS.has_value())
AppendUsersIfNeeded(RHS);
break;
}
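
The heart of this first patch is replacing the pair of std::optional<bool> sign-extension flags with an std::optional of a small SupportExt enum, so each extension kind gets its own switch arm and a later patch can add a third kind without touching every call site. As a rough standalone illustration of the pattern only (not the LLVM code itself; the opcode names below are placeholders standing in for the RISCVISD::VW*_VL values):

#include <cstdio>
#include <optional>

// Placeholder opcode ids standing in for RISCVISD::VWADD_VL, VWADDU_VL, ...
enum Opcode { VWADD_VL, VWADDU_VL, VWSUB_VL, VWSUBU_VL };

// Extension kinds the folding understands, mirroring the new SupportExt enum.
enum class SupportExt { ZExt, SExt };

// With an enum, choosing the widening opcode becomes a switch per kind
// instead of a bool-selected ternary, which is the shape that lets a
// follow-up patch add an FPExt kind without rewriting the call sites.
static Opcode getAddWidenOpcode(SupportExt Ext) {
  switch (Ext) {
  case SupportExt::ZExt:
    return VWADDU_VL;
  case SupportExt::SExt:
    return VWADD_VL;
  }
  return VWADD_VL; // not reached; keeps -Wreturn-type quiet
}

int main() {
  // std::nullopt plays the same role as before: "use the operand as-is".
  std::optional<SupportExt> ExtLHS = SupportExt::SExt;
  std::optional<SupportExt> ExtRHS = std::nullopt;
  if (ExtLHS && !ExtRHS)
    std::printf("fold LHS only, opcode id %d\n", getAddWidenOpcode(*ExtLHS));
  return 0;
}

The later patches in this series build on exactly this shape by adding a SupportExt::FPExt enumerator.
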
From ffbec13e5ee10ee5d478152956d6f01b293445a6 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Tue, 2 Jan 2024 11:43:36 +0900
Subject: [PATCH 2/7] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp
extend.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 312 +++++++++++-------
.../CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll | 14 +-
2 files changed, 195 insertions(+), 131 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ca8b5d0d9422ef..cf48cd7c378a50 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1374,13 +1374,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
- ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
- ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+ ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
+ ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
if (Subtarget.is64Bit())
setTargetDAGCombine(ISD::SRA);
if (Subtarget.hasStdExtFOrZfinx())
- setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
+ setTargetDAGCombine(
+ {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMAXNUM, ISD::FMINNUM});
if (Subtarget.hasStdExtZbb())
setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
@@ -12848,7 +12849,8 @@ namespace {
// apply a combine.
struct CombineResult;
-enum class SupportExt { ZExt, SExt };
+// Supported extension kind to be folded.
+enum class SupportExt { ZExt, SExt, FPExt };
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
@@ -12880,6 +12882,8 @@ struct NodeExtensionHelper {
/// instance, a splat constant (e.g., 3), would support being both sign and
/// zero extended.
bool SupportsSExt;
+ /// Records if this operand is like being floating-point extended.
+ bool SupportsFPExt;
/// This boolean captures whether we care if this operand would still be
/// around after the folding happens.
bool EnforceOneUse;
@@ -12901,8 +12905,10 @@ struct NodeExtensionHelper {
switch (OrigOperand.getOpcode()) {
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND:
+ case ISD::FP_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
+ case RISCVISD::FP_EXTEND_VL:
return OrigOperand.getOperand(0);
default:
return OrigOperand;
@@ -12911,15 +12917,19 @@ struct NodeExtensionHelper {
/// Check if this instance represents a splat.
bool isSplat() const {
- return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+ return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
+ OrigOperand.getOpcode() == RISCVISD::VFMV_V_F_VL;
}
+ /// Get the extended opcode.
unsigned getExtOpc(SupportExt Ext) const {
switch (Ext) {
case SupportExt::ZExt:
return RISCVISD::VZEXT_VL;
case SupportExt::SExt:
return RISCVISD::VSEXT_VL;
+ case SupportExt::FPExt:
+ return RISCVISD::FP_EXTEND_VL;
}
}
@@ -12938,20 +12948,24 @@ struct NodeExtensionHelper {
if (Source.getValueType() == NarrowVT)
return Source;
- unsigned ExtOpc = getExtOpc(*Ext);
-
+ unsigned OrigOpc = OrigOperand.getOpcode();
// If we need an extension, we should be changing the type.
SDLoc DL(Root);
auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
- switch (OrigOperand.getOpcode()) {
+ switch (OrigOpc) {
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND:
+ case ISD::FP_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
+ case RISCVISD::FP_EXTEND_VL: {
+ unsigned ExtOpc = getExtOpc(*Ext);
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+ }
+ case RISCVISD::VFMV_V_F_VL:
case RISCVISD::VMV_V_X_VL:
- return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
- DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
+ return DAG.getNode(OrigOpc, DL, NarrowVT, DAG.getUNDEF(NarrowVT),
+ Source.getOperand(1), VL);
default:
// Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
// and that operand should already have the right NarrowVT so no
@@ -12970,13 +12984,22 @@ struct NodeExtensionHelper {
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
- MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
- VT.getVectorElementCount());
- return NarrowVT;
+ // Determine the minimum narrow size.
+ unsigned MinSize = VT.isInteger() ? 8 : 32;
+
+ assert(NarrowSize >= MinSize &&
+ "Trying to extend something we can't represent");
+
+ MVT NarrowScalarVT = VT.isInteger() ? MVT::getIntegerVT(NarrowSize)
+ : MVT::getFloatingPointVT(NarrowSize);
+ MVT NarrowVectorVT =
+ MVT::getVectorVT(NarrowScalarVT, VT.getVectorElementCount());
+ return NarrowVectorVT;
}
- static unsigned getSignedDoubleWidenOpcode(unsigned Opcode) {
+ /// Get full widening (2*SEW = SEW +/-/* SEW) signed integer add/sub/mul
+ /// opcode.
+ static unsigned getSignedFullWidenOpcode(unsigned Opcode) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
@@ -12994,7 +13017,9 @@ struct NodeExtensionHelper {
}
}
- static unsigned getUnsignedDoubleWidenOpcode(unsigned Opcode) {
+ /// Get full widening (2*SEW = SEW +/-/* SEW) unsigned integer add/sub/mul
+ /// opcode.
+ static unsigned getUnsignedFullWidenOpcode(unsigned Opcode) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
@@ -13012,6 +13037,25 @@ struct NodeExtensionHelper {
}
}
+ /// Get full widening (2*SEW = SEW +/-/* SEW) FP add/sub/mul opcode.
+ static unsigned getFloatFullWidenOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ case ISD::FADD:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::VFWADD_W_VL:
+ return RISCVISD::VFWADD_VL;
+ case ISD::FSUB:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ return RISCVISD::VFWSUB_VL;
+ case ISD::FMUL:
+ case RISCVISD::FMUL_VL:
+ return RISCVISD::VFWMUL_VL;
+ default:
+ llvm_unreachable("Unexpected Opcode");
+ }
+ }
+
/// Return the opcode required to materialize the folding of the sign
/// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
/// both operands for \p Opcode.
@@ -13019,12 +13063,14 @@ struct NodeExtensionHelper {
/// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
/// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
/// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
- static unsigned getDoubleWidenOpcode(unsigned OrigOpcode, SupportExt Ext) {
+ static unsigned getFullWidenOpcode(unsigned OrigOpcode, SupportExt Ext) {
switch (Ext) {
case SupportExt::SExt:
- return getSignedDoubleWidenOpcode(OrigOpcode);
+ return getSignedFullWidenOpcode(OrigOpcode);
case SupportExt::ZExt:
- return getUnsignedDoubleWidenOpcode(OrigOpcode);
+ return getUnsignedFullWidenOpcode(OrigOpcode);
+ case SupportExt::FPExt:
+ return getFloatFullWidenOpcode(OrigOpcode);
}
}
@@ -13036,7 +13082,8 @@ struct NodeExtensionHelper {
return RISCVISD::VWMULSU_VL;
}
- static unsigned getSignedSingleWidenOpcode(unsigned Opcode) {
+ /// Get half widening (2*SEW = 2*SEW +/- SEW) signed integer add/sub opcode.
+ static unsigned getSignedHalfWidenOpcode(unsigned Opcode) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
@@ -13049,7 +13096,8 @@ struct NodeExtensionHelper {
}
}
- static unsigned getUnsignedSingleWidenOpcode(unsigned Opcode) {
+ /// Get half widening (2*SEW = 2*SEW +/- SEW) unsigned integer add/sub opcode.
+ static unsigned getUnsignedHalfWidenOpcode(unsigned Opcode) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
@@ -13062,14 +13110,28 @@ struct NodeExtensionHelper {
}
}
+ /// Get half widening (2*SEW = 2*SEW +/- SEW) FP add/sub opcode.
+ static unsigned getFloatHalfWidenOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ case RISCVISD::FADD_VL:
+ return RISCVISD::VFWADD_W_VL;
+ case RISCVISD::FSUB_VL:
+ return RISCVISD::VFWSUB_W_VL;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+ }
+
/// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
/// newOpcode(a, b).
- static unsigned getSingleWidenOpcode(unsigned Opcode, SupportExt Ext) {
+ static unsigned getHalfWidenOpcode(unsigned Opcode, SupportExt Ext) {
switch (Ext) {
case SupportExt::SExt:
- return getSignedSingleWidenOpcode(Opcode);
+ return getSignedHalfWidenOpcode(Opcode);
case SupportExt::ZExt:
- return getUnsignedSingleWidenOpcode(Opcode);
+ return getUnsignedHalfWidenOpcode(Opcode);
+ case SupportExt::FPExt:
+ return getFloatHalfWidenOpcode(Opcode);
}
}
@@ -13087,15 +13149,18 @@ struct NodeExtensionHelper {
const RISCVSubtarget &Subtarget) {
SupportsZExt = false;
SupportsSExt = false;
+ SupportsFPExt = false;
EnforceOneUse = true;
CheckMask = true;
unsigned Opc = OrigOperand.getOpcode();
switch (Opc) {
case ISD::ZERO_EXTEND:
- case ISD::SIGN_EXTEND: {
+ case ISD::SIGN_EXTEND:
+ case ISD::FP_EXTEND: {
if (OrigOperand.getValueType().isVector()) {
SupportsZExt = Opc == ISD::ZERO_EXTEND;
SupportsSExt = Opc == ISD::SIGN_EXTEND;
+ SupportsFPExt = Opc == ISD::FP_EXTEND;
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13112,6 +13177,12 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
+ case RISCVISD::FP_EXTEND_VL:
+ SupportsFPExt = true;
+ Mask = OrigOperand.getOperand(1);
+ VL = OrigOperand.getOperand(2);
+ break;
+ case RISCVISD::VFMV_V_F_VL:
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
@@ -13138,6 +13209,11 @@ struct NodeExtensionHelper {
if (ScalarBits < EltBits)
break;
+ if (VT.isFloatingPoint()) {
+ SupportsFPExt = true;
+ break;
+ }
+
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
// If the narrow type cannot be expressed with a legal VMV,
// this is not a valid candidate.
@@ -13161,7 +13237,10 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL: {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isTypeLegal(Root->getValueType(0)))
return false;
@@ -13174,6 +13253,11 @@ struct NodeExtensionHelper {
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
return true;
default:
return false;
@@ -13196,10 +13280,15 @@ struct NodeExtensionHelper {
case RISCVISD::VWADDU_W_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
if (OperandIdx == 1) {
SupportsZExt =
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
- SupportsSExt = !SupportsZExt;
+ SupportsSExt =
+ Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+ SupportsFPExt =
+ Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
CheckMask = true;
// There's no existing extension here, so we don't have to worry about
@@ -13232,7 +13321,10 @@ struct NodeExtensionHelper {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL: {
+ case ISD::MUL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL: {
SDLoc DL(Root);
MVT VT = Root->getSimpleValueType(0);
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13255,15 +13347,23 @@ struct NodeExtensionHelper {
switch (N->getOpcode()) {
case ISD::ADD:
case ISD::MUL:
+ case ISD::FADD:
+ case ISD::FMUL:
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FMUL_VL:
+ case RISCVISD::VFWADD_W_VL:
return true;
case ISD::SUB:
+ case ISD::FSUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::VFWSUB_W_VL:
return false;
default:
llvm_unreachable("Unexpected opcode");
@@ -13318,6 +13418,9 @@ struct CombineResult {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
+ case ISD::FADD:
+ case ISD::FSUB:
+ case ISD::FMUL:
Merge = DAG.getUNDEF(Root->getValueType(0));
break;
}
@@ -13338,25 +13441,30 @@ struct CombineResult {
///
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
/// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS, bool AllowSExt,
- bool AllowZExt, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
- assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
+ SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
+ bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+ assert((AllowSExt || AllowZExt || AllowFPExt) &&
+ "Forgot to set what you want?");
if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
- return CombineResult(NodeExtensionHelper::getDoubleWidenOpcode(
+ return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
Root->getOpcode(), SupportExt::ZExt),
Root, {LHS, SupportExt::ZExt},
{RHS, SupportExt::ZExt});
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
- return CombineResult(NodeExtensionHelper::getDoubleWidenOpcode(
+ return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
Root->getOpcode(), SupportExt::SExt),
Root, {LHS, SupportExt::SExt},
{RHS, SupportExt::SExt});
+ if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+ return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+ Root->getOpcode(), SupportExt::FPExt),
+ Root, {LHS, SupportExt::FPExt},
+ {RHS, SupportExt::FPExt});
return std::nullopt;
}
@@ -13371,7 +13479,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/true, DAG, Subtarget);
+ /*AllowZExt=*/true, true, DAG,
+ Subtarget);
}
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13390,13 +13499,17 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
// Control this behavior behind an option (AllowSplatInVW_W) for testing
// purposes.
if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
- return CombineResult(NodeExtensionHelper::getSingleWidenOpcode(
+ return CombineResult(NodeExtensionHelper::getHalfWidenOpcode(
Root->getOpcode(), SupportExt::ZExt),
Root, {LHS, std::nullopt}, {RHS, SupportExt::ZExt});
if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
- return CombineResult(NodeExtensionHelper::getSingleWidenOpcode(
+ return CombineResult(NodeExtensionHelper::getHalfWidenOpcode(
Root->getOpcode(), SupportExt::SExt),
Root, {LHS, std::nullopt}, {RHS, SupportExt::SExt});
+ if (RHS.SupportsFPExt && (!RHS.isSplat() || AllowSplatInVW_W))
+ return CombineResult(NodeExtensionHelper::getHalfWidenOpcode(
+ Root->getOpcode(), SupportExt::FPExt),
+ Root, {LHS, std::nullopt}, {RHS, SupportExt::FPExt});
return std::nullopt;
}
@@ -13409,7 +13522,8 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/false, DAG, Subtarget);
+ /*AllowZExt=*/false, false, DAG,
+ Subtarget);
}
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13421,7 +13535,17 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
- /*AllowZExt=*/true, DAG, Subtarget);
+ /*AllowZExt=*/true, false, DAG,
+ Subtarget);
+}
+
+static std::optional<CombineResult>
+canFoldToVFWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
+ /*AllowZExt=*/false, true, DAG,
+ Subtarget);
}
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13449,13 +13573,21 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
+ case ISD::FADD:
+ case ISD::FSUB:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FSUB_VL:
// add|sub -> vwadd(u)|vwsub(u)
Strategies.push_back(canFoldToVWWithSameExtension);
// add|sub -> vwadd(u)_w|vwsub(u)_w
Strategies.push_back(canFoldToVW_W);
break;
+ case ISD::FMUL:
+ case RISCVISD::FMUL_VL:
+ Strategies.push_back(canFoldToVWWithSameExtension);
+ break;
case ISD::MUL:
case RISCVISD::MUL_VL:
// mul -> vwmul(u)
@@ -13473,6 +13605,10 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
// vwaddu_w|vwsubu_w -> vwaddu|vwsubu
Strategies.push_back(canFoldToVWWithZEXT);
break;
+ case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFWSUB_W_VL:
+ Strategies.push_back(canFoldToVFWWithFPEXT);
+ break;
default:
llvm_unreachable("Unexpected opcode");
}
@@ -14051,7 +14187,8 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
N->getOperand(2), Mask, VL);
}
-static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG,
+static SDValue performVFMUL_VLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
if (N->getValueType(0).isScalableVector() &&
N->getValueType(0).getVectorElementType() == MVT::f32 &&
@@ -14063,36 +14200,11 @@ static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG,
// FIXME: Ignore strict opcodes for now.
assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode");
- // Try to form widening multiply.
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- SDValue Merge = N->getOperand(2);
- SDValue Mask = N->getOperand(3);
- SDValue VL = N->getOperand(4);
-
- if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
- Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
- return SDValue();
-
- // TODO: Refactor to handle more complex cases similar to
- // combineBinOp_VLToVWBinOp_VL.
- if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
- (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
- return SDValue();
-
- // Check the mask and VL are the same.
- if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
- Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
- return SDValue();
-
- Op0 = Op0.getOperand(0);
- Op1 = Op1.getOperand(0);
-
- return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0,
- Op1, Merge, Mask, VL);
+ return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
}
-static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG,
+static SDValue performFADDSUB_VLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
if (N->getValueType(0).isScalableVector() &&
N->getValueType(0).getVectorElementType() == MVT::f32 &&
@@ -14101,55 +14213,7 @@ static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- SDValue Merge = N->getOperand(2);
- SDValue Mask = N->getOperand(3);
- SDValue VL = N->getOperand(4);
-
- bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL;
-
- // Look for foldable FP_EXTENDS.
- bool Op0IsExtend =
- Op0.getOpcode() == RISCVISD::FP_EXTEND_VL &&
- (Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0)));
- bool Op1IsExtend =
- (Op0 == Op1 && Op0IsExtend) ||
- (Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse());
-
- // Check the mask and VL.
- if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL))
- Op0IsExtend = false;
- if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL))
- Op1IsExtend = false;
-
- // Canonicalize.
- if (!Op1IsExtend) {
- // Sub requires at least operand 1 to be an extend.
- if (!IsAdd)
- return SDValue();
-
- // Add is commutable, if the other operand is foldable, swap them.
- if (!Op0IsExtend)
- return SDValue();
-
- std::swap(Op0, Op1);
- std::swap(Op0IsExtend, Op1IsExtend);
- }
-
- // Op1 is a foldable extend. Op0 might be foldable.
- Op1 = Op1.getOperand(0);
- if (Op0IsExtend)
- Op0 = Op0.getOperand(0);
-
- unsigned Opc;
- if (IsAdd)
- Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL;
- else
- Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL;
-
- return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask,
- VL);
+ return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
}
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
@@ -15170,7 +15234,13 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
return V;
return performMULCombine(N, DAG);
+ case ISD::FSUB:
+ case ISD::FMUL:
+ return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
case ISD::FADD:
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ return V;
+ [[fallthrough]];
case ISD::UMAX:
case ISD::UMIN:
case ISD::SMAX:
@@ -15665,10 +15735,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::STRICT_VFNMSUB_VL:
return performVFMADD_VLCombine(N, DAG, Subtarget);
case RISCVISD::FMUL_VL:
- return performVFMUL_VLCombine(N, DAG, Subtarget);
+ return performVFMUL_VLCombine(N, DCI, Subtarget);
case RISCVISD::FADD_VL:
case RISCVISD::FSUB_VL:
- return performFADDSUB_VLCombine(N, DAG, Subtarget);
+ return performFADDSUB_VLCombine(N, DCI, Subtarget);
case ISD::LOAD:
case ISD::STORE: {
if (DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
index 8ad858d4c76598..7eaa1856ce2218 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwmul_vf_v32f32(ptr %x, float %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT: vle32.v v16, (a0)
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
+; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v16, v16, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v24, v16
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v0, v16
-; CHECK-NEXT: vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT: vfmul.vv v16, v24, v0
-; CHECK-NEXT: vfmul.vv v8, v8, v0
+; CHECK-NEXT: vfwmul.vf v16, v8, fa0
+; CHECK-NEXT: vfwmul.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
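
At a high level, the FP handling added above reuses the two existing integer folding strategies: when both operands come from an fp_extend, the operation is rewritten to the full-widening form (vfwadd/vfwsub/vfwmul, 2*SEW = SEW op SEW), and when only the second operand of an add or sub is extended, the half-widening .wv form is used instead (2*SEW = 2*SEW op SEW). A rough standalone sketch of that decision, assuming only fadd/fsub/fmul roots and writing the result as an instruction mnemonic rather than a RISCVISD opcode:

#include <cstdio>

enum class FPOp { FAdd, FSub, FMul };

// Pick the widened mnemonic for an FP binary op whose operands may come from
// fp_extend nodes.  Returns nullptr when no widening fold applies.
static const char *pickWidenedForm(FPOp Op, bool LHSIsFPExt, bool RHSIsFPExt) {
  if (LHSIsFPExt && RHSIsFPExt) {
    // Both sources are narrow: 2*SEW = SEW op SEW (full widening).
    switch (Op) {
    case FPOp::FAdd: return "vfwadd.vv";
    case FPOp::FSub: return "vfwsub.vv";
    case FPOp::FMul: return "vfwmul.vv";
    }
  }
  if (RHSIsFPExt && Op != FPOp::FMul) {
    // Only the RHS is narrow: 2*SEW = 2*SEW op SEW (the ".wv" forms);
    // there is no such half-widening form for multiply.
    return Op == FPOp::FAdd ? "vfwadd.wv" : "vfwsub.wv";
  }
  return nullptr;
}

int main() {
  std::printf("%s\n", pickWidenedForm(FPOp::FMul, true, true));   // vfwmul.vv
  std::printf("%s\n", pickWidenedForm(FPOp::FAdd, false, true));  // vfwadd.wv
  return 0;
}

The commutation case (only the LHS of an fadd extended) is omitted here; the real combine handles it by swapping operands before applying the same choice.
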
From 654fbef11707dca4e10a1c2c136d19630bf61969 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Tue, 2 Jan 2024 13:42:43 +0900
Subject: [PATCH 3/7] [RISCV][ISel] fix minimum narrow size for floating point
in NodeExtensionHelper.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index cf48cd7c378a50..f8feab14fbc9db 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12985,7 +12985,7 @@ struct NodeExtensionHelper {
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
// Determine the minimum narrow size.
- unsigned MinSize = VT.isInteger() ? 8 : 32;
+ unsigned MinSize = VT.isInteger() ? 8 : 16;
assert(NarrowSize >= MinSize &&
"Trying to extend something we can't represent");
From d3945c23727a778e3f458afbe94fade68f1adadb Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Tue, 2 Jan 2024 14:30:47 +0900
Subject: [PATCH 4/7] [RISCV][ISel] use TLI info for
NodeExtensionHelper::getNarrowType
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 13 ++++++-------
.../test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll | 8 +++-----
.../test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll | 14 ++++----------
3 files changed, 13 insertions(+), 22 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f8feab14fbc9db..024b942c60d035 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12942,7 +12942,7 @@ struct NodeExtensionHelper {
if (!Ext.has_value())
return OrigOperand;
- MVT NarrowVT = getNarrowType(Root);
+ MVT NarrowVT = getNarrowType(Root, Subtarget);
SDValue Source = getSource();
if (Source.getValueType() == NarrowVT)
@@ -12979,21 +12979,20 @@ struct NodeExtensionHelper {
/// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
/// \pre The size of the type of the elements of Root must be a multiple of 2
/// and be greater than 16.
- static MVT getNarrowType(const SDNode *Root) {
+ static MVT getNarrowType(const SDNode *Root,
+ const RISCVSubtarget &Subtarget) {
MVT VT = Root->getSimpleValueType(0);
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- // Determine the minimum narrow size.
- unsigned MinSize = VT.isInteger() ? 8 : 16;
-
- assert(NarrowSize >= MinSize &&
- "Trying to extend something we can't represent");
MVT NarrowScalarVT = VT.isInteger() ? MVT::getIntegerVT(NarrowSize)
: MVT::getFloatingPointVT(NarrowSize);
MVT NarrowVectorVT =
MVT::getVectorVT(NarrowScalarVT, VT.getVectorElementCount());
+
+ assert(Subtarget.getTargetLowering()->isTypeLegal(NarrowVectorVT) &&
+ "Trying to extend something we can't represent");
return NarrowVectorVT;
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
index c9dc75e18774f8..dd3a50cfd77377 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
@@ -396,12 +396,10 @@ define <32 x double> @vfwadd_vf_v32f32(ptr %x, float %y) {
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v0, v24, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
-; CHECK-NEXT: vfwadd.wv v16, v8, v0
-; CHECK-NEXT: vfwadd.wv v8, v8, v24
+; CHECK-NEXT: vfwadd.vf v16, v8, fa0
+; CHECK-NEXT: vfwadd.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
index d22781d6a97ac2..8cf7c5f1758654 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwsub_vf_v32f32(ptr %x, float %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT: vle32.v v16, (a0)
-; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v8, v16
+; CHECK-NEXT: vle32.v v24, (a0)
; CHECK-NEXT: vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT: vslidedown.vi v16, v16, 16
+; CHECK-NEXT: vslidedown.vi v8, v24, 16
; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v24, v16
-; CHECK-NEXT: vfmv.v.f v16, fa0
-; CHECK-NEXT: vfwcvt.f.f.v v0, v16
-; CHECK-NEXT: vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT: vfsub.vv v16, v24, v0
-; CHECK-NEXT: vfsub.vv v8, v8, v0
+; CHECK-NEXT: vfwsub.vf v16, v8, fa0
+; CHECK-NEXT: vfwsub.vf v8, v24, fa0
; CHECK-NEXT: ret
%a = load <32 x float>, ptr %x
%b = insertelement <32 x float> poison, float %y, i32 0
From 5053cf9a75ee97dda73994bbe291d2c999a3b95e Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Tue, 2 Jan 2024 14:34:39 +0900
Subject: [PATCH 5/7] [RISCV][ISel] refactor
NodeExtensionHelper::fillUpExtensionSupport
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 12 +++---------
1 file changed, 3 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 024b942c60d035..52bba83c276b4d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13167,17 +13167,11 @@ struct NodeExtensionHelper {
break;
}
case RISCVISD::VZEXT_VL:
- SupportsZExt = true;
- Mask = OrigOperand.getOperand(1);
- VL = OrigOperand.getOperand(2);
- break;
case RISCVISD::VSEXT_VL:
- SupportsSExt = true;
- Mask = OrigOperand.getOperand(1);
- VL = OrigOperand.getOperand(2);
- break;
case RISCVISD::FP_EXTEND_VL:
- SupportsFPExt = true;
+ SupportsZExt = Opc == RISCVISD::VZEXT_VL;
+ SupportsSExt = Opc == RISCVISD::VSEXT_VL;
+ SupportsFPExt = Opc == RISCVISD::FP_EXTEND_VL;
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
From 32d9ae167399a5ed9656dce0eabf4b0410ae2e33 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Tue, 2 Jan 2024 16:17:31 +0900
Subject: [PATCH 6/7] [RISCV][ISel] set minimum size for floating point in
NodeExtensionHelper::isSupportedRoot.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 41 ++++++++++++++-------
1 file changed, 27 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 52bba83c276b4d..ca0542ff175718 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13112,8 +13112,10 @@ struct NodeExtensionHelper {
/// Get half widening (2*SEW = 2*SEW +/- SEW) FP add/sub opcode.
static unsigned getFloatHalfWidenOpcode(unsigned Opcode) {
switch (Opcode) {
+ case ISD::FADD:
case RISCVISD::FADD_VL:
return RISCVISD::VFWADD_W_VL;
+ case ISD::FSUB:
case RISCVISD::FSUB_VL:
return RISCVISD::VFWSUB_W_VL;
default:
@@ -13202,7 +13204,8 @@ struct NodeExtensionHelper {
if (ScalarBits < EltBits)
break;
- if (VT.isFloatingPoint()) {
+ if (VT.isFloatingPoint() &&
+ EltBits >= (Subtarget.hasStdExtZvfh() ? 32 : 64)) {
SupportsFPExt = true;
break;
}
@@ -13226,19 +13229,31 @@ struct NodeExtensionHelper {
}
/// Check if \p Root supports any extension folding combines.
- static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
+ static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
switch (Root->getOpcode()) {
- case ISD::ADD:
- case ISD::SUB:
- case ISD::MUL:
case ISD::FADD:
case ISD::FSUB:
- case ISD::FMUL: {
+ case ISD::FMUL:
+ if (Root->getValueType(0).getScalarSizeInBits() <
+ (Subtarget.hasStdExtZvfh() ? 32 : 64))
+ return false;
+ [[fallthrough]];
+ case ISD::ADD:
+ case ISD::SUB:
+ case ISD::MUL: {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
if (!TLI.isTypeLegal(Root->getValueType(0)))
return false;
return Root->getValueType(0).isScalableVector();
}
+ case RISCVISD::FADD_VL:
+ case RISCVISD::FSUB_VL:
+ case RISCVISD::FMUL_VL:
+ if (Root->getValueType(0).getScalarSizeInBits() <
+ (Subtarget.hasStdExtZvfh() ? 32 : 64))
+ return false;
+ [[fallthrough]];
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -13246,9 +13261,6 @@ struct NodeExtensionHelper {
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
- case RISCVISD::FADD_VL:
- case RISCVISD::FSUB_VL:
- case RISCVISD::FMUL_VL:
case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFWSUB_W_VL:
return true;
@@ -13260,8 +13272,9 @@ struct NodeExtensionHelper {
/// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an "
- "unsupported root");
+ assert(isSupportedRoot(Root, DAG, Subtarget) &&
+ "Trying to build an helper with an "
+ "unsupported root");
assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
OrigOperand = Root->getOperand(OperandIdx);
@@ -13310,7 +13323,7 @@ struct NodeExtensionHelper {
static std::pair<SDValue, SDValue>
getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- assert(isSupportedRoot(Root, DAG) && "Unexpected root");
+ assert(isSupportedRoot(Root, DAG, Subtarget) && "Unexpected root");
switch (Root->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
@@ -13621,7 +13634,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;
- if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
+ if (!NodeExtensionHelper::isSupportedRoot(N, DAG, Subtarget))
return SDValue();
SmallVector<SDNode *> Worklist;
@@ -13632,7 +13645,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
while (!Worklist.empty()) {
SDNode *Root = Worklist.pop_back_val();
- if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
+ if (!NodeExtensionHelper::isSupportedRoot(Root, DAG, Subtarget))
return SDValue();
NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
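
The new guard encodes which FP result element sizes can be produced by a widening op at all: the half-width source type must itself exist as a vector element type, so f32 results (widened from f16) are only reachable when Zvfh provides f16 vectors, while f64 results (widened from f32) are always a candidate. A small sketch of that predicate, assuming f16/f32/f64 are the only FP element types in play:

#include <cstdio>

// Smallest result element width an FP widening op can produce: the op reads
// elements of half the width, and that half-width type must exist.
static unsigned minWidenedFPResultBits(bool HasZvfh) {
  // With Zvfh, f16 sources exist, so f32 results are reachable; otherwise the
  // narrowest source is f32 and only f64 results can be widened into.
  return HasZvfh ? 32 : 64;
}

static bool fpRootIsWidenable(unsigned ResultEltBits, bool HasZvfh) {
  return ResultEltBits >= minWidenedFPResultBits(HasZvfh);
}

int main() {
  std::printf("f32 result, Zvfh:    %d\n", fpRootIsWidenable(32, true));  // 1
  std::printf("f32 result, no Zvfh: %d\n", fpRootIsWidenable(32, false)); // 0
  std::printf("f64 result, no Zvfh: %d\n", fpRootIsWidenable(64, false)); // 1
  return 0;
}
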
From 81ea4288c25e433e8b5f498c050c43cbf4855599 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Tue, 2 Jan 2024 21:04:21 +0900
Subject: [PATCH 7/7] [RISCV][ISel] remove VFMV_V_F_VL in NodeExtensionHelper.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 29 ++++++++++++++-------
1 file changed, 20 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ca0542ff175718..239223813220e7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12910,6 +12910,8 @@ struct NodeExtensionHelper {
case RISCVISD::VZEXT_VL:
case RISCVISD::FP_EXTEND_VL:
return OrigOperand.getOperand(0);
+ case ISD::SPLAT_VECTOR:
+ return OrigOperand.getOperand(0)->getOperand(0);
default:
return OrigOperand;
}
@@ -12918,7 +12920,8 @@ struct NodeExtensionHelper {
/// Check if this instance represents a splat.
bool isSplat() const {
return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
- OrigOperand.getOpcode() == RISCVISD::VFMV_V_F_VL;
+ (OrigOperand.getOpcode() == ISD::SPLAT_VECTOR &&
+ OrigOperand.getOperand(0).getOpcode() == ISD::FP_EXTEND);
}
/// Get the extended opcode.
@@ -12962,7 +12965,8 @@ struct NodeExtensionHelper {
unsigned ExtOpc = getExtOpc(*Ext);
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
}
- case RISCVISD::VFMV_V_F_VL:
+ case ISD::SPLAT_VECTOR:
+ return DAG.getNode(ISD::SPLAT_VECTOR, DL, NarrowVT, Source);
case RISCVISD::VMV_V_X_VL:
return DAG.getNode(OrigOpc, DL, NarrowVT, DAG.getUNDEF(NarrowVT),
Source.getOperand(1), VL);
@@ -13177,7 +13181,20 @@ struct NodeExtensionHelper {
Mask = OrigOperand.getOperand(1);
VL = OrigOperand.getOperand(2);
break;
- case RISCVISD::VFMV_V_F_VL:
+ case ISD::SPLAT_VECTOR: {
+ if (OrigOperand.getOperand(0)->getOpcode() != ISD::FP_EXTEND)
+ break;
+
+ MVT VT = Root->getSimpleValueType(0);
+ if (VT.isFloatingPoint() &&
+ VT.getScalarSizeInBits() >= (Subtarget.hasStdExtZvfh() ? 32 : 64)) {
+ SupportsFPExt = true;
+ SDLoc DL(Root);
+ std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+ break;
+ }
+ break;
+ }
case RISCVISD::VMV_V_X_VL: {
// Historically, we didn't care about splat values not disappearing during
// combines.
@@ -13204,12 +13221,6 @@ struct NodeExtensionHelper {
if (ScalarBits < EltBits)
break;
- if (VT.isFloatingPoint() &&
- EltBits >= (Subtarget.hasStdExtZvfh() ? 32 : 64)) {
- SupportsFPExt = true;
- break;
- }
-
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
// If the narrow type cannot be expressed with a legal VMV,
// this is not a valid candidate.
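
This last patch recognizes splat_vector(fp_extend(x)) as an operand that supports FPExt: instead of splatting the widened scalar and then narrowing, the combine can re-splat the original narrow scalar x at the narrow vector type and let the widening instruction perform the conversion. A very rough standalone model of that rewrite, with hand-rolled node structs standing in for SDNodes:

#include <cstdio>
#include <string>

// Hand-rolled stand-ins for SDNodes, just to show the shape of the rewrite.
struct Node {
  std::string Op;      // "splat_vector", "fp_extend", "scalar x", ...
  const Node *Operand; // single operand where it applies
};

// splat_vector(fp_extend(x)) can feed a widening FP op as if it were an
// fp-extended operand: the fold re-splats the narrow scalar x directly.
static const Node *foldSplatOfFPExtend(const Node &Splat, Node &Storage) {
  if (Splat.Op != "splat_vector" || !Splat.Operand ||
      Splat.Operand->Op != "fp_extend")
    return nullptr;
  const Node *NarrowScalar = Splat.Operand->Operand;
  Storage = Node{"splat_vector(narrow)", NarrowScalar};
  return &Storage;
}

int main() {
  Node X{"scalar x", nullptr};
  Node Ext{"fp_extend", &X};
  Node Splat{"splat_vector", &Ext};
  Node Rewritten{};
  if (const Node *N = foldSplatOfFPExtend(Splat, Rewritten))
    std::printf("%s of %s\n", N->Op.c_str(), N->Operand->Op.c_str());
  return 0;
}
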