[llvm] r276918 - [X86] Factor out another piece of the SAD combine. NFCI.

Michael Kuperstein via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 27 13:59:51 PDT 2016


Author: mkuper
Date: Wed Jul 27 15:59:51 2016
New Revision: 276918

URL: http://llvm.org/viewvc/llvm-project?rev=276918&view=rev
Log:
[X86] Factor out another piece of the SAD combine. NFCI.

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=276918&r1=276917&r2=276918&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Jul 27 15:59:51 2016
@@ -26358,6 +26358,86 @@ static SDValue combineBitcast(SDNode *N,
   return SDValue();
 }
 
+// Given a select, detect the absolute difference of two zero-extended i8
+// vectors, i.e. the following pattern:
+// 1:    %2 = zext <N x i8> %0 to <N x i32>
+// 2:    %3 = zext <N x i8> %1 to <N x i32>
+// 3:    %4 = sub nsw <N x i32> %2, %3
+// 4:    %5 = icmp sgt <N x i32> %4, splat(0) or splat(-1)
+// 5:    %6 = sub nsw <N x i32> zeroinitializer, %4
+// 6:    %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
+// On success, Op0 and Op1 receive the two zext nodes (%2 and %3) and true is
+// returned; this is the input into a SAD pattern. Op0/Op1 are assigned before
+// the final zext check, so they may be clobbered even when false is returned.
+//
+// Only the DAG structure and the i8 source element type are checked; no
+// other type or subtarget legality is verified here.
+static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0,
+                              SDValue &Op1) {
+  // The condition must be a signed greater-than comparison.
+  SDValue Cond = Select->getOperand(0);
+  if (Cond.getOpcode() != ISD::SETCC ||
+      cast<CondCodeSDNode>(Cond.getOperand(2))->get() != ISD::SETGT)
+    return false;
+
+  SDValue Diff = Select->getOperand(1);
+  SDValue Negated = Select->getOperand(2);
+
+  // The false value must be (0 - Diff), i.e. the negation of the true value.
+  if (Negated.getOpcode() != ISD::SUB ||
+      !ISD::isBuildVectorAllZeros(Negated.getOperand(0).getNode()) ||
+      Negated.getOperand(1) != Diff)
+    return false;
+
+  // The comparison's left-hand side must be Diff itself...
+  if (Cond.getOperand(0) != Diff)
+    return false;
+
+  // ...and its right-hand side a splat of 0 or -1; with either constant the
+  // select yields the absolute value of Diff.
+  SDNode *Rhs = Cond.getOperand(1).getNode();
+  if (!ISD::isBuildVectorAllZeros(Rhs) && !ISD::isBuildVectorAllOnes(Rhs))
+    return false;
+
+  // Diff must itself be a subtraction; its operands are the candidate inputs.
+  if (Diff.getOpcode() != ISD::SUB)
+    return false;
+  Op0 = Diff.getOperand(0);
+  Op1 = Diff.getOperand(1);
+
+  // Both operands of the subtraction must be zero-extended from i8 vectors.
+  auto IsZextFromI8 = [](const SDValue &V) -> bool {
+    if (V.getOpcode() != ISD::ZERO_EXTEND)
+      return false;
+    return V.getOperand(0).getValueType().getVectorElementType() == MVT::i8;
+  };
+  return IsZextFromI8(Op0) && IsZextFromI8(Op1);
+}
+
+// Given two zexts of <N x i8> to <N x i32>, create a PSADBW of their i8
+// inputs, padded with zero elements to max(128, input bits) bits.
+static SDValue createPSADBW(SelectionDAG &DAG, SDValue Zext0, SDValue Zext1,
+                            const SDLoc &DL) {
+  // The PSADBW width is at least one 128-bit register.
+  EVT InVT = Zext0.getOperand(0).getValueType();
+  unsigned InBits = InVT.getSizeInBits();
+  unsigned RegSize = std::max(128u, InBits);
+  assert(RegSize % InBits == 0 && "Input must evenly divide the PSADBW width");
+
+  // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
+  // fill in the missing vector elements with 0.
+  SmallVector<SDValue, 16> Ops(RegSize / InBits, DAG.getConstant(0, DL, InVT));
+  Ops[0] = Zext0.getOperand(0);
+  MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
+  SDValue SadOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+  Ops[0] = Zext1.getOperand(0);
+  SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+
+  // Actually build the SAD: one i64 sum of absolute differences per 64 bits.
+  MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
+  return DAG.getNode(X86ISD::PSADBW, DL, SadVT, SadOp0, SadOp1);
+}
+
 /// Detect vector gather/scatter index generation and convert it from being a
 /// bunch of shuffles and extracts into a somewhat faster sequence.
 /// For i686, the best sequence is apparently storing the value and loading
@@ -30680,62 +30760,6 @@ static SDValue OptimizeConditionalInDecr
                      DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp);
 }
 
-// Given a select, detect the following pattern:
-// 1:    %2 = zext <N x i8> %0 to <N x i32>
-// 2:    %3 = zext <N x i8> %1 to <N x i32>
-// 3:    %4 = sub nsw <N x i32> %2, %3
-// 4:    %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N]
-// 5:    %6 = sub nsw <N x i32> zeroinitializer, %4
-// 6:    %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
-// This is useful as it is the input into a SAD pattern.
-static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0,
-                              SDValue &Op1) {
-  // Check the condition of the select instruction is greater-than.
-  SDValue SetCC = Select->getOperand(0);
-  if (SetCC.getOpcode() != ISD::SETCC)
-    return false;
-  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
-  if (CC != ISD::SETGT)
-    return false;
-
-  SDValue SelectOp1 = Select->getOperand(1);
-  SDValue SelectOp2 = Select->getOperand(2);
-
-  // The second operand of the select should be the negation of the first
-  // operand, which is implemented as 0 - SelectOp1.
-  if (!(SelectOp2.getOpcode() == ISD::SUB &&
-        ISD::isBuildVectorAllZeros(SelectOp2.getOperand(0).getNode()) &&
-        SelectOp2.getOperand(1) == SelectOp1))
-    return false;
-
-  // The first operand of SetCC is the first operand of the select, which is the
-  // difference between the two input vectors.
-  if (SetCC.getOperand(0) != SelectOp1)
-    return false;
-
-  // The second operand of the comparison can be either -1 or 0.
-  if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
-        ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
-    return false;
-
-  // The first operand of the select is the difference between the two input
-  // vectors.
-  if (SelectOp1.getOpcode() != ISD::SUB)
-    return false;
-
-  Op0 = SelectOp1.getOperand(0);
-  Op1 = SelectOp1.getOperand(1);
-
-  // Check if the operands of the sub are zero-extended from vectors of i8.
-  if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
-      Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
-      Op1.getOpcode() != ISD::ZERO_EXTEND ||
-      Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
-    return false;
-
-  return true;
-}
-
 static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
                                      const X86Subtarget &Subtarget) {
   SDLoc DL(N);
@@ -30777,31 +30801,15 @@ static SDValue combineLoopSADPattern(SDN
   // reduction. Note that the number of elements of the result of SAD is less
   // than the number of elements of its input. Therefore, we could only update
   // part of elements in the reduction vector.
-
-  // Legalize the type of the inputs of PSADBW.
-  EVT InVT = Op0.getOperand(0).getValueType();
-  if (InVT.getSizeInBits() <= 128)
-    RegSize = 128;
-  else if (InVT.getSizeInBits() <= 256)
-    RegSize = 256;
-
-  unsigned NumConcat = RegSize / InVT.getSizeInBits();
-  SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
-  Ops[0] = Op0.getOperand(0);
-  MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
-  Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
-  Ops[0] = Op1.getOperand(0);
-  Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
+  SDValue Sad = createPSADBW(DAG, Op0, Op1, DL);
 
   // The output of PSADBW is a vector of i64.
-  MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
-  SDValue Sad = DAG.getNode(X86ISD::PSADBW, DL, SadVT, Op0, Op1);
-
   // We need to turn the vector of i64 into a vector of i32.
   // If the reduction vector is at least as wide as the psadbw result, just
   // bitcast. If it's narrower, truncate - the high i32 of each i64 is zero
   // anyway.
-  MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
+  MVT ResVT =
+      MVT::getVectorVT(MVT::i32, Sad.getValueType().getSizeInBits() / 32);
   if (VT.getSizeInBits() >= ResVT.getSizeInBits())
     Sad = DAG.getNode(ISD::BITCAST, DL, ResVT, Sad);
   else




More information about the llvm-commits mailing list