[llvm] ebb181c - [X86] matchScalarReduction - add support for partial reductions
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 16 11:01:24 PDT 2020
Author: Simon Pilgrim
Date: 2020-03-16T18:01:02Z
New Revision: ebb181cf4098e87c0ac357b09ad5fb2d24b1863d
URL: https://github.com/llvm/llvm-project/commit/ebb181cf4098e87c0ac357b09ad5fb2d24b1863d
DIFF: https://github.com/llvm/llvm-project/commit/ebb181cf4098e87c0ac357b09ad5fb2d24b1863d.diff
LOG: [X86] matchScalarReduction - add support for partial reductions
Add opt-in support for partial reductions: callers may pass an optional partial mask (SrcMask) that records, for each source vector, which elements have been extracted for the scalar reduction.
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/movmsk-cmp.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 339e3c37ee25..8e8a7cce9fb1 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20964,9 +20964,12 @@ static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
}
/// Helper for matching OR(EXTRACTELT(X,0),OR(EXTRACTELT(X,1),...))
-/// style scalarized (associative) reduction patterns.
+/// style scalarized (associative) reduction patterns. Partial reductions
+/// are supported when the pointer SrcMask is non-null.
+/// TODO - move this to SelectionDAG?
static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
- SmallVectorImpl<SDValue> &SrcOps) {
+ SmallVectorImpl<SDValue> &SrcOps,
+ SmallVectorImpl<APInt> *SrcMask = nullptr) {
SmallVector<SDValue, 8> Opnds;
DenseMap<SDValue, APInt> SrcOpMap;
EVT VT = MVT::Other;
@@ -21018,12 +21021,18 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
M->second.setBit(CIdx);
}
- // Quit if not all elements are used.
- for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
- E = SrcOpMap.end();
- I != E; ++I) {
- if (!I->second.isAllOnesValue())
- return false;
+ if (SrcMask) {
+ // Collect the source partial masks.
+ for (SDValue &SrcOp : SrcOps)
+ SrcMask->push_back(SrcOpMap[SrcOp]);
+ } else {
+ // Quit if not all elements are used.
+ for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
+ E = SrcOpMap.end();
+ I != E; ++I) {
+ if (!I->second.isAllOnesValue())
+ return false;
+ }
}
return true;
@@ -41210,7 +41219,8 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
// TODO: Support multiple SrcOps.
if (VT == MVT::i1) {
SmallVector<SDValue, 2> SrcOps;
- if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps) &&
+ SmallVector<APInt, 2> SrcPartials;
+ if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps, &SrcPartials) &&
SrcOps.size() == 1) {
SDLoc dl(N);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41220,9 +41230,11 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
if (Mask) {
- APInt AllBits = APInt::getAllOnesValue(NumElts);
- return DAG.getSetCC(dl, MVT::i1, Mask,
- DAG.getConstant(AllBits, dl, MaskVT), ISD::SETEQ);
+ assert(SrcPartials[0].getBitWidth() == NumElts &&
+ "Unexpected partial reduction mask");
+ SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
+ Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
+ return DAG.getSetCC(dl, MVT::i1, Mask, PartialBits, ISD::SETEQ);
}
}
}
@@ -41685,7 +41697,8 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
// TODO: Support multiple SrcOps.
if (VT == MVT::i1) {
SmallVector<SDValue, 2> SrcOps;
- if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps) &&
+ SmallVector<APInt, 2> SrcPartials;
+ if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps, &SrcPartials) &&
SrcOps.size() == 1) {
SDLoc dl(N);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41695,9 +41708,12 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
if (Mask) {
- APInt AllBits = APInt::getNullValue(NumElts);
- return DAG.getSetCC(dl, MVT::i1, Mask,
- DAG.getConstant(AllBits, dl, MaskVT), ISD::SETNE);
+ assert(SrcPartials[0].getBitWidth() == NumElts &&
+ "Unexpected partial reduction mask");
+ SDValue ZeroBits = DAG.getConstant(0, dl, MaskVT);
+ SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
+ Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
+ return DAG.getSetCC(dl, MVT::i1, Mask, ZeroBits, ISD::SETNE);
}
}
}
diff --git a/llvm/test/CodeGen/X86/movmsk-cmp.ll b/llvm/test/CodeGen/X86/movmsk-cmp.ll
index 7f0a1418a719..4fdde8c06641 100644
--- a/llvm/test/CodeGen/X86/movmsk-cmp.ll
+++ b/llvm/test/CodeGen/X86/movmsk-cmp.ll
@@ -4225,40 +4225,25 @@ define i1 @movmsk_v16i8(<16 x i8> %x, <16 x i8> %y) {
ret i1 %u2
}
-; TODO: Replace shift+mask chain with NOT+TEST+SETE
define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
; SSE2-LABEL: movmsk_v8i16:
; SSE2: # %bb.0:
; SSE2-NEXT: pcmpgtw %xmm1, %xmm0
; SSE2-NEXT: packsswb %xmm0, %xmm0
-; SSE2-NEXT: pmovmskb %xmm0, %ecx
-; SSE2-NEXT: movl %ecx, %eax
-; SSE2-NEXT: shrb $7, %al
-; SSE2-NEXT: movl %ecx, %edx
-; SSE2-NEXT: andb $16, %dl
-; SSE2-NEXT: shrb $4, %dl
-; SSE2-NEXT: andb %al, %dl
-; SSE2-NEXT: movl %ecx, %eax
-; SSE2-NEXT: shrb %al
-; SSE2-NEXT: andb %dl, %al
-; SSE2-NEXT: andb %cl, %al
+; SSE2-NEXT: pmovmskb %xmm0, %eax
+; SSE2-NEXT: andb $-109, %al
+; SSE2-NEXT: cmpb $-109, %al
+; SSE2-NEXT: sete %al
; SSE2-NEXT: retq
;
; AVX-LABEL: movmsk_v8i16:
; AVX: # %bb.0:
; AVX-NEXT: vpcmpgtw %xmm1, %xmm0, %xmm0
; AVX-NEXT: vpacksswb %xmm0, %xmm0, %xmm0
-; AVX-NEXT: vpmovmskb %xmm0, %ecx
-; AVX-NEXT: movl %ecx, %eax
-; AVX-NEXT: shrb $7, %al
-; AVX-NEXT: movl %ecx, %edx
-; AVX-NEXT: andb $16, %dl
-; AVX-NEXT: shrb $4, %dl
-; AVX-NEXT: andb %al, %dl
-; AVX-NEXT: movl %ecx, %eax
-; AVX-NEXT: shrb %al
-; AVX-NEXT: andb %dl, %al
-; AVX-NEXT: andb %cl, %al
+; AVX-NEXT: vpmovmskb %xmm0, %eax
+; AVX-NEXT: andb $-109, %al
+; AVX-NEXT: cmpb $-109, %al
+; AVX-NEXT: sete %al
; AVX-NEXT: retq
;
; KNL-LABEL: movmsk_v8i16:
@@ -4266,34 +4251,20 @@ define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
; KNL-NEXT: vpcmpgtw %xmm1, %xmm0, %xmm0
; KNL-NEXT: vpmovsxwq %xmm0, %zmm0
; KNL-NEXT: vptestmq %zmm0, %zmm0, %k0
-; KNL-NEXT: kshiftrw $4, %k0, %k1
-; KNL-NEXT: kmovw %k1, %ecx
-; KNL-NEXT: kshiftrw $7, %k0, %k1
-; KNL-NEXT: kmovw %k1, %eax
-; KNL-NEXT: kshiftrw $1, %k0, %k1
-; KNL-NEXT: kmovw %k1, %edx
-; KNL-NEXT: kmovw %k0, %esi
-; KNL-NEXT: andb %cl, %al
-; KNL-NEXT: andb %dl, %al
-; KNL-NEXT: andb %sil, %al
-; KNL-NEXT: # kill: def $al killed $al killed $eax
+; KNL-NEXT: kmovw %k0, %eax
+; KNL-NEXT: andb $-109, %al
+; KNL-NEXT: cmpb $-109, %al
+; KNL-NEXT: sete %al
; KNL-NEXT: vzeroupper
; KNL-NEXT: retq
;
; SKX-LABEL: movmsk_v8i16:
; SKX: # %bb.0:
; SKX-NEXT: vpcmpgtw %xmm1, %xmm0, %k0
-; SKX-NEXT: kshiftrb $4, %k0, %k1
-; SKX-NEXT: kmovd %k1, %ecx
-; SKX-NEXT: kshiftrb $7, %k0, %k1
-; SKX-NEXT: kmovd %k1, %eax
-; SKX-NEXT: kshiftrb $1, %k0, %k1
-; SKX-NEXT: kmovd %k1, %edx
-; SKX-NEXT: kmovd %k0, %esi
-; SKX-NEXT: andb %cl, %al
-; SKX-NEXT: andb %dl, %al
-; SKX-NEXT: andb %sil, %al
-; SKX-NEXT: # kill: def $al killed $al killed $eax
+; SKX-NEXT: kmovd %k0, %eax
+; SKX-NEXT: andb $-109, %al
+; SKX-NEXT: cmpb $-109, %al
+; SKX-NEXT: sete %al
; SKX-NEXT: retq
%cmp = icmp sgt <8 x i16> %x, %y
%e1 = extractelement <8 x i1> %cmp, i32 0
@@ -4478,30 +4449,18 @@ define i1 @movmsk_v4f32(<4 x float> %x, <4 x float> %y) {
; KNL-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
; KNL-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
; KNL-NEXT: vcmpeq_uqps %zmm1, %zmm0, %k0
-; KNL-NEXT: kshiftrw $3, %k0, %k1
-; KNL-NEXT: kmovw %k1, %ecx
-; KNL-NEXT: kshiftrw $2, %k0, %k1
-; KNL-NEXT: kmovw %k1, %eax
-; KNL-NEXT: kshiftrw $1, %k0, %k0
-; KNL-NEXT: kmovw %k0, %edx
-; KNL-NEXT: orb %cl, %al
-; KNL-NEXT: orb %dl, %al
-; KNL-NEXT: # kill: def $al killed $al killed $eax
+; KNL-NEXT: kmovw %k0, %eax
+; KNL-NEXT: testb $14, %al
+; KNL-NEXT: setne %al
; KNL-NEXT: vzeroupper
; KNL-NEXT: retq
;
; SKX-LABEL: movmsk_v4f32:
; SKX: # %bb.0:
; SKX-NEXT: vcmpeq_uqps %xmm1, %xmm0, %k0
-; SKX-NEXT: kshiftrb $3, %k0, %k1
-; SKX-NEXT: kmovd %k1, %ecx
-; SKX-NEXT: kshiftrb $2, %k0, %k1
-; SKX-NEXT: kmovd %k1, %eax
-; SKX-NEXT: kshiftrb $1, %k0, %k0
-; SKX-NEXT: kmovd %k0, %edx
-; SKX-NEXT: orb %cl, %al
-; SKX-NEXT: orb %dl, %al
-; SKX-NEXT: # kill: def $al killed $al killed $eax
+; SKX-NEXT: kmovd %k0, %eax
+; SKX-NEXT: testb $14, %al
+; SKX-NEXT: setne %al
; SKX-NEXT: retq
%cmp = fcmp ueq <4 x float> %x, %y
%e1 = extractelement <4 x i1> %cmp, i32 1
More information about the llvm-commits mailing list