[llvm] r238637 - [x86] Restore the bitcasts I removed when refactoring this to avoid
Chandler Carruth
chandlerc at gmail.com
Fri May 29 21:05:12 PDT 2015
Author: chandlerc
Date: Fri May 29 23:05:11 2015
New Revision: 238637
URL: http://llvm.org/viewvc/llvm-project?rev=238637&view=rev
Log:
[x86] Restore the bitcasts I removed when refactoring this to avoid
shifting vectors of bytes as x86 doesn't have direct support for that.
This removes a bunch of redundant masking in the generated code for SSE2
and SSE3.
In order to avoid the really significant code size growth this would
have triggered, I also factored the completely repetitive logic for
shifting and masking into two lambdas which in turn makes all of this
much easier to read IMO.
Modified:
llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
llvm/trunk/test/CodeGen/X86/vector-popcnt-128.ll
Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=238637&r1=238636&r2=238637&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Fri May 29 23:05:11 2015
@@ -17479,7 +17479,6 @@ static SDValue LowerVectorCTPOPBitmath(S
"Only 128-bit vector bitmath lowering supported.");
int VecSize = VT.getSizeInBits();
- int NumElts = VT.getVectorNumElements();
MVT EltVT = VT.getVectorElementType();
int Len = EltVT.getSizeInBits();
@@ -17490,48 +17489,52 @@ static SDValue LowerVectorCTPOPBitmath(S
// this when we don't have SSSE3 which allows a LUT-based lowering that is
// much faster, even faster than using native popcnt instructions.
- SDValue Cst55 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x55)), DL,
- EltVT);
- SDValue Cst33 = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x33)), DL,
- EltVT);
- SDValue Cst0F = DAG.getConstant(APInt::getSplat(Len, APInt(8, 0x0F)), DL,
- EltVT);
+ auto GetShift = [&](unsigned OpCode, SDValue V, int Shifter) {
+ MVT VT = V.getSimpleValueType();
+ SmallVector<SDValue, 32> Shifters(
+ VT.getVectorNumElements(),
+ DAG.getConstant(Shifter, DL, VT.getVectorElementType()));
+ return DAG.getNode(OpCode, DL, VT, V,
+ DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Shifters));
+ };
+ auto GetMask = [&](SDValue V, APInt Mask) {
+ MVT VT = V.getSimpleValueType();
+ SmallVector<SDValue, 32> Masks(
+ VT.getVectorNumElements(),
+ DAG.getConstant(Mask, DL, VT.getVectorElementType()));
+ return DAG.getNode(ISD::AND, DL, VT, V,
+ DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Masks));
+ };
+
+ // We don't want to incur the implicit masks required to SRL vNi8 vectors on
+ // x86, so set the SRL type to have elements at least i16 wide. This is
+ // correct because all of our SRLs are followed immediately by a mask anyways
+ // that handles any bits that sneak into the high bits of the byte elements.
+ MVT SrlVT = Len > 8 ? VT : MVT::getVectorVT(MVT::i16, VecSize / 16);
SDValue V = Op;
// v = v - ((v >> 1) & 0x55555555...)
- SmallVector<SDValue, 8> Ones(NumElts, DAG.getConstant(1, DL, EltVT));
- SDValue OnesV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Ones);
- SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, V, OnesV);
-
- SmallVector<SDValue, 8> Mask55(NumElts, Cst55);
- SDValue M55 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask55);
- SDValue And = DAG.getNode(ISD::AND, DL, Srl.getValueType(), Srl, M55);
-
+ SDValue Srl = DAG.getNode(
+ ISD::BITCAST, DL, VT,
+ GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 1));
+ SDValue And = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x55)));
V = DAG.getNode(ISD::SUB, DL, VT, V, And);
// v = (v & 0x33333333...) + ((v >> 2) & 0x33333333...)
- SmallVector<SDValue, 8> Mask33(NumElts, Cst33);
- SDValue M33 = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask33);
- SDValue AndLHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), V, M33);
-
- SmallVector<SDValue, 8> Twos(NumElts, DAG.getConstant(2, DL, EltVT));
- SDValue TwosV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Twos);
- Srl = DAG.getNode(ISD::SRL, DL, VT, V, TwosV);
- SDValue AndRHS = DAG.getNode(ISD::AND, DL, M33.getValueType(), Srl, M33);
-
+ SDValue AndLHS = GetMask(V, APInt::getSplat(Len, APInt(8, 0x33)));
+ Srl = DAG.getNode(
+ ISD::BITCAST, DL, VT,
+ GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 2));
+ SDValue AndRHS = GetMask(Srl, APInt::getSplat(Len, APInt(8, 0x33)));
V = DAG.getNode(ISD::ADD, DL, VT, AndLHS, AndRHS);
// v = (v + (v >> 4)) & 0x0F0F0F0F...
- SmallVector<SDValue, 8> Fours(NumElts, DAG.getConstant(4, DL, EltVT));
- SDValue FoursV = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Fours);
- Srl = DAG.getNode(ISD::SRL, DL, VT, V, FoursV);
+ Srl = DAG.getNode(
+ ISD::BITCAST, DL, VT,
+ GetShift(ISD::SRL, DAG.getNode(ISD::BITCAST, DL, SrlVT, V), 4));
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, V, Srl);
-
- SmallVector<SDValue, 8> Mask0F(NumElts, Cst0F);
- SDValue M0F = DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Mask0F);
-
- V = DAG.getNode(ISD::AND, DL, M0F.getValueType(), Add, M0F);
+ V = GetMask(Add, APInt::getSplat(Len, APInt(8, 0x0F)));
// At this point, V contains the byte-wise population count, and we are
// merely doing a horizontal sum if necessary to get the wider element
@@ -17543,26 +17546,21 @@ static SDValue LowerVectorCTPOPBitmath(S
MVT ByteVT = MVT::getVectorVT(MVT::i8, VecSize / 8);
MVT ShiftVT = MVT::getVectorVT(MVT::i64, VecSize / 64);
V = DAG.getNode(ISD::BITCAST, DL, ByteVT, V);
- SmallVector<SDValue, 8> Csts;
assert(Len <= 64 && "We don't support element sizes of more than 64 bits!");
assert(isPowerOf2_32(Len) && "Only power of two element sizes supported!");
for (int i = Len; i > 8; i /= 2) {
- Csts.assign(VecSize / 64, DAG.getConstant(i / 2, DL, MVT::i64));
SDValue Shl = DAG.getNode(
- ISD::SHL, DL, ShiftVT, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V),
- DAG.getNode(ISD::BUILD_VECTOR, DL, ShiftVT, Csts));
- V = DAG.getNode(ISD::ADD, DL, ByteVT, V,
- DAG.getNode(ISD::BITCAST, DL, ByteVT, Shl));
+ ISD::BITCAST, DL, ByteVT,
+ GetShift(ISD::SHL, DAG.getNode(ISD::BITCAST, DL, ShiftVT, V), i / 2));
+ V = DAG.getNode(ISD::ADD, DL, ByteVT, V, Shl);
}
// The high byte now contains the sum of the element bytes. Shift it right
// (if needed) to make it the low byte.
V = DAG.getNode(ISD::BITCAST, DL, VT, V);
- if (Len > 8) {
- Csts.assign(NumElts, DAG.getConstant(Len - 8, DL, EltVT));
- V = DAG.getNode(ISD::SRL, DL, VT, V,
- DAG.getNode(ISD::BUILD_VECTOR, DL, VT, Csts));
- }
+ if (Len > 8)
+ V = GetShift(ISD::SRL, V, Len - 8);
+
return V;
}
Modified: llvm/trunk/test/CodeGen/X86/vector-popcnt-128.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/vector-popcnt-128.ll?rev=238637&r1=238636&r2=238637&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/vector-popcnt-128.ll (original)
+++ llvm/trunk/test/CodeGen/X86/vector-popcnt-128.ll Fri May 29 23:05:11 2015
@@ -339,21 +339,17 @@ define <16 x i8> @testv16i8(<16 x i8> %i
; SSE2-NEXT: movdqa %xmm0, %xmm1
; SSE2-NEXT: psrlw $1, %xmm1
; SSE2-NEXT: pand {{.*}}(%rip), %xmm1
-; SSE2-NEXT: pand {{.*}}(%rip), %xmm1
; SSE2-NEXT: psubb %xmm1, %xmm0
; SSE2-NEXT: movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
; SSE2-NEXT: movdqa %xmm0, %xmm2
; SSE2-NEXT: pand %xmm1, %xmm2
; SSE2-NEXT: psrlw $2, %xmm0
-; SSE2-NEXT: pand {{.*}}(%rip), %xmm0
; SSE2-NEXT: pand %xmm1, %xmm0
; SSE2-NEXT: paddb %xmm2, %xmm0
; SSE2-NEXT: movdqa %xmm0, %xmm1
; SSE2-NEXT: psrlw $4, %xmm1
-; SSE2-NEXT: movdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
-; SSE2-NEXT: pand %xmm2, %xmm1
; SSE2-NEXT: paddb %xmm0, %xmm1
-; SSE2-NEXT: pand %xmm2, %xmm1
+; SSE2-NEXT: pand {{.*}}(%rip), %xmm1
; SSE2-NEXT: movdqa %xmm1, %xmm0
; SSE2-NEXT: retq
;
@@ -362,21 +358,17 @@ define <16 x i8> @testv16i8(<16 x i8> %i
; SSE3-NEXT: movdqa %xmm0, %xmm1
; SSE3-NEXT: psrlw $1, %xmm1
; SSE3-NEXT: pand {{.*}}(%rip), %xmm1
-; SSE3-NEXT: pand {{.*}}(%rip), %xmm1
; SSE3-NEXT: psubb %xmm1, %xmm0
; SSE3-NEXT: movdqa {{.*#+}} xmm1 = [51,51,51,51,51,51,51,51,51,51,51,51,51,51,51,51]
; SSE3-NEXT: movdqa %xmm0, %xmm2
; SSE3-NEXT: pand %xmm1, %xmm2
; SSE3-NEXT: psrlw $2, %xmm0
-; SSE3-NEXT: pand {{.*}}(%rip), %xmm0
; SSE3-NEXT: pand %xmm1, %xmm0
; SSE3-NEXT: paddb %xmm2, %xmm0
; SSE3-NEXT: movdqa %xmm0, %xmm1
; SSE3-NEXT: psrlw $4, %xmm1
-; SSE3-NEXT: movdqa {{.*#+}} xmm2 = [15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15]
-; SSE3-NEXT: pand %xmm2, %xmm1
; SSE3-NEXT: paddb %xmm0, %xmm1
-; SSE3-NEXT: pand %xmm2, %xmm1
+; SSE3-NEXT: pand {{.*}}(%rip), %xmm1
; SSE3-NEXT: movdqa %xmm1, %xmm0
; SSE3-NEXT: retq
;
More information about the llvm-commits
mailing list