[llvm] r259346 - AVX512: fix mask handling for gather/scatter/prefetch intrinsics.
Igor Breger via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 1 01:57:15 PST 2016
Author: ibreger
Date: Mon Feb 1 03:57:15 2016
New Revision: 259346
URL: http://llvm.org/viewvc/llvm-project?rev=259346&view=rev
Log:
AVX512: fix mask handling for gather/scatter/prefetch intrinsics.
Differential Revision: http://reviews.llvm.org/D16755
Modified:
llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
llvm/trunk/test/CodeGen/X86/avx512-gather-scatter-intrin.ll
Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=259346&r1=259345&r2=259346&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Mon Feb 1 03:57:15 2016
@@ -16492,6 +16492,11 @@ static SDValue getMaskNode(SDValue Mask,
const X86Subtarget &Subtarget,
SelectionDAG &DAG, SDLoc dl) {
+ if (isAllOnesConstant(Mask))
+ return DAG.getTargetConstant(1, dl, MaskVT);
+ if (X86::isZeroNode(Mask))
+ return DAG.getTargetConstant(0, dl, MaskVT);
+
if (MaskVT.bitsGT(Mask.getSimpleValueType())) {
// Mask should be extended
Mask = DAG.getNode(ISD::ANY_EXTEND, dl,
@@ -17409,26 +17414,14 @@ static SDValue getGatherNode(unsigned Op
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
MVT MaskVT = MVT::getVectorVT(MVT::i1,
Index.getSimpleValueType().getVectorNumElements());
- SDValue MaskInReg;
- ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
- if (MaskC)
- MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
- else {
- MVT BitcastVT = MVT::getVectorVT(MVT::i1,
- Mask.getSimpleValueType().getSizeInBits());
- // In case when MaskVT equals v2i1 or v4i1, low 2 or 4 elements
- // are extracted by EXTRACT_SUBVECTOR.
- MaskInReg = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MaskVT,
- DAG.getBitcast(BitcastVT, Mask),
- DAG.getIntPtrConstant(0, dl));
- }
+ SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
SDVTList VTs = DAG.getVTList(Op.getValueType(), MaskVT, MVT::Other);
SDValue Disp = DAG.getTargetConstant(0, dl, MVT::i32);
SDValue Segment = DAG.getRegister(0, MVT::i32);
if (Src.getOpcode() == ISD::UNDEF)
Src = getZeroVector(Op.getSimpleValueType(), Subtarget, DAG, dl);
- SDValue Ops[] = {Src, MaskInReg, Base, Scale, Index, Disp, Segment, Chain};
+ SDValue Ops[] = {Src, VMask, Base, Scale, Index, Disp, Segment, Chain};
SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops);
SDValue RetOps[] = { SDValue(Res, 0), SDValue(Res, 2) };
return DAG.getMergeValues(RetOps, dl);
@@ -17436,7 +17429,8 @@ static SDValue getGatherNode(unsigned Op
static SDValue getScatterNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
SDValue Src, SDValue Mask, SDValue Base,
- SDValue Index, SDValue ScaleOp, SDValue Chain) {
+ SDValue Index, SDValue ScaleOp, SDValue Chain,
+ const X86Subtarget &Subtarget) {
SDLoc dl(Op);
auto *C = cast<ConstantSDNode>(ScaleOp);
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
@@ -17444,29 +17438,18 @@ static SDValue getScatterNode(unsigned O
SDValue Segment = DAG.getRegister(0, MVT::i32);
MVT MaskVT = MVT::getVectorVT(MVT::i1,
Index.getSimpleValueType().getVectorNumElements());
- SDValue MaskInReg;
- ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
- if (MaskC)
- MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
- else {
- MVT BitcastVT = MVT::getVectorVT(MVT::i1,
- Mask.getSimpleValueType().getSizeInBits());
- // In case when MaskVT equals v2i1 or v4i1, low 2 or 4 elements
- // are extracted by EXTRACT_SUBVECTOR.
- MaskInReg = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MaskVT,
- DAG.getBitcast(BitcastVT, Mask),
- DAG.getIntPtrConstant(0, dl));
- }
+ SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
SDVTList VTs = DAG.getVTList(MaskVT, MVT::Other);
- SDValue Ops[] = {Base, Scale, Index, Disp, Segment, MaskInReg, Src, Chain};
+ SDValue Ops[] = {Base, Scale, Index, Disp, Segment, VMask, Src, Chain};
SDNode *Res = DAG.getMachineNode(Opc, dl, VTs, Ops);
return SDValue(Res, 1);
}
static SDValue getPrefetchNode(unsigned Opc, SDValue Op, SelectionDAG &DAG,
SDValue Mask, SDValue Base, SDValue Index,
- SDValue ScaleOp, SDValue Chain) {
+ SDValue ScaleOp, SDValue Chain,
+ const X86Subtarget &Subtarget) {
SDLoc dl(Op);
auto *C = cast<ConstantSDNode>(ScaleOp);
SDValue Scale = DAG.getTargetConstant(C->getZExtValue(), dl, MVT::i8);
@@ -17474,14 +17457,9 @@ static SDValue getPrefetchNode(unsigned
SDValue Segment = DAG.getRegister(0, MVT::i32);
MVT MaskVT =
MVT::getVectorVT(MVT::i1, Index.getSimpleValueType().getVectorNumElements());
- SDValue MaskInReg;
- ConstantSDNode *MaskC = dyn_cast<ConstantSDNode>(Mask);
- if (MaskC)
- MaskInReg = DAG.getTargetConstant(MaskC->getSExtValue(), dl, MaskVT);
- else
- MaskInReg = DAG.getBitcast(MaskVT, Mask);
+ SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
//SDVTList VTs = DAG.getVTList(MVT::Other);
- SDValue Ops[] = {MaskInReg, Base, Scale, Index, Disp, Segment, Chain};
+ SDValue Ops[] = {VMask, Base, Scale, Index, Disp, Segment, Chain};
SDNode *Res = DAG.getMachineNode(Opc, dl, MVT::Other, Ops);
return SDValue(Res, 0);
}
@@ -17678,7 +17656,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SD
SDValue Src = Op.getOperand(5);
SDValue Scale = Op.getOperand(6);
return getScatterNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index,
- Scale, Chain);
+ Scale, Chain, Subtarget);
}
case PREFETCH: {
SDValue Hint = Op.getOperand(6);
@@ -17690,7 +17668,8 @@ static SDValue LowerINTRINSIC_W_CHAIN(SD
SDValue Index = Op.getOperand(3);
SDValue Base = Op.getOperand(4);
SDValue Scale = Op.getOperand(5);
- return getPrefetchNode(Opcode, Op, DAG, Mask, Base, Index, Scale, Chain);
+ return getPrefetchNode(Opcode, Op, DAG, Mask, Base, Index, Scale, Chain,
+ Subtarget);
}
// Read Time Stamp Counter (RDTSC) and Processor ID (RDTSCP).
case RDTSC: {
Modified: llvm/trunk/test/CodeGen/X86/avx512-gather-scatter-intrin.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/avx512-gather-scatter-intrin.ll?rev=259346&r1=259345&r2=259346&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/avx512-gather-scatter-intrin.ll (original)
+++ llvm/trunk/test/CodeGen/X86/avx512-gather-scatter-intrin.ll Mon Feb 1 03:57:15 2016
@@ -259,18 +259,22 @@ define void @prefetch(<8 x i64> %ind, i8
; CHECK: ## BB#0:
; CHECK-NEXT: kxnorw %k0, %k0, %k1
; CHECK-NEXT: vgatherpf0qps (%rdi,%zmm0,4) {%k1}
+; CHECK-NEXT: kxorw %k0, %k0, %k1
; CHECK-NEXT: vgatherpf1qps (%rdi,%zmm0,4) {%k1}
+; CHECK-NEXT: movb $1, %al
+; CHECK-NEXT: kmovb %eax, %k1
; CHECK-NEXT: vscatterpf0qps (%rdi,%zmm0,2) {%k1}
+; CHECK-NEXT: movb $120, %al
+; CHECK-NEXT: kmovb %eax, %k1
; CHECK-NEXT: vscatterpf1qps (%rdi,%zmm0,2) {%k1}
; CHECK-NEXT: retq
call void @llvm.x86.avx512.gatherpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 4, i32 0)
- call void @llvm.x86.avx512.gatherpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 4, i32 1)
- call void @llvm.x86.avx512.scatterpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 2, i32 0)
- call void @llvm.x86.avx512.scatterpf.qps.512(i8 -1, <8 x i64> %ind, i8* %base, i32 2, i32 1)
+ call void @llvm.x86.avx512.gatherpf.qps.512(i8 0, <8 x i64> %ind, i8* %base, i32 4, i32 1)
+ call void @llvm.x86.avx512.scatterpf.qps.512(i8 1, <8 x i64> %ind, i8* %base, i32 2, i32 0)
+ call void @llvm.x86.avx512.scatterpf.qps.512(i8 120, <8 x i64> %ind, i8* %base, i32 2, i32 1)
ret void
}
-
declare <2 x double> @llvm.x86.avx512.gather3div2.df(<2 x double>, i8*, <2 x i64>, i8, i32)
define <2 x double>@test_int_x86_avx512_gather3div2_df(<2 x double> %x0, i8* %x1, <2 x i64> %x2, i8 %x3) {
@@ -790,3 +794,54 @@ define void @test_int_x86_avx512_scatters
ret void
}
+define void @scatter_mask_test(i8* %x0, <8 x i32> %x2, <8 x i32> %x3) {
+; CHECK-LABEL: scatter_mask_test:
+; CHECK: ## BB#0:
+; CHECK-NEXT: kxnorw %k0, %k0, %k1
+; CHECK-NEXT: vpscatterdd %ymm1, (%rdi,%ymm0,2) {%k1}
+; CHECK-NEXT: kxorw %k0, %k0, %k1
+; CHECK-NEXT: vpscatterdd %ymm1, (%rdi,%ymm0,4) {%k1}
+; CHECK-NEXT: movb $1, %al
+; CHECK-NEXT: kmovb %eax, %k1
+; CHECK-NEXT: vpscatterdd %ymm1, (%rdi,%ymm0,2) {%k1}
+; CHECK-NEXT: movb $96, %al
+; CHECK-NEXT: kmovb %eax, %k1
+; CHECK-NEXT: vpscatterdd %ymm1, (%rdi,%ymm0,4) {%k1}
+; CHECK-NEXT: retq
+ call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 -1, <8 x i32> %x2, <8 x i32> %x3, i32 2)
+ call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 0, <8 x i32> %x2, <8 x i32> %x3, i32 4)
+ call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 1, <8 x i32> %x2, <8 x i32> %x3, i32 2)
+ call void @llvm.x86.avx512.scattersiv8.si(i8* %x0, i8 96, <8 x i32> %x2, <8 x i32> %x3, i32 4)
+ ret void
+}
+
+define <16 x float> @gather_mask_test(<16 x i32> %ind, <16 x float> %src, i8* %base) {
+; CHECK-LABEL: gather_mask_test:
+; CHECK: ## BB#0:
+; CHECK-NEXT: kxnorw %k0, %k0, %k1
+; CHECK-NEXT: vmovaps %zmm1, %zmm2
+; CHECK-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm2 {%k1}
+; CHECK-NEXT: kxorw %k0, %k0, %k1
+; CHECK-NEXT: vmovaps %zmm1, %zmm3
+; CHECK-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm3 {%k1}
+; CHECK-NEXT: movw $1, %ax
+; CHECK-NEXT: kmovw %eax, %k1
+; CHECK-NEXT: vmovaps %zmm1, %zmm4
+; CHECK-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm4 {%k1}
+; CHECK-NEXT: movw $220, %ax
+; CHECK-NEXT: kmovw %eax, %k1
+; CHECK-NEXT: vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1}
+; CHECK-NEXT: vaddps %zmm3, %zmm2, %zmm0
+; CHECK-NEXT: vaddps %zmm4, %zmm1, %zmm1
+; CHECK-NEXT: vaddps %zmm0, %zmm1, %zmm0
+; CHECK-NEXT: retq
+ %res = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 -1, i32 4)
+ %res1 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 0, i32 4)
+ %res2 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 1, i32 4)
+ %res3 = call <16 x float> @llvm.x86.avx512.gather.dps.512 (<16 x float> %src, i8* %base, <16 x i32>%ind, i16 220, i32 4)
+
+ %res4 = fadd <16 x float> %res, %res1
+ %res5 = fadd <16 x float> %res3, %res2
+ %res6 = fadd <16 x float> %res5, %res4
+ ret <16 x float> %res6
+}
More information about the llvm-commits mailing list