[llvm] r350826 - [x86] fix horizontal binop matching for 256-bit vectors (PR40243)

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 10 07:04:52 PST 2019


Author: spatel
Date: Thu Jan 10 07:04:52 2019
New Revision: 350826

URL: http://llvm.org/viewvc/llvm-project?rev=350826&view=rev
Log:
[x86] fix horizontal binop matching for 256-bit vectors (PR40243)

This is a partial fix for:
https://bugs.llvm.org/show_bug.cgi?id=40243
...as seen in the integer test, we still need to correct the result when using the
existing (old) horizontal-op matching function, because it does not model the way
x86 256-bit horizontal ops return results (each 128-bit half is its own horizontal op).
A potential follow-up change for that is discussed in the bug report; see also D56490.

This generally duplicates a lot of the existing matching code, but we can't just remove
the old code without introducing regressions, so it is renamed and used less often.
Follow-ups may try to reduce that overlap.

Differential Revision: https://reviews.llvm.org/D56450

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/haddsub-undef.ll
    llvm/trunk/test/CodeGen/X86/phaddsub-undef.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=350826&r1=350825&r2=350826&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Thu Jan 10 07:04:52 2019
@@ -7851,13 +7851,14 @@ static SDValue LowerBUILD_VECTORvXi1(SDV
   return DstVec;
 }
 
-/// Return true if \p N implements a horizontal binop and return the
-/// operands for the horizontal binop into V0 and V1.
-///
 /// This is a helper function of LowerToHorizontalOp().
 /// This function checks that the build_vector \p N in input implements a
-/// horizontal operation. Parameter \p Opcode defines the kind of horizontal
-/// operation to match.
+/// 128-bit partial horizontal operation on a 256-bit vector, but that operation
+/// may not match the layout of an x86 256-bit horizontal instruction.
+/// In other words, if this returns true, then some extraction/insertion will
+/// be required to produce a valid horizontal instruction.
+///
+/// Parameter \p Opcode defines the kind of horizontal operation to match.
 /// For example, if \p Opcode is equal to ISD::ADD, then this function
 /// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode
 /// is equal to ISD::SUB, then this function checks if this is a horizontal
@@ -7865,12 +7866,17 @@ static SDValue LowerBUILD_VECTORvXi1(SDV
 ///
 /// This function only analyzes elements of \p N whose indices are
 /// in range [BaseIdx, LastIdx).
-static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
-                              SelectionDAG &DAG,
-                              unsigned BaseIdx, unsigned LastIdx,
-                              SDValue &V0, SDValue &V1) {
+///
+/// TODO: This function was originally used to match both real and fake partial
+/// horizontal operations, but the index-matching logic is incorrect for that.
+/// See the corrected implementation in isHopBuildVector(). Can we reduce this
+/// code because it is only used for partial h-op matching now?
+static bool isHorizontalBinOpPart(const BuildVectorSDNode *N, unsigned Opcode,
+                                  SelectionDAG &DAG,
+                                  unsigned BaseIdx, unsigned LastIdx,
+                                  SDValue &V0, SDValue &V1) {
   EVT VT = N->getValueType(0);
-
+  assert(VT.is256BitVector() && "Only use for matching partial 256-bit h-ops");
   assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
   assert(VT.isVector() && VT.getVectorNumElements() >= LastIdx &&
          "Invalid Vector in input!");
@@ -8211,17 +8217,148 @@ static SDValue lowerToAddSubOrFMAddSub(c
   return DAG.getNode(X86ISD::ADDSUB, DL, VT, Opnd0, Opnd1);
 }
 
+static bool isHopBuildVector(const BuildVectorSDNode *BV, SelectionDAG &DAG,
+                             unsigned &HOpcode, SDValue &V0, SDValue &V1) {
+  // Outputs start at known values; a 'false' return may leave them partial.
+  MVT VT = BV->getSimpleValueType(0);
+  HOpcode = ISD::DELETED_NODE;
+  V0 = DAG.getUNDEF(VT);
+  V1 = DAG.getUNDEF(VT);
+
+  // Check whether BV matches an x86 horizontal op. The 256-bit forms are
+  // non-obvious: each 128-bit half of the result is computed independently
+  // from the 128-bit halves of the inputs, so the index checks run per chunk.
+  unsigned NumElts = VT.getVectorNumElements();
+  unsigned GenericOpcode = ISD::DELETED_NODE;
+  unsigned Num128BitChunks = VT.is256BitVector() ? 2 : 1;
+  unsigned NumEltsIn128Bits = NumElts / Num128BitChunks;
+  unsigned NumEltsIn64Bits = NumEltsIn128Bits / 2;
+  for (unsigned i = 0; i != Num128BitChunks; ++i) {
+    for (unsigned j = 0; j != NumEltsIn128Bits; ++j) {
+      // Ignore undef elements.
+      SDValue Op = BV->getOperand(i * NumEltsIn128Bits + j);
+      if (Op.isUndef())
+        continue;
+
+      // Every non-undef element must use the same binop as the first one.
+      if (HOpcode != ISD::DELETED_NODE && Op.getOpcode() != GenericOpcode)
+        return false;
+
+      // First non-undef element: map its binop to the x86 h-op opcode.
+      if (HOpcode == ISD::DELETED_NODE) {
+        GenericOpcode = Op.getOpcode();
+        switch (GenericOpcode) {
+        case ISD::ADD: HOpcode = X86ISD::HADD; break;
+        case ISD::SUB: HOpcode = X86ISD::HSUB; break;
+        case ISD::FADD: HOpcode = X86ISD::FHADD; break;
+        case ISD::FSUB: HOpcode = X86ISD::FHSUB; break;
+        default: return false;
+        }
+      }
+
+      SDValue Op0 = Op.getOperand(0);
+      SDValue Op1 = Op.getOperand(1);
+      if (Op0.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+          Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+          Op0.getOperand(0) != Op1.getOperand(0) ||
+          !isa<ConstantSDNode>(Op0.getOperand(1)) ||
+          !isa<ConstantSDNode>(Op1.getOperand(1)) || !Op.hasOneUse())
+        return false;
+
+      // Within each 128-bit chunk, the low 64 bits of the result come from
+      // V0 and the high 64 bits from V1; remember the first source of each.
+      if (j < NumEltsIn64Bits) {
+        if (V0.isUndef())
+          V0 = Op0.getOperand(0);
+      } else {
+        if (V1.isUndef())
+          V1 = Op0.getOperand(0);
+      }
+
+      SDValue SourceVec = (j < NumEltsIn64Bits) ? V0 : V1;
+      if (SourceVec != Op0.getOperand(0))
+        return false;
+
+      // op (extract_vector_elt A, I), (extract_vector_elt A, I+1)
+      unsigned ExtIndex0 = Op0.getConstantOperandVal(1);
+      unsigned ExtIndex1 = Op1.getConstantOperandVal(1);
+      unsigned ExpectedIndex = i * NumEltsIn128Bits +
+                               (j % NumEltsIn64Bits) * 2;
+      if (ExpectedIndex == ExtIndex0 && ExtIndex1 == ExtIndex0 + 1)
+        continue;
+
+      // SUB/FSUB are not commutative, so swapped operands cannot match.
+      if (GenericOpcode != ISD::ADD && GenericOpcode != ISD::FADD)
+        return false;
+
+      // Addition is commutative, so try swapping the extract indexes.
+      // op (extract_vector_elt A, I+1), (extract_vector_elt A, I)
+      if (ExpectedIndex == ExtIndex1 && ExtIndex0 == ExtIndex1 + 1)
+        continue;
+
+      // The extract indexes do not fit the horizontal-op layout.
+      return false;
+    }
+  }
+  // Matched. Note: an all-undef BV also lands here with HOpcode unset.
+  return true;
+}
+
+static SDValue getHopForBuildVector(const BuildVectorSDNode *BV,
+                                    SelectionDAG &DAG, unsigned HOpcode,
+                                    SDValue V0, SDValue V1) {
+  // TODO: extract/insert to match BV's size; callers don't retry on failure.
+  MVT VT = BV->getSimpleValueType(0);
+  if (V0.getValueType() != VT || V1.getValueType() != VT)
+    return SDValue();
+
+  return DAG.getNode(HOpcode, SDLoc(BV), VT, V0, V1);
+}
+
 /// Lower BUILD_VECTOR to a horizontal add/sub operation if possible.
 static SDValue LowerToHorizontalOp(const BuildVectorSDNode *BV,
                                    const X86Subtarget &Subtarget,
                                    SelectionDAG &DAG) {
+  // We need at least 2 non-undef elements to make this worthwhile by default.
+  unsigned NumNonUndefs = 0;
+  for (const SDValue &V : BV->op_values())
+    if (!V.isUndef())
+      ++NumNonUndefs;
+
+  if (NumNonUndefs < 2)
+    return SDValue();
+
+  // There are 4 sets of horizontal math operations distinguished by type:
+  // int/FP at 128-bit/256-bit. Each type was introduced with a different
+  // subtarget feature. Try to match those "native" patterns first.
   MVT VT = BV->getSimpleValueType(0);
+  unsigned HOpcode;
+  SDValue V0, V1;
+  if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget.hasSSE3())
+    if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
+      return getHopForBuildVector(BV, DAG, HOpcode, V0, V1);
+
+  if ((VT == MVT::v8i16 || VT == MVT::v4i32) && Subtarget.hasSSSE3())
+    if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
+      return getHopForBuildVector(BV, DAG, HOpcode, V0, V1);
+
+  if ((VT == MVT::v8f32 || VT == MVT::v4f64) && Subtarget.hasAVX())
+    if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
+      return getHopForBuildVector(BV, DAG, HOpcode, V0, V1);
+
+  if ((VT == MVT::v16i16 || VT == MVT::v8i32) && Subtarget.hasAVX2())
+    if (isHopBuildVector(BV, DAG, HOpcode, V0, V1))
+      return getHopForBuildVector(BV, DAG, HOpcode, V0, V1);
+
+  // Try harder to match 256-bit ops by using extract/concat.
+  if (!Subtarget.hasAVX() || !VT.is256BitVector())
+    return SDValue();
+
+  // Count the number of UNDEF operands in the build_vector in input.
   unsigned NumElts = VT.getVectorNumElements();
+  unsigned Half = NumElts / 2;
   unsigned NumUndefsLO = 0;
   unsigned NumUndefsHI = 0;
-  unsigned Half = NumElts/2;
-
-  // Count the number of UNDEF operands in the build_vector in input.
   for (unsigned i = 0, e = Half; i != e; ++i)
     if (BV->getOperand(i)->isUndef())
       NumUndefsLO++;
@@ -8230,72 +8367,31 @@ static SDValue LowerToHorizontalOp(const
     if (BV->getOperand(i)->isUndef())
       NumUndefsHI++;
 
-  // Early exit if this is either a build_vector of all UNDEFs or all the
-  // operands but one are UNDEF.
-  if (NumUndefsLO + NumUndefsHI + 1 >= NumElts)
-    return SDValue();
-
   SDLoc DL(BV);
   SDValue InVec0, InVec1;
-  if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget.hasSSE3()) {
-    // Try to match an SSE3 float HADD/HSUB.
-    if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1))
-      return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
-
-    if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1))
-      return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
-  } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget.hasSSSE3()) {
-    // Try to match an SSSE3 integer HADD/HSUB.
-    if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1))
-      return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1);
-
-    if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1))
-      return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1);
-  }
-
-  if (!Subtarget.hasAVX())
-    return SDValue();
-
-  if ((VT == MVT::v8f32 || VT == MVT::v4f64)) {
-    // Try to match an AVX horizontal add/sub of packed single/double
-    // precision floating point values from 256-bit vectors.
-    SDValue InVec2, InVec3;
-    if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, Half, InVec0, InVec1) &&
-        isHorizontalBinOp(BV, ISD::FADD, DAG, Half, NumElts, InVec2, InVec3) &&
-        ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
-        ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
-      return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
-
-    if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, Half, InVec0, InVec1) &&
-        isHorizontalBinOp(BV, ISD::FSUB, DAG, Half, NumElts, InVec2, InVec3) &&
-        ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
-        ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
-      return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
-  } else if (VT == MVT::v8i32 || VT == MVT::v16i16) {
+  if (VT == MVT::v8i32 || VT == MVT::v16i16) {
     // Try to match an AVX2 horizontal add/sub of signed integers.
     SDValue InVec2, InVec3;
     unsigned X86Opcode;
     bool CanFold = true;
 
-    if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, Half, InVec0, InVec1) &&
-        isHorizontalBinOp(BV, ISD::ADD, DAG, Half, NumElts, InVec2, InVec3) &&
+    if (isHorizontalBinOpPart(BV, ISD::ADD, DAG, 0, Half, InVec0, InVec1) &&
+        isHorizontalBinOpPart(BV, ISD::ADD, DAG, Half, NumElts, InVec2,
+                              InVec3) &&
         ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
         ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
       X86Opcode = X86ISD::HADD;
-    else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, Half, InVec0, InVec1) &&
-        isHorizontalBinOp(BV, ISD::SUB, DAG, Half, NumElts, InVec2, InVec3) &&
-        ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
-        ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
+    else if (isHorizontalBinOpPart(BV, ISD::SUB, DAG, 0, Half, InVec0,
+                                   InVec1) &&
+             isHorizontalBinOpPart(BV, ISD::SUB, DAG, Half, NumElts, InVec2,
+                                   InVec3) &&
+             ((InVec0.isUndef() || InVec2.isUndef()) || InVec0 == InVec2) &&
+             ((InVec1.isUndef() || InVec3.isUndef()) || InVec1 == InVec3))
       X86Opcode = X86ISD::HSUB;
     else
       CanFold = false;
 
     if (CanFold) {
-      // Fold this build_vector into a single horizontal add/sub.
-      // Do this only if the target has AVX2.
-      if (Subtarget.hasAVX2())
-        return DAG.getNode(X86Opcode, DL, VT, InVec0, InVec1);
-
       // Do not try to expand this build_vector into a pair of horizontal
       // add/sub if we can emit a pair of scalar add/sub.
       if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half)
@@ -8310,16 +8406,19 @@ static SDValue LowerToHorizontalOp(const
     }
   }
 
-  if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
-       VT == MVT::v16i16) && Subtarget.hasAVX()) {
+  if (VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
+      VT == MVT::v16i16) {
     unsigned X86Opcode;
-    if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1))
+    if (isHorizontalBinOpPart(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1))
       X86Opcode = X86ISD::HADD;
-    else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1))
+    else if (isHorizontalBinOpPart(BV, ISD::SUB, DAG, 0, NumElts, InVec0,
+                                   InVec1))
       X86Opcode = X86ISD::HSUB;
-    else if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1))
+    else if (isHorizontalBinOpPart(BV, ISD::FADD, DAG, 0, NumElts, InVec0,
+                                   InVec1))
       X86Opcode = X86ISD::FHADD;
-    else if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1))
+    else if (isHorizontalBinOpPart(BV, ISD::FSUB, DAG, 0, NumElts, InVec0,
+                                   InVec1))
       X86Opcode = X86ISD::FHSUB;
     else
       return SDValue();

Modified: llvm/trunk/test/CodeGen/X86/haddsub-undef.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/haddsub-undef.ll?rev=350826&r1=350825&r2=350826&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/haddsub-undef.ll (original)
+++ llvm/trunk/test/CodeGen/X86/haddsub-undef.ll Thu Jan 10 07:04:52 2019
@@ -300,7 +300,7 @@ define <8 x float> @test11_undef(<8 x fl
 ;
 ; AVX-LABEL: test11_undef:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vhaddps %ymm0, %ymm0, %ymm0
+; AVX-NEXT:    vhaddps %ymm1, %ymm0, %ymm0
 ; AVX-NEXT:    retq
   %vecext = extractelement <8 x float> %a, i32 0
   %vecext1 = extractelement <8 x float> %a, i32 1
@@ -934,12 +934,12 @@ define <8 x float> @v16f32_inputs_v8f32_
 ;
 ; AVX1-SLOW-LABEL: v16f32_inputs_v8f32_output_4567:
 ; AVX1-SLOW:       # %bb.0:
-; AVX1-SLOW-NEXT:    vhaddps %ymm0, %ymm0, %ymm0
+; AVX1-SLOW-NEXT:    vhaddps %ymm2, %ymm0, %ymm0
 ; AVX1-SLOW-NEXT:    retq
 ;
 ; AVX1-FAST-LABEL: v16f32_inputs_v8f32_output_4567:
 ; AVX1-FAST:       # %bb.0:
-; AVX1-FAST-NEXT:    vhaddps %ymm0, %ymm0, %ymm0
+; AVX1-FAST-NEXT:    vhaddps %ymm2, %ymm0, %ymm0
 ; AVX1-FAST-NEXT:    retq
 ;
 ; AVX512-LABEL: v16f32_inputs_v8f32_output_4567:
@@ -973,7 +973,7 @@ define <8 x float> @PR40243(<8 x float>
 ;
 ; AVX-LABEL: PR40243:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vhaddps %ymm0, %ymm0, %ymm0
+; AVX-NEXT:    vhaddps %ymm1, %ymm0, %ymm0
 ; AVX-NEXT:    retq
   %a4 = extractelement <8 x float> %a, i32 4
   %a5 = extractelement <8 x float> %a, i32 5

Modified: llvm/trunk/test/CodeGen/X86/phaddsub-undef.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/phaddsub-undef.ll?rev=350826&r1=350825&r2=350826&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/phaddsub-undef.ll (original)
+++ llvm/trunk/test/CodeGen/X86/phaddsub-undef.ll Thu Jan 10 07:04:52 2019
@@ -75,12 +75,12 @@ define <8 x i32> @test15_undef(<8 x i32>
 ;
 ; AVX2-LABEL: test15_undef:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX2-NEXT:    vphaddd %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    retq
 ;
 ; AVX512-LABEL: test15_undef:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX512-NEXT:    vphaddd %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT:    retq
   %vecext = extractelement <8 x i32> %a, i32 0
   %vecext1 = extractelement <8 x i32> %a, i32 1
@@ -105,12 +105,12 @@ define <8 x i32> @PR40243_alt(<8 x i32>
 ;
 ; AVX2-LABEL: PR40243_alt:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX2-NEXT:    vphaddd %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    retq
 ;
 ; AVX512-LABEL: PR40243_alt:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vphaddd %ymm0, %ymm0, %ymm0
+; AVX512-NEXT:    vphaddd %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT:    retq
   %a4 = extractelement <8 x i32> %a, i32 4
   %a5 = extractelement <8 x i32> %a, i32 5




More information about the llvm-commits mailing list