[llvm] [DAGCombiner][VP] add getNegative for VPMatchContext (PR #80635)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 4 22:10:30 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Shao-Ce SUN (sunshaoce)

<details>
<summary>Changes</summary>

This is my attempt to reuse as much existing code as possible in order to provide a helper function (a VP-aware `getNegative` usable through `VPMatchContext`) for https://github.com/llvm/llvm-project/pull/80105.

---
Full diff: https://github.com/llvm/llvm-project/pull/80635.diff


3 Files Affected:

- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+33-1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+31-24) 
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+25-2) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index b9ec30754f0c3..22981a6284fac 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -88,6 +88,32 @@ class TargetMachine;
 class TargetSubtargetInfo;
 class Value;
 
+class VPMaskAndVL {
+  bool IsVP;
+  SDValue MaskOp;
+  SDValue VectorLenOp;
+
+public:
+  VPMaskAndVL(SDValue Mask, SDValue VectorLen)
+      : IsVP(true), MaskOp(Mask), VectorLenOp(VectorLen) {}
+  VPMaskAndVL() : IsVP(false), MaskOp(), VectorLenOp() {}
+
+  bool empty() const { return !IsVP; }
+  bool isMaskEqualsTo(const SDValue &Val) const { return MaskOp == Val; }
+  bool isVLEqualsTo(const SDValue &Val) const { return VectorLenOp == Val; }
+  SDValue getMask() const { return MaskOp; }
+  SDValue getVL() const { return VectorLenOp; }
+
+  SDValue setMask(SDValue Val) {
+    IsVP = true;
+    return MaskOp = Val;
+  }
+  SDValue setVL(SDValue Val) {
+    IsVP = true;
+    return VectorLenOp = Val;
+  }
+};
+
 template <typename T> class GenericSSAContext;
 using SSAContext = GenericSSAContext<Function>;
 template <typename T> class GenericUniformityInfo;
@@ -1004,7 +1030,8 @@ class SelectionDAG {
   SDValue getBoolExtOrTrunc(SDValue Op, const SDLoc &SL, EVT VT, EVT OpVT);
 
   /// Create negative operation as (SUB 0, Val).
-  SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT);
+  SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT,
+                      VPMaskAndVL VPOp = VPMaskAndVL());
 
   /// Create a bitwise NOT operation as (XOR Val, -1).
   SDValue getNOT(const SDLoc &DL, SDValue Val, EVT VT);
@@ -1116,6 +1143,9 @@ class SelectionDAG {
                   ArrayRef<SDUse> Ops);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                   ArrayRef<SDValue> Ops, const SDNodeFlags Flags);
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                  ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
+                  VPMaskAndVL VPOp);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, ArrayRef<EVT> ResultTys,
                   ArrayRef<SDValue> Ops);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
@@ -1124,6 +1154,8 @@ class SelectionDAG {
   // Use flags from current flag inserter.
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                   ArrayRef<SDValue> Ops);
+  SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                  ArrayRef<SDValue> Ops, VPMaskAndVL VPOp);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                   ArrayRef<SDValue> Ops);
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3ce45e0e43bf4..3ed1a7533dd43 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -905,11 +905,16 @@ class EmptyMatchContext {
     return Opcode == OpN->getOpcode();
   }
 
-  // Same as SelectionDAG::getNode().
-  template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
-    return DAG.getNode(std::forward<ArgT>(Args)...);
+  // Same as SelectionDAG::FUNCT_NAME(Args...).
+#define GET_SELECTION_DAG_FUNCT(FUNCT_NAME)                                    \
+  template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) {             \
+    return DAG.FUNCT_NAME(std::forward<ArgT>(Args)...);                        \
   }
 
+  GET_SELECTION_DAG_FUNCT(getNode)
+  GET_SELECTION_DAG_FUNCT(getNegative)
+#undef GET_SELECTION_DAG_FUNCT
+
   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
                                 bool LegalOnly = false) const {
     return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
@@ -919,22 +924,21 @@ class EmptyMatchContext {
 class VPMatchContext {
   SelectionDAG &DAG;
   const TargetLowering &TLI;
-  SDValue RootMaskOp;
-  SDValue RootVectorLenOp;
+  VPMaskAndVL VPOp;
 
 public:
   VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
-      : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
+      : DAG(DAG), TLI(TLI), VPOp() {
     assert(Root->isVPOpcode());
     if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
-      RootMaskOp = Root->getOperand(*RootMaskPos);
+      VPOp.setMask(Root->getOperand(*RootMaskPos));
     else if (Root->getOpcode() == ISD::VP_SELECT)
-      RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
-                                          Root->getOperand(0).getValueType());
+      VPOp.setMask(DAG.getAllOnesConstant(SDLoc(Root),
+                                          Root->getOperand(0).getValueType()));
 
     if (auto RootVLenPos =
             ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
-      RootVectorLenOp = Root->getOperand(*RootVLenPos);
+      VPOp.setVL(Root->getOperand(*RootVLenPos));
   }
 
   /// whether \p OpVal is a node that is functionally compatible with the
@@ -952,14 +956,14 @@ class VPMatchContext {
     unsigned VPOpcode = OpVal->getOpcode();
     if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
       SDValue MaskOp = OpVal.getOperand(*MaskPos);
-      if (RootMaskOp != MaskOp &&
+      if (!VPOp.isMaskEqualsTo(MaskOp) &&
           !ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
         return false;
     }
 
     // Make sure the EVL of OpVal is same as Root's.
     if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
-      if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
+      if (!VPOp.isVLEqualsTo(OpVal.getOperand(*VLenPos)))
         return false;
     return true;
   }
@@ -972,8 +976,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {Operand, RootMaskOp, RootVectorLenOp});
+    return DAG.getNode(VPOpcode, DL, VT, {Operand}, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -981,8 +984,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, RootMaskOp, RootVectorLenOp});
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -990,8 +992,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, N3, RootMaskOp, RootVectorLenOp});
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
@@ -999,8 +1000,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
-    return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
-                       Flags);
+    return DAG.getNode(VPOpcode, DL, VT, {Operand}, Flags, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -1008,8 +1008,7 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
-    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
-                       Flags);
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, Flags, VPOp);
   }
 
   SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -1017,10 +1016,18 @@ class VPMatchContext {
     unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
     assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
            ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
-    return DAG.getNode(VPOpcode, DL, VT,
-                       {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
+    return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, Flags, VPOp);
+  }
+
+  // Same as SelectionDAG::FUNCT_NAME(Args, VPOp).
+#define GET_SELECTION_DAG_VP_FUNCT(FUNCT_NAME)                                 \
+  template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) {             \
+    return DAG.FUNCT_NAME(std::forward<ArgT>(Args)..., VPOp);                  \
   }
 
+  GET_SELECTION_DAG_VP_FUNCT(getNegative)
+#undef GET_SELECTION_DAG_VP_FUNCT
+
   bool isOperationLegalOrCustom(unsigned Op, EVT VT,
                                 bool LegalOnly = false) const {
     unsigned VPOp = ISD::getVPForBaseOpcode(Op);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3c1343836187a..0c6e70a86caf8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1548,8 +1548,10 @@ SDValue SelectionDAG::getPtrExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) {
   return getZeroExtendInReg(Op, DL, VT);
 }
 
-SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT) {
-  return getNode(ISD::SUB, DL, VT, getConstant(0, DL, VT), Val);
+SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT,
+                                  VPMaskAndVL VPOp) {
+  auto Opcode = VPOp.empty() ? ISD::SUB : ISD::VP_SUB;
+  return getNode(Opcode, DL, VT, {getConstant(0, DL, VT), Val}, VPOp);
 }
 
 /// getNOT - Create a bitwise NOT operation as (XOR Val, -1).
@@ -9708,6 +9710,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   return getNode(Opcode, DL, VT, Ops, Flags);
 }
 
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                              ArrayRef<SDValue> Ops, VPMaskAndVL VPOp) {
+  SmallVector<SDValue, 8> OpsVec(Ops);
+  if (!VPOp.empty()) {
+    OpsVec.push_back(VPOp.getMask());
+    OpsVec.push_back(VPOp.getVL());
+  }
+  return getNode(Opcode, DL, VT, OpsVec);
+}
+
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
   unsigned NumOps = Ops.size();
@@ -9820,6 +9832,17 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   return V;
 }
 
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+                              ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
+                              VPMaskAndVL VPOp) {
+  SmallVector<SDValue, 8> OpsVec(Ops);
+  if (!VPOp.empty()) {
+    OpsVec.push_back(VPOp.getMask());
+    OpsVec.push_back(VPOp.getVL());
+  }
+  return getNode(Opcode, DL, VT, OpsVec, Flags);
+}
+
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
                               ArrayRef<EVT> ResultTys, ArrayRef<SDValue> Ops) {
   return getNode(Opcode, DL, getVTList(ResultTys), Ops);

``````````

</details>


https://github.com/llvm/llvm-project/pull/80635


More information about the llvm-commits mailing list