[llvm-commits] [llvm] r144987 - in /llvm/trunk: lib/Target/X86/X86ISelLowering.cpp lib/Target/X86/X86InstrFragmentsSIMD.td lib/Target/X86/X86InstrSSE.td test/CodeGen/X86/avx2-logic.ll

Craig Topper craig.topper at gmail.com
Fri Nov 18 23:07:26 PST 2011


Author: ctopper
Date: Sat Nov 19 01:07:26 2011
New Revision: 144987

URL: http://llvm.org/viewvc/llvm-project?rev=144987&view=rev
Log:
Extend VPBLENDVB and VPSIGN lowering to work for AVX2.

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td
    llvm/trunk/lib/Target/X86/X86InstrSSE.td
    llvm/trunk/test/CodeGen/X86/avx2-logic.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=144987&r1=144986&r2=144987&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Sat Nov 19 01:07:26 2011
@@ -13859,98 +13859,105 @@
     return R;
 
   EVT VT = N->getValueType(0);
-  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64 && VT != MVT::v2i64)
-    return SDValue();
 
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
   // look for psign/blend
-  if (Subtarget->hasSSSE3() || Subtarget->hasAVX()) {
-    if (VT == MVT::v2i64) {
-      // Canonicalize pandn to RHS
-      if (N0.getOpcode() == X86ISD::ANDNP)
-        std::swap(N0, N1);
-      // or (and (m, x), (pandn m, y))
-      if (N0.getOpcode() == ISD::AND && N1.getOpcode() == X86ISD::ANDNP) {
-        SDValue Mask = N1.getOperand(0);
-        SDValue X    = N1.getOperand(1);
-        SDValue Y;
-        if (N0.getOperand(0) == Mask)
-          Y = N0.getOperand(1);
-        if (N0.getOperand(1) == Mask)
-          Y = N0.getOperand(0);
-
-        // Check to see if the mask appeared in both the AND and ANDNP and
-        if (!Y.getNode())
-          return SDValue();
-
-        // Validate that X, Y, and Mask are BIT_CONVERTS, and see through them.
-        if (Mask.getOpcode() != ISD::BITCAST ||
-            X.getOpcode() != ISD::BITCAST ||
-            Y.getOpcode() != ISD::BITCAST)
-          return SDValue();
-
-        // Look through mask bitcast.
-        Mask = Mask.getOperand(0);
-        EVT MaskVT = Mask.getValueType();
-
-        // Validate that the Mask operand is a vector sra node.  The sra node
-        // will be an intrinsic.
-        if (Mask.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
-          return SDValue();
-
-        // FIXME: what to do for bytes, since there is a psignb/pblendvb, but
-        // there is no psrai.b
-        switch (cast<ConstantSDNode>(Mask.getOperand(0))->getZExtValue()) {
-        case Intrinsic::x86_sse2_psrai_w:
-        case Intrinsic::x86_sse2_psrai_d:
-          break;
-        default: return SDValue();
-        }
+  if (VT == MVT::v2i64 || VT == MVT::v4i64) {
+    if (!(Subtarget->hasSSSE3() || Subtarget->hasAVX()) ||
+        (VT == MVT::v4i64 && !Subtarget->hasAVX2()))
+      return SDValue();
 
-        // Check that the SRA is all signbits.
-        SDValue SraC = Mask.getOperand(2);
-        unsigned SraAmt  = cast<ConstantSDNode>(SraC)->getZExtValue();
-        unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
-        if ((SraAmt + 1) != EltBits)
-          return SDValue();
-
-        DebugLoc DL = N->getDebugLoc();
-
-        // Now we know we at least have a plendvb with the mask val.  See if
-        // we can form a psignb/w/d.
-        // psign = x.type == y.type == mask.type && y = sub(0, x);
-        X = X.getOperand(0);
-        Y = Y.getOperand(0);
-        if (Y.getOpcode() == ISD::SUB && Y.getOperand(1) == X &&
-            ISD::isBuildVectorAllZeros(Y.getOperand(0).getNode()) &&
-            X.getValueType() == MaskVT && X.getValueType() == Y.getValueType()){
-          unsigned Opc = 0;
-          switch (EltBits) {
-          case 8: Opc = X86ISD::PSIGNB; break;
-          case 16: Opc = X86ISD::PSIGNW; break;
-          case 32: Opc = X86ISD::PSIGND; break;
-          default: break;
-          }
-          if (Opc) {
-            SDValue Sign = DAG.getNode(Opc, DL, MaskVT, X, Mask.getOperand(1));
-            return DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, Sign);
-          }
+    // Canonicalize pandn to RHS
+    if (N0.getOpcode() == X86ISD::ANDNP)
+      std::swap(N0, N1);
+    // or (and (m, y), (pandn m, x))
+    if (N0.getOpcode() == ISD::AND && N1.getOpcode() == X86ISD::ANDNP) {
+      SDValue Mask = N1.getOperand(0);
+      SDValue X    = N1.getOperand(1);
+      SDValue Y;
+      if (N0.getOperand(0) == Mask)
+        Y = N0.getOperand(1);
+      if (N0.getOperand(1) == Mask)
+        Y = N0.getOperand(0);
+
+      // Check to see if the mask appeared in both the AND and ANDNP.
+      if (!Y.getNode())
+        return SDValue();
+
+      // Validate that X, Y, and Mask are BITCASTs, and see through them.
+      if (Mask.getOpcode() != ISD::BITCAST ||
+          X.getOpcode() != ISD::BITCAST ||
+          Y.getOpcode() != ISD::BITCAST)
+        return SDValue();
+
+      // Look through mask bitcast.
+      Mask = Mask.getOperand(0);
+      EVT MaskVT = Mask.getValueType();
+
+      // Validate that the Mask operand is a vector sra node.  The sra node
+      // will be an intrinsic.
+      if (Mask.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+        return SDValue();
+
+      // FIXME: what to do for bytes, since there is a psignb/pblendvb, but
+      // there is no psrai.b
+      switch (cast<ConstantSDNode>(Mask.getOperand(0))->getZExtValue()) {
+      case Intrinsic::x86_sse2_psrai_w:
+      case Intrinsic::x86_sse2_psrai_d:
+      case Intrinsic::x86_avx2_psrai_w:
+      case Intrinsic::x86_avx2_psrai_d:
+        break;
+      default: return SDValue();
+      }
+
+      // Check that the SRA is all signbits.
+      SDValue SraC = Mask.getOperand(2);
+      unsigned SraAmt  = cast<ConstantSDNode>(SraC)->getZExtValue();
+      unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
+      if ((SraAmt + 1) != EltBits)
+        return SDValue();
+
+      DebugLoc DL = N->getDebugLoc();
+
+      // Now we know we at least have a pblendvb with the mask val.  See if
+      // we can form a psignb/w/d.
+      // psign = x.type == y.type == mask.type && y = sub(0, x);
+      X = X.getOperand(0);
+      Y = Y.getOperand(0);
+      if (Y.getOpcode() == ISD::SUB && Y.getOperand(1) == X &&
+          ISD::isBuildVectorAllZeros(Y.getOperand(0).getNode()) &&
+          X.getValueType() == MaskVT && X.getValueType() == Y.getValueType()){
+        unsigned Opc = 0;
+        switch (EltBits) {
+        case 8: Opc = X86ISD::PSIGNB; break;
+        case 16: Opc = X86ISD::PSIGNW; break;
+        case 32: Opc = X86ISD::PSIGND; break;
+        default: break;
+        }
+        if (Opc) {
+          SDValue Sign = DAG.getNode(Opc, DL, MaskVT, X, Mask.getOperand(1));
+          return DAG.getNode(ISD::BITCAST, DL, VT, Sign);
         }
-        // PBLENDVB only available on SSE 4.1
-        if (!(Subtarget->hasSSE41() || Subtarget->hasAVX()))
-          return SDValue();
-
-        X = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, X);
-        Y = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Y);
-        Mask = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Mask);
-        Mask = DAG.getNode(ISD::VSELECT, DL, MVT::v16i8, Mask, X, Y);
-        return DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, Mask);
       }
+      // PBLENDVB only available on SSE 4.1
+      if (!(Subtarget->hasSSE41() || Subtarget->hasAVX()))
+        return SDValue();
+
+      EVT BlendVT = (VT == MVT::v4i64) ? MVT::v32i8 : MVT::v16i8;
+
+      X = DAG.getNode(ISD::BITCAST, DL, BlendVT, X);
+      Y = DAG.getNode(ISD::BITCAST, DL, BlendVT, Y);
+      Mask = DAG.getNode(ISD::BITCAST, DL, BlendVT, Mask);
+      Mask = DAG.getNode(ISD::VSELECT, DL, BlendVT, Mask, X, Y);
+      return DAG.getNode(ISD::BITCAST, DL, VT, Mask);
     }
   }
 
+  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+
   // fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c)
   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
     std::swap(N0, N1);

Modified: llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td?rev=144987&r1=144986&r2=144987&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td (original)
+++ llvm/trunk/lib/Target/X86/X86InstrFragmentsSIMD.td Sat Nov 19 01:07:26 2011
@@ -52,13 +52,13 @@
                  SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
                                       SDTCisSameAs<0,2>]>>;
 def X86psignb  : SDNode<"X86ISD::PSIGNB",
-                 SDTypeProfile<1, 2, [SDTCisVT<0, v16i8>, SDTCisSameAs<0,1>,
+                 SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
                                       SDTCisSameAs<0,2>]>>;
 def X86psignw  : SDNode<"X86ISD::PSIGNW",
-                 SDTypeProfile<1, 2, [SDTCisVT<0, v8i16>, SDTCisSameAs<0,1>,
+                 SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
                                       SDTCisSameAs<0,2>]>>;
 def X86psignd  : SDNode<"X86ISD::PSIGND",
-                 SDTypeProfile<1, 2, [SDTCisVT<0, v4i32>, SDTCisSameAs<0,1>,
+                 SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>,
                                       SDTCisSameAs<0,2>]>>;
 def X86pextrb  : SDNode<"X86ISD::PEXTRB",
                  SDTypeProfile<1, 2, [SDTCisVT<0, i32>, SDTCisPtrTy<2>]>>;

Modified: llvm/trunk/lib/Target/X86/X86InstrSSE.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrSSE.td?rev=144987&r1=144986&r2=144987&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86InstrSSE.td (original)
+++ llvm/trunk/lib/Target/X86/X86InstrSSE.td Sat Nov 19 01:07:26 2011
@@ -3824,51 +3824,51 @@
 
 let Predicates = [HasAVX] in {
   def : Pat<(int_x86_sse2_psll_dq VR128:$src1, imm:$src2),
-            (v2i64 (VPSLLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+            (VPSLLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
   def : Pat<(int_x86_sse2_psrl_dq VR128:$src1, imm:$src2),
-            (v2i64 (VPSRLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+            (VPSRLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
   def : Pat<(int_x86_sse2_psll_dq_bs VR128:$src1, imm:$src2),
-            (v2i64 (VPSLLDQri VR128:$src1, imm:$src2))>;
+            (VPSLLDQri VR128:$src1, imm:$src2)>;
   def : Pat<(int_x86_sse2_psrl_dq_bs VR128:$src1, imm:$src2),
-            (v2i64 (VPSRLDQri VR128:$src1, imm:$src2))>;
+            (VPSRLDQri VR128:$src1, imm:$src2)>;
   def : Pat<(v2f64 (X86fsrl VR128:$src1, i32immSExt8:$src2)),
-            (v2f64 (VPSRLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+            (VPSRLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
 
   // Shift up / down and insert zero's.
   def : Pat<(v2i64 (X86vshl  VR128:$src, (i8 imm:$amt))),
-            (v2i64 (VPSLLDQri VR128:$src, (BYTE_imm imm:$amt)))>;
+            (VPSLLDQri VR128:$src, (BYTE_imm imm:$amt))>;
   def : Pat<(v2i64 (X86vshr  VR128:$src, (i8 imm:$amt))),
-            (v2i64 (VPSRLDQri VR128:$src, (BYTE_imm imm:$amt)))>;
+            (VPSRLDQri VR128:$src, (BYTE_imm imm:$amt))>;
 }
 
 let Predicates = [HasAVX2] in {
   def : Pat<(int_x86_avx2_psll_dq VR256:$src1, imm:$src2),
-            (v4i64 (VPSLLDQYri VR256:$src1, (BYTE_imm imm:$src2)))>;
+            (VPSLLDQYri VR256:$src1, (BYTE_imm imm:$src2))>;
   def : Pat<(int_x86_avx2_psrl_dq VR256:$src1, imm:$src2),
-            (v4i64 (VPSRLDQYri VR256:$src1, (BYTE_imm imm:$src2)))>;
+            (VPSRLDQYri VR256:$src1, (BYTE_imm imm:$src2))>;
   def : Pat<(int_x86_avx2_psll_dq_bs VR256:$src1, imm:$src2),
-            (v4i64 (VPSLLDQYri VR256:$src1, imm:$src2))>;
+            (VPSLLDQYri VR256:$src1, imm:$src2)>;
   def : Pat<(int_x86_avx2_psrl_dq_bs VR256:$src1, imm:$src2),
-            (v4i64 (VPSRLDQYri VR256:$src1, imm:$src2))>;
+            (VPSRLDQYri VR256:$src1, imm:$src2)>;
 }
 
 let Predicates = [HasSSE2] in {
   def : Pat<(int_x86_sse2_psll_dq VR128:$src1, imm:$src2),
-            (v2i64 (PSLLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+            (PSLLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
   def : Pat<(int_x86_sse2_psrl_dq VR128:$src1, imm:$src2),
-            (v2i64 (PSRLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+            (PSRLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
   def : Pat<(int_x86_sse2_psll_dq_bs VR128:$src1, imm:$src2),
-            (v2i64 (PSLLDQri VR128:$src1, imm:$src2))>;
+            (PSLLDQri VR128:$src1, imm:$src2)>;
   def : Pat<(int_x86_sse2_psrl_dq_bs VR128:$src1, imm:$src2),
-            (v2i64 (PSRLDQri VR128:$src1, imm:$src2))>;
+            (PSRLDQri VR128:$src1, imm:$src2)>;
   def : Pat<(v2f64 (X86fsrl VR128:$src1, i32immSExt8:$src2)),
-            (v2f64 (PSRLDQri VR128:$src1, (BYTE_imm imm:$src2)))>;
+            (PSRLDQri VR128:$src1, (BYTE_imm imm:$src2))>;
 
   // Shift up / down and insert zero's.
   def : Pat<(v2i64 (X86vshl  VR128:$src, (i8 imm:$amt))),
-            (v2i64 (PSLLDQri VR128:$src, (BYTE_imm imm:$amt)))>;
+            (PSLLDQri VR128:$src, (BYTE_imm imm:$amt))>;
   def : Pat<(v2i64 (X86vshr  VR128:$src, (i8 imm:$amt))),
-            (v2i64 (PSRLDQri VR128:$src, (BYTE_imm imm:$amt)))>;
+            (PSRLDQri VR128:$src, (BYTE_imm imm:$amt))>;
 }
 
 //===---------------------------------------------------------------------===//
@@ -5316,11 +5316,11 @@
                                         int_x86_avx2_pmadd_ub_sw>, VEX_4V;
   defm VPSHUFB    : SS3I_binop_rm_int_y<0x00, "vpshufb", memopv32i8,
                                         int_x86_avx2_pshuf_b>, VEX_4V;
-  defm VPSIGNB    : SS3I_binop_rm_int_y<0x08, "vpsignb", memopv16i8,
+  defm VPSIGNB    : SS3I_binop_rm_int_y<0x08, "vpsignb", memopv32i8,
                                         int_x86_avx2_psign_b>, VEX_4V;
-  defm VPSIGNW    : SS3I_binop_rm_int_y<0x09, "vpsignw", memopv8i16,
+  defm VPSIGNW    : SS3I_binop_rm_int_y<0x09, "vpsignw", memopv16i16,
                                         int_x86_avx2_psign_w>, VEX_4V;
-  defm VPSIGND    : SS3I_binop_rm_int_y<0x0A, "vpsignd", memopv4i32,
+  defm VPSIGND    : SS3I_binop_rm_int_y<0x0A, "vpsignd", memopv8i32,
                                         int_x86_avx2_psign_d>, VEX_4V;
 }
 defm VPMULHRSW    : SS3I_binop_rm_int_y<0x0B, "vpmulhrsw", memopv16i16,
@@ -5363,11 +5363,11 @@
   def : Pat<(X86pshufb VR128:$src, (bc_v16i8 (memopv2i64 addr:$mask))),
             (PSHUFBrm128 VR128:$src, addr:$mask)>;
 
-  def : Pat<(X86psignb VR128:$src1, VR128:$src2),
+  def : Pat<(v16i8 (X86psignb VR128:$src1, VR128:$src2)),
             (PSIGNBrr128 VR128:$src1, VR128:$src2)>;
-  def : Pat<(X86psignw VR128:$src1, VR128:$src2),
+  def : Pat<(v8i16 (X86psignw VR128:$src1, VR128:$src2)),
             (PSIGNWrr128 VR128:$src1, VR128:$src2)>;
-  def : Pat<(X86psignd VR128:$src1, VR128:$src2),
+  def : Pat<(v4i32 (X86psignd VR128:$src1, VR128:$src2)),
             (PSIGNDrr128 VR128:$src1, VR128:$src2)>;
 }
 
@@ -5377,14 +5377,23 @@
   def : Pat<(X86pshufb VR128:$src, (bc_v16i8 (memopv2i64 addr:$mask))),
             (VPSHUFBrm128 VR128:$src, addr:$mask)>;
 
-  def : Pat<(X86psignb VR128:$src1, VR128:$src2),
+  def : Pat<(v16i8 (X86psignb VR128:$src1, VR128:$src2)),
             (VPSIGNBrr128 VR128:$src1, VR128:$src2)>;
-  def : Pat<(X86psignw VR128:$src1, VR128:$src2),
+  def : Pat<(v8i16 (X86psignw VR128:$src1, VR128:$src2)),
             (VPSIGNWrr128 VR128:$src1, VR128:$src2)>;
-  def : Pat<(X86psignd VR128:$src1, VR128:$src2),
+  def : Pat<(v4i32 (X86psignd VR128:$src1, VR128:$src2)),
             (VPSIGNDrr128 VR128:$src1, VR128:$src2)>;
 }
 
+let Predicates = [HasAVX2] in {
+  def : Pat<(v32i8 (X86psignb VR256:$src1, VR256:$src2)),
+            (VPSIGNBrr256 VR256:$src1, VR256:$src2)>;
+  def : Pat<(v16i16 (X86psignw VR256:$src1, VR256:$src2)),
+            (VPSIGNWrr256 VR256:$src1, VR256:$src2)>;
+  def : Pat<(v8i32 (X86psignd VR256:$src1, VR256:$src2)),
+            (VPSIGNDrr256 VR256:$src1, VR256:$src2)>;
+}
+
 //===---------------------------------------------------------------------===//
 // SSSE3 - Packed Align Instruction Patterns
 //===---------------------------------------------------------------------===//

Modified: llvm/trunk/test/CodeGen/X86/avx2-logic.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/avx2-logic.ll?rev=144987&r1=144986&r2=144987&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/avx2-logic.ll (original)
+++ llvm/trunk/test/CodeGen/X86/avx2-logic.ll Sat Nov 19 01:07:26 2011
@@ -53,3 +53,32 @@
   %min = select <32 x i1> %min_is_x, <32 x i8> %x, <32 x i8> %y
   ret <32 x i8> %min
 }
+
+define <8 x i32> @signd(<8 x i32> %a, <8 x i32> %b) nounwind {
+entry:
+; CHECK: signd:
+; CHECK: psignd
+; CHECK-NOT: sub
+; CHECK: ret
+  %b.lobit = ashr <8 x i32> %b, <i32 31, i32 31, i32 31, i32 31, i32 31, i32 31, i32 31, i32 31>
+  %sub = sub nsw <8 x i32> zeroinitializer, %a
+  %0 = xor <8 x i32> %b.lobit, <i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1>
+  %1 = and <8 x i32> %a, %0
+  %2 = and <8 x i32> %b.lobit, %sub
+  %cond = or <8 x i32> %1, %2
+  ret <8 x i32> %cond
+}
+
+define <8 x i32> @blendvb(<8 x i32> %b, <8 x i32> %a, <8 x i32> %c) nounwind {
+entry:
+; CHECK: blendvb:
+; CHECK: pblendvb
+; CHECK: ret
+  %b.lobit = ashr <8 x i32> %b, <i32 31, i32 31, i32 31, i32 31, i32 31, i32 31, i32 31, i32 31>
+  %sub = sub nsw <8 x i32> zeroinitializer, %a
+  %0 = xor <8 x i32> %b.lobit, <i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1>
+  %1 = and <8 x i32> %c, %0
+  %2 = and <8 x i32> %a, %b.lobit
+  %cond = or <8 x i32> %1, %2
+  ret <8 x i32> %cond
+}





More information about the llvm-commits mailing list