[llvm] ebb181c - [X86] matchScalarReduction - add support for partial reductions

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 16 11:01:24 PDT 2020


Author: Simon Pilgrim
Date: 2020-03-16T18:01:02Z
New Revision: ebb181cf4098e87c0ac357b09ad5fb2d24b1863d

URL: https://github.com/llvm/llvm-project/commit/ebb181cf4098e87c0ac357b09ad5fb2d24b1863d
DIFF: https://github.com/llvm/llvm-project/commit/ebb181cf4098e87c0ac357b09ad5fb2d24b1863d.diff

LOG: [X86] matchScalarReduction - add support for partial reductions

Add opt-in support for partial reductions: callers may pass an optional mask vector that receives, for each source vector, a bitmask indicating which of its elements were extracted for the scalar reduction.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/movmsk-cmp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 339e3c37ee25..8e8a7cce9fb1 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -20964,9 +20964,12 @@ static SDValue getSETCC(X86::CondCode Cond, SDValue EFLAGS, const SDLoc &dl,
 }
 
 /// Helper for matching OR(EXTRACTELT(X,0),OR(EXTRACTELT(X,1),...))
-/// style scalarized (associative) reduction patterns.
+/// style scalarized (associative) reduction patterns. Partial reductions
+/// are supported when the pointer SrcMask is non-null.
+/// TODO - move this to SelectionDAG?
 static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
-                                 SmallVectorImpl<SDValue> &SrcOps) {
+                                 SmallVectorImpl<SDValue> &SrcOps,
+                                 SmallVectorImpl<APInt> *SrcMask = nullptr) {
   SmallVector<SDValue, 8> Opnds;
   DenseMap<SDValue, APInt> SrcOpMap;
   EVT VT = MVT::Other;
@@ -21018,12 +21021,18 @@ static bool matchScalarReduction(SDValue Op, ISD::NodeType BinOp,
     M->second.setBit(CIdx);
   }
 
-  // Quit if not all elements are used.
-  for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
-                                                E = SrcOpMap.end();
-       I != E; ++I) {
-    if (!I->second.isAllOnesValue())
-      return false;
+  if (SrcMask) {
+    // Collect the source partial masks.
+    for (SDValue &SrcOp : SrcOps)
+      SrcMask->push_back(SrcOpMap[SrcOp]);
+  } else {
+    // Quit if not all elements are used.
+    for (DenseMap<SDValue, APInt>::const_iterator I = SrcOpMap.begin(),
+                                                  E = SrcOpMap.end();
+         I != E; ++I) {
+      if (!I->second.isAllOnesValue())
+        return false;
+    }
   }
 
   return true;
@@ -41210,7 +41219,8 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
   // TODO: Support multiple SrcOps.
   if (VT == MVT::i1) {
     SmallVector<SDValue, 2> SrcOps;
-    if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps) &&
+    SmallVector<APInt, 2> SrcPartials;
+    if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps, &SrcPartials) &&
         SrcOps.size() == 1) {
       SDLoc dl(N);
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41220,9 +41230,11 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
       if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
         Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
       if (Mask) {
-        APInt AllBits = APInt::getAllOnesValue(NumElts);
-        return DAG.getSetCC(dl, MVT::i1, Mask,
-                            DAG.getConstant(AllBits, dl, MaskVT), ISD::SETEQ);
+        assert(SrcPartials[0].getBitWidth() == NumElts &&
+               "Unexpected partial reduction mask");
+        SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
+        Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
+        return DAG.getSetCC(dl, MVT::i1, Mask, PartialBits, ISD::SETEQ);
       }
     }
   }
@@ -41685,7 +41697,8 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
   // TODO: Support multiple SrcOps.
   if (VT == MVT::i1) {
     SmallVector<SDValue, 2> SrcOps;
-    if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps) &&
+    SmallVector<APInt, 2> SrcPartials;
+    if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps, &SrcPartials) &&
         SrcOps.size() == 1) {
       SDLoc dl(N);
       const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -41695,9 +41708,12 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
       if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
         Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
       if (Mask) {
-        APInt AllBits = APInt::getNullValue(NumElts);
-        return DAG.getSetCC(dl, MVT::i1, Mask,
-                            DAG.getConstant(AllBits, dl, MaskVT), ISD::SETNE);
+        assert(SrcPartials[0].getBitWidth() == NumElts &&
+               "Unexpected partial reduction mask");
+        SDValue ZeroBits = DAG.getConstant(0, dl, MaskVT);
+        SDValue PartialBits = DAG.getConstant(SrcPartials[0], dl, MaskVT);
+        Mask = DAG.getNode(ISD::AND, dl, MaskVT, Mask, PartialBits);
+        return DAG.getSetCC(dl, MVT::i1, Mask, ZeroBits, ISD::SETNE);
       }
     }
   }

diff  --git a/llvm/test/CodeGen/X86/movmsk-cmp.ll b/llvm/test/CodeGen/X86/movmsk-cmp.ll
index 7f0a1418a719..4fdde8c06641 100644
--- a/llvm/test/CodeGen/X86/movmsk-cmp.ll
+++ b/llvm/test/CodeGen/X86/movmsk-cmp.ll
@@ -4225,40 +4225,25 @@ define i1 @movmsk_v16i8(<16 x i8> %x, <16 x i8> %y) {
   ret i1 %u2
 }
 
-; TODO: Replace shift+mask chain with NOT+TEST+SETE
 define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; SSE2-LABEL: movmsk_v8i16:
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    pcmpgtw %xmm1, %xmm0
 ; SSE2-NEXT:    packsswb %xmm0, %xmm0
-; SSE2-NEXT:    pmovmskb %xmm0, %ecx
-; SSE2-NEXT:    movl %ecx, %eax
-; SSE2-NEXT:    shrb $7, %al
-; SSE2-NEXT:    movl %ecx, %edx
-; SSE2-NEXT:    andb $16, %dl
-; SSE2-NEXT:    shrb $4, %dl
-; SSE2-NEXT:    andb %al, %dl
-; SSE2-NEXT:    movl %ecx, %eax
-; SSE2-NEXT:    shrb %al
-; SSE2-NEXT:    andb %dl, %al
-; SSE2-NEXT:    andb %cl, %al
+; SSE2-NEXT:    pmovmskb %xmm0, %eax
+; SSE2-NEXT:    andb $-109, %al
+; SSE2-NEXT:    cmpb $-109, %al
+; SSE2-NEXT:    sete %al
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: movmsk_v8i16:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vpcmpgtw %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpacksswb %xmm0, %xmm0, %xmm0
-; AVX-NEXT:    vpmovmskb %xmm0, %ecx
-; AVX-NEXT:    movl %ecx, %eax
-; AVX-NEXT:    shrb $7, %al
-; AVX-NEXT:    movl %ecx, %edx
-; AVX-NEXT:    andb $16, %dl
-; AVX-NEXT:    shrb $4, %dl
-; AVX-NEXT:    andb %al, %dl
-; AVX-NEXT:    movl %ecx, %eax
-; AVX-NEXT:    shrb %al
-; AVX-NEXT:    andb %dl, %al
-; AVX-NEXT:    andb %cl, %al
+; AVX-NEXT:    vpmovmskb %xmm0, %eax
+; AVX-NEXT:    andb $-109, %al
+; AVX-NEXT:    cmpb $-109, %al
+; AVX-NEXT:    sete %al
 ; AVX-NEXT:    retq
 ;
 ; KNL-LABEL: movmsk_v8i16:
@@ -4266,34 +4251,20 @@ define i1 @movmsk_v8i16(<8 x i16> %x, <8 x i16> %y) {
 ; KNL-NEXT:    vpcmpgtw %xmm1, %xmm0, %xmm0
 ; KNL-NEXT:    vpmovsxwq %xmm0, %zmm0
 ; KNL-NEXT:    vptestmq %zmm0, %zmm0, %k0
-; KNL-NEXT:    kshiftrw $4, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
-; KNL-NEXT:    kshiftrw $7, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %eax
-; KNL-NEXT:    kshiftrw $1, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %edx
-; KNL-NEXT:    kmovw %k0, %esi
-; KNL-NEXT:    andb %cl, %al
-; KNL-NEXT:    andb %dl, %al
-; KNL-NEXT:    andb %sil, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    andb $-109, %al
+; KNL-NEXT:    cmpb $-109, %al
+; KNL-NEXT:    sete %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_v8i16:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpcmpgtw %xmm1, %xmm0, %k0
-; SKX-NEXT:    kshiftrb $4, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
-; SKX-NEXT:    kshiftrb $7, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %eax
-; SKX-NEXT:    kshiftrb $1, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %edx
-; SKX-NEXT:    kmovd %k0, %esi
-; SKX-NEXT:    andb %cl, %al
-; SKX-NEXT:    andb %dl, %al
-; SKX-NEXT:    andb %sil, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    kmovd %k0, %eax
+; SKX-NEXT:    andb $-109, %al
+; SKX-NEXT:    cmpb $-109, %al
+; SKX-NEXT:    sete %al
 ; SKX-NEXT:    retq
   %cmp = icmp sgt <8 x i16> %x, %y
   %e1 = extractelement <8 x i1> %cmp, i32 0
@@ -4478,30 +4449,18 @@ define i1 @movmsk_v4f32(<4 x float> %x, <4 x float> %y) {
 ; KNL-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
 ; KNL-NEXT:    vcmpeq_uqps %zmm1, %zmm0, %k0
-; KNL-NEXT:    kshiftrw $3, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
-; KNL-NEXT:    kshiftrw $2, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %eax
-; KNL-NEXT:    kshiftrw $1, %k0, %k0
-; KNL-NEXT:    kmovw %k0, %edx
-; KNL-NEXT:    orb %cl, %al
-; KNL-NEXT:    orb %dl, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb $14, %al
+; KNL-NEXT:    setne %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_v4f32:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vcmpeq_uqps %xmm1, %xmm0, %k0
-; SKX-NEXT:    kshiftrb $3, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
-; SKX-NEXT:    kshiftrb $2, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %eax
-; SKX-NEXT:    kshiftrb $1, %k0, %k0
-; SKX-NEXT:    kmovd %k0, %edx
-; SKX-NEXT:    orb %cl, %al
-; SKX-NEXT:    orb %dl, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    kmovd %k0, %eax
+; SKX-NEXT:    testb $14, %al
+; SKX-NEXT:    setne %al
 ; SKX-NEXT:    retq
   %cmp = fcmp ueq <4 x float> %x, %y
   %e1 = extractelement <4 x i1> %cmp, i32 1


        


More information about the llvm-commits mailing list