[llvm] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp extend. (PR #81248)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 04:07:10 PST 2024


https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/81248

>From b7ebaeb98e77ae6f20b7a9c9531dd4f7b45a557a Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Fri, 9 Feb 2024 20:47:50 +0800
Subject: [PATCH 1/3] [RISCV][ISel] Combine vector fadd/fsub/fmul with fp
 extend.

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 356 +++++++++---------
 .../fixed-vectors-vfw-web-simplification.ll   |  88 +++++
 .../CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll |   8 +-
 .../CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll |  14 +-
 .../CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll |  14 +-
 5 files changed, 285 insertions(+), 195 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2fef4993f6ec8..2b302e94e1ed0c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13316,12 +13316,16 @@ namespace {
 // apply a combine.
 struct CombineResult;
 
+enum class ExtKind { ZExt, SExt, FPExt };
+
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
 /// add | add_vl -> vwadd(u) | vwadd(u)_w
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul | mul_vl -> vwmul(u) | vwmul_su
-///
+/// fadd -> vfwadd | vfwadd_w
+/// fsub -> vfwsub | vfwsub_w
+/// fmul -> vfwmul
 /// An object of this class represents an operand of the operation we want to
 /// combine.
 /// E.g., when trying to combine `mul_vl a, b`, we will have one instance of
@@ -13335,7 +13339,8 @@ struct CombineResult;
 /// - VWADDU_W == add(op0, zext(op1))
 /// - VWSUB_W == sub(op0, sext(op1))
 /// - VWSUBU_W == sub(op0, zext(op1))
-///
+/// - VFWADD_W == fadd(op0, fpext(op1))
+/// - VFWSUB_W == fsub(op0, fpext(op1))
 /// And VMV_V_X_VL, depending on the value, is conceptually equivalent to
 /// zext|sext(smaller_value).
 struct NodeExtensionHelper {
@@ -13346,6 +13351,8 @@ struct NodeExtensionHelper {
   /// instance, a splat constant (e.g., 3), would support being both sign and
   /// zero extended.
   bool SupportsSExt;
+  /// Records if this operand is like being floating-point extended.
+  bool SupportsFPExt;
   /// This boolean captures whether we care if this operand would still be
   /// around after the folding happens.
   bool EnforceOneUse;
@@ -13369,6 +13376,7 @@ struct NodeExtensionHelper {
     case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL:
       return OrigOperand.getOperand(0);
     default:
       return OrigOperand;
@@ -13380,22 +13388,37 @@ struct NodeExtensionHelper {
     return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
   }
 
+  /// Return the RISCVISD extension opcode (a node taking a mask and a VL) that
+  /// materializes the extension described by \p SupportsExt.
+  static unsigned getExtOpc(ExtKind SupportsExt) {
+    switch (SupportsExt) {
+    case ExtKind::SExt:
+      return RISCVISD::VSEXT_VL;
+    case ExtKind::ZExt:
+      return RISCVISD::VZEXT_VL;
+    case ExtKind::FPExt:
+      return RISCVISD::FP_EXTEND_VL;
+    }
+    // All enumerators return above; this avoids -Wreturn-type (e.g. in GCC).
+    llvm_unreachable("Unknown ExtKind enum");
+  }
+
   /// Get or create a value that can feed \p Root with the given extension \p
-  /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
-  /// \see ::getSource().
+  /// SupportsExt. If \p SupportsExt is std::nullopt, this returns the
+  /// original operand unchanged. \see ::getSource().
   SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
                                 const RISCVSubtarget &Subtarget,
-                                std::optional<bool> SExt) const {
-    if (!SExt.has_value())
+                                std::optional<ExtKind> SupportsExt) const {
+    if (!SupportsExt.has_value())
       return OrigOperand;
 
-    MVT NarrowVT = getNarrowType(Root);
+    MVT NarrowVT = getNarrowType(Root, *SupportsExt);
 
     SDValue Source = getSource();
     if (Source.getValueType() == NarrowVT)
       return Source;
 
-    unsigned ExtOpc = *SExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
+    unsigned ExtOpc = getExtOpc(*SupportsExt);
 
     // If we need an extension, we should be changing the type.
     SDLoc DL(OrigOperand);
@@ -13405,6 +13425,7 @@ struct NodeExtensionHelper {
     case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
+    case RISCVISD::FP_EXTEND_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
@@ -13420,41 +13441,57 @@ struct NodeExtensionHelper {
   /// Helper function to get the narrow type for \p Root.
   /// The narrow type is the type of \p Root where we divided the size of each
   /// element by 2. E.g., if Root's type <2xi16> -> narrow type <2xi8>.
-  /// \pre The size of the type of the elements of Root must be a multiple of 2
-  /// and be greater than 16.
-  static MVT getNarrowType(const SDNode *Root) {
+  /// \pre Both the narrow type and the original type should be legal.
+  static MVT getNarrowType(const SDNode *Root, ExtKind SupportsExt) {
     MVT VT = Root->getSimpleValueType(0);
 
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-    assert(NarrowSize >= 8 && "Trying to extend something we can't represent");
-    MVT NarrowVT = MVT::getVectorVT(MVT::getIntegerVT(NarrowSize),
-                                    VT.getVectorElementCount());
+
+    unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
+
+    MVT EltVT = SupportsExt == ExtKind::FPExt
+                    ? MVT::getFloatingPointVT(NarrowSize)
+                    : MVT::getIntegerVT(NarrowSize);
+
+    assert(NarrowSize >= NarrowMinSize &&
+           "Trying to extend something we can't represent");
+    MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
     return NarrowVT;
   }
 
-  /// Return the opcode required to materialize the folding of the sign
-  /// extensions (\p IsSExt == true) or zero extensions (IsSExt == false) for
+  /// Return the opcode required to materialize the folding for
   /// both operands for \p Opcode.
   /// Put differently, get the opcode to materialize:
-  /// - ISExt == true: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
-  /// - ISExt == false: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+  /// - ExtKind::SExt|ZExt: \p Opcode(s|zext(a), s|zext(b)) -> newOpcode(a, b)
+  /// - ExtKind::FPExt: \p Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
   /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
-  static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
+  static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
-      return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
+                                          : RISCVISD::VWADDU_VL;
     case ISD::MUL:
     case RISCVISD::MUL_VL:
-      return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
+                                          : RISCVISD::VWMULU_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return IsSExt ? RISCVISD::VWSUB_VL : RISCVISD::VWSUBU_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
+                                          : RISCVISD::VWSUBU_VL;
+    case RISCVISD::FADD_VL:
+    case RISCVISD::VFWADD_W_VL:
+      return RISCVISD::VFWADD_VL;
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
+      return RISCVISD::VFWSUB_VL;
+    case RISCVISD::FMUL_VL:
+      return RISCVISD::VFWMUL_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
@@ -13468,16 +13505,22 @@ struct NodeExtensionHelper {
     return RISCVISD::VWMULSU_VL;
   }
 
-  /// Get the opcode to materialize \p Opcode(a, s|zext(b)) ->
-  /// newOpcode(a, b).
-  static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
+  /// Get the opcode to materialize
+  /// \p Opcode(a, s|z|fpext(b)) -> newOpcode(a, b).
+  static unsigned getWOpcode(unsigned Opcode, ExtKind SupportsExt) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
-      return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
+                                          : RISCVISD::VWADDU_W_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
-      return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
+      return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_W_VL
+                                          : RISCVISD::VWSUBU_W_VL;
+    case RISCVISD::FADD_VL:
+      return RISCVISD::VFWADD_W_VL;
+    case RISCVISD::FSUB_VL:
+      return RISCVISD::VFWSUB_W_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
@@ -13497,6 +13540,7 @@ struct NodeExtensionHelper {
                               const RISCVSubtarget &Subtarget) {
     SupportsZExt = false;
     SupportsSExt = false;
+    SupportsFPExt = false;
     EnforceOneUse = true;
     CheckMask = true;
     unsigned Opc = OrigOperand.getOpcode();
@@ -13538,6 +13582,11 @@ struct NodeExtensionHelper {
       Mask = OrigOperand.getOperand(1);
       VL = OrigOperand.getOperand(2);
       break;
+    case RISCVISD::FP_EXTEND_VL:
+      SupportsFPExt = true;
+      Mask = OrigOperand.getOperand(1);
+      VL = OrigOperand.getOperand(2);
+      break;
     case RISCVISD::VMV_V_X_VL: {
       // Historically, we didn't care about splat values not disappearing during
       // combines.
@@ -13584,15 +13633,16 @@ struct NodeExtensionHelper {
 
   /// Check if \p Root supports any extension folding combines.
   static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL: {
-      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
       if (!TLI.isTypeLegal(Root->getValueType(0)))
         return false;
       return Root->getValueType(0).isScalableVector();
     }
+    // Vector Widening Integer Add/Sub/Mul Instructions
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13600,7 +13650,13 @@ struct NodeExtensionHelper {
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return true;
+    // Vector Widening Floating-Point Add/Sub/Mul Instructions
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
+      return TLI.isTypeLegal(Root->getValueType(0));
     default:
       return false;
     }
@@ -13616,16 +13672,23 @@ struct NodeExtensionHelper {
 
     unsigned Opc = Root->getOpcode();
     switch (Opc) {
-    // We consider VW<ADD|SUB>(U)_W(LHS, RHS) as if they were
-    // <ADD|SUB>(LHS, S|ZEXT(RHS))
+    // We consider
+    // VW<ADD|SUB>_W(LHS, RHS) -> <ADD|SUB>(LHS, SEXT(RHS))
+    // VW<ADD|SUB>U_W(LHS, RHS) -> <ADD|SUB>(LHS, ZEXT(RHS))
+    // VFW<ADD|SUB>_W(LHS, RHS) -> F<ADD|SUB>(LHS, FPEXT(RHS))
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFWSUB_W_VL:
       if (OperandIdx == 1) {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
-        SupportsSExt = !SupportsZExt;
+        SupportsSExt =
+            Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
+        SupportsFPExt =
+            Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
         std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
@@ -13685,11 +13748,16 @@ struct NodeExtensionHelper {
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
+    case RISCVISD::FADD_VL:
+    case RISCVISD::FMUL_VL:
+    case RISCVISD::VFWADD_W_VL:
       return true;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
+    case RISCVISD::FSUB_VL:
+    case RISCVISD::VFWSUB_W_VL:
       return false;
     default:
       llvm_unreachable("Unexpected opcode");
@@ -13711,10 +13779,9 @@ struct NodeExtensionHelper {
 struct CombineResult {
   /// Opcode to be generated when materializing the combine.
   unsigned TargetOpcode;
-  // No value means no extension is needed. If extension is needed, the value
-  // indicates if it needs to be sign extended.
-  std::optional<bool> SExtLHS;
-  std::optional<bool> SExtRHS;
+  // Extension required on each operand; std::nullopt means none is needed.
+  std::optional<ExtKind> LHSExt;
+  std::optional<ExtKind> RHSExt;
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
@@ -13723,10 +13790,10 @@ struct CombineResult {
   NodeExtensionHelper RHS;
 
   CombineResult(unsigned TargetOpcode, SDNode *Root,
-                const NodeExtensionHelper &LHS, std::optional<bool> SExtLHS,
-                const NodeExtensionHelper &RHS, std::optional<bool> SExtRHS)
-      : TargetOpcode(TargetOpcode), SExtLHS(SExtLHS), SExtRHS(SExtRHS),
-        Root(Root), LHS(LHS), RHS(RHS) {}
+                const NodeExtensionHelper &LHS, std::optional<ExtKind> LHSExt,
+                const NodeExtensionHelper &RHS, std::optional<ExtKind> RHSExt)
+      : TargetOpcode(TargetOpcode), LHSExt(LHSExt), RHSExt(RHSExt), Root(Root),
+        LHS(LHS), RHS(RHS) {}
 
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
@@ -13747,8 +13814,8 @@ struct CombineResult {
       break;
     }
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, LHSExt),
+                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, RHSExt),
                        Merge, Mask, VL);
   }
 };
@@ -13763,24 +13830,30 @@ struct CombineResult {
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
 /// can be used to apply the pattern.
-static std::optional<CombineResult>
-canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
-                                 const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt, SelectionDAG &DAG,
-                                 const RISCVSubtarget &Subtarget) {
-  assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
+static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
+    SDNode *Root, const NodeExtensionHelper &LHS,
+    const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
+    bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+  assert((AllowSExt || AllowZExt || AllowFPExt) &&
+         "Forgot to set what you want?");
   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+                             Root->getOpcode(), ExtKind::ZExt),
+                         Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
+                         /*RHSExt=*/{ExtKind::ZExt});
   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), /*IsSExt=*/true),
-                         Root, LHS, /*SExtLHS=*/true, RHS,
-                         /*SExtRHS=*/true);
+                             Root->getOpcode(), ExtKind::SExt),
+                         Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+                         /*RHSExt=*/{ExtKind::SExt});
+  if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
+                             Root->getOpcode(), ExtKind::FPExt),
+                         Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
+                         /*RHSExt=*/{ExtKind::FPExt});
   return std::nullopt;
 }
 
@@ -13795,7 +13868,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
                              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                              const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/true, DAG, Subtarget);
+                                          /*AllowZExt=*/true,
+                                          /*AllowFPExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13809,18 +13883,23 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
   if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
 
+  if (RHS.SupportsFPExt)
+    return CombineResult(
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::FPExt),
+        Root, LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::FPExt});
+
   // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
   // sext/zext?
   // Control this behavior behind an option (AllowSplatInVW_W) for testing
   // purposes.
   if (RHS.SupportsZExt && (!RHS.isSplat() || AllowSplatInVW_W))
     return CombineResult(
-        NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/false),
-        Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/false);
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::ZExt), Root,
+        LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::ZExt});
   if (RHS.SupportsSExt && (!RHS.isSplat() || AllowSplatInVW_W))
     return CombineResult(
-        NodeExtensionHelper::getWOpcode(Root->getOpcode(), /*IsSExt=*/true),
-        Root, LHS, /*SExtLHS=*/std::nullopt, RHS, /*SExtRHS=*/true);
+        NodeExtensionHelper::getWOpcode(Root->getOpcode(), ExtKind::SExt), Root,
+        LHS, /*LHSExt=*/std::nullopt, RHS, /*RHSExt=*/{ExtKind::SExt});
   return std::nullopt;
 }
 
@@ -13833,7 +13912,8 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/false, DAG, Subtarget);
+                                          /*AllowZExt=*/false,
+                                          /*AllowFPExt=*/false, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13845,7 +13925,21 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/true, DAG, Subtarget);
+                                          /*AllowZExt=*/true,
+                                          /*AllowFPExt=*/false, DAG, Subtarget);
+}
+
+/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
+///
+/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
+/// can be used to apply the pattern.
+static std::optional<CombineResult>
+canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
+                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                     const RISCVSubtarget &Subtarget) {
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
+                                          /*AllowZExt=*/false,
+                                          /*AllowFPExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13863,7 +13957,8 @@ canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
-                       Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
+                       Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
+                       /*RHSExt=*/{ExtKind::ZExt});
 }
 
 SmallVector<NodeExtensionHelper::CombineToTry>
@@ -13874,11 +13969,16 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case ISD::SUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
-    // add|sub -> vwadd(u)|vwsub(u)
+  case RISCVISD::FADD_VL:
+  case RISCVISD::FSUB_VL:
+    // add|sub|fadd|fsub -> vwadd(u)|vwsub(u)|vfwadd|vfwsub
     Strategies.push_back(canFoldToVWWithSameExtension);
-    // add|sub -> vwadd(u)_w|vwsub(u)_w
+    // add|sub|fadd|fsub -> vwadd(u)_w|vwsub(u)_w|vfwadd_w|vfwsub_w
     Strategies.push_back(canFoldToVW_W);
     break;
+  case RISCVISD::FMUL_VL:
+    Strategies.push_back(canFoldToVWWithSameExtension);
+    break;
   case ISD::MUL:
   case RISCVISD::MUL_VL:
     // mul -> vwmul(u)
@@ -13896,6 +13996,11 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // vwaddu_w|vwsubu_w -> vwaddu|vwsubu
     Strategies.push_back(canFoldToVWWithZEXT);
     break;
+  case RISCVISD::VFWADD_W_VL:
+  case RISCVISD::VFWSUB_W_VL:
+    // vfwadd_w|vfwsub_w -> vfwadd|vfwsub
+    Strategies.push_back(canFoldToVWWithFPEXT);
+    break;
   default:
     llvm_unreachable("Unexpected opcode");
   }
@@ -13908,8 +14013,13 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// add_vl -> vwadd(u) | vwadd(u)_w
 /// sub_vl -> vwsub(u) | vwsub(u)_w
 /// mul_vl -> vwmul(u) | vwmul_su
+/// fadd_vl -> vfwadd | vfwadd_w
+/// fsub_vl -> vfwsub | vfwsub_w
+/// fmul_vl -> vfwmul
 /// vwadd_w(u) -> vwadd(u)
-/// vwub_w(u) -> vwadd(u)
+/// vwsub_w(u) -> vwsub(u)
+/// vfwadd_w -> vfwadd
+/// vfwsub_w -> vfwsub
 static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
                                            TargetLowering::DAGCombinerInfo &DCI,
                                            const RISCVSubtarget &Subtarget) {
@@ -13965,9 +14075,9 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
           // All the inputs that are extended need to be folded, otherwise
           // we would be leaving the old input (since it is may still be used),
           // and the new one.
-          if (Res->SExtLHS.has_value())
+          if (Res->LHSExt.has_value())
             AppendUsersIfNeeded(LHS);
-          if (Res->SExtRHS.has_value())
+          if (Res->RHSExt.has_value())
             AppendUsersIfNeeded(RHS);
           break;
         }
@@ -14532,107 +14642,6 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
                      N->getOperand(2), Mask, VL);
 }
 
-static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG,
-                                      const RISCVSubtarget &Subtarget) {
-  if (N->getValueType(0).isScalableVector() &&
-      N->getValueType(0).getVectorElementType() == MVT::f32 &&
-      (Subtarget.hasVInstructionsF16Minimal() &&
-       !Subtarget.hasVInstructionsF16())) {
-    return SDValue();
-  }
-
-  // FIXME: Ignore strict opcodes for now.
-  assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode");
-
-  // Try to form widening multiply.
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  SDValue Merge = N->getOperand(2);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
-
-  if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
-      Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
-    return SDValue();
-
-  // TODO: Refactor to handle more complex cases similar to
-  // combineBinOp_VLToVWBinOp_VL.
-  if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
-      (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
-    return SDValue();
-
-  // Check the mask and VL are the same.
-  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
-      Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
-    return SDValue();
-
-  Op0 = Op0.getOperand(0);
-  Op1 = Op1.getOperand(0);
-
-  return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0,
-                     Op1, Merge, Mask, VL);
-}
-
-static SDValue performFADDSUB_VLCombine(SDNode *N, SelectionDAG &DAG,
-                                        const RISCVSubtarget &Subtarget) {
-  if (N->getValueType(0).isScalableVector() &&
-      N->getValueType(0).getVectorElementType() == MVT::f32 &&
-      (Subtarget.hasVInstructionsF16Minimal() &&
-       !Subtarget.hasVInstructionsF16())) {
-    return SDValue();
-  }
-
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  SDValue Merge = N->getOperand(2);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
-
-  bool IsAdd = N->getOpcode() == RISCVISD::FADD_VL;
-
-  // Look for foldable FP_EXTENDS.
-  bool Op0IsExtend =
-      Op0.getOpcode() == RISCVISD::FP_EXTEND_VL &&
-      (Op0.hasOneUse() || (Op0 == Op1 && Op0->hasNUsesOfValue(2, 0)));
-  bool Op1IsExtend =
-      (Op0 == Op1 && Op0IsExtend) ||
-      (Op1.getOpcode() == RISCVISD::FP_EXTEND_VL && Op1.hasOneUse());
-
-  // Check the mask and VL.
-  if (Op0IsExtend && (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL))
-    Op0IsExtend = false;
-  if (Op1IsExtend && (Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL))
-    Op1IsExtend = false;
-
-  // Canonicalize.
-  if (!Op1IsExtend) {
-    // Sub requires at least operand 1 to be an extend.
-    if (!IsAdd)
-      return SDValue();
-
-    // Add is commutable, if the other operand is foldable, swap them.
-    if (!Op0IsExtend)
-      return SDValue();
-
-    std::swap(Op0, Op1);
-    std::swap(Op0IsExtend, Op1IsExtend);
-  }
-
-  // Op1 is a foldable extend. Op0 might be foldable.
-  Op1 = Op1.getOperand(0);
-  if (Op0IsExtend)
-    Op0 = Op0.getOperand(0);
-
-  unsigned Opc;
-  if (IsAdd)
-    Opc = Op0IsExtend ? RISCVISD::VFWADD_VL : RISCVISD::VFWADD_W_VL;
-  else
-    Opc = Op0IsExtend ? RISCVISD::VFWSUB_VL : RISCVISD::VFWSUB_W_VL;
-
-  return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op0, Op1, Merge, Mask,
-                     VL);
-}
-
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -16165,11 +16174,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::STRICT_VFMSUB_VL:
   case RISCVISD::STRICT_VFNMSUB_VL:
     return performVFMADD_VLCombine(N, DAG, Subtarget);
-  case RISCVISD::FMUL_VL:
-    return performVFMUL_VLCombine(N, DAG, Subtarget);
   case RISCVISD::FADD_VL:
   case RISCVISD::FSUB_VL:
-    return performFADDSUB_VLCombine(N, DAG, Subtarget);
+  case RISCVISD::FMUL_VL:
+  case RISCVISD::VFWADD_W_VL:
+  case RISCVISD::VFWSUB_W_VL: {
+    if (N->getValueType(0).isScalableVector() &&
+        N->getValueType(0).getVectorElementType() == MVT::f32 &&
+        (Subtarget.hasVInstructionsF16Minimal() &&
+         !Subtarget.hasVInstructionsF16()))
+      return SDValue();
+    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+  }
   case ISD::LOAD:
   case ISD::STORE: {
     if (DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
new file mode 100644
index 00000000000000..26f77225dbb0e1
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfhmin,+f,+d -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING,ZVFHMIN
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv64 -mattr=+v,+zfh,+zvfh,+f,+d -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+define void @vfwmul_v2f16_multiple_users(ptr %x, ptr %y, ptr %z, <2 x half> %a, <2 x half> %b, <2 x half> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f16_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vse32.v v10, (a0)
+; NO_FOLDING-NEXT:    vse32.v v11, (a1)
+; NO_FOLDING-NEXT:    vse32.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; ZVFHMIN-LABEL: vfwmul_v2f16_multiple_users:
+; ZVFHMIN:       # %bb.0:
+; ZVFHMIN-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v11, v8
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v8, v9
+; ZVFHMIN-NEXT:    vfwcvt.f.f.v v9, v10
+; ZVFHMIN-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; ZVFHMIN-NEXT:    vfmul.vv v10, v11, v8
+; ZVFHMIN-NEXT:    vfadd.vv v11, v11, v9
+; ZVFHMIN-NEXT:    vfsub.vv v8, v8, v9
+; ZVFHMIN-NEXT:    vse32.v v10, (a0)
+; ZVFHMIN-NEXT:    vse32.v v11, (a1)
+; ZVFHMIN-NEXT:    vse32.v v8, (a2)
+; ZVFHMIN-NEXT:    ret
+  %c = fpext <2 x half> %a to <2 x float>
+  %d = fpext <2 x half> %b to <2 x float>
+  %d2 = fpext <2 x half> %b2 to <2 x float>
+  %e = fmul <2 x float> %c, %d
+  %f = fadd <2 x float> %c, %d2
+  %g = fsub <2 x float> %d, %d2
+  store <2 x float> %e, ptr %x
+  store <2 x float> %f, ptr %y
+  store <2 x float> %g, ptr %z
+  ret void
+}
+
+define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfadd.vv v11, v11, v9
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vse64.v v10, (a0)
+; NO_FOLDING-NEXT:    vse64.v v11, (a1)
+; NO_FOLDING-NEXT:    vse64.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmul_v2f32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT:    vfwmul.vv v11, v8, v9
+; FOLDING-NEXT:    vfwadd.vv v12, v8, v10
+; FOLDING-NEXT:    vfwsub.vv v8, v9, v10
+; FOLDING-NEXT:    vse64.v v11, (a0)
+; FOLDING-NEXT:    vse64.v v12, (a1)
+; FOLDING-NEXT:    vse64.v v8, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <2 x float> %a to <2 x double>
+  %d = fpext <2 x float> %b to <2 x double>
+  %d2 = fpext <2 x float> %b2 to <2 x double>
+  %e = fmul <2 x double> %c, %d
+  %f = fadd <2 x double> %c, %d2
+  %g = fsub <2 x double> %d, %d2
+  store <2 x double> %e, ptr %x
+  store <2 x double> %f, ptr %y
+  store <2 x double> %g, ptr %z
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
index c9dc75e18774f8..dd3a50cfd77377 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwadd.ll
@@ -396,12 +396,10 @@ define <32 x double> @vfwadd_vf_v32f32(ptr %x, float %y) {
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
 ; CHECK-NEXT:    vle32.v v24, (a0)
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT:    vslidedown.vi v0, v24, 16
+; CHECK-NEXT:    vslidedown.vi v8, v24, 16
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfmv.v.f v16, fa0
-; CHECK-NEXT:    vfwcvt.f.f.v v8, v16
-; CHECK-NEXT:    vfwadd.wv v16, v8, v0
-; CHECK-NEXT:    vfwadd.wv v8, v8, v24
+; CHECK-NEXT:    vfwadd.vf v16, v8, fa0
+; CHECK-NEXT:    vfwadd.vf v8, v24, fa0
 ; CHECK-NEXT:    ret
   %a = load <32 x float>, ptr %x
   %b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
index 8ad858d4c76598..7eaa1856ce2218 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmul.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwmul_vf_v32f32(ptr %x, float %y) {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    li a1, 32
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT:    vle32.v v16, (a0)
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v8, v16
+; CHECK-NEXT:    vle32.v v24, (a0)
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT:    vslidedown.vi v16, v16, 16
+; CHECK-NEXT:    vslidedown.vi v8, v24, 16
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v24, v16
-; CHECK-NEXT:    vfmv.v.f v16, fa0
-; CHECK-NEXT:    vfwcvt.f.f.v v0, v16
-; CHECK-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT:    vfmul.vv v16, v24, v0
-; CHECK-NEXT:    vfmul.vv v8, v8, v0
+; CHECK-NEXT:    vfwmul.vf v16, v8, fa0
+; CHECK-NEXT:    vfwmul.vf v8, v24, fa0
 ; CHECK-NEXT:    ret
   %a = load <32 x float>, ptr %x
   %b = insertelement <32 x float> poison, float %y, i32 0
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
index d22781d6a97ac2..8cf7c5f1758654 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwsub.ll
@@ -394,18 +394,12 @@ define <32 x double> @vfwsub_vf_v32f32(ptr %x, float %y) {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    li a1, 32
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, ta, ma
-; CHECK-NEXT:    vle32.v v16, (a0)
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v8, v16
+; CHECK-NEXT:    vle32.v v24, (a0)
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m8, ta, ma
-; CHECK-NEXT:    vslidedown.vi v16, v16, 16
+; CHECK-NEXT:    vslidedown.vi v8, v24, 16
 ; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v24, v16
-; CHECK-NEXT:    vfmv.v.f v16, fa0
-; CHECK-NEXT:    vfwcvt.f.f.v v0, v16
-; CHECK-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-NEXT:    vfsub.vv v16, v24, v0
-; CHECK-NEXT:    vfsub.vv v8, v8, v0
+; CHECK-NEXT:    vfwsub.vf v16, v8, fa0
+; CHECK-NEXT:    vfwsub.vf v8, v24, fa0
 ; CHECK-NEXT:    ret
   %a = load <32 x float>, ptr %x
   %b = insertelement <32 x float> poison, float %y, i32 0

>From f0e6c8b2f91c2476d8ddc2d44f611881db142791 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Sun, 18 Feb 2024 21:45:10 +0900
Subject: [PATCH 2/3] add AllowExtMask

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 47 +++++++++------------
 1 file changed, 20 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2b302e94e1ed0c..55135adad8b5f3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13316,8 +13316,7 @@ namespace {
 // apply a combine.
 struct CombineResult;
 
-enum class ExtKind { ZExt, SExt, FPExt };
-
+enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
 /// add | add_vl -> vwadd(u) | vwadd(u)_w
@@ -13448,13 +13447,11 @@ struct NodeExtensionHelper {
     // Determine the narrow size.
     unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
 
-    unsigned NarrowMinSize = SupportsExt == ExtKind::FPExt ? 16 : 8;
-
     MVT EltVT = SupportsExt == ExtKind::FPExt
                     ? MVT::getFloatingPointVT(NarrowSize)
                     : MVT::getIntegerVT(NarrowSize);
 
-    assert(NarrowSize >= NarrowMinSize &&
+    assert(NarrowSize >= (SupportsExt == ExtKind::FPExt ? 16 : 8) &&
            "Trying to extend something we can't represent");
     MVT NarrowVT = MVT::getVectorVT(EltVT, VT.getVectorElementCount());
     return NarrowVT;
@@ -13823,33 +13820,32 @@ struct CombineResult {
 /// Check if \p Root follows a pattern Root(ext(LHS), ext(RHS))
 /// where `ext` is the same for both LHS and RHS (i.e., both are sext or both
 /// are zext) and LHS and RHS can be folded into Root.
-/// AllowSExt and AllozZExt define which form `ext` can take in this pattern.
+/// AllowExtMask defines which forms `ext` can take in this pattern.
 ///
 /// \note If the pattern can match with both zext and sext, the returned
 /// CombineResult will feature the zext result.
 ///
 /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
 /// can be used to apply the pattern.
-static std::optional<CombineResult> canFoldToVWWithSameExtensionImpl(
-    SDNode *Root, const NodeExtensionHelper &LHS,
-    const NodeExtensionHelper &RHS, bool AllowSExt, bool AllowZExt,
-    bool AllowFPExt, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
-  assert((AllowSExt || AllowZExt || AllowFPExt) &&
-         "Forgot to set what you want?");
+static std::optional<CombineResult>
+canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
+                                 const NodeExtensionHelper &RHS,
+                                 uint8_t AllowExtMask, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
-  if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
+  if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), ExtKind::ZExt),
                          Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
                          /*RHSExt=*/{ExtKind::ZExt});
-  if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
+  if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), ExtKind::SExt),
                          Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
                          /*RHSExt=*/{ExtKind::SExt});
-  if (AllowFPExt && LHS.SupportsFPExt && RHS.SupportsFPExt)
+  if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), ExtKind::FPExt),
                          Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
@@ -13867,9 +13863,9 @@ static std::optional<CombineResult>
 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
                              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                              const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/true,
-                                          /*AllowFPExt=*/true, DAG, Subtarget);
+  return canFoldToVWWithSameExtensionImpl(
+      Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
+      Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13911,9 +13907,8 @@ static std::optional<CombineResult>
 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/false,
-                                          /*AllowFPExt=*/false, DAG, Subtarget);
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
+                                          Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13924,9 +13919,8 @@ static std::optional<CombineResult>
 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                     const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                     const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/true,
-                                          /*AllowFPExt=*/false, DAG, Subtarget);
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
+                                          Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
@@ -13937,9 +13931,8 @@ static std::optional<CombineResult>
 canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
                      const NodeExtensionHelper &RHS, SelectionDAG &DAG,
                      const RISCVSubtarget &Subtarget) {
-  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/false,
-                                          /*AllowFPExt=*/true, DAG, Subtarget);
+  return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
+                                          Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))

>From d94f74f574244fbf3654351309c6f62ebc7ef25e Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Mon, 19 Feb 2024 15:28:53 +0900
Subject: [PATCH 3/3] split getSameExtensionOpcode

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 65 ++++++++++++++-------
 1 file changed, 43 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 55135adad8b5f3..812bb26f201a00 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13457,30 +13457,54 @@ struct NodeExtensionHelper {
     return NarrowVT;
   }
 
-  /// Return the opcode required to materialize the folding for
-  /// both operands for \p Opcode.
-  /// Put differently, get the opcode to materialize:
-  /// - ExtKind::SExt: \p Opcode(sext(a), sext(b)) -> newOpcode(a, b)
-  /// - ExtKind::ZExt: \p Opcode(zext(a), zext(b)) -> newOpcode(a, b)
-  /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
-  static unsigned getSameExtensionOpcode(unsigned Opcode, ExtKind SupportsExt) {
+  /// Get the opcode to materialize:
+  /// Opcode(sext(a), sext(b)) -> newOpcode(a, b)
+  static unsigned getSExtOpcode(unsigned Opcode) {
     switch (Opcode) {
     case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
-      return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_VL
-                                          : RISCVISD::VWADDU_VL;
+      return RISCVISD::VWADD_VL;
+    case ISD::SUB:
+    case RISCVISD::SUB_VL:
+    case RISCVISD::VWSUB_W_VL:
+    case RISCVISD::VWSUBU_W_VL:
+      return RISCVISD::VWSUB_VL;
     case ISD::MUL:
     case RISCVISD::MUL_VL:
-      return SupportsExt == ExtKind::SExt ? RISCVISD::VWMUL_VL
-                                          : RISCVISD::VWMULU_VL;
+      return RISCVISD::VWMUL_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
+
+  /// Get the opcode to materialize:
+  /// Opcode(zext(a), zext(b)) -> newOpcode(a, b)
+  static unsigned getZExtOpcode(unsigned Opcode) {
+    switch (Opcode) {
+    case ISD::ADD:
+    case RISCVISD::ADD_VL:
+    case RISCVISD::VWADD_W_VL:
+    case RISCVISD::VWADDU_W_VL:
+      return RISCVISD::VWADDU_VL;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
-      return SupportsExt == ExtKind::SExt ? RISCVISD::VWSUB_VL
-                                          : RISCVISD::VWSUBU_VL;
+      return RISCVISD::VWSUBU_VL;
+    case ISD::MUL:
+    case RISCVISD::MUL_VL:
+      return RISCVISD::VWMULU_VL;
+    default:
+      llvm_unreachable("Unexpected opcode");
+    }
+  }
+
+  /// Get the opcode to materialize:
+  /// Opcode(fpext(a), fpext(b)) -> newOpcode(a, b)
+  static unsigned getFPExtOpcode(unsigned Opcode) {
+    switch (Opcode) {
     case RISCVISD::FADD_VL:
     case RISCVISD::VFWADD_W_VL:
       return RISCVISD::VFWADD_VL;
@@ -13835,19 +13859,17 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
   if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
       !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
-  if (AllowExtMask & ExtKind::ZExt && LHS.SupportsZExt && RHS.SupportsZExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), ExtKind::ZExt),
+  if ((AllowExtMask & ExtKind::ZExt) && LHS.SupportsZExt && RHS.SupportsZExt)
+    return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
                          /*RHSExt=*/{ExtKind::ZExt});
-  if (AllowExtMask & ExtKind::SExt && LHS.SupportsSExt && RHS.SupportsSExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), ExtKind::SExt),
+  if ((AllowExtMask & ExtKind::SExt) && LHS.SupportsSExt && RHS.SupportsSExt)
+    return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
                          /*RHSExt=*/{ExtKind::SExt});
-  if (AllowExtMask & ExtKind::FPExt && RHS.SupportsFPExt)
-    return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
-                             Root->getOpcode(), ExtKind::FPExt),
+  // Both LHS and RHS must be fp-extended, like the ZExt/SExt cases above.
+  if ((AllowExtMask & ExtKind::FPExt) && LHS.SupportsFPExt && RHS.SupportsFPExt)
+    return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
                          Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
                          /*RHSExt=*/{ExtKind::FPExt});
   return std::nullopt;



More information about the llvm-commits mailing list