[llvm] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. (PR #76695)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 1 18:49:44 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

<details>
<summary>Changes</summary>

This patch extends #72340 and [D133739](https://reviews.llvm.org/D133739) so that floating-point extensions (fpext) feeding vector fadd/fsub/fmul are also folded into the widening FP operations.

---

Patch is 31.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76695.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+288-157) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll (+4-10) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 51580d15451ca2..cf48cd7c378a50 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1374,13 +1374,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
 
   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
-                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
-                       ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
+                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
   if (Subtarget.hasStdExtFOrZfinx())
-    setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
+    setTargetDAGCombine(
+        {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMAXNUM, ISD::FMINNUM});
 
   if (Subtarget.hasStdExtZbb())
     setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
@@ -12848,6 +12849,9 @@ namespace {
 // apply a combine.
 struct CombineResult;
 
+// Supported extension kind to be folded.
+enum class SupportExt { ZExt, SExt, FPExt };
+
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
 /// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -12878,6 +12882,8 @@ struct NodeExtensionHelper {
   /// instance, a splat constant (e.g., 3), would support being both sign and
   /// zero extended.
   bool SupportsSExt;
+  /// Records if this operand is like being floating-point extended.
+  bool SupportsFPExt;
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
@@ -12899,8 +12905,10 @@ struct NodeExtensionHelper {
     switch (OrigOperand.getOpcode()) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND:
+    case ISD::FP_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
     default:
       return OrigOperand;
@@ -12909,7 +12917,20 @@ struct NodeExtensionHelper {
 
   /// Check if this instance represents a splat.
   bool isSplat() const {
-    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
+           OrigOperand.getOpcode() == RISCVISD::VFMV_V_F_VL;
+  }
+
+  /// Get the RISCVISD extension opcode that materializes \p Ext.
+  unsigned getExtOpc(SupportExt Ext) const {
+    switch (Ext) {
+    case SupportExt::ZExt:
+      return RISCVISD::VZEXT_VL;
+    case SupportExt::SExt:
+      return RISCVISD::VSEXT_VL;
+    case SupportExt::FPExt:
+      return RISCVISD::FP_EXTEND_VL;
+    }
   }
 
   /// Get or create a value that can feed \p Root with the given extension \p
@@ -12917,8 +12938,8 @@ struct NodeExtensionHelper {
   /// \see ::getSource().
   SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
                                 const RISCVSubtarget &Subtarget,
-                                std::optional<bool> SExt) const {
-    if (!SExt.has_value())
+                                std::optional<SupportExt> Ext) const {
+    if (!Ext.has_value())
       return OrigOperand;
 
     MVT NarrowVT = getNarrowType(Root);
@@ -12927,20 +12948,24 @@ struct NodeExtensionHelper {
     if (Source.getValueType() == NarrowVT)
       return Source;
 
-    unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
-
+    unsigned OrigOpc = OrigOperand.getOpcode();
     // If we need an extension, we should be changing the type.
     SDLoc DL(Root);
     auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
-    switch (OrigOperand.getOpcode()) {
+    switch (OrigOpc) {
     case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND:
+    case ISD::FP_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL: {
+      unsigned ExtOpc = getExtOpc(*Ext);
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+    }
+    case RISCVISD::VFMV_V_F_VL:
     case RISCVISD::VMV_V_X_VL:
-      return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
-                         DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
+      return DAG.getNode(OrigOpc, DL, NarrowVT, DAG.getUNDEF(NarrowVT),
+                         Source.getOperand(1), VL);
     default:
       // Other opcodes can only come from the original LHS of VW(ADD|SUB)_W_VL
       // and that operand should already have the right NarrowVT so no
@@ -12959,62 +12984,157 @@ struct NodeExtensionHelper {
 
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-    assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
-    MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
-                                    VT.getVectorElementCount());
-    return NarrowVT;
+    // Determine the minimum narrow size.
+    unsigned MinSize = VT.isInteger() ? 8 : 32;
+
+    assert(NarrowSize >= MinSize &&
+           "Trying to extend something we can't represent");
+
+    MVT NarrowScalarVT = VT.isInteger() ? MVT::getIntegerVT(NarrowSize)
+                                        : MVT::getFloatingPointVT(NarrowSize);
+    MVT NarrowVectorVT =
+        MVT::getVectorVT(NarrowScalarVT, VT.getVectorElementCount());
+    return NarrowVectorVT;
   }
 
-  /// Return the opcode required to materialize the folding of the sign
-  /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
-  /// both operands for \p Opcode.
-  /// Put differently, get the opcode to materialize:
-  /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
-  /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
-  /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
-  static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+  /// Get full widening (2*SEW = SEW +/-/* SEW) signed integer add/sub/mul
+  /// opcode.
+  static unsigned getSignedFullWidenOpcode(unsigned Opcode) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
-    case RISCVISD::VWADDU_W_VL:
-      return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+      return RISCVISD::VWADD_VL;
     case ISD::MUL:
     case RISCVISD::MUL_VL:
-      return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+      return RISCVISD::VWMUL_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
+      return RISCVISD::VWSUB_VL;
+    default:
+      llvm_unreachable("Unexpected Opcode");
+    }
+  }
+
+  /// Get full widening (2*SEW = SEW +/-/* SEW) unsigned integer add/sub/mul
+  /// opcode.
+  static unsigned getUnsignedFullWidenOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case ISD::ADD:
+    case RISCVISD::ADD_VL:
+    case RISCVISD::VWADDU_W_VL:
+      return RISCVISD::VWADDU_VL;
+    case ISD::MUL:
+    case RISCVISD::MUL_VL:
+      return RISCVISD::VWMULU_VL;
+    case ISD::SUB:
+    case RISCVISD::SUB_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+      return RISCVISD::VWSUBU_VL;
     default:
-      llvm_unreachable("Unexpected opcode");
+      llvm_unreachable("Unexpected Opcode");
+    }
+  }
+
+  /// Get full widening (2*SEW = SEW +/-/* SEW) FP add/sub/mul opcode.
+  static unsigned getFloatFullWidenOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case ISD::FADD:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::VFWADD_W_VL:
+      return RISCVISD::VFWADD_VL;
+    case ISD::FSUB:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
+      return RISCVISD::VFWSUB_VL;
+    case ISD::FMUL:
+    case RISCVISD::FMUL_VL:
+      return RISCVISD::VFWMUL_VL;
+    default:
+      llvm_unreachable("Unexpected Opcode");
+    }
+  }
+
+  /// Return the opcode required to materialize the folding of the extension
+  /// \p Ext (sign, zero or floating-point) applied to both operands of
+  /// \p OrigOpcode. Put differently, get the opcode to materialize:
+  /// - SExt:  \p OrigOpcode(sext(a), sext(b)) -> newOpcode(a, b)
+  /// - ZExt:  \p OrigOpcode(zext(a), zext(b)) -> newOpcode(a, b)
+  /// - FPExt: \p OrigOpcode(fpext(a), fpext(b)) -> newOpcode(a, b)
+  /// \pre \p OrigOpcode is a supported root (\see ::isSupportedRoot()).
+  static unsigned getFullWidenOpcode(unsigned OrigOpcode, SupportExt Ext) {
+    switch (Ext) {
+    case SupportExt::SExt:
+      return getSignedFullWidenOpcode(OrigOpcode);
+    case SupportExt::ZExt:
+      return getUnsignedFullWidenOpcode(OrigOpcode);
+    case SupportExt::FPExt:
+      return getFloatFullWidenOpcode(OrigOpcode);
     }
   }
 
   /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
   /// newOpcode(a, b).
-  static unsigned getSUOpcode(unsigned Opcode) {
+  static unsigned getSignedUnsignedWidenOpcode(unsigned Opcode) {
     assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
            "SU is only supported for MUL");
     return RISCVISD::VWMULSU_VL;
   }
 
-  /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
-  /// newOpcode(a, b).
-  static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+  /// Get half widening (2*SEW = 2*SEW +/- SEW) signed integer add/sub opcode.
+  static unsigned getSignedHalfWidenOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case ISD::ADD:
+    case RISCVISD::ADD_VL:
+      return RISCVISD::VWADD_W_VL;
+    case ISD::SUB:
+    case RISCVISD::SUB_VL:
+      return RISCVISD::VWSUB_W_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
+
+  /// Get half widening (2*SEW = 2*SEW +/- SEW) unsigned integer add/sub opcode.
+  static unsigned getUnsignedHalfWidenOpcode(unsigned Opcode) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
-      return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+      return RISCVISD::VWADDU_W_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
-      return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+      return RISCVISD::VWSUBU_W_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
+
+  /// Get half widening (2*SEW = 2*SEW +/- SEW) FP add/sub opcode.
+  static unsigned getFloatHalfWidenOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case RISCVISD::FADD_VL:
+      return RISCVISD::VFWADD_W_VL;
+    case RISCVISD::FSUB_VL:
+      return RISCVISD::VFWSUB_W_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
   }
 
+  /// Get the opcode to materialize \p Opcode(a, s|z|fpext(b)) ->
+  /// newOpcode(a, b), where \p Ext is the extension applied to b.
+  static unsigned getHalfWidenOpcode(unsigned Opcode, SupportExt Ext) {
+    switch (Ext) {
+    case SupportExt::SExt:
+      return getSignedHalfWidenOpcode(Opcode);
+    case SupportExt::ZExt:
+      return getUnsignedHalfWidenOpcode(Opcode);
+    case SupportExt::FPExt:
+      return getFloatHalfWidenOpcode(Opcode);
+    }
+  }
+
   using CombineToTry = std::function<std::optional<CombineResult>(
       SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
       const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
@@ -13029,15 +13149,18 @@ struct NodeExtensionHelper {
                               const RISCVSubtarget &Subtarget) {
     SupportsZExt = false;
     SupportsSExt = false;
+    SupportsFPExt = false;
     EnforceOneUse = true;
     CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
     switch (Opc) {
     case ISD::ZERO_EXTEND:
-    case ISD::SIGN_EXTEND: {
+    case ISD::SIGN_EXTEND:
+    case ISD::FP_EXTEND: {
       if (OrigOperand.getValueType().isVector()) {
         SupportsZExt = Opc == ISD::ZERO_EXTEND;
         SupportsSExt = Opc == ISD::SIGN_EXTEND;
+        SupportsFPExt = Opc == ISD::FP_EXTEND;
         SDLoc DL(Root);
         MVT VT = Root->getSimpleValueType(0);
         std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13054,6 +13177,12 @@ struct NodeExtensionHelper {
       Mask = OrigOperand.getOperand(1);
       VL = OrigOperand.getOperand(2);
       break;
+    case RISCVISD::FP_EXTEND_VL:
+      SupportsFPExt = true;
+      Mask = OrigOperand.getOperand(1);
+      VL = OrigOperand.getOperand(2);
+      break;
+    case RISCVISD::VFMV_V_F_VL:
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
@@ -13080,6 +13209,11 @@ struct NodeExtensionHelper {
       if (ScalarBits < EltBits)
         break;
 
+      if (VT.isFloatingPoint()) {
+        SupportsFPExt = true;
+        break;
+      }
+
       unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
       // If the narrow type cannot be expressed with a legal VMV,
       // this is not a valid candidate.
@@ -13103,7 +13237,10 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL: {
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
       if (!TLI.isTypeLegal(Root->getValueType(0)))
         return false;
@@ -13116,6 +13253,11 @@ struct NodeExtensionHelper {
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
       return true;
     default:
       return false;
@@ -13138,10 +13280,15 @@ struct NodeExtensionHelper {
     case RISCVISD::VWADDU_W_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
       if (OperandIdx == 1) {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
-        SupportsSExt = !SupportsZExt;
+        SupportsSExt =
+            Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+        SupportsFPExt =
+            Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
         std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
@@ -13174,7 +13321,10 @@ struct NodeExtensionHelper {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
-    case ISD::MUL: {
+    case ISD::MUL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
       return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13197,15 +13347,23 @@ struct NodeExtensionHelper {
     switch (N->getOpcode()) {
     case ISD::ADD:
     case ISD::MUL:
+    case ISD::FADD:
+    case ISD::FMUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
       return true;
     case ISD::SUB:
+    case ISD::FSUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
       return false;
     default:
       llvm_unreachable("Unexpected opcode");
@@ -13227,22 +13385,23 @@ struct NodeExtensionHelper {
 struct CombineResult {
   /// Opcode to be generated when materializing the combine.
   unsigned TargetOpcode;
-  // No value means no extension is needed. If extension is needed, the value
-  // indicates if it needs to be sign extended.
-  std::optional<bool> SExtLHS;
-  std::optional<bool> SExtRHS;
-  /// Root of the combine.
-  SDNode *Root;
   /// LHS of the TargetOpcode.
   NodeExtensionHelper LHS;
+  /// Extension of the LHS
+  std::optional<SupportExt> ExtLHS;
   /// RHS of the TargetOpcode.
   NodeExtensionHelper RHS;
+  /// Extension of the RHS
+  std::optional<SupportExt> ExtRHS;
+  /// Root of the combine.
+  SDNode *Root;
 
-  CombineResult(unsigned TargetOpcode, SDNode *Root,
-                const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
-                const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
-      : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
-        Root(Root), LHS(LHS), RHS(RHS) {}
+  CombineResult(
+      unsigned TargetOpcode, SDNode *Root,
+      std::pair<const NodeExtensionHelper &, std::optional<SupportExt>> LHS,
+      std::pair<const NodeExtensionHelper &, std::optional<SupportExt>> RHS)
+      : TargetOpcode(TargetOpcode), LHS(LHS.first), ExtLHS(LHS.second),
+        RHS(RHS.first), ExtRHS(RHS.second), Root(Root) {}
 
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
@@ -13259,12 +13418,15 @@ struct CombineResult {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL:
+    case ISD::FADD:
+    case ISD::FSUB:
+    case ISD::FMUL:
       Merge = DAG.getUNDEF(Root->getValueType(0));
       break;
     }
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, ExtRHS),
                        Merge, Mask, VL);
   }
 };
@@ -13279,24 +13441,30 @@ struct CombineResult {
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
 /// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
-                                 const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt, SelectionDAG &DAG,
-                                 const RISCVSubtarget &Subtarget) {
-  assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
+    SDNode *Root, const NodeExtensionHelper &LHS,
+    const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
+    bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+  assert((AllowSExt || AllowZExt || AllowFPExt) &&
+         "Forgot to set what you want?");
   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+    return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+                             Root->getOpcode(), SupportExt::ZExt),
+                         Root, {LHS, SupportExt::ZExt},
+                         {RHS, SupportExt::ZExt});
   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/true),
-                         Root, LHS, /*SExtLHS=*/true, RHS,
-                         /*SExtRHS=*/true);
+    return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+                             Root->getOpcode(), SupportExt::SExt),
+                         Root, {LHS, SupportExt::SExt},
+                         {RHS, SupportExt::SExt});
+  if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+    return CombineResult(NodeExtensionHelper::getFullWidenOpcode(
+                             Root->getOpcode(), SupportExt::FPExt),
+                         Root, {LHS, SupportExt::FPExt},
+                         {RHS, SupportExt::FPExt});
   return std::nullopt;
 }
 
@@ -13311,7 +13479,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
                              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                              const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*Allow...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/76695


More information about the llvm-commits mailing list