[llvm] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. (PR #81248)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 9 04:59:01 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

<details>
<summary>Changes</summary>

Extend D133739 and #76785 to support the vector widening floating-point add/sub/mul instructions (see #80477).

Specifically, this patch enables the following optimization:

### Source code
```
define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
  %c = fpext <2 x float> %a to <2 x double>
  %d = fpext <2 x float> %b to <2 x double>
  %d2 = fpext <2 x float> %b2 to <2 x double>
  %e = fmul <2 x double> %c, %d
  %f = fadd <2 x double> %c, %d2
  %g = fsub <2 x double> %d, %d2
  store <2 x double> %e, ptr %x
  store <2 x double> %f, ptr %y
  store <2 x double> %g, ptr %z
  ret void
}
```

### Before this patch
[Compiler Explorer](https://godbolt.org/z/aaEMs5s9h)
```
vfwmul_v2f32_multiple_users:
        vsetivli        zero, 2, e32, mf2, ta, ma
        vfwcvt.f.f.v    v11, v8
        vfwcvt.f.f.v    v8, v9
        vfwcvt.f.f.v    v9, v10
        vsetvli zero, zero, e64, m1, ta, ma
        vfmul.vv        v10, v11, v8
        vfadd.vv        v11, v11, v9
        vfsub.vv        v8, v8, v9
        vse64.v v10, (a0)
        vse64.v v11, (a1)
        vse64.v v8, (a2)
        ret
```

### After this patch
```
vfwmul_v2f32_multiple_users:
        vsetivli zero, 2, e32, mf2, ta, ma
        vfwmul.vv v11, v8, v9
        vfwadd.vv v12, v8, v10
        vfwsub.vv v8, v9, v10
        vse64.v v11, (a0)
        vse64.v v12, (a1)
        vse64.v v8, (a2)
```

> [!NOTE]
> Support for scalable vectors will be added in a separate patch.

---

Patch is 34.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81248.diff


5 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+186-170) 
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll (+88) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll (+3-5) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll (+4-10) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll (+4-10) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 12c0cd53514da..4ca834b8cf0f0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13209,12 +13209,16 @@ namespace {
 // apply a combine.
 struct CombineResult;
 
+enum class ExtKind { ZExt, SExt, FPExt };
+
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
 /// add | add_vl -> vwadd(u) | vwadd(u)_w
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
-///
+/// fadd -> vfwadd | vfwadd_w
+/// fsub -> vfwsub | vfwsub_w
+/// fmul -> vfwmul
 /// An object of this class represents an operand of the operation we want to
 /// combine.
 /// E.g., when trying to combine `mul_vl a, b`, we will have one instance of
@@ -13228,7 +13232,8 @@ struct CombineResult;
 /// - VWADDU_W == add(op0, zext(op1))
 /// - VWSUB_W == sub(op0, sext(op1))
 /// - VWSUBU_W == sub(op0, zext(op1))
-///
+/// - VFWADD_W == fadd(op0, fpext(op1))
+/// - VFWSUB_W == fsub(op0, fpext(op1))
 /// And VMV_V_X_VL, depending on the value, is conceptually equivalent to
 /// zext|sext(smaller_value).
 struct NodeExtensionHelper {
@@ -13239,6 +13244,8 @@ struct NodeExtensionHelper {
   /// instance, a splat constant (e.g., 3), would support being both sign and
   /// zero extended.
   bool SupportsSExt;
+  /// Records if this operand is like being floating-Point extended.
+  bool SupportsFPExt;
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
@@ -13262,6 +13269,7 @@ struct NodeExtensionHelper {
     case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
     default:
       return OrigOperand;
@@ -13273,22 +13281,34 @@ struct NodeExtensionHelper {
     return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
   }
 
+  /// Get the extended opcode.
+  unsigned getExtOpc(ExtKind SupportsExt) const {
+    switch (SupportsExt) {
+    case ExtKind::SExt:
+      return RISCVISD::VSEXT_VL;
+    case ExtKind::ZExt:
+      return RISCVISD::VZEXT_VL;
+    case ExtKind::FPExt:
+      return RISCVISD::FP_EXTEND_VL;
+    }
+  }
+
   /// Get or create a value that can feed \p Root with the given extension \p
-  /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
-  /// \see ::getSource().
+  /// SupportsExt. If \p SExt is std::nullopt, this returns the source of this
+  /// operand. \see ::getSource().
   SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
                                 const RISCVSubtarget &Subtarget,
-                                std::optional<bool> SExt) const {
-    if (!SExt.has_value())
+                                std::optional<ExtKind> SupportsExt) const {
+    if (!SupportsExt.has_value())
       return OrigOperand;
 
-    MVT NarrowVT = getNarrowType(Root);
+    MVT NarrowVT = getNarrowType(Root, *SupportsExt);
 
     SDValue Source = getSource();
     if (Source.getValueType() == NarrowVT)
       return Source;
 
-    unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+    unsigned ExtOpc = getExtOpc(*SupportsExt);
 
     // If we need an extension, we should be changing the type.
     SDLoc DL(Root);
@@ -13298,6 +13318,7 @@ struct NodeExtensionHelper {
     case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
@@ -13313,41 +13334,57 @@ struct NodeExtensionHelper {
   /// Helper function to get the narrow type for \p Root.
   /// The narrow type is the type of \p Root where we divided the size of each
   /// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
-  /// \pre The size of the type of the elements of Root must be a multiple of 2
-  /// and be greater than 16.
-  static MVT getNarrowType(const SDNode *Root) {
+  /// \pre Both the narrow type and the original type should be legal.
+  static MVT getNarrowType(const SDNode *Root, ExtKind SupportsExt) {
     MVT VT = Root->getSimpleValueType(0);
 
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-    assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
-    MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
-                                    VT.getVectorElementCount());
+
+    unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
+
+    MVT EltVT = SupportsExt == ExtKind::FPExt
+                    ? MVT::getFloatingPointVT(NarrowSize)
+                    : MVT::getIntegerVT(NarrowSize);
+
+    assert(NarrowSize >= NarrowMinSize &&
+           "Trying to extend something we can't represent");
+    MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
     return NarrowVT;
   }
 
-  /// Return the opcode required to materialize the folding of the sign
-  /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+  /// Return the opcode required to materialize the folding for
   /// both operands for \p Opcode.
   /// Put differently, get the opcode to materialize:
-  /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
-  /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+  /// - ExtKind::SExt: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+  /// - ExtKind::ZExt: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
   /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
-  static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+  static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
-      return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
+                                          : RISCVISD::VWADDU_VL;
     case ISD::MUL:
     case RISCVISD::MUL_VL:
-      return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
+                                          : RISCVISD::VWMULU_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
+                                          : RISCVISD::VWSUBU_VL;
+    case RISCVISD::FADD_VL:
+    case RISCVISD::VFWADD_W_VL:
+      return RISCVISD::VFWADD_VL;
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
+      return RISCVISD::VFWSUB_VL;
+    case RISCVISD::FMUL_VL:
+      return RISCVISD::VFWMUL_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
@@ -13361,16 +13398,22 @@ struct NodeExtensionHelper {
     return RISCVISD::VWMULSU_VL;
   }
 
-  /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
-  /// newOpcode(a, b).
-  static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+  /// Get the opcode to materialize
+  /// \p Opcode(a, s|z|fpext(b)) -> newOpcode(a, b).
+  static unsigned getWOpcode(unsigned Opcode, ExtKind SupportsExt) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
-      return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
+                                          : RISCVISD::VWADDU_W_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
-      return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
+                                          : RISCVISD::VWSUBU_W_VL;
+    case RISCVISD::FADD_VL:
+      return RISCVISD::VFWADD_W_VL;
+    case RISCVISD::FSUB_VL:
+      return RISCVISD::VFWSUB_W_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
@@ -13390,6 +13433,7 @@ struct NodeExtensionHelper {
                               const RISCVSubtarget &Subtarget) {
     SupportsZExt = false;
     SupportsSExt = false;
+    SupportsFPExt = false;
     EnforceOneUse = true;
     CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
@@ -13431,6 +13475,11 @@ struct NodeExtensionHelper {
       Mask = OrigOperand.getOperand(1);
       VL = OrigOperand.getOperand(2);
       break;
+    case RISCVISD::FP_EXTEND_VL:
+      SupportsFPExt = true;
+      Mask = OrigOperand.getOperand(1);
+      VL = OrigOperand.getOperand(2);
+      break;
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
@@ -13477,15 +13526,16 @@ struct NodeExtensionHelper {
 
   /// Check if \p Root supports any extension folding combines.
   static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL: {
-      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
       if (!TLI.isTypeLegal(Root->getValueType(0)))
         return false;
       return Root->getValueType(0).isScalableVector();
     }
+    // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13493,7 +13543,13 @@ struct NodeExtensionHelper {
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return true;
+    // Vector Widening Floating-Point Add/Sub/Mul Instructions
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
+      return TLI.isTypeLegal(Root->getValueType(0));
     default:
       return false;
     }
@@ -13509,16 +13565,23 @@ struct NodeExtensionHelper {
 
     unsigned Opc = Root->getOpcode();
     switch (Opc) {
-    // We consider VW<ADD|SUB>(U)_W(LHS, RHS) as if they were
-    // <ADD|SUB>(LHS, S|ZEXT(RHS))
+    // We consider
+    // VW<ADD|SUB>_W(LHS, RHS) -> <ADD|SUB>(LHS, SEXT(RHS))
+    // VW<ADD|SUB>U_W(LHS, RHS) -> <ADD|SUB>(LHS, ZEXT(RHS))
+    // VFW<ADD|SUB>_W(LHS, RHS) -> F<ADD|SUB>(LHS, FPEXT(RHS))
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
       if (OperandIdx == 1) {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
-        SupportsSExt = !SupportsZExt;
+        SupportsSExt =
+            Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+        SupportsFPExt =
+            Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
         std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
@@ -13578,11 +13641,16 @@ struct NodeExtensionHelper {
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
       return true;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
       return false;
     default:
       llvm_unreachable("Unexpected opcode");
@@ -13604,10 +13672,9 @@ struct NodeExtensionHelper {
 struct CombineResult {
   /// Opcode to be generated when materializing the combine.
   unsigned TargetOpcode;
-  // No value means no extension is needed. If extension is needed, the value
-  // indicates if it needs to be sign extended.
-  std::optional<bool> SExtLHS;
-  std::optional<bool> SExtRHS;
+  // No value means no extension is needed.
+  std::optional<ExtKind> LHSExt;
+  std::optional<ExtKind> RHSExt;
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
@@ -13616,10 +13683,10 @@ struct CombineResult {
   NodeExtensionHelper RHS;
 
   CombineResult(unsigned TargetOpcode, SDNode *Root,
-                const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
-                const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
-      : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
-        Root(Root), LHS(LHS), RHS(RHS) {}
+                const NodeExtensionHelper &LHS, std::optional<ExtKind> LHSExt,
+                const NodeExtensionHelper &RHS, std::optional<ExtKind> RHSExt)
+      : TargetOpcode(TargetOpcode), LHSExt(LHSExt), RHSExt(RHSExt), Root(Root),
+        LHS(LHS), RHS(RHS) {}
 
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
@@ -13640,8 +13707,8 @@ struct CombineResult {
       break;
     }
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, LHSExt),
+                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, RHSExt),
                        Merge, Mask, VL);
   }
 };
@@ -13656,24 +13723,30 @@ struct CombineResult {
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
 /// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
-                                 const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt, SelectionDAG &DAG,
-                                 const RISCVSubtarget &Subtarget) {
-  assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
+    SDNode *Root, const NodeExtensionHelper &LHS,
+    const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
+    bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+  assert((AllowSExt || AllowZExt || AllowFPExt) &&
+         "Forgot to set what you want?");
   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+                             Root->getOpcode(), ExtKind::ZExt),
+                         Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
+                         /*RHSExt=*/{ExtKind::ZExt});
   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/true),
-                         Root, LHS, /*SExtLHS=*/true, RHS,
-                         /*SExtRHS=*/true);
+                             Root->getOpcode(), ExtKind::SExt),
+                         Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+                         /*RHSExt=*/{ExtKind::SExt});
+  if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
+                             Root->getOpcode(), ExtKind::FPExt),
+                         Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
+                         /*RHSExt=*/{ExtKind::FPExt});
   return std::nullopt;
 }
 
@@ -13688,7 +13761,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
                              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                              const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/true, DAG, Subtarget);
+                                          /*AllowZExt=*/true,
+                                          /*AllowFPExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13702,18 +13776,23 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
   if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
 
+  if (RHS.SupportsFPExt)
+    return CombineResult(
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
+        Root, LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::FPExt});
+
   // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
   // sext/zext?
   // Control this behavior behind an option (AllowSplatInVW_W) for testing
   // purposes.
   if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
     return CombineResult(
-        NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
-        Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false);
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::ZExt), Root,
+        LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::ZExt});
   if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
     return CombineResult(
-        NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
-        Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true);
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::SExt), Root,
+        LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::SExt});
   return std::nullopt;
 }
 
@@ -13726,7 +13805,8 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/false, DAG, Subtarget);
+                                          /*AllowZExt=*/false,
+                                          /*AllowFPExt=*/false, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13738,7 +13818,21 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/true, DAG, Subtarget);
+                                          /*AllowZExt=*/true,
+                                          /*AllowFPExt=*/false, DAG, Subtarget);
+}
+
+/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                     const RISCVSubtarget &Subtarget) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
+                                          /*AllowZExt=*/false,
+                                          /*AllowFPExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13756,7 +13850,8 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
-                       Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
+                       Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+                       /*RHSExt=*/{ExtKind::ZExt});
 }
 
 SmallVector<NodeExtensionHelper::Co...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81248


More information about the llvm-commits mailing list