[llvm] 4e347b4 - Revert "[RISCV][ISel] Combine scalable vector add/sub/mul with zero/sign extension (#72340)"

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 2 19:50:10 PST 2024


Author: Craig Topper
Date: 2024-01-02T19:49:42-08:00
New Revision: 4e347b4e38b95bc455d0e620e11ac58fc0172a94

URL: https://github.com/llvm/llvm-project/commit/4e347b4e38b95bc455d0e620e11ac58fc0172a94
DIFF: https://github.com/llvm/llvm-project/commit/4e347b4e38b95bc455d0e620e11ac58fc0172a94.diff

LOG: Revert "[RISCV][ISel] Combine scalable vector add/sub/mul with zero/sign extension (#72340)"

This reverts most of commit 5b155aea0e529b7b5c807e189fef6ea5cd5faec9.
I have left the new test file, but regenerated the checks.

The reverted commit causes failures in our downstream testing. The input
types to the extends need to be checked so we don't create
RISCVISD::VZEXT_VL with an illegal or unsupported input type.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll
    llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 03a59f8a8b57ca..27bb69dc9868c8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1374,8 +1374,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
 
   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
-                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
-                       ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
+                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
@@ -12850,9 +12850,9 @@ struct CombineResult;
 
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
-/// add | add_vl -> vwadd(u) | vwadd(u)_w
-/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul | mul_vl -> vwmul(u) | vwmul_su
+/// add_vl -> vwadd(u) | vwadd(u)_w
+/// sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul_vl -> vwmul(u) | vwmul_su
 ///
 /// An object of this class represents an operand of the operation we want to
 /// combine.
@@ -12897,8 +12897,6 @@ struct NodeExtensionHelper {
   /// E.g., for zext(a), this would return a.
   SDValue getSource() const {
     switch (OrigOperand.getOpcode()) {
-    case ISD::ZERO_EXTEND:
-    case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
       return OrigOperand.getOperand(0);
@@ -12915,8 +12913,7 @@ struct NodeExtensionHelper {
   /// Get or create a value that can feed \p Root with the given extension \p
   /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
   /// \see ::getSource().
-  SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
-                                const RISCVSubtarget &Subtarget,
+  SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
                                 std::optional<bool> SExt) const {
     if (!SExt.has_value())
       return OrigOperand;
@@ -12931,10 +12928,8 @@ struct NodeExtensionHelper {
 
     // If we need an extension, we should be changing the type.
     SDLoc DL(Root);
-    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
+    auto [Mask, VL] = getMaskAndVL(Root);
     switch (OrigOperand.getOpcode()) {
-    case ISD::ZERO_EXTEND:
-    case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
@@ -12974,15 +12969,12 @@ struct NodeExtensionHelper {
   /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
   static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
     switch (Opcode) {
-    case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
       return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
-    case ISD::MUL:
     case RISCVISD::MUL_VL:
       return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
-    case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -12995,8 +12987,7 @@ struct NodeExtensionHelper {
   /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
   /// newOpcode(a, b).
   static unsigned getSUOpcode(unsigned Opcode) {
-    assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
-           "SU is only supported for MUL");
+    assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL");
     return RISCVISD::VWMULSU_VL;
   }
 
@@ -13004,10 +12995,8 @@ struct NodeExtensionHelper {
   /// newOpcode(a, b).
   static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
     switch (Opcode) {
-    case ISD::ADD:
     case RISCVISD::ADD_VL:
       return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
-    case ISD::SUB:
     case RISCVISD::SUB_VL:
       return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
     default:
@@ -13017,33 +13006,19 @@ struct NodeExtensionHelper {
 
   using CombineToTry = std::function<std::optional<CombineResult>(
       SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
-      const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
-      const RISCVSubtarget &)>;
+      const NodeExtensionHelper & /*RHS*/)>;
 
   /// Check if this node needs to be fully folded or extended for all users.
   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
 
   /// Helper method to set the various fields of this struct based on the
   /// type of \p Root.
-  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
-                              const RISCVSubtarget &Subtarget) {
+  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) {
     SupportsZExt = false;
     SupportsSExt = false;
     EnforceOneUse = true;
     CheckMask = true;
-    unsigned Opc = OrigOperand.getOpcode();
-    switch (Opc) {
-    case ISD::ZERO_EXTEND:
-    case ISD::SIGN_EXTEND: {
-      if (OrigOperand.getValueType().isVector()) {
-        SupportsZExt = Opc == ISD::ZERO_EXTEND;
-        SupportsSExt = Opc == ISD::SIGN_EXTEND;
-        SDLoc DL(Root);
-        MVT VT = Root->getSimpleValueType(0);
-        std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
-      }
-      break;
-    }
+    switch (OrigOperand.getOpcode()) {
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
       Mask = OrigOperand.getOperand(1);
@@ -13099,16 +13074,8 @@ struct NodeExtensionHelper {
   }
 
   /// Check if \p Root supports any extension folding combines.
-  static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
+  static bool isSupportedRoot(const SDNode *Root) {
     switch (Root->getOpcode()) {
-    case ISD::ADD:
-    case ISD::SUB:
-    case ISD::MUL: {
-      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-      if (!TLI.isTypeLegal(Root->getValueType(0)))
-        return false;
-      return Root->getValueType(0).isScalableVector();
-    }
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13123,10 +13090,9 @@ struct NodeExtensionHelper {
   }
 
   /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
-  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
-                      const RISCVSubtarget &Subtarget) {
-    assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an "
-                                         "unsupported root");
+  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) {
+    assert(isSupportedRoot(Root) && "Trying to build an helper with an "
+                                    "unsupported root");
     assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
     OrigOperand = Root->getOperand(OperandIdx);
 
@@ -13142,7 +13108,7 @@ struct NodeExtensionHelper {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
         SupportsSExt = !SupportsZExt;
-        std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
+        std::tie(Mask, VL) = getMaskAndVL(Root);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
@@ -13151,7 +13117,7 @@ struct NodeExtensionHelper {
       }
       [[fallthrough]];
     default:
-      fillUpExtensionSupport(Root, DAG, Subtarget);
+      fillUpExtensionSupport(Root, DAG);
       break;
     }
   }
@@ -13167,27 +13133,14 @@ struct NodeExtensionHelper {
   }
 
   /// Helper function to get the Mask and VL from \p Root.
-  static std::pair<SDValue, SDValue>
-  getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
-               const RISCVSubtarget &Subtarget) {
-    assert(isSupportedRoot(Root, DAG) && "Unexpected root");
-    switch (Root->getOpcode()) {
-    case ISD::ADD:
-    case ISD::SUB:
-    case ISD::MUL: {
-      SDLoc DL(Root);
-      MVT VT = Root->getSimpleValueType(0);
-      return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
-    }
-    default:
-      return std::make_pair(Root->getOperand(3), Root->getOperand(4));
-    }
+  static std::pair<SDValue, SDValue> getMaskAndVL(const SDNode *Root) {
+    assert(isSupportedRoot(Root) && "Unexpected root");
+    return std::make_pair(Root->getOperand(3), Root->getOperand(4));
   }
 
   /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
-                              const RISCVSubtarget &Subtarget) const {
-    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
+  bool areVLAndMaskCompatible(const SDNode *Root) const {
+    auto [Mask, VL] = getMaskAndVL(Root);
     return isMaskCompatible(Mask) && isVLCompatible(VL);
   }
 
@@ -13195,14 +13148,11 @@ struct NodeExtensionHelper {
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
     switch (N->getOpcode()) {
-    case ISD::ADD:
-    case ISD::MUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
       return true;
-    case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -13247,25 +13197,14 @@ struct CombineResult {
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
   /// The actual replacement is *not* done in that method.
-  SDValue materialize(SelectionDAG &DAG,
-                      const RISCVSubtarget &Subtarget) const {
+  SDValue materialize(SelectionDAG &DAG) const {
     SDValue Mask, VL, Merge;
-    std::tie(Mask, VL) =
-        NodeExtensionHelper::getMaskAndVL(Root, DAG, Subtarget);
-    switch (Root->getOpcode()) {
-    default:
-      Merge = Root->getOperand(2);
-      break;
-    case ISD::ADD:
-    case ISD::SUB:
-    case ISD::MUL:
-      Merge = DAG.getUNDEF(Root->getValueType(0));
-      break;
-    }
+    std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
+    Merge = Root->getOperand(2);
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
-                       Merge, Mask, VL);
+                       LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
+                       Mask, VL);
   }
 };
 
@@ -13282,16 +13221,15 @@ struct CombineResult {
 static std::optional<CombineResult>
 canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt, SelectionDAG &DAG,
-                                 const RISCVSubtarget &Subtarget) {
+                                 bool AllowZExt) {
   assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
+  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
     return std::nullopt;
   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
+                         Root, LHS, /*SExtLHS=*/false, RHS,
+                         /*SExtRHS=*/false);
   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), /*IsSExt=*/true),
@@ -13308,10 +13246,9 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
-                             const NodeExtensionHelper &RHS, SelectionDAG &DAG,
-                             const RISCVSubtarget &Subtarget) {
+                             const NodeExtensionHelper &RHS) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/true, DAG, Subtarget);
+                                          /*AllowZExt=*/true);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13320,9 +13257,8 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
-              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
-              const RISCVSubtarget &Subtarget) {
-  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
+              const NodeExtensionHelper &RHS) {
+  if (!RHS.areVLAndMaskCompatible(Root))
     return std::nullopt;
 
   // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
@@ -13346,10 +13282,9 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                    const NodeExtensionHelper &RHS, SelectionDAG &DAG,
-                    const RISCVSubtarget &Subtarget) {
+                    const NodeExtensionHelper &RHS) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/false, DAG, Subtarget);
+                                          /*AllowZExt=*/false);
 }
 
 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13358,10 +13293,9 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                    const NodeExtensionHelper &RHS, SelectionDAG &DAG,
-                    const RISCVSubtarget &Subtarget) {
+                    const NodeExtensionHelper &RHS) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/true, DAG, Subtarget);
+                                          /*AllowZExt=*/true);
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13370,13 +13304,10 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
-               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
-               const RISCVSubtarget &Subtarget) {
-
+               const NodeExtensionHelper &RHS) {
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
-      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
+  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
     return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
@@ -13386,8 +13317,6 @@ SmallVector<NodeExtensionHelper::CombineToTry>
 NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   SmallVector<CombineToTry> Strategies;
   switch (Root->getOpcode()) {
-  case ISD::ADD:
-  case ISD::SUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
     // add|sub -> vwadd(u)|vwsub(u)
@@ -13395,7 +13324,6 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // add|sub -> vwadd(u)_w|vwsub(u)_w
     Strategies.push_back(canFoldToVW_W);
     break;
-  case ISD::MUL:
   case RISCVISD::MUL_VL:
     // mul -> vwmul(u)
     Strategies.push_back(canFoldToVWWithSameExtension);
@@ -13426,14 +13354,12 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// mul_vl -> vwmul(u) | vwmul_su
 /// vwadd_w(u) -> vwadd(u)
 /// vwub_w(u) -> vwadd(u)
-static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
-                                           TargetLowering::DAGCombinerInfo &DCI,
-                                           const RISCVSubtarget &Subtarget) {
+static SDValue
+combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SelectionDAG &DAG = DCI.DAG;
 
-  if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
-    return SDValue();
-
+  assert(NodeExtensionHelper::isSupportedRoot(N) &&
+         "Shouldn't have called this method");
   SmallVector<SDNode *> Worklist;
   SmallSet<SDNode *, 8> Inserted;
   Worklist.push_back(N);
@@ -13442,11 +13368,11 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
 
   while (!Worklist.empty()) {
     SDNode *Root = Worklist.pop_back_val();
-    if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
+    if (!NodeExtensionHelper::isSupportedRoot(Root))
       return SDValue();
 
-    NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
-    NodeExtensionHelper RHS(N, 1, DAG, Subtarget);
+    NodeExtensionHelper LHS(N, 0, DAG);
+    NodeExtensionHelper RHS(N, 1, DAG);
     auto AppendUsersIfNeeded = [&Worklist,
                                 &Inserted](const NodeExtensionHelper &Op) {
       if (Op.needToPromoteOtherUsers()) {
@@ -13473,8 +13399,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
 
       for (NodeExtensionHelper::CombineToTry FoldingStrategy :
            FoldingStrategies) {
-        std::optional<CombineResult> Res =
-            FoldingStrategy(N, LHS, RHS, DAG, Subtarget);
+        std::optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
         if (Res) {
           Matched = true;
           CombinesToApply.push_back(*Res);
@@ -13503,7 +13428,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
   SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
   ValuesToReplace.reserve(CombinesToApply.size());
   for (CombineResult Res : CombinesToApply) {
-    SDValue NewValue = Res.materialize(DAG, Subtarget);
+    SDValue NewValue = Res.materialize(DAG);
     if (!InputRootReplacement) {
       assert(Res.Root == N &&
              "First element is expected to be the current node");
@@ -14775,20 +14700,13 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
-
-  assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
-
-  if (N->getValueType(0).isFixedLengthVector())
-    return SDValue();
-
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
   SDValue Addend = N->getOperand(0);
   SDValue MulOp = N->getOperand(1);
+  SDValue AddMergeOp = N->getOperand(2);
 
-  if (N->getOpcode() == RISCVISD::ADD_VL) {
-    SDValue AddMergeOp = N->getOperand(2);
-    if (!AddMergeOp.isUndef())
-      return SDValue();
-  }
+  if (!AddMergeOp.isUndef())
+    return SDValue();
 
   auto IsVWMulOpc = [](unsigned Opc) {
     switch (Opc) {
@@ -14812,16 +14730,8 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   if (!MulMergeOp.isUndef())
     return SDValue();
 
-  auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
-                             const RISCVSubtarget &Subtarget) {
-    if (N->getOpcode() == ISD::ADD) {
-      SDLoc DL(N);
-      return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
-                                     Subtarget);
-    }
-    return std::make_pair(N->getOperand(3), N->getOperand(4));
-  }(N, DAG, Subtarget);
-
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
   SDValue MulMask = MulOp.getOperand(3);
   SDValue MulVL = MulOp.getOperand(4);
 
@@ -15087,18 +14997,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::AND, DL, VT, NewFMV,
                        DAG.getConstant(~SignBit, DL, VT));
   }
-  case ISD::ADD: {
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
-      return V;
-    if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
-      return V;
+  case ISD::ADD:
     return performADDCombine(N, DAG, Subtarget);
-  }
-  case ISD::SUB: {
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
-      return V;
+  case ISD::SUB:
     return performSUBCombine(N, DAG, Subtarget);
-  }
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
   case ISD::OR:
@@ -15106,8 +15008,6 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::XOR:
     return performXORCombine(N, DAG, Subtarget);
   case ISD::MUL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
-      return V;
     return performMULCombine(N, DAG);
   case ISD::FADD:
   case ISD::UMAX:
@@ -15584,7 +15484,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::ADD_VL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
       return V;
     return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::SUB_VL:
@@ -15593,7 +15493,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::VWSUB_W_VL:
   case RISCVISD::VWSUBU_W_VL:
   case RISCVISD::MUL_VL:
-    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+    return combineBinOp_VLToVWBinOp_VL(N, DCI);
   case RISCVISD::VFMADD_VL:
   case RISCVISD::VFNMADD_VL:
   case RISCVISD::VFMSUB_VL:

diff  --git a/llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll
index fc94f8c2a52797..47d65c2593a4cc 100644
--- a/llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll
@@ -1231,17 +1231,16 @@ define <vscale x 1 x i64> @ctlz_nxv1i64(<vscale x 1 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_nxv1i64:
 ; CHECK-F:       # %bb.0:
-; CHECK-F-NEXT:    li a0, 190
-; CHECK-F-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; CHECK-F-NEXT:    vmv.v.x v9, a0
-; CHECK-F-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-F-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vfncvt.f.xu.w v10, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v10, 23
-; CHECK-F-NEXT:    vwsubu.wv v9, v9, v8
-; CHECK-F-NEXT:    li a1, 64
+; CHECK-F-NEXT:    vfncvt.f.xu.w v9, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v9, 23
 ; CHECK-F-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
-; CHECK-F-NEXT:    vminu.vx v8, v9, a1
+; CHECK-F-NEXT:    vzext.vf2 v9, v8
+; CHECK-F-NEXT:    li a1, 190
+; CHECK-F-NEXT:    vrsub.vx v8, v9, a1
+; CHECK-F-NEXT:    li a1, 64
+; CHECK-F-NEXT:    vminu.vx v8, v8, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;
@@ -1372,17 +1371,16 @@ define <vscale x 2 x i64> @ctlz_nxv2i64(<vscale x 2 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_nxv2i64:
 ; CHECK-F:       # %bb.0:
-; CHECK-F-NEXT:    li a0, 190
-; CHECK-F-NEXT:    vsetvli a1, zero, e64, m2, ta, ma
-; CHECK-F-NEXT:    vmv.v.x v10, a0
-; CHECK-F-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-F-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vfncvt.f.xu.w v12, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v12, 23
-; CHECK-F-NEXT:    vwsubu.wv v10, v10, v8
-; CHECK-F-NEXT:    li a1, 64
+; CHECK-F-NEXT:    vfncvt.f.xu.w v10, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v10, 23
 ; CHECK-F-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
-; CHECK-F-NEXT:    vminu.vx v8, v10, a1
+; CHECK-F-NEXT:    vzext.vf2 v10, v8
+; CHECK-F-NEXT:    li a1, 190
+; CHECK-F-NEXT:    vrsub.vx v8, v10, a1
+; CHECK-F-NEXT:    li a1, 64
+; CHECK-F-NEXT:    vminu.vx v8, v8, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;
@@ -1513,17 +1511,16 @@ define <vscale x 4 x i64> @ctlz_nxv4i64(<vscale x 4 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_nxv4i64:
 ; CHECK-F:       # %bb.0:
-; CHECK-F-NEXT:    li a0, 190
-; CHECK-F-NEXT:    vsetvli a1, zero, e64, m4, ta, ma
-; CHECK-F-NEXT:    vmv.v.x v12, a0
-; CHECK-F-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; CHECK-F-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vfncvt.f.xu.w v16, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v16, 23
-; CHECK-F-NEXT:    vwsubu.wv v12, v12, v8
-; CHECK-F-NEXT:    li a1, 64
+; CHECK-F-NEXT:    vfncvt.f.xu.w v12, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v12, 23
 ; CHECK-F-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
-; CHECK-F-NEXT:    vminu.vx v8, v12, a1
+; CHECK-F-NEXT:    vzext.vf2 v12, v8
+; CHECK-F-NEXT:    li a1, 190
+; CHECK-F-NEXT:    vrsub.vx v8, v12, a1
+; CHECK-F-NEXT:    li a1, 64
+; CHECK-F-NEXT:    vminu.vx v8, v8, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;
@@ -1654,17 +1651,16 @@ define <vscale x 8 x i64> @ctlz_nxv8i64(<vscale x 8 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_nxv8i64:
 ; CHECK-F:       # %bb.0:
-; CHECK-F-NEXT:    li a0, 190
-; CHECK-F-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
-; CHECK-F-NEXT:    vmv.v.x v16, a0
-; CHECK-F-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-F-NEXT:    vsetvli a0, zero, e32, m4, ta, ma
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vfncvt.f.xu.w v24, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v24, 23
-; CHECK-F-NEXT:    vwsubu.wv v16, v16, v8
-; CHECK-F-NEXT:    li a1, 64
+; CHECK-F-NEXT:    vfncvt.f.xu.w v16, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v16, 23
 ; CHECK-F-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
-; CHECK-F-NEXT:    vminu.vx v8, v16, a1
+; CHECK-F-NEXT:    vzext.vf2 v16, v8
+; CHECK-F-NEXT:    li a1, 190
+; CHECK-F-NEXT:    vrsub.vx v8, v16, a1
+; CHECK-F-NEXT:    li a1, 64
+; CHECK-F-NEXT:    vminu.vx v8, v8, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;
@@ -2837,16 +2833,15 @@ define <vscale x 1 x i64> @ctlz_zero_undef_nxv1i64(<vscale x 1 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_zero_undef_nxv1i64:
 ; CHECK-F:       # %bb.0:
-; CHECK-F-NEXT:    li a0, 190
-; CHECK-F-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; CHECK-F-NEXT:    vmv.v.x v9, a0
-; CHECK-F-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-F-NEXT:    vsetvli a0, zero, e32, mf2, ta, ma
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vfncvt.f.xu.w v10, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v10, 23
-; CHECK-F-NEXT:    vwsubu.wv v9, v9, v8
+; CHECK-F-NEXT:    vfncvt.f.xu.w v9, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v9, 23
+; CHECK-F-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-F-NEXT:    vzext.vf2 v9, v8
+; CHECK-F-NEXT:    li a1, 190
+; CHECK-F-NEXT:    vrsub.vx v8, v9, a1
 ; CHECK-F-NEXT:    fsrm a0
-; CHECK-F-NEXT:    vmv1r.v v8, v9
 ; CHECK-F-NEXT:    ret
 ;
 ; CHECK-D-LABEL: ctlz_zero_undef_nxv1i64:
@@ -2973,16 +2968,15 @@ define <vscale x 2 x i64> @ctlz_zero_undef_nxv2i64(<vscale x 2 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_zero_undef_nxv2i64:
 ; CHECK-F:       # %bb.0:
-; CHECK-F-NEXT:    li a0, 190
-; CHECK-F-NEXT:    vsetvli a1, zero, e64, m2, ta, ma
-; CHECK-F-NEXT:    vmv.v.x v10, a0
-; CHECK-F-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-F-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vfncvt.f.xu.w v12, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v12, 23
-; CHECK-F-NEXT:    vwsubu.wv v10, v10, v8
+; CHECK-F-NEXT:    vfncvt.f.xu.w v10, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v10, 23
+; CHECK-F-NEXT:    vsetvli zero, zero, e64, m2, ta, ma
+; CHECK-F-NEXT:    vzext.vf2 v10, v8
+; CHECK-F-NEXT:    li a1, 190
+; CHECK-F-NEXT:    vrsub.vx v8, v10, a1
 ; CHECK-F-NEXT:    fsrm a0
-; CHECK-F-NEXT:    vmv2r.v v8, v10
 ; CHECK-F-NEXT:    ret
 ;
 ; CHECK-D-LABEL: ctlz_zero_undef_nxv2i64:
@@ -3109,16 +3103,15 @@ define <vscale x 4 x i64> @ctlz_zero_undef_nxv4i64(<vscale x 4 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_zero_undef_nxv4i64:
 ; CHECK-F:       # %bb.0:
-; CHECK-F-NEXT:    li a0, 190
-; CHECK-F-NEXT:    vsetvli a1, zero, e64, m4, ta, ma
-; CHECK-F-NEXT:    vmv.v.x v12, a0
-; CHECK-F-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
+; CHECK-F-NEXT:    vsetvli a0, zero, e32, m2, ta, ma
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vfncvt.f.xu.w v16, v8
-; CHECK-F-NEXT:    vsrl.vi v8, v16, 23
-; CHECK-F-NEXT:    vwsubu.wv v12, v12, v8
+; CHECK-F-NEXT:    vfncvt.f.xu.w v12, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v12, 23
+; CHECK-F-NEXT:    vsetvli zero, zero, e64, m4, ta, ma
+; CHECK-F-NEXT:    vzext.vf2 v12, v8
+; CHECK-F-NEXT:    li a1, 190
+; CHECK-F-NEXT:    vrsub.vx v8, v12, a1
 ; CHECK-F-NEXT:    fsrm a0
-; CHECK-F-NEXT:    vmv4r.v v8, v12
 ; CHECK-F-NEXT:    ret
 ;
 ; CHECK-D-LABEL: ctlz_zero_undef_nxv4i64:
@@ -3245,15 +3238,14 @@ define <vscale x 8 x i64> @ctlz_zero_undef_nxv8i64(<vscale x 8 x i64> %va) {
 ;
 ; CHECK-F-LABEL: ctlz_zero_undef_nxv8i64:
 ; CHECK-F:       # %bb.0:
-; CHECK-F-NEXT:    vmv8r.v v16, v8
-; CHECK-F-NEXT:    li a0, 190
-; CHECK-F-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
-; CHECK-F-NEXT:    vmv.v.x v8, a0
-; CHECK-F-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-F-NEXT:    vsetvli a0, zero, e32, m4, ta, ma
 ; CHECK-F-NEXT:    fsrmi a0, 1
-; CHECK-F-NEXT:    vfncvt.f.xu.w v24, v16
-; CHECK-F-NEXT:    vsrl.vi v16, v24, 23
-; CHECK-F-NEXT:    vwsubu.wv v8, v8, v16
+; CHECK-F-NEXT:    vfncvt.f.xu.w v16, v8
+; CHECK-F-NEXT:    vsrl.vi v8, v16, 23
+; CHECK-F-NEXT:    vsetvli zero, zero, e64, m8, ta, ma
+; CHECK-F-NEXT:    vzext.vf2 v16, v8
+; CHECK-F-NEXT:    li a1, 190
+; CHECK-F-NEXT:    vrsub.vx v8, v16, a1
 ; CHECK-F-NEXT:    fsrm a0
 ; CHECK-F-NEXT:    ret
 ;

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
index fe605d5ca6f99b..d99e3a7fe690aa 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll
@@ -10,12 +10,14 @@
 ; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
 ; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
 
+; FIXME: We should use vwadd/vwsub/vwmul instructions.
 
 ; Check that the scalable vector add/sub/mul operations are all promoted into their
 ; vw counterpart when the folding of the web size is increased to 3.
 ; We need the web size to be at least 3 for the folding to happen, because
 ; %c has 3 uses.
 ; see https://github.com/llvm/llvm-project/pull/72340
+; FIXME: We don't currently use widening instructions.
 define <vscale x 2 x i16> @vwop_vscale_sext_multiple_users(ptr %x, ptr %y, ptr %z) {
 ; NO_FOLDING-LABEL: vwop_vscale_sext_multiple_users:
 ; NO_FOLDING:       # %bb.0:
@@ -35,16 +37,18 @@ define <vscale x 2 x i16> @vwop_vscale_sext_multiple_users(ptr %x, ptr %y, ptr %
 ;
 ; FOLDING-LABEL: vwop_vscale_sext_multiple_users:
 ; FOLDING:       # %bb.0:
-; FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
 ; FOLDING-NEXT:    vle8.v v8, (a0)
 ; FOLDING-NEXT:    vle8.v v9, (a1)
 ; FOLDING-NEXT:    vle8.v v10, (a2)
-; FOLDING-NEXT:    vwmul.vv v11, v8, v9
-; FOLDING-NEXT:    vwadd.vv v9, v8, v10
-; FOLDING-NEXT:    vwsub.vv v12, v8, v10
-; FOLDING-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
-; FOLDING-NEXT:    vor.vv v8, v11, v9
-; FOLDING-NEXT:    vor.vv v8, v8, v12
+; FOLDING-NEXT:    vsext.vf2 v11, v8
+; FOLDING-NEXT:    vsext.vf2 v8, v9
+; FOLDING-NEXT:    vsext.vf2 v9, v10
+; FOLDING-NEXT:    vmul.vv v8, v11, v8
+; FOLDING-NEXT:    vadd.vv v10, v11, v9
+; FOLDING-NEXT:    vsub.vv v9, v11, v9
+; FOLDING-NEXT:    vor.vv v8, v8, v10
+; FOLDING-NEXT:    vor.vv v8, v8, v9
 ; FOLDING-NEXT:    ret
   %a = load <vscale x 2 x i8>, ptr %x
   %b = load <vscale x 2 x i8>, ptr %y
@@ -81,16 +85,18 @@ define <vscale x 2 x i16> @vwop_vscale_zext_multiple_users(ptr %x, ptr %y, ptr %
 ;
 ; FOLDING-LABEL: vwop_vscale_zext_multiple_users:
 ; FOLDING:       # %bb.0:
-; FOLDING-NEXT:    vsetvli a3, zero, e8, mf4, ta, ma
+; FOLDING-NEXT:    vsetvli a3, zero, e16, mf2, ta, ma
 ; FOLDING-NEXT:    vle8.v v8, (a0)
 ; FOLDING-NEXT:    vle8.v v9, (a1)
 ; FOLDING-NEXT:    vle8.v v10, (a2)
-; FOLDING-NEXT:    vwmulu.vv v11, v8, v9
-; FOLDING-NEXT:    vwaddu.vv v9, v8, v10
-; FOLDING-NEXT:    vwsubu.vv v12, v8, v10
-; FOLDING-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
-; FOLDING-NEXT:    vor.vv v8, v11, v9
-; FOLDING-NEXT:    vor.vv v8, v8, v12
+; FOLDING-NEXT:    vzext.vf2 v11, v8
+; FOLDING-NEXT:    vzext.vf2 v8, v9
+; FOLDING-NEXT:    vzext.vf2 v9, v10
+; FOLDING-NEXT:    vmul.vv v8, v11, v8
+; FOLDING-NEXT:    vadd.vv v10, v11, v9
+; FOLDING-NEXT:    vsub.vv v9, v11, v9
+; FOLDING-NEXT:    vor.vv v8, v8, v10
+; FOLDING-NEXT:    vor.vv v8, v8, v9
 ; FOLDING-NEXT:    ret
   %a = load <vscale x 2 x i8>, ptr %x
   %b = load <vscale x 2 x i8>, ptr %y


        


More information about the llvm-commits mailing list