[llvm] r317985 - [X86] Attempt to match multiple binary reduction ops at once. NFCI

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 11 10:16:55 PST 2017


Author: rksimon
Date: Sat Nov 11 10:16:55 2017
New Revision: 317985

URL: http://llvm.org/viewvc/llvm-project?rev=317985&view=rev
Log:
[X86] Attempt to match multiple binary reduction ops at once. NFCI

matchBinOpReduction currently matches against a single opcode, but we already have one case where we call it repeatedly to match against AND and then OR, and I'll shortly be adding another case for SMAX/SMIN/UMAX/UMIN (D39729).

This NFCI patch alters matchBinOpReduction to pattern-match against any of a provided list of candidate binary ops in a single pass, which saves compile time.

Differential Revision: https://reviews.llvm.org/D39726

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=317985&r1=317984&r2=317985&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Sat Nov 11 10:16:55 2017
@@ -30093,16 +30093,22 @@ static SDValue combineBitcast(SDNode *N,
 // the elements of a vector.
 // Returns the vector that is being reduced on, or SDValue() if a reduction
 // was not matched.
-static SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType BinOp) {
+static SDValue matchBinOpReduction(SDNode *Extract, unsigned &BinOp,
+                                   ArrayRef<ISD::NodeType> CandidateBinOps) {
   // The pattern must end in an extract from index 0.
   if ((Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT) ||
       !isNullConstant(Extract->getOperand(1)))
     return SDValue();
 
-  unsigned Stages =
-      Log2_32(Extract->getOperand(0).getValueType().getVectorNumElements());
-
   SDValue Op = Extract->getOperand(0);
+  unsigned Stages = Log2_32(Op.getValueType().getVectorNumElements());
+
+  // Match against one of the candidate binary ops.
+  if (llvm::none_of(CandidateBinOps, [Op](ISD::NodeType BinOp) {
+        return Op.getOpcode() == BinOp;
+      }))
+    return SDValue();
+
   // At each stage, we're looking for something that looks like:
   // %s = shufflevector <8 x i32> %op, <8 x i32> undef,
   //                    <8 x i32> <i32 2, i32 3, i32 undef, i32 undef,
@@ -30113,8 +30119,9 @@ static SDValue matchBinOpReduction(SDNod
   // <4,5,6,7,u,u,u,u>
   // <2,3,u,u,u,u,u,u>
   // <1,u,u,u,u,u,u,u>
+  unsigned CandidateBinOp = Op.getOpcode();
   for (unsigned i = 0; i < Stages; ++i) {
-    if (Op.getOpcode() != BinOp)
+    if (Op.getOpcode() != CandidateBinOp)
       return SDValue();
 
     ShuffleVectorSDNode *Shuffle =
@@ -30127,8 +30134,8 @@ static SDValue matchBinOpReduction(SDNod
     }
 
     // The first operand of the shuffle should be the same as the other operand
-    // of the add.
-    if (!Shuffle || (Shuffle->getOperand(0) != Op))
+    // of the binop.
+    if (!Shuffle || Shuffle->getOperand(0) != Op)
       return SDValue();
 
     // Verify the shuffle has the expected (at this stage of the pyramid) mask.
@@ -30137,6 +30144,7 @@ static SDValue matchBinOpReduction(SDNod
         return SDValue();
   }
 
+  BinOp = CandidateBinOp;
   return Op;
 }
 
@@ -30250,66 +30258,63 @@ static SDValue combineHorizontalPredicat
     return SDValue();
 
   // Check for OR(any_of) and AND(all_of) horizontal reduction patterns.
-  for (ISD::NodeType Op : {ISD::OR, ISD::AND}) {
-    SDValue Match = matchBinOpReduction(Extract, Op);
-    if (!Match)
-      continue;
-
-    // EXTRACT_VECTOR_ELT can require implicit extension of the vector element
-    // which we can't support here for now.
-    if (Match.getScalarValueSizeInBits() != BitWidth)
-      continue;
-
-    // We require AVX2 for PMOVMSKB for v16i16/v32i8;
-    unsigned MatchSizeInBits = Match.getValueSizeInBits();
-    if (!(MatchSizeInBits == 128 ||
-          (MatchSizeInBits == 256 &&
-           ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
-      return SDValue();
+  unsigned BinOp = 0;
+  SDValue Match = matchBinOpReduction(Extract, BinOp, {ISD::OR, ISD::AND});
+  if (!Match)
+    return SDValue();
 
-    // Don't bother performing this for 2-element vectors.
-    if (Match.getValueType().getVectorNumElements() <= 2)
-      return SDValue();
+  // EXTRACT_VECTOR_ELT can require implicit extension of the vector element
+  // which we can't support here for now.
+  if (Match.getScalarValueSizeInBits() != BitWidth)
+    return SDValue();
 
-    // Check that we are extracting a reduction of all sign bits.
-    if (DAG.ComputeNumSignBits(Match) != BitWidth)
-      return SDValue();
+  // We require AVX2 for PMOVMSKB for v16i16/v32i8;
+  unsigned MatchSizeInBits = Match.getValueSizeInBits();
+  if (!(MatchSizeInBits == 128 ||
+        (MatchSizeInBits == 256 &&
+         ((Subtarget.hasAVX() && BitWidth >= 32) || Subtarget.hasAVX2()))))
+    return SDValue();
 
-    // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
-    MVT MaskVT;
-    if (64 == BitWidth || 32 == BitWidth)
-      MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
-                                MatchSizeInBits / BitWidth);
-    else
-      MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
-
-    APInt CompareBits;
-    ISD::CondCode CondCode;
-    if (Op == ISD::OR) {
-      // any_of -> MOVMSK != 0
-      CompareBits = APInt::getNullValue(32);
-      CondCode = ISD::CondCode::SETNE;
-    } else {
-      // all_of -> MOVMSK == ((1 << NumElts) - 1)
-      CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
-      CondCode = ISD::CondCode::SETEQ;
-    }
+  // Don't bother performing this for 2-element vectors.
+  if (Match.getValueType().getVectorNumElements() <= 2)
+    return SDValue();
+
+  // Check that we are extracting a reduction of all sign bits.
+  if (DAG.ComputeNumSignBits(Match) != BitWidth)
+    return SDValue();
 
-    // Perform the select as i32/i64 and then truncate to avoid partial register
-    // stalls.
-    unsigned ResWidth = std::max(BitWidth, 32u);
-    EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
-    SDLoc DL(Extract);
-    SDValue Zero = DAG.getConstant(0, DL, ResVT);
-    SDValue Ones = DAG.getAllOnesConstant(DL, ResVT);
-    SDValue Res = DAG.getBitcast(MaskVT, Match);
-    Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
-    Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
-                          Ones, Zero, CondCode);
-    return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
+  // For 32/64 bit comparisons use MOVMSKPS/MOVMSKPD, else PMOVMSKB.
+  MVT MaskVT;
+  if (64 == BitWidth || 32 == BitWidth)
+    MaskVT = MVT::getVectorVT(MVT::getFloatingPointVT(BitWidth),
+                              MatchSizeInBits / BitWidth);
+  else
+    MaskVT = MVT::getVectorVT(MVT::i8, MatchSizeInBits / 8);
+
+  APInt CompareBits;
+  ISD::CondCode CondCode;
+  if (BinOp == ISD::OR) {
+    // any_of -> MOVMSK != 0
+    CompareBits = APInt::getNullValue(32);
+    CondCode = ISD::CondCode::SETNE;
+  } else {
+    // all_of -> MOVMSK == ((1 << NumElts) - 1)
+    CompareBits = APInt::getLowBitsSet(32, MaskVT.getVectorNumElements());
+    CondCode = ISD::CondCode::SETEQ;
   }
 
-  return SDValue();
+  // Perform the select as i32/i64 and then truncate to avoid partial register
+  // stalls.
+  unsigned ResWidth = std::max(BitWidth, 32u);
+  EVT ResVT = EVT::getIntegerVT(*DAG.getContext(), ResWidth);
+  SDLoc DL(Extract);
+  SDValue Zero = DAG.getConstant(0, DL, ResVT);
+  SDValue Ones = DAG.getAllOnesConstant(DL, ResVT);
+  SDValue Res = DAG.getBitcast(MaskVT, Match);
+  Res = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Res);
+  Res = DAG.getSelectCC(DL, Res, DAG.getConstant(CompareBits, DL, MVT::i32),
+                        Ones, Zero, CondCode);
+  return DAG.getSExtOrTrunc(Res, DL, ExtractVT);
 }
 
 static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
@@ -30336,7 +30341,8 @@ static SDValue combineBasicSADPattern(SD
     return SDValue();
 
   // Match shuffle + add pyramid.
-  SDValue Root = matchBinOpReduction(Extract, ISD::ADD);
+  unsigned BinOp = 0;
+  SDValue Root = matchBinOpReduction(Extract, BinOp, {ISD::ADD});
 
   // The operand is expected to be zero extended from i8
   // (verified in detectZextAbsDiff).




More information about the llvm-commits mailing list