[llvm] r317985 - [X86] Attempt to match multiple binary reduction ops at once. NFCI
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sat Nov 11 10:16:55 PST 2017
Author: rksimon
Date: Sat Nov 11 10:16:55 2017
New Revision: 317985
URL: http://llvm.org/viewvc/llvm-project?rev=317985&view=rev
Log:
[X86] Attempt to match multiple binary reduction ops at once. NFCI
matchBinOpReduction currently matches against a single opcode, but we already have a case where we repeat calls to match against AND and then OR, and I'll shortly be adding another case for SMAX/SMIN/UMAX/UMIN (D39729).
This NFCI patch changes matchBinOpReduction to try to pattern-match against any opcode in a provided list of candidate binary ops in a single call, saving time.
Differential Revision: https://reviews.llvm.org/D39726
Modified:
llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=317985&r1=317984&r2=317985&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Sat Nov 11 10:16:55 2017
@@ -30093,16 +30093,22 @@ static SDValue combineBitcast(SDNode *N,
// the elements of a vector.
// Returns the vector that is being reduced on, or SDValue() if a reduction
// was not matched.
-static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
+static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
+ ArrayRef<ISD::NodeType> CandidateBinOps) {
// The pattern must end in an extract from index 0.
if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
!isNullConstant(Extract->getOperand(1)))
return SDValue();
- unsigned Stages =
- Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements());
-
SDValue Op = Extract->getOperand(0);
+ unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
+
+ // Match against one of the candidate binary ops.
+ if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
+ return Op.getOpcode() == BinOp;
+ }))
+ return SDValue();
+
// At each stage, we're looking for something that looks like:
// %s = shufflevector <8 x i32> %op, <8 x i32> undef,
// <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
@@ -30113,8 +30119,9 @@ static SDValue matchBinOpReduction(SDNod
// <4,5,6,7,u,u,u,u>
// <2,3,u,u,u,u,u,u>
// <1,u,u,u,u,u,u,u>
+ unsigned CandidateBinOp = Op.getOpcode();
for (unsigned i = 0; i < Stages; ++i) {
- if (Op.getOpcode() != BinOp)
+ if (Op.getOpcode() != CandidateBinOp)
return SDValue();
ShuffleVectorSDNode *Shuffle =
@@ -30127,8 +30134,8 @@ static SDValue matchBinOpReduction(SDNod
}
// The first operand of the shuffle should be the same as the other operand
- // of the add.
- if (!Shuffle || (Shuffle->getOperand(0) != Op))
+ // of the binop.
+ if (!Shuffle || Shuffle->getOperand(0) != Op)
return SDValue();
// Verify the shuffle has the expected (at this stage of the pyramid) mask.
@@ -30137,6 +30144,7 @@ static SDValue matchBinOpReduction(SDNod
return SDValue();
}
+ BinOp = CandidateBinOp;
return Op;
}
@@ -30250,66 +30258,63 @@ static SDValue combineHorizontalPredicat
return SDValue();
// Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
- for (ISD::NodeType Op : {ISD::OR, ISD::AND}) {
- SDValue Match = matchBinOpReduction(Extract, Op);
- if (!Match)
- continue;
-
- // EXTRACT_VECTOR_ELT can require implicit extension of the vector element
- // which we can't support here for now.
- if (Match.getScalarValueSizeInBits() != BitWidth)
- continue;
-
- // We require AVX2 for PMOVMSKB for v16i16/v32i8;
- unsigned MatchSizeInBits = Match.getValueSizeInBits();
- if (!(MatchSizeInBits == 128 ||
- (MatchSizeInBits == 256 &&
- ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
- return SDValue();
+ unsigned BinOp = 0;
+ SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
+ if (!Match)
+ return SDValue();
- // Don't bother performing this for 2-element vectors.
- if (Match.getValueType().getVectorNumElements() <= 2)
- return SDValue();
+ // EXTRACT_VECTOR_ELT can require implicit extension of the vector element
+ // which we can't support here for now.
+ if (Match.getScalarValueSizeInBits() != BitWidth)
+ return SDValue();
- // Check that we are extracting a reduction of all sign bits.
- if (DAG.ComputeNumSignBits(Match) != BitWidth)
- return SDValue();
+ // We require AVX2 for PMOVMSKB for v16i16/v32i8;
+ unsigned MatchSizeInBits = Match.getValueSizeInBits();
+ if (!(MatchSizeInBits == 128 ||
+ (MatchSizeInBits == 256 &&
+ ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
+ return SDValue();
- // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
- MVT MaskVT;
- if (64 == BitWidth || 32 == BitWidth)
- MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
- MatchSizeInBits / BitWidth);
- else
- MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
-
- APInt CompareBits;
- ISD::CondCode CondCode;
- if (Op == ISD::OR) {
- // any_of -> MOVMSK != 0
- CompareBits = APInt::getNullValue(32);
- CondCode = ISD::CondCode::SETNE;
- } else {
- // all_of -> MOVMSK == ((1 << NumElts) - 1)
- CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
- CondCode = ISD::CondCode::SETEQ;
- }
+ // Don't bother performing this for 2-element vectors.
+ if (Match.getValueType().getVectorNumElements() <= 2)
+ return SDValue();
+
+ // Check that we are extracting a reduction of all sign bits.
+ if (DAG.ComputeNumSignBits(Match) != BitWidth)
+ return SDValue();
- // Perform the select as i32/i64 and then truncate to avoid partial register
- // stalls.
- unsigned ResWidth = std::max(BitWidth, 32u);
- EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
- SDLoc DL(Extract);
- SDValue Zero = DAG.getConstant(0, DL, ResVT);
- SDValue Ones = DAG.getAllOnesConstant(DL, ResVT);
- SDValue Res = DAG.getBitcast(MaskVT, Match);
- Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
- Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
- Ones, Zero, CondCode);
- return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
+ // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
+ MVT MaskVT;
+ if (64 == BitWidth || 32 == BitWidth)
+ MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
+ MatchSizeInBits / BitWidth);
+ else
+ MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
+
+ APInt CompareBits;
+ ISD::CondCode CondCode;
+ if (BinOp == ISD::OR) {
+ // any_of -> MOVMSK != 0
+ CompareBits = APInt::getNullValue(32);
+ CondCode = ISD::CondCode::SETNE;
+ } else {
+ // all_of -> MOVMSK == ((1 << NumElts) - 1)
+ CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
+ CondCode = ISD::CondCode::SETEQ;
}
- return SDValue();
+ // Perform the select as i32/i64 and then truncate to avoid partial register
+ // stalls.
+ unsigned ResWidth = std::max(BitWidth, 32u);
+ EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
+ SDLoc DL(Extract);
+ SDValue Zero = DAG.getConstant(0, DL, ResVT);
+ SDValue Ones = DAG.getAllOnesConstant(DL, ResVT);
+ SDValue Res = DAG.getBitcast(MaskVT, Match);
+ Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
+ Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
+ Ones, Zero, CondCode);
+ return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
}
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
@@ -30336,7 +30341,8 @@ static SDValue combineBasicSADPattern(SD
return SDValue();
// Match shuffle + add pyramid.
- SDValue Root = matchBinOpReduction(Extract, ISD::ADD);
+ unsigned BinOp = 0;
+ SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
// The operand is expected to be zero extended from i8
// (verified in detectZextAbsDiff).
More information about the llvm-commits mailing list