[llvm] [DAGCombiner][VP] add getNegative for VPMatchContext (PR #80635)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Feb 4 22:10:30 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: Shao-Ce SUN (sunshaoce)
<details>
<summary>Changes</summary>
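This patch adds a `getNegative` helper to `VPMatchContext`, reusing existing code as much as possible, so that it can be used by https://github.com/llvm/llvm-project/pull/80105. To support this, it introduces a small `VPMaskAndVL` class that bundles a VP mask and explicit vector length. It also adds `SelectionDAG::getNegative`/`getNode` overloads that append those operands when they are present.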
---
Full diff: https://github.com/llvm/llvm-project/pull/80635.diff
3 Files Affected:
- (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+33-1)
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+31-24)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+25-2)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index b9ec30754f0c3..22981a6284fac 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -88,6 +88,32 @@ class TargetMachine;
class TargetSubtargetInfo;
class Value;
+class VPMaskAndVL {
+ bool IsVP;
+ SDValue MaskOp;
+ SDValue VectorLenOp;
+
+public:
+ VPMaskAndVL(SDValue Mask, SDValue VectorLen)
+ : IsVP(true), MaskOp(Mask), VectorLenOp(VectorLen) {}
+ VPMaskAndVL() : IsVP(false), MaskOp(), VectorLenOp() {}
+
+ bool empty() const { return !IsVP; }
+ bool isMaskEqualsTo(const SDValue &Val) const { return MaskOp == Val; }
+ bool isVLEqualsTo(const SDValue &Val) const { return VectorLenOp == Val; }
+ SDValue getMask() const { return MaskOp; }
+ SDValue getVL() const { return VectorLenOp; }
+
+ SDValue setMask(SDValue Val) {
+ IsVP = true;
+ return MaskOp = Val;
+ }
+ SDValue setVL(SDValue Val) {
+ IsVP = true;
+ return VectorLenOp = Val;
+ }
+};
+
template <typename T> class GenericSSAContext;
using SSAContext = GenericSSAContext<Function>;
template <typename T> class GenericUniformityInfo;
@@ -1004,7 +1030,8 @@ class SelectionDAG {
SDValue getBoolExtOrTrunc(SDValue Op, const SDLoc &SL, EVT VT, EVT OpVT);
/// Create negative operation as (SUB 0, Val).
- SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT);
+ SDValue getNegative(SDValue Val, const SDLoc &DL, EVT VT,
+ VPMaskAndVL VPOp = VPMaskAndVL());
/// Create a bitwise NOT operation as (XOR Val, -1).
SDValue getNOT(const SDLoc &DL, SDValue Val, EVT VT);
@@ -1116,6 +1143,9 @@ class SelectionDAG {
ArrayRef<SDUse> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags);
+ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
+ VPMaskAndVL VPOp);
SDValue getNode(unsigned Opcode, const SDLoc &DL, ArrayRef<EVT> ResultTys,
ArrayRef<SDValue> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
@@ -1124,6 +1154,8 @@ class SelectionDAG {
// Use flags from current flag inserter.
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops);
+ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, VPMaskAndVL VPOp);
SDValue getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
ArrayRef<SDValue> Ops);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3ce45e0e43bf4..3ed1a7533dd43 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -905,11 +905,16 @@ class EmptyMatchContext {
return Opcode == OpN->getOpcode();
}
- // Same as SelectionDAG::getNode().
- template <typename... ArgT> SDValue getNode(ArgT &&...Args) {
- return DAG.getNode(std::forward<ArgT>(Args)...);
+ // Same as SelectionDAG::FUNCT_NAME(Args...).
+#define GET_SELECTION_DAG_FUNCT(FUNCT_NAME) \
+ template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) { \
+ return DAG.FUNCT_NAME(std::forward<ArgT>(Args)...); \
}
+ GET_SELECTION_DAG_FUNCT(getNode)
+ GET_SELECTION_DAG_FUNCT(getNegative)
+#undef GET_SELECTION_DAG_FUNCT
+
bool isOperationLegalOrCustom(unsigned Op, EVT VT,
bool LegalOnly = false) const {
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
@@ -919,22 +924,21 @@ class EmptyMatchContext {
class VPMatchContext {
SelectionDAG &DAG;
const TargetLowering &TLI;
- SDValue RootMaskOp;
- SDValue RootVectorLenOp;
+ VPMaskAndVL VPOp;
public:
VPMatchContext(SelectionDAG &DAG, const TargetLowering &TLI, SDNode *Root)
- : DAG(DAG), TLI(TLI), RootMaskOp(), RootVectorLenOp() {
+ : DAG(DAG), TLI(TLI), VPOp() {
assert(Root->isVPOpcode());
if (auto RootMaskPos = ISD::getVPMaskIdx(Root->getOpcode()))
- RootMaskOp = Root->getOperand(*RootMaskPos);
+ VPOp.setMask(Root->getOperand(*RootMaskPos));
else if (Root->getOpcode() == ISD::VP_SELECT)
- RootMaskOp = DAG.getAllOnesConstant(SDLoc(Root),
- Root->getOperand(0).getValueType());
+ VPOp.setMask(DAG.getAllOnesConstant(SDLoc(Root),
+ Root->getOperand(0).getValueType()));
if (auto RootVLenPos =
ISD::getVPExplicitVectorLengthIdx(Root->getOpcode()))
- RootVectorLenOp = Root->getOperand(*RootVLenPos);
+ VPOp.setVL(Root->getOperand(*RootVLenPos));
}
/// whether \p OpVal is a node that is functionally compatible with the
@@ -952,14 +956,14 @@ class VPMatchContext {
unsigned VPOpcode = OpVal->getOpcode();
if (auto MaskPos = ISD::getVPMaskIdx(VPOpcode)) {
SDValue MaskOp = OpVal.getOperand(*MaskPos);
- if (RootMaskOp != MaskOp &&
+ if (!VPOp.isMaskEqualsTo(MaskOp) &&
!ISD::isConstantSplatVectorAllOnes(MaskOp.getNode()))
return false;
}
// Make sure the EVL of OpVal is same as Root's.
if (auto VLenPos = ISD::getVPExplicitVectorLengthIdx(VPOpcode))
- if (RootVectorLenOp != OpVal.getOperand(*VLenPos))
+ if (!VPOp.isVLEqualsTo(OpVal.getOperand(*VLenPos)))
return false;
return true;
}
@@ -972,8 +976,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
- return DAG.getNode(VPOpcode, DL, VT,
- {Operand, RootMaskOp, RootVectorLenOp});
+ return DAG.getNode(VPOpcode, DL, VT, {Operand}, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -981,8 +984,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, RootMaskOp, RootVectorLenOp});
+ return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -990,8 +992,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, N3, RootMaskOp, RootVectorLenOp});
+ return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand,
@@ -999,8 +1000,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 1 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 2);
- return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp},
- Flags);
+ return DAG.getNode(VPOpcode, DL, VT, {Operand}, Flags, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -1008,8 +1008,7 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 2 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 3);
- return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp},
- Flags);
+ return DAG.getNode(VPOpcode, DL, VT, {N1, N2}, Flags, VPOp);
}
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
@@ -1017,10 +1016,18 @@ class VPMatchContext {
unsigned VPOpcode = ISD::getVPForBaseOpcode(Opcode);
assert(ISD::getVPMaskIdx(VPOpcode) == 3 &&
ISD::getVPExplicitVectorLengthIdx(VPOpcode) == 4);
- return DAG.getNode(VPOpcode, DL, VT,
- {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags);
+ return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3}, Flags, VPOp);
+ }
+
+ // Same as SelectionDAG::FUNCT_NAME(Args, VPOp).
+#define GET_SELECTION_DAG_VP_FUNCT(FUNCT_NAME) \
+ template <typename... ArgT> SDValue FUNCT_NAME(ArgT &&...Args) { \
+ return DAG.FUNCT_NAME(std::forward<ArgT>(Args)..., VPOp); \
}
+ GET_SELECTION_DAG_VP_FUNCT(getNegative)
+#undef GET_SELECTION_DAG_VP_FUNCT
+
bool isOperationLegalOrCustom(unsigned Op, EVT VT,
bool LegalOnly = false) const {
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3c1343836187a..0c6e70a86caf8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1548,8 +1548,10 @@ SDValue SelectionDAG::getPtrExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) {
return getZeroExtendInReg(Op, DL, VT);
}
-SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT) {
- return getNode(ISD::SUB, DL, VT, getConstant(0, DL, VT), Val);
+SDValue SelectionDAG::getNegative(SDValue Val, const SDLoc &DL, EVT VT,
+ VPMaskAndVL VPOp) {
+ auto Opcode = VPOp.empty() ? ISD::SUB : ISD::VP_SUB;
+ return getNode(Opcode, DL, VT, {getConstant(0, DL, VT), Val}, VPOp);
}
/// getNOT - Create a bitwise NOT operation as (XOR Val, -1).
@@ -9708,6 +9710,16 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return getNode(Opcode, DL, VT, Ops, Flags);
}
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, VPMaskAndVL VPOp) {
+ SmallVector<SDValue, 8> OpsVec(Ops);
+ if (!VPOp.empty()) {
+ OpsVec.push_back(VPOp.getMask());
+ OpsVec.push_back(VPOp.getVL());
+ }
+ return getNode(Opcode, DL, VT, OpsVec);
+}
+
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
unsigned NumOps = Ops.size();
@@ -9820,6 +9832,17 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
return V;
}
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
+ ArrayRef<SDValue> Ops, const SDNodeFlags Flags,
+ VPMaskAndVL VPOp) {
+ SmallVector<SDValue, 8> OpsVec(Ops);
+ if (!VPOp.empty()) {
+ OpsVec.push_back(VPOp.getMask());
+ OpsVec.push_back(VPOp.getVL());
+ }
+ return getNode(Opcode, DL, VT, OpsVec, Flags);
+}
+
SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
ArrayRef<EVT> ResultTys, ArrayRef<SDValue> Ops) {
return getNode(Opcode, DL, getVTList(ResultTys), Ops);
``````````
</details>
https://github.com/llvm/llvm-project/pull/80635
More information about the llvm-commits mailing list