[llvm] 7bc2cd6 - [VP][DAGCombiner] Introduce generalized pattern match for vp sdnodes.

Yeting Kuo via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 7 21:45:42 PST 2023


Author: Yeting Kuo
Date: 2023-02-08T13:45:35+08:00
New Revision: 7bc2cd614ec479d5ffacba6e80ce849cd922ffdf

URL: https://github.com/llvm/llvm-project/commit/7bc2cd614ec479d5ffacba6e80ce849cd922ffdf
DIFF: https://github.com/llvm/llvm-project/commit/7bc2cd614ec479d5ffacba6e80ce849cd922ffdf.diff

LOG: [VP][DAGCombiner] Introduce generalized pattern match for vp sdnodes.

The patch tries to solve duplicated combine work for vp sdnodes. The idea is to
introduce a MatchContext that verifies specific patterns and generates specific
node information. There are two MatchContexts in DAGCombiner: EmptyMatchContext
is for normal nodes and VPMatchContext is for vp nodes.

The idea of this patch comes from Simon Moll's proposal [0]. I only fixed some
minor issues and added a few new features in this patch.

[0]: https://github.com/sx-aurora-dev/llvm-project/commit/c38a14484aa2945f3b05369560b65916dd832f76

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D141891

Added: 
    llvm/test/CodeGen/RISCV/rvv/fold-vp-fadd-and-vp-fmul.ll

Modified: 
    llvm/include/llvm/CodeGen/ISDOpcodes.h
    llvm/include/llvm/IR/VPIntrinsics.def
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 157247dfba98d..9bf311f843149 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1346,6 +1346,12 @@ std::optional<unsigned> getVPMaskIdx(unsigned Opcode);
 /// The operand position of the explicit vector length parameter.
 std::optional<unsigned> getVPExplicitVectorLengthIdx(unsigned Opcode);
 
+/// Translate this VP Opcode to its corresponding non-VP Opcode.
+std::optional<unsigned> getBaseOpcodeForVP(unsigned Opcode, bool hasFPExcept);
+
+/// Translate this non-VP Opcode to its corresponding VP Opcode.
+unsigned getVPForBaseOpcode(unsigned Opcode);
+
 //===--------------------------------------------------------------------===//
 /// MemIndexedMode enum - This enum defines the load / store indexed
 /// addressing modes.

diff  --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
index 2a1a34e33ea49..359f178ba5622 100644
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -106,6 +106,12 @@
 #define VP_PROPERTY_CONSTRAINEDFP(HASROUND, HASEXCEPT, INTRINID)
 #endif
 
+// The intrinsic and/or SDNode has the same function as this ISD Opcode.
+// \p SDOPC      The opcode of the instruction with the same function.
+#ifndef VP_PROPERTY_FUNCTIONAL_SDOPC
+#define VP_PROPERTY_FUNCTIONAL_SDOPC(SDOPC)
+#endif
+
 // Map this VP intrinsic to its canonical functional intrinsic.
 // \p INTRIN     The non-VP intrinsics with the same function.
 #ifndef VP_PROPERTY_FUNCTIONAL_INTRINSIC
@@ -265,27 +271,28 @@ END_REGISTER_VP(vp_fshr, VP_FSHR)
 #error                                                                         \
     "The internal helper macro HELPER_REGISTER_BINARY_FP_VP is already defined!"
 #endif
-#define HELPER_REGISTER_BINARY_FP_VP(OPSUFFIX, VPSD, IROPC)                    \
+#define HELPER_REGISTER_BINARY_FP_VP(OPSUFFIX, VPSD, IROPC, SDOPC)             \
   BEGIN_REGISTER_VP(vp_##OPSUFFIX, 2, 3, VPSD, -1)                             \
   VP_PROPERTY_FUNCTIONAL_OPC(IROPC)                                            \
   VP_PROPERTY_CONSTRAINEDFP(1, 1, experimental_constrained_##OPSUFFIX)         \
+  VP_PROPERTY_FUNCTIONAL_SDOPC(SDOPC)                                          \
   VP_PROPERTY_BINARYOP                                                         \
   END_REGISTER_VP(vp_##OPSUFFIX, VPSD)
 
 // llvm.vp.fadd(x,y,mask,vlen)
-HELPER_REGISTER_BINARY_FP_VP(fadd, VP_FADD, FAdd)
+HELPER_REGISTER_BINARY_FP_VP(fadd, VP_FADD, FAdd, FADD)
 
 // llvm.vp.fsub(x,y,mask,vlen)
-HELPER_REGISTER_BINARY_FP_VP(fsub, VP_FSUB, FSub)
+HELPER_REGISTER_BINARY_FP_VP(fsub, VP_FSUB, FSub, FSUB)
 
 // llvm.vp.fmul(x,y,mask,vlen)
-HELPER_REGISTER_BINARY_FP_VP(fmul, VP_FMUL, FMul)
+HELPER_REGISTER_BINARY_FP_VP(fmul, VP_FMUL, FMul, FMUL)
 
 // llvm.vp.fdiv(x,y,mask,vlen)
-HELPER_REGISTER_BINARY_FP_VP(fdiv, VP_FDIV, FDiv)
+HELPER_REGISTER_BINARY_FP_VP(fdiv, VP_FDIV, FDiv, FDIV)
 
 // llvm.vp.frem(x,y,mask,vlen)
-HELPER_REGISTER_BINARY_FP_VP(frem, VP_FREM, FRem)
+HELPER_REGISTER_BINARY_FP_VP(frem, VP_FREM, FRem, FREM)
 
 #undef HELPER_REGISTER_BINARY_FP_VP
 
@@ -305,6 +312,7 @@ END_REGISTER_VP(vp_sqrt, VP_SQRT)
 // llvm.vp.fma(x,y,z,mask,vlen)
 BEGIN_REGISTER_VP(vp_fma, 3, 4, VP_FMA, -1)
 VP_PROPERTY_CONSTRAINEDFP(1, 1, experimental_constrained_fma)
+VP_PROPERTY_FUNCTIONAL_SDOPC(FMA)
 END_REGISTER_VP(vp_fma, VP_FMA)
 
 // llvm.vp.fmuladd(x,y,z,mask,vlen)
@@ -630,5 +638,6 @@ END_REGISTER_VP(experimental_vp_splice, EXPERIMENTAL_VP_SPLICE)
 #undef VP_PROPERTY_CONSTRAINEDFP
 #undef VP_PROPERTY_FUNCTIONAL_INTRINSIC
 #undef VP_PROPERTY_FUNCTIONAL_OPC
+#undef VP_PROPERTY_FUNCTIONAL_SDOPC
 #undef VP_PROPERTY_MEMOP
 #undef VP_PROPERTY_REDUCTION

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0b7912588f647..d1d6f4ab63874 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -482,6 +482,7 @@ namespace {
     SDValue visitFREEZE(SDNode *N);
     SDValue visitBUILD_PAIR(SDNode *N);
     SDValue visitFADD(SDNode *N);
+    SDValue visitVP_FADD(SDNode *N);
     SDValue visitSTRICT_FADD(SDNode *N);
     SDValue visitFSUB(SDNode *N);
     SDValue visitFMUL(SDNode *N);
@@ -534,6 +535,7 @@ namespace {
     SDValue visitVECREDUCE(SDNode *N);
     SDValue visitVPOp(SDNode *N);
 
+    template <class MatchContextClass>
     SDValue visitFADDForFMACombine(SDNode *N);
     SDValue visitFSUBForFMACombine(SDNode *N);
     SDValue visitFMULForFMADistributiveCombine(SDNode *N);
@@ -847,6 +849,125 @@ class WorklistInserter : public SelectionDAG::DAGUpdateListener {
   void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); }
 };
 
+class EmptyMatchContext {
+  SelectionDAG &DAG;
+
+public:
+  EmptyMatchContext(SelectionDAG &DAG, SDNode *Root) : DAG(DAG) {}
+
+  bool match(SDValue OpN, unsigned Opcode) const {
+    return Opcode == OpN->getOpcode();
+  }
+
+  // Same as SelectionDAG::getNode().
+  template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
+    return DAG.getNode(std::forward<ArgT>(Args)...);
+  }
+};
+
+class VPMatchContext {
+  SelectionDAG &DAG;
+  SDNode *Root;
+  SDValue RootMaskOp;
+  SDValue RootVectorLenOp;
+
+public:
+  VPMatchContext(SelectionDAG &DAG, SDNode *Root)
+      : DAG(DAG), Root(Root), RootMaskOp(), RootVectorLenOp() {
+    assert(Root->isVPOpcode());
+    if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
+      RootMaskOp = Root->getOperand(*RootMaskPos);
+
+    if (auto RootVLenPos =
+            ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
+      RootVectorLenOp = Root->getOperand(*RootVLenPos);
+  }
+
+  /// whether \p OpVal is a node that is functionally compatible with the
+  /// NodeType \p Opc
+  bool match(SDValue OpVal, unsigned Opc) const {
+    if (!OpVal->isVPOpcode())
+      return OpVal->getOpcode() == Opc;
+
+    auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(),
+                                           !OpVal->getFlags().hasNoFPExcept());
+    if (BaseOpc != Opc)
+      return false;
+
+    // Make sure the mask of OpVal is true mask or is same as Root's.
+    unsigned VPOpcode = OpVal->getOpcode();
+    if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
+      SDValue MaskOp = OpVal.getOperand(*MaskPos);
+      if (RootMaskOp != MaskOp &&
+          !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
+        return false;
+    }
+
+    // Make sure the EVL of OpVal is same as Root's.
+    if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
+      if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
+        return false;
+    return true;
+  }
+
+  // Specialize based on number of operands.
+  // TODO emit VP intrinsics where MaskOp/VectorLenOp != null
+  // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return
+  // DAG.getNode(Opcode, DL, VT); }
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand) {
+    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
+    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
+           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
+    return DAG.getNode(VPOpcode, DL, VT,
+                       {Operand, RootMaskOp, RootVectorLenOp});
+  }
+
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
+                  SDValue N2) {
+    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
+    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
+           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
+    return DAG.getNode(VPOpcode, DL, VT,
+                       {N1, N2, RootMaskOp, RootVectorLenOp});
+  }
+
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
+                  SDValue N2, SDValue N3) {
+    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
+    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
+           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
+    return DAG.getNode(VPOpcode, DL, VT,
+                       {N1, N2, N3, RootMaskOp, RootVectorLenOp});
+  }
+
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
+                  SDNodeFlags Flags) {
+    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
+    assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
+           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
+    return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
+                       Flags);
+  }
+
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
+                  SDValue N2, SDNodeFlags Flags) {
+    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
+    assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
+           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
+                       Flags);
+  }
+
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
+                  SDValue N2, SDValue N3, SDNodeFlags Flags) {
+    unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
+    assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
+           ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
+    return DAG.getNode(VPOpcode, DL, VT,
+                       {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
+  }
+};
+
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -14676,21 +14797,26 @@ static bool hasNoInfs(const TargetOptions &Options, SDValue N) {
 }
 
 /// Try to perform FMA combining on a given FADD node.
+template <class MatchContextClass>
 SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N->getValueType(0);
   SDLoc SL(N);
-
+  MatchContextClass matcher(DAG, N);
   const TargetOptions &Options = DAG.getTarget().Options;
 
+  bool UseVP = std::is_same_v<MatchContextClass, VPMatchContext>;
+
   // Floating-point multiply-add with intermediate rounding.
-  bool HasFMAD = (LegalOperations && TLI.isFMADLegal(DAG, N));
+  // FIXME: Make isFMADLegal have specific behavior when using VPMatchContext.
+  // FIXME: Add VP_FMAD opcode.
+  bool HasFMAD = !UseVP && (LegalOperations && TLI.isFMADLegal(DAG, N));
 
   // Floating-point multiply-add without intermediate rounding.
-  bool HasFMA =
-      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
-      (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
+  unsigned FMAOpc = UseVP ? ISD::VP_FMA : ISD::FMA;
+  bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
+                (!LegalOperations || TLI.isOperationLegalOrCustom(FMAOpc, VT));
 
   // No valid opcode, do not combine.
   if (!HasFMAD && !HasFMA)
@@ -14712,14 +14838,13 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   bool Aggressive = TLI.enableAggressiveFMAFusion(VT);
 
   auto isFusedOp = [&](SDValue N) {
-    unsigned Opcode = N.getOpcode();
-    return Opcode == ISD::FMA || Opcode == ISD::FMAD;
+    return matcher.match(N, ISD::FMA) || matcher.match(N, ISD::FMAD);
   };
 
   // Is the node an FMUL and contractable either due to global flags or
   // SDNodeFlags.
-  auto isContractableFMUL = [AllowFusionGlobally](SDValue N) {
-    if (N.getOpcode() != ISD::FMUL)
+  auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) {
+    if (!matcher.match(N, ISD::FMUL))
       return false;
     return AllowFusionGlobally || N->getFlags().hasAllowContract();
   };
@@ -14732,15 +14857,15 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
 
   // fold (fadd (fmul x, y), z) -> (fma x, y, z)
   if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) {
-    return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
-                       N0.getOperand(1), N1);
+    return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0),
+                           N0.getOperand(1), N1);
   }
 
   // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
   // Note: Commutes FADD operands.
   if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) {
-    return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
-                       N1.getOperand(1), N0);
+    return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0),
+                           N1.getOperand(1), N0);
   }
 
   // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
@@ -14781,29 +14906,29 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
   // Look through FP_EXTEND nodes to do more combining.
 
   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
-  if (N0.getOpcode() == ISD::FP_EXTEND) {
+  if (matcher.match(N0, ISD::FP_EXTEND)) {
     SDValue N00 = N0.getOperand(0);
     if (isContractableFMUL(N00) &&
         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                             N00.getValueType())) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT,
-                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
-                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)),
-                         N1);
+      return matcher.getNode(
+          PreferredFusedOpcode, SL, VT,
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)),
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1);
     }
   }
 
   // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
   // Note: Commutes FADD operands.
-  if (N1.getOpcode() == ISD::FP_EXTEND) {
+  if (matcher.match(N1, ISD::FP_EXTEND)) {
     SDValue N10 = N1.getOperand(0);
     if (isContractableFMUL(N10) &&
         TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
                             N10.getValueType())) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT,
-                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
-                         DAG.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)),
-                         N0);
+      return matcher.getNode(
+          PreferredFusedOpcode, SL, VT,
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)),
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0);
     }
   }
 
@@ -14813,15 +14938,15 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
     auto FoldFAddFMAFPExtFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
                                     SDValue Z) {
-      return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y,
-                         DAG.getNode(PreferredFusedOpcode, SL, VT,
-                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
-                                     DAG.getNode(ISD::FP_EXTEND, SL, VT, V),
-                                     Z));
+      return matcher.getNode(
+          PreferredFusedOpcode, SL, VT, X, Y,
+          matcher.getNode(PreferredFusedOpcode, SL, VT,
+                          matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
+                          matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
     };
     if (isFusedOp(N0)) {
       SDValue N02 = N0.getOperand(2);
-      if (N02.getOpcode() == ISD::FP_EXTEND) {
+      if (matcher.match(N02, ISD::FP_EXTEND)) {
         SDValue N020 = N02.getOperand(0);
         if (isContractableFMUL(N020) &&
             TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT,
@@ -14840,12 +14965,13 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
     // interesting for all targets, especially GPUs.
     auto FoldFAddFPExtFMAFMul = [&](SDValue X, SDValue Y, SDValue U, SDValue V,
                                     SDValue Z) {
-      return DAG.getNode(
-          PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FP_EXTEND, SL, VT, X),
-          DAG.getNode(ISD::FP_EXTEND, SL, VT, Y),
-          DAG.getNode(PreferredFusedOpcode, SL, VT,
-                      DAG.getNode(ISD::FP_EXTEND, SL, VT, U),
-                      DAG.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
+      return matcher.getNode(
+          PreferredFusedOpcode, SL, VT,
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, X),
+          matcher.getNode(ISD::FP_EXTEND, SL, VT, Y),
+          matcher.getNode(PreferredFusedOpcode, SL, VT,
+                          matcher.getNode(ISD::FP_EXTEND, SL, VT, U),
+                          matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z));
     };
     if (N0.getOpcode() == ISD::FP_EXTEND) {
       SDValue N00 = N0.getOperand(0);
@@ -15308,6 +15434,17 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitVP_FADD(SDNode *N) {
+  SelectionDAG::FlagInserter FlagsInserter(DAG, N);
+
+  // FADD -> FMA combines:
+  if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) {
+    AddToWorklist(Fused.getNode());
+    return Fused;
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitFADD(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -15488,7 +15625,7 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
   } // enable-unsafe-fp-math
 
   // FADD -> FMA combines:
-  if (SDValue Fused = visitFADDForFMACombine(N)) {
+  if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) {
     AddToWorklist(Fused.getNode());
     return Fused;
   }
@@ -24929,8 +25066,13 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
         ISD::isConstantSplatVectorAllZeros(N->getOperand(*MaskIdx).getNode());
 
   // This is the only generic VP combine we support for now.
-  if (!AreAllEltsDisabled)
+  if (!AreAllEltsDisabled) {
+    switch (N->getOpcode()) {
+    case ISD::VP_FADD:
+      return visitVP_FADD(N);
+    }
     return SDValue();
+  }
 
   // Binary operations can be replaced by UNDEF.
   if (ISD::isVPBinaryOp(N->getOpcode()))

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0d2ae398815c1..8292f9c3de3ae 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -516,6 +516,31 @@ std::optional<unsigned> ISD::getVPExplicitVectorLengthIdx(unsigned Opcode) {
   }
 }
 
+std::optional<unsigned> ISD::getBaseOpcodeForVP(unsigned VPOpcode,
+                                                bool hasFPExcept) {
+  // FIXME: Return strict opcodes in case of fp exceptions.
+  switch (VPOpcode) {
+  default:
+    return std::nullopt;
+#define BEGIN_REGISTER_VP_SDNODE(VPOPC, ...) case ISD::VPOPC:
+#define VP_PROPERTY_FUNCTIONAL_SDOPC(SDOPC) return ISD::SDOPC;
+#define END_REGISTER_VP_SDNODE(VPOPC) break;
+#include "llvm/IR/VPIntrinsics.def"
+  }
+  return std::nullopt;
+}
+
+unsigned ISD::getVPForBaseOpcode(unsigned Opcode) {
+  switch (Opcode) {
+  default:
+    llvm_unreachable("can not translate this Opcode to VP.");
+#define BEGIN_REGISTER_VP_SDNODE(VPOPC, ...) break;
+#define VP_PROPERTY_FUNCTIONAL_SDOPC(SDOPC) case ISD::SDOPC:
+#define END_REGISTER_VP_SDNODE(VPOPC) return ISD::VPOPC;
+#include "llvm/IR/VPIntrinsics.def"
+  }
+}
+
 ISD::NodeType ISD::getExtForLoadExtType(bool IsFP, ISD::LoadExtType ExtType) {
   switch (ExtType) {
   case ISD::EXTLOAD:

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fold-vp-fadd-and-vp-fmul.ll b/llvm/test/CodeGen/RISCV/rvv/fold-vp-fadd-and-vp-fmul.ll
new file mode 100644
index 0000000000000..1d30328e97815
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fold-vp-fadd-and-vp-fmul.ll
@@ -0,0 +1,59 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+
+declare <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+declare <vscale x 1 x double> @llvm.vp.fadd.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+
+; (fadd (fmul x, y), z)) -> (fma x, y, z)
+define <vscale x 1 x double> @fma(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: fma:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, mu
+; CHECK-NEXT:    vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %1 = call fast <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fadd.nxv1f64(<vscale x 1 x double> %1, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %2
+}
+
+; (fadd z, (fmul x, y))) -> (fma x, y, z)
+define <vscale x 1 x double> @fma_commute(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: fma_commute:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, mu
+; CHECK-NEXT:    vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %1 = call fast <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %m, i32 %vl)
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fadd.nxv1f64(<vscale x 1 x double> %z, <vscale x 1 x double> %1, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %2
+}
+
+; Test operand with true mask
+define <vscale x 1 x double> @fma_true(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: fma_true:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, mu
+; CHECK-NEXT:    vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %head = insertelement <vscale x 1 x i1> poison, i1 true, i32 0
+  %true = shufflevector <vscale x 1 x i1> %head, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %1 = call fast <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x i1> %true, i32 %vl)
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fadd.nxv1f64(<vscale x 1 x double> %1, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %2
+}
+
+; Test operand with normal opcode.
+define <vscale x 1 x double> @fma_nonvp(<vscale x 1 x double> %x, <vscale x 1 x double> %y, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: fma_nonvp:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, mu
+; CHECK-NEXT:    vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %1 = fmul fast <vscale x 1 x double> %x, %y
+  %2 = call fast <vscale x 1 x double> @llvm.vp.fadd.nxv1f64(<vscale x 1 x double> %1, <vscale x 1 x double> %z, <vscale x 1 x i1> %m, i32 %vl)
+  ret <vscale x 1 x double> %2
+}


        


More information about the llvm-commits mailing list