[llvm] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. (PR #81248)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Feb 18 22:18:47 PST 2024
================
@@ -13396,41 +13416,55 @@ struct NodeExtensionHelper {
/// Helper function to get the narrow type for \p Root.
/// The narrow type is the type of \p Root where we divided the size of each
/// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
- /// \pre The size of the type of the elements of Root must be a multiple of 2
- /// and be greater than 16.
- static MVT getNarrowType(const SDNode *Root) {
+ /// \pre Both the narrow type and the original type should be legal.
+ static MVT getNarrowType(const SDNode *Root, ExtKind SupportsExt) {
MVT VT = Root->getSimpleValueType(0);
// Determine the narrow size.
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
- assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
- MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
- VT.getVectorElementCount());
+
+ // Check the width before building the element type: for FPExt,
+ // MVT::getFloatingPointVT hits llvm_unreachable on unsupported widths, so
+ // asserting afterwards would never report this message.
+ assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
+ "Trying to extend something we can't represent");
+
+ MVT EltVT = SupportsExt == ExtKind::FPExt
+ ? MVT::getFloatingPointVT(NarrowSize)
+ : MVT::getIntegerVT(NarrowSize);
+ MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
return NarrowVT;
}
- /// Return the opcode required to materialize the folding of the sign
- /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+ /// Return the opcode required to materialize the folding for
/// both operands for \p Opcode.
/// Put differently, get the opcode to materialize:
- /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
- /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+ /// - ExtKind::SExt: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+ /// - ExtKind::ZExt: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
/// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
- static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+ static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
switch (Opcode) {
case ISD::ADD:
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
- return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
+ : RISCVISD::VWADDU_VL;
case ISD::MUL:
case RISCVISD::MUL_VL:
- return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
+ : RISCVISD::VWMULU_VL;
case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
- return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+ return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
+ : RISCVISD::VWSUBU_VL;
+ case RISCVISD::FADD_VL:
+ case RISCVISD::VFWADD_W_VL:
+ return RISCVISD::VFWADD_VL;
----------------
sun-jacobi wrote:
I think we could also split `getSameExtensionOpcode` into `getSExtOpcode`, `getZExtOpcode`, and `getFPExtOpcode`? It might be cleaner.
https://github.com/llvm/llvm-project/pull/81248
More information about the llvm-commits mailing list