[llvm] Recreate "[RISCV][ISel] Combine scalable vector add/sub/mul with zero/sign extension. (#72340)" (PR #76785)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 2 23:15:03 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

<details>
<summary>Changes</summary>

This recreates #<!-- -->72340, which was reverted by 4e347b4.

---

Patch is 50.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76785.diff


3 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+172-61) 
- (modified) llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll (+68-60) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll (+402-34) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27bb69dc9868c8..2fb79c81b7f169 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1374,8 +1374,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
 
   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
-                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
-                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
+                       ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
@@ -12850,9 +12850,9 @@ struct CombineResult;
 
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
 ///
 /// An object of this class represents an operand of the operation we want to
 /// combine.
@@ -12897,6 +12897,8 @@ struct NodeExtensionHelper {
   /// E.g., for zext(a), this would return a.
   SDValue getSource() const {
     switch (OrigOperand.getOpcode()) {
+    case ISD::ZERO_EXTEND:
+    case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
       return OrigOperand.getOperand(0);
@@ -12913,7 +12915,8 @@ struct NodeExtensionHelper {
   /// Get or create a value that can feed \p Root with the given extension \p
   /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
   /// \see ::getSource().
-  SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
+  SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
+                                const RISCVSubtarget &Subtarget,
                                 std::optional<bool> SExt) const {
     if (!SExt.has_value())
       return OrigOperand;
@@ -12928,8 +12931,10 @@ struct NodeExtensionHelper {
 
     // If we need an extension, we should be changing the type.
     SDLoc DL(Root);
-    auto [Mask, VL] = getMaskAndVL(Root);
+    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
     switch (OrigOperand.getOpcode()) {
+    case ISD::ZERO_EXTEND:
+    case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
@@ -12969,12 +12974,15 @@ struct NodeExtensionHelper {
   /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
   static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
     switch (Opcode) {
+    case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
       return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+    case ISD::MUL:
     case RISCVISD::MUL_VL:
       return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -12987,7 +12995,8 @@ struct NodeExtensionHelper {
   /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
   /// newOpcode(a, b).
   static unsigned getSUOpcode(unsigned Opcode) {
-    assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL");
+    assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
+           "SU is only supported for MUL");
     return RISCVISD::VWMULSU_VL;
   }
 
@@ -12995,8 +13004,10 @@ struct NodeExtensionHelper {
   /// newOpcode(a, b).
   static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
     switch (Opcode) {
+    case ISD::ADD:
     case RISCVISD::ADD_VL:
       return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
       return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
     default:
@@ -13006,19 +13017,44 @@ struct NodeExtensionHelper {
 
   using CombineToTry = std::function<std::optional<CombineResult>(
       SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
-      const NodeExtensionHelper & /*RHS*/)>;
+      const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
+      const RISCVSubtarget &)>;
 
   /// Check if this node needs to be fully folded or extended for all users.
   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
 
   /// Helper method to set the various fields of this struct based on the
   /// type of \p Root.
-  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) {
+  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
+                              const RISCVSubtarget &Subtarget) {
     SupportsZExt = false;
     SupportsSExt = false;
     EnforceOneUse = true;
     CheckMask = true;
-    switch (OrigOperand.getOpcode()) {
+    unsigned Opc = OrigOperand.getOpcode();
+    switch (Opc) {
+    case ISD::ZERO_EXTEND:
+    case ISD::SIGN_EXTEND: {
+      MVT VT = OrigOperand.getSimpleValueType();
+      if (!VT.isVector())
+        break;
+
+      MVT NarrowVT = OrigOperand.getOperand(0)->getSimpleValueType(0);
+
+      unsigned ScalarBits = VT.getScalarSizeInBits();
+      unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits();
+
+      // Ensure the extension's semantic is equivalent to rvv vzext or vsext.
+      if (ScalarBits != NarrowScalarBits * 2)
+        break;
+
+      SupportsZExt = Opc == ISD::ZERO_EXTEND;
+      SupportsSExt = Opc == ISD::SIGN_EXTEND;
+
+      SDLoc DL(Root);
+      std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+      break;
+    }
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
       Mask = OrigOperand.getOperand(1);
@@ -13074,8 +13110,16 @@ struct NodeExtensionHelper {
   }
 
   /// Check if \p Root supports any extension folding combines.
-  static bool isSupportedRoot(const SDNode *Root) {
+  static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
     switch (Root->getOpcode()) {
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL: {
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+      if (!TLI.isTypeLegal(Root->getValueType(0)))
+        return false;
+      return Root->getValueType(0).isScalableVector();
+    }
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13090,9 +13134,10 @@ struct NodeExtensionHelper {
   }
 
   /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
-  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) {
-    assert(isSupportedRoot(Root) && "Trying to build an helper with an "
-                                    "unsupported root");
+  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
+                      const RISCVSubtarget &Subtarget) {
+    assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an "
+                                         "unsupported root");
     assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
     OrigOperand = Root->getOperand(OperandIdx);
 
@@ -13108,7 +13153,7 @@ struct NodeExtensionHelper {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
         SupportsSExt = !SupportsZExt;
-        std::tie(Mask, VL) = getMaskAndVL(Root);
+        std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
@@ -13117,7 +13162,7 @@ struct NodeExtensionHelper {
       }
       [[fallthrough]];
     default:
-      fillUpExtensionSupport(Root, DAG);
+      fillUpExtensionSupport(Root, DAG, Subtarget);
       break;
     }
   }
@@ -13133,14 +13178,27 @@ struct NodeExtensionHelper {
   }
 
   /// Helper function to get the Mask and VL from \p Root.
-  static std::pair<SDValue, SDValue> getMaskAndVL(const SDNode *Root) {
-    assert(isSupportedRoot(Root) && "Unexpected root");
-    return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+  static std::pair<SDValue, SDValue>
+  getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
+               const RISCVSubtarget &Subtarget) {
+    assert(isSupportedRoot(Root, DAG) && "Unexpected root");
+    switch (Root->getOpcode()) {
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL: {
+      SDLoc DL(Root);
+      MVT VT = Root->getSimpleValueType(0);
+      return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+    }
+    default:
+      return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+    }
   }
 
   /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(const SDNode *Root) const {
-    auto [Mask, VL] = getMaskAndVL(Root);
+  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
+                              const RISCVSubtarget &Subtarget) const {
+    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
     return isMaskCompatible(Mask) && isVLCompatible(VL);
   }
 
@@ -13148,11 +13206,14 @@ struct NodeExtensionHelper {
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
     switch (N->getOpcode()) {
+    case ISD::ADD:
+    case ISD::MUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
       return true;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -13197,14 +13258,25 @@ struct CombineResult {
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
   /// The actual replacement is *not* done in that method.
-  SDValue materialize(SelectionDAG &DAG) const {
+  SDValue materialize(SelectionDAG &DAG,
+                      const RISCVSubtarget &Subtarget) const {
     SDValue Mask, VL, Merge;
-    std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
-    Merge = Root->getOperand(2);
+    std::tie(Mask, VL) =
+        NodeExtensionHelper::getMaskAndVL(Root, DAG, Subtarget);
+    switch (Root->getOpcode()) {
+    default:
+      Merge = Root->getOperand(2);
+      break;
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL:
+      Merge = DAG.getUNDEF(Root->getValueType(0));
+      break;
+    }
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
-                       Mask, VL);
+                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+                       Merge, Mask, VL);
   }
 };
 
@@ -13221,15 +13293,16 @@ struct CombineResult {
 static std::optional<CombineResult>
 canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt) {
+                                 bool AllowZExt, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
   assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
-  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
+      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS,
-                         /*SExtRHS=*/false);
+                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), /*IsSExt=*/true),
@@ -13246,9 +13319,10 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
-                             const NodeExtensionHelper &RHS) {
+                             const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                             const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/true);
+                                          /*AllowZExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13257,8 +13331,9 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
-              const NodeExtensionHelper &RHS) {
-  if (!RHS.areVLAndMaskCompatible(Root))
+              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+              const RISCVSubtarget &Subtarget) {
+  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
 
   // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
@@ -13282,9 +13357,10 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                    const NodeExtensionHelper &RHS) {
+                    const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                    const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/false);
+                                          /*AllowZExt=*/false, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13293,9 +13369,10 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                    const NodeExtensionHelper &RHS) {
+                    const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                    const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/true);
+                                          /*AllowZExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13304,10 +13381,13 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
-               const NodeExtensionHelper &RHS) {
+               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+               const RISCVSubtarget &Subtarget) {
+
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
+      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
@@ -13317,6 +13397,8 @@ SmallVector<NodeExtensionHelper::CombineToTry>
 NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   SmallVector<CombineToTry> Strategies;
   switch (Root->getOpcode()) {
+  case ISD::ADD:
+  case ISD::SUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
     // add|sub -> vwadd(u)|vwsub(u)
@@ -13324,6 +13406,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // add|sub -> vwadd(u)_w|vwsub(u)_w
     Strategies.push_back(canFoldToVW_W);
     break;
+  case ISD::MUL:
   case RISCVISD::MUL_VL:
     // mul -> vwmul(u)
     Strategies.push_back(canFoldToVWWithSameExtension);
@@ -13354,12 +13437,14 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// mul_vl -> vwmul(u) | vwmul_su
 /// vwadd_w(u) -> vwadd(u)
 /// vwub_w(u) -> vwadd(u)
-static SDValue
-combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
 
-  assert(NodeExtensionHelper::isSupportedRoot(N) &&
-         "Shouldn't have called this method");
+  if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
+    return SDValue();
+
   SmallVector<SDNode *> Worklist;
   SmallSet<SDNode *, 8> Inserted;
   Worklist.push_back(N);
@@ -13368,11 +13453,11 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 
   while (!Worklist.empty()) {
     SDNode *Root = Worklist.pop_back_val();
-    if (!NodeExtensionHelper::isSupportedRoot(Root))
+    if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
       return SDValue();
 
-    NodeExtensionHelper LHS(N, 0, DAG);
-    NodeExtensionHelper RHS(N, 1, DAG);
+    NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
+    NodeExtensionHelper RHS(N, 1, DAG, Subtarget);
     auto AppendUsersIfNeeded = [&Worklist,
                                 &Inserted](const NodeExtensionHelper &Op) {
       if (Op.needToPromoteOtherUsers()) {
@@ -13399,7 +13484,8 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 
       for (NodeExtensionHelper::CombineToTry FoldingStrategy :
            FoldingStrategies) {
-        std::optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+        std::optional<CombineResult> Res =
+            FoldingStrategy(N, LHS, RHS, DAG, Subtarget);
         if (Res) {
           Matched = true;
           CombinesToApply.push_back(*Res);
@@ -13428,7 +13514,7 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
   ValuesToReplace.reserve(CombinesToApply.size());
   for (CombineResult Res : CombinesToApply) {
-    SDValue NewValue = Res.materialize(DAG);
+    SDValue NewValue = Res.materialize(DAG, Subtarget);
     if (!InputRootReplacement) {
       assert(Res.Root == N &&
              "First element is expected to be the current node");
@@ -14700,13 +14786,20 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
-  assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+  assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
+
+  if (N->getValueType(0).isFixedLengthVector())
+    return SDValue();
+
   SDValue Addend = N->getOperand(0);
   SDValue MulOp = N->getOperand(1);
-  SDValue AddMergeOp = N->getOperand(2);
 
-  if (!AddMergeOp.isUndef())
-    return SDValue();
+  if (N->getOpcode() == RISCVISD::ADD_VL) {
+    SDValue AddMergeOp = N->getOperand(2);
+    if (!AddMergeOp.isUndef())
+      return SDValue();
+  }
 
   auto IsVWMulOpc = [](unsigned Opc) {
     switch (Opc) {
@@ -14730,8 +14823,16 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   if (!MulMergeOp.isUndef())
     return SDValue();
 
-  SDValue AddMask = N->getOperand(3);
-  SDValue AddVL = N->getOperand(4);
+  auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
+                             const RISCVSubtarget &Subtarget) {
+    if (N->getOpcode() == ISD::ADD) {
+      SDLoc DL(N);
+      return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
+                                     Subtarget);
+    }
+    re...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/76785


More information about the llvm-commits mailing list