[llvm] r235837 - [X86][SSE] Add v16i8/v32i8 multiplication support

Simon Pilgrim llvm-dev at redking.me.uk
Mon Apr 27 00:55:47 PDT 2015


Author: rksimon
Date: Mon Apr 27 02:55:46 2015
New Revision: 235837

URL: http://llvm.org/viewvc/llvm-project?rev=235837&view=rev
Log:
[X86][SSE] Add v16i8/v32i8 multiplication support

Patch to allow v16i8/v32i8 (i8 element) vectors to be multiplied on the SSE unit instead of being scalarized.

The patch sign-extends the i8 lanes to i16, multiplies them with the SSE2 pmullw instruction, then keeps the low byte of each i16 result (mask + packuswb on SSE, truncation on AVX2).

Differential Revision: http://reviews.llvm.org/D9115

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/avx2-arith.ll
    llvm/trunk/test/CodeGen/X86/pmul.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=235837&r1=235836&r2=235837&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Mon Apr 27 02:55:46 2015
@@ -802,6 +802,7 @@ X86TargetLowering::X86TargetLowering(con
     setOperationAction(ISD::ADD,                MVT::v8i16, Legal);
     setOperationAction(ISD::ADD,                MVT::v4i32, Legal);
     setOperationAction(ISD::ADD,                MVT::v2i64, Legal);
+    setOperationAction(ISD::MUL,                MVT::v16i8, Custom);
     setOperationAction(ISD::MUL,                MVT::v4i32, Custom);
     setOperationAction(ISD::MUL,                MVT::v2i64, Custom);
     setOperationAction(ISD::UMUL_LOHI,          MVT::v4i32, Custom);
@@ -1122,7 +1123,7 @@ X86TargetLowering::X86TargetLowering(con
       setOperationAction(ISD::MUL,             MVT::v4i64, Custom);
       setOperationAction(ISD::MUL,             MVT::v8i32, Legal);
       setOperationAction(ISD::MUL,             MVT::v16i16, Legal);
-      // Don't lower v32i8 because there is no 128-bit byte mul
+      setOperationAction(ISD::MUL,             MVT::v32i8, Custom);
 
       setOperationAction(ISD::UMUL_LOHI,       MVT::v8i32, Custom);
       setOperationAction(ISD::SMUL_LOHI,       MVT::v8i32, Custom);
@@ -1171,7 +1172,7 @@ X86TargetLowering::X86TargetLowering(con
       setOperationAction(ISD::MUL,             MVT::v4i64, Custom);
       setOperationAction(ISD::MUL,             MVT::v8i32, Custom);
       setOperationAction(ISD::MUL,             MVT::v16i16, Custom);
-      // Don't lower v32i8 because there is no 128-bit byte mul
+      setOperationAction(ISD::MUL,             MVT::v32i8, Custom);
     }
 
     // In the customized shift lowering, the legal cases in AVX2 will be
@@ -9894,7 +9895,7 @@ static SDValue lower256BitVectorShuffle(
   int NumV2Elements = std::count_if(Mask.begin(), Mask.end(), [NumElts](int M) {
     return M >= NumElts;
   });
-  
+
   if (NumV2Elements == 1 && Mask[0] >= NumElts)
     if (SDValue Insertion = lowerVectorShuffleAsElementInsertion(
                               DL, VT, V1, V2, Mask, Subtarget, DAG))
@@ -10646,7 +10647,7 @@ SDValue X86TargetLowering::LowerINSERT_V
         return DAG.getNode(X86ISD::BLENDI, dl, VT, N0, N1Vec, N2);
       }
     }
-    
+
     // Get the desired 128-bit vector chunk.
     SDValue V = Extract128BitVector(N0, IdxVal, DAG, dl);
 
@@ -15908,6 +15909,79 @@ static SDValue LowerMUL(SDValue Op, cons
   SDValue A = Op.getOperand(0);
   SDValue B = Op.getOperand(1);
 
+  // Lower v16i8/v32i8 mul by sign extending to i16 lanes, multiplying as i16
+  // and keeping only the low byte of each product (truncate or mask+pack).
+  if (VT == MVT::v16i8 || VT == MVT::v32i8) {
+    if (Subtarget->hasInt256()) {
+      if (VT == MVT::v32i8) { // Split into two 128-bit v16i8 muls.
+        MVT SubVT = MVT::getVectorVT(MVT::i8, VT.getVectorNumElements() / 2);
+        SDValue Lo = DAG.getIntPtrConstant(0);
+        SDValue Hi = DAG.getIntPtrConstant(VT.getVectorNumElements() / 2);
+        SDValue ALo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, A, Lo);
+        SDValue BLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, B, Lo);
+        SDValue AHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, A, Hi);
+        SDValue BHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, B, Hi);
+        return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
+                           DAG.getNode(ISD::MUL, dl, SubVT, ALo, BLo),
+                           DAG.getNode(ISD::MUL, dl, SubVT, AHi, BHi));
+      }
+
+      MVT ExVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements());
+      return DAG.getNode(
+          ISD::TRUNCATE, dl, VT,
+          DAG.getNode(ISD::MUL, dl, ExVT,
+                      DAG.getNode(ISD::SIGN_EXTEND, dl, ExVT, A),
+                      DAG.getNode(ISD::SIGN_EXTEND, dl, ExVT, B)));
+    }
+
+    assert(VT == MVT::v16i8 &&
+           "Pre-AVX2 support only supports v16i8 multiplication");
+    MVT ExVT = MVT::v8i16;
+
+    // Sign extend the lo bytes to i16 (SSE2: byte into hi half, then SRA 8).
+    SDValue ALo, BLo;
+    if (Subtarget->hasSSE41()) {
+      ALo = DAG.getNode(X86ISD::VSEXT, dl, ExVT, A);
+      BLo = DAG.getNode(X86ISD::VSEXT, dl, ExVT, B);
+    } else {
+      const int ShufMask[] = {-1, 0, -1, 1, -1, 2, -1, 3,
+                              -1, 4, -1, 5, -1, 6, -1, 7};
+      ALo = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
+      BLo = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
+      ALo = DAG.getNode(ISD::BITCAST, dl, ExVT, ALo);
+      BLo = DAG.getNode(ISD::BITCAST, dl, ExVT, BLo);
+      ALo = DAG.getNode(ISD::SRA, dl, ExVT, ALo, DAG.getConstant(8, ExVT));
+      BLo = DAG.getNode(ISD::SRA, dl, ExVT, BLo, DAG.getConstant(8, ExVT));
+    }
+
+    // Sign extend the hi bytes to i16 (SSE2: byte into hi half, then SRA 8).
+    SDValue AHi, BHi;
+    if (Subtarget->hasSSE41()) {
+      const int ShufMask[] = {8,  9,  10, 11, 12, 13, 14, 15,
+                              -1, -1, -1, -1, -1, -1, -1, -1};
+      AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
+      BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
+      AHi = DAG.getNode(X86ISD::VSEXT, dl, ExVT, AHi);
+      BHi = DAG.getNode(X86ISD::VSEXT, dl, ExVT, BHi);
+    } else {
+      const int ShufMask[] = {-1, 8,  -1, 9,  -1, 10, -1, 11,
+                              -1, 12, -1, 13, -1, 14, -1, 15};
+      AHi = DAG.getVectorShuffle(VT, dl, A, A, ShufMask);
+      BHi = DAG.getVectorShuffle(VT, dl, B, B, ShufMask);
+      AHi = DAG.getNode(ISD::BITCAST, dl, ExVT, AHi);
+      BHi = DAG.getNode(ISD::BITCAST, dl, ExVT, BHi);
+      AHi = DAG.getNode(ISD::SRA, dl, ExVT, AHi, DAG.getConstant(8, ExVT));
+      BHi = DAG.getNode(ISD::SRA, dl, ExVT, BHi, DAG.getConstant(8, ExVT));
+    }
+
+    // Multiply, mask the low 8 bits of the lo/hi results and pack them.
+    SDValue RLo = DAG.getNode(ISD::MUL, dl, ExVT, ALo, BLo);
+    SDValue RHi = DAG.getNode(ISD::MUL, dl, ExVT, AHi, BHi);
+    RLo = DAG.getNode(ISD::AND, dl, ExVT, RLo, DAG.getConstant(255, ExVT));
+    RHi = DAG.getNode(ISD::AND, dl, ExVT, RHi, DAG.getConstant(255, ExVT));
+    return DAG.getNode(X86ISD::PACKUS, dl, VT, RLo, RHi);
+  }
+
   // Lower v4i32 mul as 2x shuffle, 2x pmuludq, 2x shuffle.
   if (VT == MVT::v4i32) {
     assert(Subtarget->hasSSE2() && !Subtarget->hasSSE41() &&

Modified: llvm/trunk/test/CodeGen/X86/avx2-arith.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/avx2-arith.ll?rev=235837&r1=235836&r2=235837&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/avx2-arith.ll (original)
+++ llvm/trunk/test/CodeGen/X86/avx2-arith.ll Mon Apr 27 02:55:46 2015
@@ -60,6 +60,49 @@ define <16 x i16> @test_vpmullw(<16 x i1
   ret <16 x i16> %x
 }
 
+; CHECK: mul-v16i8
+; CHECK:       # BB#0:
+; CHECK-NEXT:  vpmovsxbw %xmm1, %ymm1
+; CHECK-NEXT:  vpmovsxbw %xmm0, %ymm0
+; CHECK-NEXT:  vpmullw %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:  vextracti128 $1, %ymm0, %xmm1
+; CHECK-NEXT:  vmovdqa {{.*#+}} xmm2 = <0,2,4,6,8,10,12,14,u,u,u,u,u,u,u,u>
+; CHECK-NEXT:  vpshufb %xmm2, %xmm1, %xmm1
+; CHECK-NEXT:  vpshufb %xmm2, %xmm0, %xmm0
+; CHECK-NEXT:  vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; CHECK-NEXT:  vzeroupper
+; CHECK-NEXT:  retq
+define <16 x i8> @mul-v16i8(<16 x i8> %i, <16 x i8> %j) nounwind readnone {
+  %x = mul <16 x i8> %i, %j
+  ret <16 x i8> %x
+}
+
+; CHECK: mul-v32i8
+; CHECK:       # BB#0:
+; CHECK-NEXT:  vextracti128 $1, %ymm1, %xmm2
+; CHECK-NEXT:  vpmovsxbw %xmm2, %ymm2
+; CHECK-NEXT:  vextracti128 $1, %ymm0, %xmm3
+; CHECK-NEXT:  vpmovsxbw %xmm3, %ymm3
+; CHECK-NEXT:  vpmullw %ymm2, %ymm3, %ymm2
+; CHECK-NEXT:  vextracti128 $1, %ymm2, %xmm3
+; CHECK-NEXT:  vmovdqa {{.*#+}} xmm4 = <0,2,4,6,8,10,12,14,u,u,u,u,u,u,u,u>
+; CHECK-NEXT:  vpshufb %xmm4, %xmm3, %xmm3
+; CHECK-NEXT:  vpshufb %xmm4, %xmm2, %xmm2
+; CHECK-NEXT:  vpunpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
+; CHECK-NEXT:  vpmovsxbw %xmm1, %ymm1
+; CHECK-NEXT:  vpmovsxbw %xmm0, %ymm0
+; CHECK-NEXT:  vpmullw %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:  vextracti128 $1, %ymm0, %xmm1
+; CHECK-NEXT:  vpshufb %xmm4, %xmm1, %xmm1
+; CHECK-NEXT:  vpshufb %xmm4, %xmm0, %xmm0
+; CHECK-NEXT:  vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; CHECK-NEXT:  vinserti128 $1, %xmm2, %ymm0, %ymm0
+; CHECK-NEXT:  retq
+define <32 x i8> @mul-v32i8(<32 x i8> %i, <32 x i8> %j) nounwind readnone {
+  %x = mul <32 x i8> %i, %j
+  ret <32 x i8> %x
+}
+
 ; CHECK: mul-v4i64
 ; CHECK: vpmuludq %ymm
 ; CHECK-NEXT: vpsrlq $32, %ymm

Modified: llvm/trunk/test/CodeGen/X86/pmul.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/pmul.ll?rev=235837&r1=235836&r2=235837&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/pmul.ll (original)
+++ llvm/trunk/test/CodeGen/X86/pmul.ll Mon Apr 27 02:55:46 2015
@@ -1,6 +1,53 @@
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown | FileCheck %s --check-prefix=ALL --check-prefix=SSE2
 ; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=sse4.1 | FileCheck %s --check-prefix=ALL --check-prefix=SSE41
 
+define <16 x i8> @mul8c(<16 x i8> %i) nounwind  {
+; SSE2-LABEL: mul8c:
+; SSE2:       # BB#0: # %entry
+; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [117,117,117,117,117,117,117,117,117,117,117,117,117,117,117,117]
+; SSE2-NEXT:    psraw $8, %xmm1
+; SSE2-NEXT:    movdqa %xmm0, %xmm2
+; SSE2-NEXT:    punpckhbw {{.*#+}} xmm2 = xmm2[8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15]
+; SSE2-NEXT:    psraw $8, %xmm2
+; SSE2-NEXT:    pmullw %xmm1, %xmm2
+; SSE2-NEXT:    movdqa {{.*#+}} xmm3 = [255,255,255,255,255,255,255,255]
+; SSE2-NEXT:    pand %xmm3, %xmm2
+; SSE2-NEXT:    punpcklbw {{.*#+}} xmm0 = xmm0[0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7]
+; SSE2-NEXT:    psraw $8, %xmm0
+; SSE2-NEXT:    pmullw %xmm1, %xmm0
+; SSE2-NEXT:    pand %xmm3, %xmm0
+; SSE2-NEXT:    packuswb %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: mul8c:
+; SSE41:       # BB#0: # %entry
+; SSE41-NEXT:    pmovsxbw %xmm0, %xmm1
+; SSE41-NEXT:    pmovsxbw {{.*}}(%rip), %xmm2
+; SSE41-NEXT:    pmullw %xmm2, %xmm1
+; SSE41-NEXT:    movdqa {{.*#+}} xmm3 = [255,255,255,255,255,255,255,255]
+; SSE41-NEXT:    pand %xmm3, %xmm1
+; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,0,1]
+; SSE41-NEXT:    pmovsxbw %xmm0, %xmm0
+; SSE41-NEXT:    pmullw %xmm2, %xmm0
+; SSE41-NEXT:    pand %xmm3, %xmm0
+; SSE41-NEXT:    packuswb %xmm0, %xmm1
+; SSE41-NEXT:    movdqa %xmm1, %xmm0
+; SSE41-NEXT:    retq
+entry:
+  %A = mul <16 x i8> %i, < i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117, i8 117 >
+  ret <16 x i8> %A
+}
+
+define <8 x i16> @mul16c(<8 x i16> %i) nounwind  {
+; ALL-LABEL: mul16c:
+; ALL:       # BB#0: # %entry
+; ALL-NEXT:    pmullw {{.*}}(%rip), %xmm0
+; ALL-NEXT:    retq
+entry:
+  %A = mul <8 x i16> %i, < i16 117, i16 117, i16 117, i16 117, i16 117, i16 117, i16 117, i16 117 >
+  ret <8 x i16> %A
+}
+
 define <4 x i32> @a(<4 x i32> %i) nounwind  {
 ; SSE2-LABEL: a:
 ; SSE2:       # BB#0: # %entry
@@ -42,6 +89,59 @@ entry:
   ret <2 x i64> %A
 }
 
+define <16 x i8> @mul8(<16 x i8> %i, <16 x i8> %j) nounwind  {
+; SSE2-LABEL: mul8:
+; SSE2:       # BB#0: # %entry
+; SSE2-NEXT:    movdqa %xmm1, %xmm3
+; SSE2-NEXT:    punpcklbw {{.*#+}} xmm3 = xmm3[0],xmm0[0],xmm3[1],xmm0[1],xmm3[2],xmm0[2],xmm3[3],xmm0[3],xmm3[4],xmm0[4],xmm3[5],xmm0[5],xmm3[6],xmm0[6],xmm3[7],xmm0[7]
+; SSE2-NEXT:    psraw $8, %xmm3
+; SSE2-NEXT:    movdqa %xmm0, %xmm2
+; SSE2-NEXT:    punpcklbw {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3],xmm2[4],xmm0[4],xmm2[5],xmm0[5],xmm2[6],xmm0[6],xmm2[7],xmm0[7]
+; SSE2-NEXT:    psraw $8, %xmm2
+; SSE2-NEXT:    pmullw %xmm3, %xmm2
+; SSE2-NEXT:    movdqa {{.*#+}} xmm3 = [255,255,255,255,255,255,255,255]
+; SSE2-NEXT:    pand %xmm3, %xmm2
+; SSE2-NEXT:    punpckhbw {{.*#+}} xmm1 = xmm1[8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15]
+; SSE2-NEXT:    psraw $8, %xmm1
+; SSE2-NEXT:    punpckhbw {{.*#+}} xmm0 = xmm0[8,8,9,9,10,10,11,11,12,12,13,13,14,14,15,15]
+; SSE2-NEXT:    psraw $8, %xmm0
+; SSE2-NEXT:    pmullw %xmm1, %xmm0
+; SSE2-NEXT:    pand %xmm3, %xmm0
+; SSE2-NEXT:    packuswb %xmm0, %xmm2
+; SSE2-NEXT:    movdqa %xmm2, %xmm0
+; SSE2-NEXT:    retq
+;
+; SSE41-LABEL: mul8:
+; SSE41:       # BB#0: # %entry
+; SSE41-NEXT:    pmovsxbw %xmm1, %xmm3
+; SSE41-NEXT:    pmovsxbw %xmm0, %xmm2
+; SSE41-NEXT:    pmullw %xmm3, %xmm2
+; SSE41-NEXT:    movdqa {{.*#+}} xmm3 = [255,255,255,255,255,255,255,255]
+; SSE41-NEXT:    pand %xmm3, %xmm2
+; SSE41-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[2,3,0,1]
+; SSE41-NEXT:    pmovsxbw %xmm1, %xmm1
+; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,0,1]
+; SSE41-NEXT:    pmovsxbw %xmm0, %xmm0
+; SSE41-NEXT:    pmullw %xmm1, %xmm0
+; SSE41-NEXT:    pand %xmm3, %xmm0
+; SSE41-NEXT:    packuswb %xmm0, %xmm2
+; SSE41-NEXT:    movdqa %xmm2, %xmm0
+; SSE41-NEXT:    retq
+entry:
+  %A = mul <16 x i8> %i, %j
+  ret <16 x i8> %A
+}
+
+define <8 x i16> @mul16(<8 x i16> %i, <8 x i16> %j) nounwind  {
+; ALL-LABEL: mul16:
+; ALL:       # BB#0: # %entry
+; ALL-NEXT:    pmullw %xmm1, %xmm0
+; ALL-NEXT:    retq
+entry:
+  %A = mul <8 x i16> %i, %j
+  ret <8 x i16> %A
+}
+
 define <4 x i32> @c(<4 x i32> %i, <4 x i32> %j) nounwind  {
 ; SSE2-LABEL: c:
 ; SSE2:       # BB#0: # %entry





More information about the llvm-commits mailing list