[llvm] 5cca982 - [X86][AVX512] Avoid bitcasts between scalar and vXi1 bool vectors

Simon Pilgrim via llvm-commits <llvm-commits at lists.llvm.org>
Thu Jun 11 02:23:31 PDT 2020


Author: Simon Pilgrim
Date: 2020-06-11T10:22:55+01:00
New Revision: 5cca9828ff1ccdb0a7bf99633afc836032dcf393

URL: https://github.com/llvm/llvm-project/commit/5cca9828ff1ccdb0a7bf99633afc836032dcf393
DIFF: https://github.com/llvm/llvm-project/commit/5cca9828ff1ccdb0a7bf99633afc836032dcf393.diff

LOG: [X86][AVX512] Avoid bitcasts between scalar and vXi1 bool vectors

AVX512 mask types are often bitcast to scalar integers for various ops before being bitcast back for use as a predicate. In many cases we can avoid these KMASK<->GPR transfers and perform the equivalent operations on the mask unit.
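
For illustration, a reduced IR sketch of the round trip this targets (the value names are invented; the shape mirrors the bad_mask_transition_2 test updated below):

  %m = fcmp olt <8 x double> %a, %b   ; v8i1 compare result in a kmask
  %s = bitcast <8 x i1> %m to i8      ; kmovw %k0, %eax
  %z = zext i8 %s to i16              ; movzbl %al, %eax
  %k = bitcast i16 %z to <16 x i1>    ; kmovw %eax, %k1
  %r = select <16 x i1> %k, <16 x float> %x, <16 x float> %y

With this combine the zero-extend is rewritten as an INSERT_SUBVECTOR of the v8i1 mask into a zero v16i1 vector, so the compare result never leaves the mask unit.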

If the destination mask type is legal and we can confirm that the scalar op originally came from a mask/vector/float/double type, then we should try to avoid the scalar entirely.

This avoids some codegen issues noticed while working on PTEST/MOVMSK improvements.

Partially fixes PR32547 - we don't create a KUNPCK yet, but the remaining OR(X,KSHIFTL(Y)) pattern can be handled in a separate patch.
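
For reference, the PR32547-style pattern concatenates two v8i1 masks through a GPR (again a reduced sketch with invented value names):

  %m0 = fcmp olt <8 x float> %a, %b   ; v8i1 mask in k0
  %m1 = fcmp olt <8 x float> %c, %d   ; v8i1 mask in k1
  %s0 = bitcast <8 x i1> %m0 to i8
  %s1 = bitcast <8 x i1> %m1 to i8
  %zh = zext i8 %s0 to i16
  %zl = zext i8 %s1 to i16
  %hi = shl i16 %zh, 8                ; first mask becomes the high half
  %or = or i16 %hi, %zl               ; concatenate the two masks
  %k  = bitcast i16 %or to <16 x i1>

With this patch the whole chain stays in kmask registers as KSHIFTL+KOR (see the vector-shuffle-v1.ll diff below); matching OR(X,KSHIFTL(Y)) into KUNPCKBW is the step left for the follow-up.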

Differential Revision: https://reviews.llvm.org/D81548

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/avx512-intrinsics.ll
    llvm/test/CodeGen/X86/pr41619.ll
    llvm/test/CodeGen/X86/vector-shuffle-v1.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 4b1d1a20777a..ead00a9d2015 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37889,6 +37889,74 @@ static SDValue createMMXBuildVector(BuildVectorSDNode *BV, SelectionDAG &DAG,
   return Ops[0];
 }
 
+// Recursive function that attempts to find if a bool vector node was originally
+// a vector/float/double that got truncated/extended/bitcast to/from a scalar
+// integer. If so, replace the scalar ops with bool vector equivalents back down
+// the chain.
+static SDValue combineBitcastToBoolVector(EVT VT, SDValue V, SDLoc DL,
+                                          SelectionDAG &DAG,
+                                          const X86Subtarget &Subtarget) {
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  unsigned Opc = V.getOpcode();
+  switch (Opc) {
+  case ISD::BITCAST: {
+    // Bitcast from a vector/float/double, we can cheaply bitcast to VT.
+    SDValue Src = V.getOperand(0);
+    EVT SrcVT = Src.getValueType();
+    if (SrcVT.isVector() || SrcVT.isFloatingPoint())
+      return DAG.getBitcast(VT, Src);
+    break;
+  }
+  case ISD::TRUNCATE: {
+    // If we find a suitable source, a truncated scalar becomes a subvector.
+    SDValue Src = V.getOperand(0);
+    EVT NewSrcVT =
+        EVT::getVectorVT(*DAG.getContext(), MVT::i1, Src.getValueSizeInBits());
+    if (TLI.isTypeLegal(NewSrcVT))
+      if (SDValue N0 =
+              combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
+        return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N0,
+                           DAG.getIntPtrConstant(0, DL));
+    break;
+  }
+  case ISD::ANY_EXTEND:
+  case ISD::ZERO_EXTEND: {
+    // If we find a suitable source, an extended scalar becomes a subvector.
+    SDValue Src = V.getOperand(0);
+    EVT NewSrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
+                                    Src.getScalarValueSizeInBits());
+    if (TLI.isTypeLegal(NewSrcVT))
+      if (SDValue N0 =
+              combineBitcastToBoolVector(NewSrcVT, Src, DL, DAG, Subtarget))
+        return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT,
+                           Opc == ISD::ANY_EXTEND ? DAG.getUNDEF(VT)
+                                                  : DAG.getConstant(0, DL, VT),
+                           N0, DAG.getIntPtrConstant(0, DL));
+    break;
+  }
+  case ISD::OR: {
+    // If we find suitable sources, we can just move an OR to the vector domain.
+    SDValue Src0 = V.getOperand(0);
+    SDValue Src1 = V.getOperand(1);
+    if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
+      if (SDValue N1 = combineBitcastToBoolVector(VT, Src1, DL, DAG, Subtarget))
+        return DAG.getNode(Opc, DL, VT, N0, N1);
+    break;
+  }
+  case ISD::SHL: {
+    // If we find a suitable source, a SHL becomes a KSHIFTL.
+    SDValue Src0 = V.getOperand(0);
+    if (auto *Amt = dyn_cast<ConstantSDNode>(V.getOperand(1)))
+      if (SDValue N0 = combineBitcastToBoolVector(VT, Src0, DL, DAG, Subtarget))
+        return DAG.getNode(
+            X86ISD::KSHIFTL, DL, VT, N0,
+            DAG.getTargetConstant(Amt->getZExtValue(), DL, MVT::i8));
+    break;
+  }
+  }
+  return SDValue();
+}
+
 static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
                               TargetLowering::DAGCombinerInfo &DCI,
                               const X86Subtarget &Subtarget) {
@@ -37948,6 +38016,16 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
       N0 = DAG.getBitcast(MVT::i8, N0);
       return DAG.getNode(ISD::TRUNCATE, dl, VT, N0);
     }
+  } else {
+    // If we're bitcasting from iX to vXi1, see if the integer originally
+    // began as a vXi1 and whether we can remove the bitcast entirely.
+    if (VT.isVector() && VT.getScalarType() == MVT::i1 &&
+        SrcVT.isScalarInteger() &&
+        DAG.getTargetLoweringInfo().isTypeLegal(VT)) {
+      if (SDValue V =
+              combineBitcastToBoolVector(VT, N0, SDLoc(N), DAG, Subtarget))
+        return V;
+    }
   }
 
   // Look for (i8 (bitcast (v8i1 (extract_subvector (v16i1 X), 0)))) and

diff --git a/llvm/test/CodeGen/X86/avx512-intrinsics.ll b/llvm/test/CodeGen/X86/avx512-intrinsics.ll
index 2e5dc1e69c8a..a8222b1edab0 100644
--- a/llvm/test/CodeGen/X86/avx512-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512-intrinsics.ll
@@ -7496,13 +7496,7 @@ define <16 x float> @bad_mask_transition(<8 x double> %a, <8 x double> %b, <8 x
 ; X64-LABEL: bad_mask_transition:
 ; X64:       # %bb.0: # %entry
 ; X64-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k0
-; X64-NEXT:    kmovw %k0, %eax
-; X64-NEXT:    vcmplt_oqpd %zmm3, %zmm2, %k0
-; X64-NEXT:    kmovw %k0, %ecx
-; X64-NEXT:    movzbl %al, %eax
-; X64-NEXT:    movzbl %cl, %ecx
-; X64-NEXT:    kmovw %eax, %k0
-; X64-NEXT:    kmovw %ecx, %k1
+; X64-NEXT:    vcmplt_oqpd %zmm3, %zmm2, %k1
 ; X64-NEXT:    kunpckbw %k0, %k1, %k1
 ; X64-NEXT:    vblendmps %zmm5, %zmm4, %zmm0 {%k1}
 ; X64-NEXT:    retq
@@ -7518,13 +7512,7 @@ define <16 x float> @bad_mask_transition(<8 x double> %a, <8 x double> %b, <8 x
 ; X86-NEXT:    subl $64, %esp
 ; X86-NEXT:    vmovaps 72(%ebp), %zmm3
 ; X86-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k0
-; X86-NEXT:    kmovw %k0, %eax
-; X86-NEXT:    vcmplt_oqpd 8(%ebp), %zmm2, %k0
-; X86-NEXT:    kmovw %k0, %ecx
-; X86-NEXT:    movzbl %al, %eax
-; X86-NEXT:    movzbl %cl, %ecx
-; X86-NEXT:    kmovw %eax, %k0
-; X86-NEXT:    kmovw %ecx, %k1
+; X86-NEXT:    vcmplt_oqpd 8(%ebp), %zmm2, %k1
 ; X86-NEXT:    kunpckbw %k0, %k1, %k1
 ; X86-NEXT:    vmovaps 136(%ebp), %zmm3 {%k1}
 ; X86-NEXT:    vmovaps %zmm3, %zmm0
@@ -7551,10 +7539,7 @@ entry:
 define <16 x float> @bad_mask_transition_2(<8 x double> %a, <8 x double> %b, <8 x double> %c, <8 x double> %d, <16 x float> %e, <16 x float> %f) {
 ; X64-LABEL: bad_mask_transition_2:
 ; X64:       # %bb.0: # %entry
-; X64-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k0
-; X64-NEXT:    kmovw %k0, %eax
-; X64-NEXT:    movzbl %al, %eax
-; X64-NEXT:    kmovw %eax, %k1
+; X64-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k1
 ; X64-NEXT:    vblendmps %zmm5, %zmm4, %zmm0 {%k1}
 ; X64-NEXT:    retq
 ;
@@ -7568,10 +7553,7 @@ define <16 x float> @bad_mask_transition_2(<8 x double> %a, <8 x double> %b, <8
 ; X86-NEXT:    andl $-64, %esp
 ; X86-NEXT:    subl $64, %esp
 ; X86-NEXT:    vmovaps 72(%ebp), %zmm2
-; X86-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k0
-; X86-NEXT:    kmovw %k0, %eax
-; X86-NEXT:    movzbl %al, %eax
-; X86-NEXT:    kmovw %eax, %k1
+; X86-NEXT:    vcmplt_oqpd %zmm1, %zmm0, %k1
 ; X86-NEXT:    vmovaps 136(%ebp), %zmm2 {%k1}
 ; X86-NEXT:    vmovaps %zmm2, %zmm0
 ; X86-NEXT:    movl %ebp, %esp

diff --git a/llvm/test/CodeGen/X86/pr41619.ll b/llvm/test/CodeGen/X86/pr41619.ll
index 13bfd910587c..87c629270903 100644
--- a/llvm/test/CodeGen/X86/pr41619.ll
+++ b/llvm/test/CodeGen/X86/pr41619.ll
@@ -44,8 +44,6 @@ define i32 @bar(double %blah) nounwind {
 ; AVX512-LABEL: bar:
 ; AVX512:       ## %bb.0:
 ; AVX512-NEXT:    vmovq %xmm0, %rax
-; AVX512-NEXT:    kmovd %eax, %k0
-; AVX512-NEXT:    kmovq %k0, %rax
 ; AVX512-NEXT:    ## kill: def $eax killed $eax killed $rax
 ; AVX512-NEXT:    retq
   %z = bitcast double %blah to i64

diff --git a/llvm/test/CodeGen/X86/vector-shuffle-v1.ll b/llvm/test/CodeGen/X86/vector-shuffle-v1.ll
index 782303e97b12..c2c5eafb9cf9 100644
--- a/llvm/test/CodeGen/X86/vector-shuffle-v1.ll
+++ b/llvm/test/CodeGen/X86/vector-shuffle-v1.ll
@@ -891,12 +891,10 @@ define void @PR32547(<8 x float> %a, <8 x float> %b, <8 x float> %c, <8 x float>
 ; AVX512F-NEXT:    # kill: def $ymm0 killed $ymm0 def $zmm0
 ; AVX512F-NEXT:    vcmpltps %zmm1, %zmm0, %k0
 ; AVX512F-NEXT:    vcmpltps %zmm3, %zmm2, %k1
-; AVX512F-NEXT:    kmovw %k1, %eax
-; AVX512F-NEXT:    kmovw %k0, %ecx
-; AVX512F-NEXT:    movzbl %al, %eax
-; AVX512F-NEXT:    shll $8, %ecx
-; AVX512F-NEXT:    orl %eax, %ecx
-; AVX512F-NEXT:    kmovw %ecx, %k1
+; AVX512F-NEXT:    kshiftlw $8, %k0, %k0
+; AVX512F-NEXT:    kshiftlw $8, %k1, %k1
+; AVX512F-NEXT:    kshiftrw $8, %k1, %k1
+; AVX512F-NEXT:    korw %k1, %k0, %k1
 ; AVX512F-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX512F-NEXT:    vmovaps %zmm0, (%rdi) {%k1}
 ; AVX512F-NEXT:    vzeroupper
@@ -906,12 +904,8 @@ define void @PR32547(<8 x float> %a, <8 x float> %b, <8 x float> %c, <8 x float>
 ; AVX512VL:       # %bb.0: # %entry
 ; AVX512VL-NEXT:    vcmpltps %ymm1, %ymm0, %k0
 ; AVX512VL-NEXT:    vcmpltps %ymm3, %ymm2, %k1
-; AVX512VL-NEXT:    kmovw %k1, %eax
-; AVX512VL-NEXT:    kmovw %k0, %ecx
-; AVX512VL-NEXT:    movzbl %al, %eax
-; AVX512VL-NEXT:    shll $8, %ecx
-; AVX512VL-NEXT:    orl %eax, %ecx
-; AVX512VL-NEXT:    kmovw %ecx, %k1
+; AVX512VL-NEXT:    kshiftlw $8, %k0, %k0
+; AVX512VL-NEXT:    korw %k1, %k0, %k1
 ; AVX512VL-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX512VL-NEXT:    vmovaps %zmm0, (%rdi) {%k1}
 ; AVX512VL-NEXT:    vzeroupper
@@ -921,11 +915,8 @@ define void @PR32547(<8 x float> %a, <8 x float> %b, <8 x float> %c, <8 x float>
 ; VL_BW_DQ:       # %bb.0: # %entry
 ; VL_BW_DQ-NEXT:    vcmpltps %ymm1, %ymm0, %k0
 ; VL_BW_DQ-NEXT:    vcmpltps %ymm3, %ymm2, %k1
-; VL_BW_DQ-NEXT:    kmovd %k0, %eax
-; VL_BW_DQ-NEXT:    kmovb %k1, %ecx
-; VL_BW_DQ-NEXT:    shll $8, %eax
-; VL_BW_DQ-NEXT:    orl %ecx, %eax
-; VL_BW_DQ-NEXT:    kmovd %eax, %k1
+; VL_BW_DQ-NEXT:    kshiftlw $8, %k0, %k0
+; VL_BW_DQ-NEXT:    korw %k1, %k0, %k1
 ; VL_BW_DQ-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; VL_BW_DQ-NEXT:    vmovaps %zmm0, (%rdi) {%k1}
 ; VL_BW_DQ-NEXT:    vzeroupper

