[llvm] [clang-tools-extra] [clang] [RISCV][ISel] Combine scalable vector add/sub/mul with zero/sign extension (PR #72340)

via cfe-commits cfe-commits at lists.llvm.org
Mon Dec 11 00:36:28 PST 2023


https://github.com/sun-jacobi updated https://github.com/llvm/llvm-project/pull/72340

>From 6a4198b6120d8f25a4460622fb37a96bd4eb6304 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Wed, 15 Nov 2023 11:50:11 +0900
Subject: [PATCH 1/8] [RISCV] Combine non-fixed length vector add/sub/mul with
 zero/sign extension

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 232 +++++++++++++++-----
 1 file changed, 179 insertions(+), 53 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a4cd8327f45f82..d777f9da1f9b81 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1364,8 +1364,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setJumpIsExpensive();
 
   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
-                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
-                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
+                       ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
@@ -12538,9 +12538,9 @@ struct CombineResult;
 
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
 ///
 /// An object of this class represents an operand of the operation we want to
 /// combine.
@@ -12585,6 +12585,8 @@ struct NodeExtensionHelper {
   /// E.g., for zext(a), this would return a.
   SDValue getSource() const {
     switch (OrigOperand.getOpcode()) {
+    case ISD::ZERO_EXTEND:
+    case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
       return OrigOperand.getOperand(0);
@@ -12601,7 +12603,8 @@ struct NodeExtensionHelper {
   /// Get or create a value that can feed \p Root with the given extension \p
   /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
   /// \see ::getSource().
-  SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
+  SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
+                                const RISCVSubtarget &Subtarget,
                                 std::optional<bool> SExt) const {
     if (!SExt.has_value())
       return OrigOperand;
@@ -12616,8 +12619,10 @@ struct NodeExtensionHelper {
 
     // If we need an extension, we should be changing the type.
     SDLoc DL(Root);
-    auto [Mask, VL] = getMaskAndVL(Root);
+    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
     switch (OrigOperand.getOpcode()) {
+    case ISD::ZERO_EXTEND:
+    case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
@@ -12657,12 +12662,15 @@ struct NodeExtensionHelper {
   /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
   static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
     switch (Opcode) {
+    case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
       return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+    case ISD::MUL:
     case RISCVISD::MUL_VL:
       return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -12675,7 +12683,8 @@ struct NodeExtensionHelper {
   /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
   /// newOpcode(a, b).
   static unsigned getSUOpcode(unsigned Opcode) {
-    assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL");
+    assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
+           "SU is only supported for MUL");
     return RISCVISD::VWMULSU_VL;
   }
 
@@ -12683,8 +12692,10 @@ struct NodeExtensionHelper {
   /// newOpcode(a, b).
   static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
     switch (Opcode) {
+    case ISD::ADD:
     case RISCVISD::ADD_VL:
       return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
       return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
     default:
@@ -12694,19 +12705,45 @@ struct NodeExtensionHelper {
 
   using CombineToTry = std::function<std::optional<CombineResult>(
       SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
-      const NodeExtensionHelper & /*RHS*/)>;
+      const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
+      const RISCVSubtarget &)>;
 
   /// Check if this node needs to be fully folded or extended for all users.
   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
 
   /// Helper method to set the various fields of this struct based on the
   /// type of \p Root.
-  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) {
+  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
+                              const RISCVSubtarget &Subtarget) {
     SupportsZExt = false;
     SupportsSExt = false;
     EnforceOneUse = true;
     CheckMask = true;
     switch (OrigOperand.getOpcode()) {
+    case ISD::ZERO_EXTEND: {
+      SupportsZExt = true;
+      SDLoc DL(Root);
+      MVT VT = Root->getSimpleValueType(0);
+      if (VT.isFixedLengthVector()) {
+        MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+        std::tie(Mask, VL) =
+            getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+      } else if (VT.isVector())
+        std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+      break;
+    }
+    case ISD::SIGN_EXTEND: {
+      SupportsSExt = true;
+      SDLoc DL(Root);
+      MVT VT = Root->getSimpleValueType(0);
+      if (VT.isFixedLengthVector()) {
+        MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+        std::tie(Mask, VL) =
+            getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+      } else if (VT.isVector())
+        std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+      break;
+    }
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
       Mask = OrigOperand.getOperand(1);
@@ -12764,6 +12801,15 @@ struct NodeExtensionHelper {
   /// Check if \p Root supports any extension folding combines.
   static bool isSupportedRoot(const SDNode *Root) {
     switch (Root->getOpcode()) {
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL: {
+      EVT VT0 = Root->getOperand(0).getValueType();
+      EVT VT1 = Root->getOperand(1).getValueType();
+      if (VT0.isFixedLengthVector() || VT0.isFixedLengthVector())
+        return false;
+      return (VT0.isVector() || VT1.isVector());
+    }
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -12778,7 +12824,8 @@ struct NodeExtensionHelper {
   }
 
   /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
-  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) {
+  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
+                      const RISCVSubtarget &Subtarget) {
     assert(isSupportedRoot(Root) && "Trying to build an helper with an "
                                     "unsupported root");
     assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
@@ -12796,7 +12843,8 @@ struct NodeExtensionHelper {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
         SupportsSExt = !SupportsZExt;
-        std::tie(Mask, VL) = getMaskAndVL(Root);
+        Mask = Root->getOperand(3);
+        VL = Root->getOperand(4);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
@@ -12805,7 +12853,7 @@ struct NodeExtensionHelper {
       }
       [[fallthrough]];
     default:
-      fillUpExtensionSupport(Root, DAG);
+      fillUpExtensionSupport(Root, DAG, Subtarget);
       break;
     }
   }
@@ -12821,14 +12869,35 @@ struct NodeExtensionHelper {
   }
 
   /// Helper function to get the Mask and VL from \p Root.
-  static std::pair<SDValue, SDValue> getMaskAndVL(const SDNode *Root) {
+  static std::pair<SDValue, SDValue>
+  getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
+               const RISCVSubtarget &Subtarget) {
     assert(isSupportedRoot(Root) && "Unexpected root");
-    return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+    switch (Root->getOpcode()) {
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL: {
+      SDLoc DL(Root);
+      MVT VT = Root->getSimpleValueType(0);
+      SDValue Mask, VL;
+      if (VT.isFixedLengthVector()) {
+        MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+        std::tie(Mask, VL) =
+            getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+      } else
+        std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+      return std::make_pair(Mask, VL);
+    }
+
+    default:
+      return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+    }
   }
 
   /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(const SDNode *Root) const {
-    auto [Mask, VL] = getMaskAndVL(Root);
+  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
+                              const RISCVSubtarget &Subtarget) const {
+    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
     return isMaskCompatible(Mask) && isVLCompatible(VL);
   }
 
@@ -12836,11 +12905,14 @@ struct NodeExtensionHelper {
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
     switch (N->getOpcode()) {
+    case ISD::ADD:
+    case ISD::MUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
       return true;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -12885,14 +12957,25 @@ struct CombineResult {
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
   /// The actual replacement is *not* done in that method.
-  SDValue materialize(SelectionDAG &DAG) const {
+  SDValue materialize(SelectionDAG &DAG,
+                      const RISCVSubtarget &Subtarget) const {
     SDValue Mask, VL, Merge;
-    std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
-    Merge = Root->getOperand(2);
+    std::tie(Mask, VL) =
+        NodeExtensionHelper::getMaskAndVL(Root, DAG, Subtarget);
+    switch (Root->getOpcode()) {
+    default:
+      Merge = Root->getOperand(2);
+      break;
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL:
+      Merge = DAG.getUNDEF(Root->getValueType(0));
+      break;
+    }
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
-                       Mask, VL);
+                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+                       Merge, Mask, VL);
   }
 };
 
@@ -12909,15 +12992,16 @@ struct CombineResult {
 static std::optional<CombineResult>
 canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt) {
+                                 bool AllowZExt, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
   assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
-  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
+      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS,
-                         /*SExtRHS=*/false);
+                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), /*IsSExt=*/true),
@@ -12934,9 +13018,10 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
-                             const NodeExtensionHelper &RHS) {
+                             const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                             const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/true);
+                                          /*AllowZExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -12945,8 +13030,9 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
-              const NodeExtensionHelper &RHS) {
-  if (!RHS.areVLAndMaskCompatible(Root))
+              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+              const RISCVSubtarget &Subtarget) {
+  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
 
   // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
@@ -12970,9 +13056,10 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                    const NodeExtensionHelper &RHS) {
+                    const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                    const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/false);
+                                          /*AllowZExt=*/false, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -12981,9 +13068,10 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                    const NodeExtensionHelper &RHS) {
+                    const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                    const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/true);
+                                          /*AllowZExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -12992,10 +13080,13 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
-               const NodeExtensionHelper &RHS) {
+               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+               const RISCVSubtarget &Subtarget) {
+
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
+      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
@@ -13005,6 +13096,8 @@ SmallVector<NodeExtensionHelper::CombineToTry>
 NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   SmallVector<CombineToTry> Strategies;
   switch (Root->getOpcode()) {
+  case ISD::ADD:
+  case ISD::SUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
     // add|sub -> vwadd(u)|vwsub(u)
@@ -13012,6 +13105,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // add|sub -> vwadd(u)_w|vwsub(u)_w
     Strategies.push_back(canFoldToVW_W);
     break;
+  case ISD::MUL:
   case RISCVISD::MUL_VL:
     // mul -> vwmul(u)
     Strategies.push_back(canFoldToVWWithSameExtension);
@@ -13042,12 +13136,14 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// mul_vl -> vwmul(u) | vwmul_su
 /// vwadd_w(u) -> vwadd(u)
 /// vwub_w(u) -> vwadd(u)
-static SDValue
-combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
 
-  assert(NodeExtensionHelper::isSupportedRoot(N) &&
-         "Shouldn't have called this method");
+  if (!NodeExtensionHelper::isSupportedRoot(N))
+    return SDValue();
+
   SmallVector<SDNode *> Worklist;
   SmallSet<SDNode *, 8> Inserted;
   Worklist.push_back(N);
@@ -13059,8 +13155,8 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     if (!NodeExtensionHelper::isSupportedRoot(Root))
       return SDValue();
 
-    NodeExtensionHelper LHS(N, 0, DAG);
-    NodeExtensionHelper RHS(N, 1, DAG);
+    NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
+    NodeExtensionHelper RHS(N, 1, DAG, Subtarget);
     auto AppendUsersIfNeeded = [&Worklist,
                                 &Inserted](const NodeExtensionHelper &Op) {
       if (Op.needToPromoteOtherUsers()) {
@@ -13087,7 +13183,8 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 
       for (NodeExtensionHelper::CombineToTry FoldingStrategy :
            FoldingStrategies) {
-        std::optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+        std::optional<CombineResult> Res =
+            FoldingStrategy(N, LHS, RHS, DAG, Subtarget);
         if (Res) {
           Matched = true;
           CombinesToApply.push_back(*Res);
@@ -13116,7 +13213,7 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
   ValuesToReplace.reserve(CombinesToApply.size());
   for (CombineResult Res : CombinesToApply) {
-    SDValue NewValue = Res.materialize(DAG);
+    SDValue NewValue = Res.materialize(DAG, Subtarget);
     if (!InputRootReplacement) {
       assert(Res.Root == N &&
              "First element is expected to be the current node");
@@ -14282,10 +14379,20 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
-  assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+  assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
+
+  if (N->getValueType(0).isFixedLengthVector())
+    return SDValue();
+
   SDValue Addend = N->getOperand(0);
   SDValue MulOp = N->getOperand(1);
-  SDValue AddMergeOp = N->getOperand(2);
+  SDValue AddMergeOp = [](SDNode *N, SelectionDAG &DAG) {
+    if (N->getOpcode() == ISD::ADD)
+      return DAG.getUNDEF(N->getValueType(0));
+    else
+      return N->getOperand(2);
+  }(N, DAG);
 
   if (!AddMergeOp.isUndef())
     return SDValue();
@@ -14312,8 +14419,17 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   if (!MulMergeOp.isUndef())
     return SDValue();
 
-  SDValue AddMask = N->getOperand(3);
-  SDValue AddVL = N->getOperand(4);
+  auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
+                             const RISCVSubtarget &Subtarget) {
+    if (N->getOpcode() == ISD::ADD) {
+      SDLoc DL(N);
+      return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
+                                     Subtarget);
+    } else {
+      return std::make_pair(N->getOperand(3), N->getOperand(4));
+    }
+  }(N, DAG, Subtarget);
+
   SDValue MulMask = MulOp.getOperand(3);
   SDValue MulVL = MulOp.getOperand(4);
 
@@ -14579,10 +14695,20 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::AND, DL, VT, NewFMV,
                        DAG.getConstant(~SignBit, DL, VT));
   }
-  case ISD::ADD:
+  case ISD::MUL:
+    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+  case ISD::ADD: {
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
+    if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
+      return V;
     return performADDCombine(N, DAG, Subtarget);
-  case ISD::SUB:
+  }
+  case ISD::SUB: {
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
     return performSUBCombine(N, DAG, Subtarget);
+  }
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
   case ISD::OR:
@@ -15064,7 +15190,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::ADD_VL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
       return V;
     return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::SUB_VL:
@@ -15073,7 +15199,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::VWSUB_W_VL:
   case RISCVISD::VWSUBU_W_VL:
   case RISCVISD::MUL_VL:
-    return combineBinOp_VLToVWBinOp_VL(N, DCI);
+    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
   case RISCVISD::VFMADD_VL:
   case RISCVISD::VFNMADD_VL:
   case RISCVISD::VFMSUB_VL:

>From cb29e4b3e5e0fc315d2efc276f4781b1f643015b Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Wed, 15 Nov 2023 11:51:42 +0900
Subject: [PATCH 2/8] [RISCV] update ctlz-sdnode test for optimization

---
 llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll | 128 +++++++++++----------
 1 file changed, 68 insertions(+), 60 deletions(-)

diff --git a/llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll
index d78d67d5e35987..763f72f6e89b9d 100644
--- a/llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll
@@ -1231,16 +1231,17 @@ define <vscale x 1 x i64> @ctlz_nxv1i64(<vscale x 1 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_nxv1i64:
 ; CHECK-F:       # %bb.0:
+; CHECK-F-NEXT:    li a0, 190
+; CHECK-F-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
+; CHECK-F-NEXT:    vmv.v.x v9, a0
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
-; CHECK-F-NEXT:    vfncvt.f.xu.w v9, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v9, 23
-; CHECK-F-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
-; CHECK-F-NEXT:    vzext.vf2 v9, v8
-; CHECK-F-NEXT:    li a1, 190
-; CHECK-F-NEXT:    vrsub.vx v8, v9, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-F-NEXT:    vfncvt.f.xu.w v10, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v10, 23
+; CHECK-F-NEXT:    vwsubu.wv v9, v9, v8
 ; CHECK-F-NEXT:    li a1, 64
-; CHECK-F-NEXT:    vminu.vx v8, v8, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-F-NEXT:    vminu.vx v8, v9, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;
@@ -1371,16 +1372,17 @@ define <vscale x 2 x i64> @ctlz_nxv2i64(<vscale x 2 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_nxv2i64:
 ; CHECK-F:       # %bb.0:
+; CHECK-F-NEXT:    li a0, 190
+; CHECK-F-NEXT:    vsetvli a1, zero, e64, m2, ta, ma
+; CHECK-F-NEXT:    vmv.v.x v10, a0
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-F-NEXT:    vfncvt.f.xu.w v10, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v10, 23
-; CHECK-F-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
-; CHECK-F-NEXT:    vzext.vf2 v10, v8
-; CHECK-F-NEXT:    li a1, 190
-; CHECK-F-NEXT:    vrsub.vx v8, v10, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-F-NEXT:    vfncvt.f.xu.w v12, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v12, 23
+; CHECK-F-NEXT:    vwsubu.wv v10, v10, v8
 ; CHECK-F-NEXT:    li a1, 64
-; CHECK-F-NEXT:    vminu.vx v8, v8, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; CHECK-F-NEXT:    vminu.vx v8, v10, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;
@@ -1511,16 +1513,17 @@ define <vscale x 4 x i64> @ctlz_nxv4i64(<vscale x 4 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_nxv4i64:
 ; CHECK-F:       # %bb.0:
+; CHECK-F-NEXT:    li a0, 190
+; CHECK-F-NEXT:    vsetvli a1, zero, e64, m4, ta, ma
+; CHECK-F-NEXT:    vmv.v.x v12, a0
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vsetvli a1, zero, e32, m2, ta, ma
-; CHECK-F-NEXT:    vfncvt.f.xu.w v12, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v12, 23
-; CHECK-F-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
-; CHECK-F-NEXT:    vzext.vf2 v12, v8
-; CHECK-F-NEXT:    li a1, 190
-; CHECK-F-NEXT:    vrsub.vx v8, v12, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; CHECK-F-NEXT:    vfncvt.f.xu.w v16, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v16, 23
+; CHECK-F-NEXT:    vwsubu.wv v12, v12, v8
 ; CHECK-F-NEXT:    li a1, 64
-; CHECK-F-NEXT:    vminu.vx v8, v8, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
+; CHECK-F-NEXT:    vminu.vx v8, v12, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;
@@ -1651,16 +1654,17 @@ define <vscale x 8 x i64> @ctlz_nxv8i64(<vscale x 8 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_nxv8i64:
 ; CHECK-F:       # %bb.0:
+; CHECK-F-NEXT:    li a0, 190
+; CHECK-F-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; CHECK-F-NEXT:    vmv.v.x v16, a0
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
-; CHECK-F-NEXT:    vfncvt.f.xu.w v16, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v16, 23
-; CHECK-F-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-F-NEXT:    vzext.vf2 v16, v8
-; CHECK-F-NEXT:    li a1, 190
-; CHECK-F-NEXT:    vrsub.vx v8, v16, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-F-NEXT:    vfncvt.f.xu.w v24, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v24, 23
+; CHECK-F-NEXT:    vwsubu.wv v16, v16, v8
 ; CHECK-F-NEXT:    li a1, 64
-; CHECK-F-NEXT:    vminu.vx v8, v8, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
+; CHECK-F-NEXT:    vminu.vx v8, v16, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;
@@ -2833,15 +2837,16 @@ define <vscale x 1 x i64> @ctlz_zero_undef_nxv1i64(<vscale x 1 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_zero_undef_nxv1i64:
 ; CHECK-F:       # %bb.0:
+; CHECK-F-NEXT:    li a0, 190
+; CHECK-F-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
+; CHECK-F-NEXT:    vmv.v.x v9, a0
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
-; CHECK-F-NEXT:    vfncvt.f.xu.w v9, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v9, 23
-; CHECK-F-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
-; CHECK-F-NEXT:    vzext.vf2 v9, v8
-; CHECK-F-NEXT:    li a1, 190
-; CHECK-F-NEXT:    vrsub.vx v8, v9, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-F-NEXT:    vfncvt.f.xu.w v10, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v10, 23
+; CHECK-F-NEXT:    vwsubu.wv v9, v9, v8
 ; CHECK-F-NEXT:    fsrm a0
+; CHECK-F-NEXT:    vmv1r.v v8, v9
 ; CHECK-F-NEXT:    ret
 ;
 ; CHECK-D-LABEL: ctlz_zero_undef_nxv1i64:
@@ -2968,15 +2973,16 @@ define <vscale x 2 x i64> @ctlz_zero_undef_nxv2i64(<vscale x 2 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_zero_undef_nxv2i64:
 ; CHECK-F:       # %bb.0:
+; CHECK-F-NEXT:    li a0, 190
+; CHECK-F-NEXT:    vsetvli a1, zero, e64, m2, ta, ma
+; CHECK-F-NEXT:    vmv.v.x v10, a0
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vsetvli a1, zero, e32, m1, ta, ma
-; CHECK-F-NEXT:    vfncvt.f.xu.w v10, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v10, 23
-; CHECK-F-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
-; CHECK-F-NEXT:    vzext.vf2 v10, v8
-; CHECK-F-NEXT:    li a1, 190
-; CHECK-F-NEXT:    vrsub.vx v8, v10, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-F-NEXT:    vfncvt.f.xu.w v12, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v12, 23
+; CHECK-F-NEXT:    vwsubu.wv v10, v10, v8
 ; CHECK-F-NEXT:    fsrm a0
+; CHECK-F-NEXT:    vmv2r.v v8, v10
 ; CHECK-F-NEXT:    ret
 ;
 ; CHECK-D-LABEL: ctlz_zero_undef_nxv2i64:
@@ -3103,15 +3109,16 @@ define <vscale x 4 x i64> @ctlz_zero_undef_nxv4i64(<vscale x 4 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_zero_undef_nxv4i64:
 ; CHECK-F:       # %bb.0:
+; CHECK-F-NEXT:    li a0, 190
+; CHECK-F-NEXT:    vsetvli a1, zero, e64, m4, ta, ma
+; CHECK-F-NEXT:    vmv.v.x v12, a0
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vsetvli a1, zero, e32, m2, ta, ma
-; CHECK-F-NEXT:    vfncvt.f.xu.w v12, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v12, 23
-; CHECK-F-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
-; CHECK-F-NEXT:    vzext.vf2 v12, v8
-; CHECK-F-NEXT:    li a1, 190
-; CHECK-F-NEXT:    vrsub.vx v8, v12, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; CHECK-F-NEXT:    vfncvt.f.xu.w v16, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v16, 23
+; CHECK-F-NEXT:    vwsubu.wv v12, v12, v8
 ; CHECK-F-NEXT:    fsrm a0
+; CHECK-F-NEXT:    vmv4r.v v8, v12
 ; CHECK-F-NEXT:    ret
 ;
 ; CHECK-D-LABEL: ctlz_zero_undef_nxv4i64:
@@ -3238,14 +3245,15 @@ define <vscale x 8 x i64> @ctlz_zero_undef_nxv8i64(<vscale x 8 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_zero_undef_nxv8i64:
 ; CHECK-F:       # %bb.0:
+; CHECK-F-NEXT:    vmv8r.v v16, v8
+; CHECK-F-NEXT:    li a0, 190
+; CHECK-F-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; CHECK-F-NEXT:    vmv.v.x v8, a0
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
-; CHECK-F-NEXT:    vfncvt.f.xu.w v16, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v16, 23
-; CHECK-F-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-F-NEXT:    vzext.vf2 v16, v8
-; CHECK-F-NEXT:    li a1, 190
-; CHECK-F-NEXT:    vrsub.vx v8, v16, a1
+; CHECK-F-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-F-NEXT:    vfncvt.f.xu.w v24, v16
+; CHECK-F-NEXT:    vsrl.vi v16, v24, 23
+; CHECK-F-NEXT:    vwsubu.wv v8, v8, v16
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;

>From f5098679c24c1a2da5f7fd2541a105d423eda3e2 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Wed, 15 Nov 2023 23:02:39 +0900
Subject: [PATCH 3/8] [RISCV] refactor and fixed typo

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 50 ++++++---------------
 1 file changed, 13 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d777f9da1f9b81..a6a48b17bfd098 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12719,29 +12719,16 @@ struct NodeExtensionHelper {
     SupportsSExt = false;
     EnforceOneUse = true;
     CheckMask = true;
-    switch (OrigOperand.getOpcode()) {
-    case ISD::ZERO_EXTEND: {
-      SupportsZExt = true;
-      SDLoc DL(Root);
-      MVT VT = Root->getSimpleValueType(0);
-      if (VT.isFixedLengthVector()) {
-        MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
-        std::tie(Mask, VL) =
-            getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-      } else if (VT.isVector())
-        std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
-      break;
-    }
+    unsigned Opc = OrigOperand.getOpcode();
+    switch (Opc) {
+    case ISD::ZERO_EXTEND:
     case ISD::SIGN_EXTEND: {
-      SupportsSExt = true;
-      SDLoc DL(Root);
-      MVT VT = Root->getSimpleValueType(0);
-      if (VT.isFixedLengthVector()) {
-        MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
-        std::tie(Mask, VL) =
-            getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-      } else if (VT.isVector())
+      if (OrigOperand.getValueType().isVector()) {
+        SupportsZExt = Opc == ISD::ZERO_EXTEND;
+        SDLoc DL(Root);
+        MVT VT = Root->getSimpleValueType(0);
         std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+      }
       break;
     }
     case RISCVISD::VZEXT_VL:
@@ -12806,7 +12793,7 @@ struct NodeExtensionHelper {
     case ISD::MUL: {
       EVT VT0 = Root->getOperand(0).getValueType();
       EVT VT1 = Root->getOperand(1).getValueType();
-      if (VT0.isFixedLengthVector() || VT0.isFixedLengthVector())
+      if (VT0.isFixedLengthVector() || VT1.isFixedLengthVector())
         return false;
       return (VT0.isVector() || VT1.isVector());
     }
@@ -12843,8 +12830,7 @@ struct NodeExtensionHelper {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
         SupportsSExt = !SupportsZExt;
-        Mask = Root->getOperand(3);
-        VL = Root->getOperand(4);
+        std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
@@ -12879,16 +12865,8 @@ struct NodeExtensionHelper {
     case ISD::MUL: {
       SDLoc DL(Root);
       MVT VT = Root->getSimpleValueType(0);
-      SDValue Mask, VL;
-      if (VT.isFixedLengthVector()) {
-        MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
-        std::tie(Mask, VL) =
-            getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-      } else
-        std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
-      return std::make_pair(Mask, VL);
+      return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
     }
-
     default:
       return std::make_pair(Root->getOperand(3), Root->getOperand(4));
     }
@@ -14390,8 +14368,7 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   SDValue AddMergeOp = [](SDNode *N, SelectionDAG &DAG) {
     if (N->getOpcode() == ISD::ADD)
       return DAG.getUNDEF(N->getValueType(0));
-    else
-      return N->getOperand(2);
+    return N->getOperand(2);
   }(N, DAG);
 
   if (!AddMergeOp.isUndef())
@@ -14425,9 +14402,8 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
       SDLoc DL(N);
       return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
                                      Subtarget);
-    } else {
-      return std::make_pair(N->getOperand(3), N->getOperand(4));
     }
+    return std::make_pair(N->getOperand(3), N->getOperand(4));
   }(N, DAG, Subtarget);
 
   SDValue MulMask = MulOp.getOperand(3);

>From ff849d9295379785b7c2b024b33b23982e5abb92 Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 16 Nov 2023 13:52:01 +0900
Subject: [PATCH 4/8] [RISCV] use isScalableVector and check value legality in
 isSupportedRoot

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index a6a48b17bfd098..2d203141d98bdd 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12786,16 +12786,15 @@ struct NodeExtensionHelper {
   }
 
   /// Check if \p Root supports any extension folding combines.
-  static bool isSupportedRoot(const SDNode *Root) {
+  static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
     case ISD::MUL: {
-      EVT VT0 = Root->getOperand(0).getValueType();
-      EVT VT1 = Root->getOperand(1).getValueType();
-      if (VT0.isFixedLengthVector() || VT1.isFixedLengthVector())
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+      if (!TLI.isTypeLegal(Root->getValueType(0)))
         return false;
-      return (VT0.isVector() || VT1.isVector());
+      return Root->getValueType(0).isScalableVector();
     }
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
@@ -12813,8 +12812,8 @@ struct NodeExtensionHelper {
   /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
   NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
                       const RISCVSubtarget &Subtarget) {
-    assert(isSupportedRoot(Root) && "Trying to build an helper with an "
-                                    "unsupported root");
+    assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an "
+                                         "unsupported root");
     assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
     OrigOperand = Root->getOperand(OperandIdx);
 
@@ -12858,7 +12857,7 @@ struct NodeExtensionHelper {
   static std::pair<SDValue, SDValue>
   getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
                const RISCVSubtarget &Subtarget) {
-    assert(isSupportedRoot(Root) && "Unexpected root");
+    assert(isSupportedRoot(Root, DAG) && "Unexpected root");
     switch (Root->getOpcode()) {
     case ISD::ADD:
     case ISD::SUB:
@@ -13119,7 +13118,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
                                            const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
 
-  if (!NodeExtensionHelper::isSupportedRoot(N))
+  if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
     return SDValue();
 
   SmallVector<SDNode *> Worklist;
@@ -13130,7 +13129,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
 
   while (!Worklist.empty()) {
     SDNode *Root = Worklist.pop_back_val();
-    if (!NodeExtensionHelper::isSupportedRoot(Root))
+    if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
       return SDValue();
 
     NodeExtensionHelper LHS(N, 0, DAG, Subtarget);

>From 8b4cddd90d6af075a5e3f2784b12c8e364c0e84f Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Sat, 18 Nov 2023 13:00:26 +0900
Subject: [PATCH 5/8] [RISCV] fix SupportsSExt in fillUpExtensionSupport

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2d203141d98bdd..6390e725e63485 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12725,6 +12725,7 @@ struct NodeExtensionHelper {
     case ISD::SIGN_EXTEND: {
       if (OrigOperand.getValueType().isVector()) {
         SupportsZExt = Opc == ISD::ZERO_EXTEND;
+        SupportsSExt = Opc == ISD::SIGN_EXTEND;
         SDLoc DL(Root);
         MVT VT = Root->getSimpleValueType(0);
         std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);

>From a144d503d5644fa851342c6066be4db8c7fafa6a Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Sat, 18 Nov 2023 13:01:26 +0900
Subject: [PATCH 6/8] [RISCV] add test for vscale vwop folding

---
 .../RISCV/rvv/vscale-vw-web-simplification.ll | 107 ++++++++++++++++++
 1 file changed, 107 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll

diff --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
new file mode 100644
index 00000000000000..fe605d5ca6f99b
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
@@ -0,0 +1,107 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; Check that the default value enables the web folding, i.e. that the
+; default maximum web size is at least 3.
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+
+; Check that the scalable vector add/sub/mul operations are all promoted into their
+; vw counterparts when the maximum web size for folding is increased to 3.
+; We need the web size to be at least 3 for the folding to happen, because
+; %c has 3 uses.
+; see https://github.com/llvm/llvm-project/pull/72340
+define <vscale x 2 x i16> @vwop_vscale_sext_multiple_users(ptr %x, ptr %y, ptr %z) {
+; NO_FOLDING-LABEL: vwop_vscale_sext_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; NO_FOLDING-NEXT:    vle8.v v8, (a0)
+; NO_FOLDING-NEXT:    vle8.v v9, (a1)
+; NO_FOLDING-NEXT:    vle8.v v10, (a2)
+; NO_FOLDING-NEXT:    vsext.vf2 v11, v8
+; NO_FOLDING-NEXT:    vsext.vf2 v8, v9
+; NO_FOLDING-NEXT:    vsext.vf2 v9, v10
+; NO_FOLDING-NEXT:    vmul.vv v8, v11, v8
+; NO_FOLDING-NEXT:    vadd.vv v10, v11, v9
+; NO_FOLDING-NEXT:    vsub.vv v9, v11, v9
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v10
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v9
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vwop_vscale_sext_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; FOLDING-NEXT:    vle8.v v8, (a0)
+; FOLDING-NEXT:    vle8.v v9, (a1)
+; FOLDING-NEXT:    vle8.v v10, (a2)
+; FOLDING-NEXT:    vwmul.vv v11, v8, v9
+; FOLDING-NEXT:    vwadd.vv v9, v8, v10
+; FOLDING-NEXT:    vwsub.vv v12, v8, v10
+; FOLDING-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; FOLDING-NEXT:    vor.vv v8, v11, v9
+; FOLDING-NEXT:    vor.vv v8, v8, v12
+; FOLDING-NEXT:    ret
+  %a = load <vscale x 2 x i8>, ptr %x
+  %b = load <vscale x 2 x i8>, ptr %y
+  %b2 = load <vscale x 2 x i8>, ptr %z
+  %c = sext <vscale x 2 x i8> %a to <vscale x 2 x i16>
+  %d = sext <vscale x 2 x i8> %b to <vscale x 2 x i16>
+  %d2 = sext <vscale x 2 x i8> %b2 to <vscale x 2 x i16>
+  %e = mul <vscale x 2 x i16> %c, %d
+  %f = add <vscale x 2 x i16> %c, %d2
+  %g = sub <vscale x 2 x i16> %c, %d2
+  %h = or <vscale x 2 x i16> %e, %f
+  %i = or <vscale x 2 x i16> %h, %g
+  ret <vscale x 2 x i16> %i
+}
+
+
+
+define <vscale x 2 x i16> @vwop_vscale_zext_multiple_users(ptr %x, ptr %y, ptr %z) {
+; NO_FOLDING-LABEL: vwop_vscale_zext_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
+; NO_FOLDING-NEXT:    vle8.v v8, (a0)
+; NO_FOLDING-NEXT:    vle8.v v9, (a1)
+; NO_FOLDING-NEXT:    vle8.v v10, (a2)
+; NO_FOLDING-NEXT:    vzext.vf2 v11, v8
+; NO_FOLDING-NEXT:    vzext.vf2 v8, v9
+; NO_FOLDING-NEXT:    vzext.vf2 v9, v10
+; NO_FOLDING-NEXT:    vmul.vv v8, v11, v8
+; NO_FOLDING-NEXT:    vadd.vv v10, v11, v9
+; NO_FOLDING-NEXT:    vsub.vv v9, v11, v9
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v10
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v9
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vwop_vscale_zext_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; FOLDING-NEXT:    vle8.v v8, (a0)
+; FOLDING-NEXT:    vle8.v v9, (a1)
+; FOLDING-NEXT:    vle8.v v10, (a2)
+; FOLDING-NEXT:    vwmulu.vv v11, v8, v9
+; FOLDING-NEXT:    vwaddu.vv v9, v8, v10
+; FOLDING-NEXT:    vwsubu.vv v12, v8, v10
+; FOLDING-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; FOLDING-NEXT:    vor.vv v8, v11, v9
+; FOLDING-NEXT:    vor.vv v8, v8, v12
+; FOLDING-NEXT:    ret
+  %a = load <vscale x 2 x i8>, ptr %x
+  %b = load <vscale x 2 x i8>, ptr %y
+  %b2 = load <vscale x 2 x i8>, ptr %z
+  %c = zext <vscale x 2 x i8> %a to <vscale x 2 x i16>
+  %d = zext <vscale x 2 x i8> %b to <vscale x 2 x i16>
+  %d2 = zext <vscale x 2 x i8> %b2 to <vscale x 2 x i16>
+  %e = mul <vscale x 2 x i16> %c, %d
+  %f = add <vscale x 2 x i16> %c, %d2
+  %g = sub <vscale x 2 x i16> %c, %d2
+  %h = or <vscale x 2 x i16> %e, %f
+  %i = or <vscale x 2 x i16> %h, %g
+  ret <vscale x 2 x i16> %i
+}

>From 9177e7d3bdab9d80bb0aa9ef6ff7f264f215955d Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Thu, 7 Dec 2023 18:08:00 +0900
Subject: [PATCH 7/8] [RISCV] avoid lambda for AddMergeOp

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6390e725e63485..9523d70e5a4a09 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14365,14 +14365,12 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
 
   SDValue Addend = N->getOperand(0);
   SDValue MulOp = N->getOperand(1);
-  SDValue AddMergeOp = [](SDNode *N, SelectionDAG &DAG) {
-    if (N->getOpcode() == ISD::ADD)
-      return DAG.getUNDEF(N->getValueType(0));
-    return N->getOperand(2);
-  }(N, DAG);
 
-  if (!AddMergeOp.isUndef())
-    return SDValue();
+  if (N->getOpcode() == RISCVISD::ADD_VL) {
+    SDValue AddMergeOp = N->getOperand(2);
+    if (!AddMergeOp.isUndef())
+      return SDValue();
+  }
 
   auto IsVWMulOpc = [](unsigned Opc) {
     switch (Opc) {

>From 07c9be2fe2c1b860e5059bf604fca61c35a02dbb Mon Sep 17 00:00:00 2001
From: sun-jacobi <sun1011jacobi at gmail.com>
Date: Mon, 11 Dec 2023 17:36:08 +0900
Subject: [PATCH 8/8] [RISCV] remove duplicate ISD::MUL in PerformDAGCombine.

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 201f77a49ec807..2eed1413fc156f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14935,8 +14935,6 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::AND, DL, VT, NewFMV,
                        DAG.getConstant(~SignBit, DL, VT));
   }
-  case ISD::MUL:
-    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
   case ISD::ADD: {
     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
       return V;
@@ -14956,6 +14954,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::XOR:
     return performXORCombine(N, DAG, Subtarget);
   case ISD::MUL:
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      return V;
     return performMULCombine(N, DAG);
   case ISD::FADD:
   case ISD::UMAX:



More information about the cfe-commits mailing list