[llvm] r311837 - [AVX512] Add patterns to match masked extract_subvector with bitcasts between the vselect and the extract_subvector. Remove the late DAG combine.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 26 15:24:58 PDT 2017


Author: ctopper
Date: Sat Aug 26 15:24:57 2017
New Revision: 311837

URL: http://llvm.org/viewvc/llvm-project?rev=311837&view=rev
Log:
[AVX512] Add patterns to match masked extract_subvector with bitcasts between the vselect and the extract_subvector. Remove the late DAG combine.

We used to do a late DAG combine to move the bitcasts out of the way, but I'm starting to think that it's better to canonicalize extract_subvector's type to match the type of its input. I've seen cases where we formed two different extract_subvectors from the same node, one with a bitcast and one without.

Add some more test cases to ensure most of the zero-masking forms are covered as well.
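
For illustration, a minimal IR sketch of the shape these patterns now match (mirroring the merge-masked tests already in vector-shuffle-masked.ll; the function name is made up for this example). A subvector extract whose result is bitcast before being masked by a vselect should lower to a single masked vextract instruction instead of relying on the removed DAG combine:

; Sketch: extract the upper 256 bits of a v8i64, bitcast to v8i32, then
; merge-mask with a select. Expected lowering is one masked vextracti32x8.
define <8 x i32> @masked_extract_sketch(<8 x i64> %a, <8 x i32> %passthru, i8 %mask) {
  %extract = shufflevector <8 x i64> %a, <8 x i64> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
  %cast = bitcast <4 x i64> %extract to <8 x i32>
  %mask.cast = bitcast i8 %mask to <8 x i1>
  %res = select <8 x i1> %mask.cast, <8 x i32> %cast, <8 x i32> %passthru
  ret <8 x i32> %res
}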

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/lib/Target/X86/X86InstrAVX512.td
    llvm/trunk/test/CodeGen/X86/vector-shuffle-masked.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=311837&r1=311836&r2=311837&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Sat Aug 26 15:24:57 2017
@@ -30276,27 +30276,6 @@ static bool combineBitcastForMaskedOp(SD
                               DAG.getIntPtrConstant(Imm, DL)));
     return true;
   }
-  case ISD::EXTRACT_SUBVECTOR: {
-    unsigned EltSize = EltVT.getSizeInBits();
-    if (EltSize != 32 && EltSize != 64)
-      return false;
-    MVT OpEltVT = Op.getSimpleValueType().getVectorElementType();
-    // Only change element size, not type.
-    if (EltVT.isInteger() != OpEltVT.isInteger())
-      return false;
-    uint64_t Imm = cast<ConstantSDNode>(Op.getOperand(1))->getZExtValue();
-    Imm = (Imm * OpEltVT.getSizeInBits()) / EltSize;
-    // Op0 needs to be bitcasted to a larger vector with the same element type.
-    SDValue Op0 = Op.getOperand(0);
-    MVT Op0VT = MVT::getVectorVT(EltVT,
-                            Op0.getSimpleValueType().getSizeInBits() / EltSize);
-    Op0 = DAG.getBitcast(Op0VT, Op0);
-    DCI.AddToWorklist(Op0.getNode());
-    DCI.CombineTo(OrigOp.getNode(),
-                  DAG.getNode(Opcode, DL, VT, Op0,
-                              DAG.getIntPtrConstant(Imm, DL)));
-    return true;
-  }
   case X86ISD::SUBV_BROADCAST: {
     unsigned EltSize = EltVT.getSizeInBits();
     if (EltSize != 32 && EltSize != 64)

Modified: llvm/trunk/lib/Target/X86/X86InstrAVX512.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrAVX512.td?rev=311837&r1=311836&r2=311837&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86InstrAVX512.td (original)
+++ llvm/trunk/lib/Target/X86/X86InstrAVX512.td Sat Aug 26 15:24:57 2017
@@ -887,6 +887,112 @@ def : Pat<(v64i8 (insert_subvector undef
           (INSERT_SUBREG (v64i8 (IMPLICIT_DEF)), VR256X:$src, sub_ymm)>;
 }
 
+// Additional patterns for handling a bitcast between the vselect and the
+// extract_subvector.
+multiclass vextract_for_mask_cast<string InstrStr, X86VectorVTInfo From,
+                                  X86VectorVTInfo To, X86VectorVTInfo Cast,
+                                  PatFrag vextract_extract,
+                                  SDNodeXForm EXTRACT_get_vextract_imm,
+                                  list<Predicate> p> {
+let Predicates = p in {
+  def : Pat<(Cast.VT (vselect Cast.KRCWM:$mask,
+                              (bitconvert
+                               (To.VT (vextract_extract:$ext
+                                       (From.VT From.RC:$src), (iPTR imm)))),
+                              To.RC:$src0)),
+            (Cast.VT (!cast<Instruction>(InstrStr#"rrk")
+                      Cast.RC:$src0, Cast.KRCWM:$mask, From.RC:$src,
+                      (EXTRACT_get_vextract_imm To.RC:$ext)))>;
+
+  def : Pat<(Cast.VT (vselect Cast.KRCWM:$mask,
+                              (bitconvert
+                               (To.VT (vextract_extract:$ext
+                                       (From.VT From.RC:$src), (iPTR imm)))),
+                              Cast.ImmAllZerosV)),
+            (Cast.VT (!cast<Instruction>(InstrStr#"rrkz")
+                      Cast.KRCWM:$mask, From.RC:$src,
+                      (EXTRACT_get_vextract_imm To.RC:$ext)))>;
+}
+}
+
+defm : vextract_for_mask_cast<"VEXTRACTF32x4Z256", v4f64x_info, v2f64x_info,
+                              v4f32x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasVLX]>;
+defm : vextract_for_mask_cast<"VEXTRACTF64x2Z256", v8f32x_info, v4f32x_info,
+                              v2f64x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasDQI, HasVLX]>;
+
+defm : vextract_for_mask_cast<"VEXTRACTI32x4Z256", v4i64x_info, v2i64x_info,
+                              v4i32x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasVLX]>;
+defm : vextract_for_mask_cast<"VEXTRACTI32x4Z256", v16i16x_info, v8i16x_info,
+                              v4i32x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasVLX]>;
+defm : vextract_for_mask_cast<"VEXTRACTI32x4Z256", v32i8x_info, v16i8x_info,
+                              v4i32x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasVLX]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x2Z256", v8i32x_info, v4i32x_info,
+                              v2i64x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasDQI, HasVLX]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x2Z256", v16i16x_info, v8i16x_info,
+                              v2i64x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasDQI, HasVLX]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x2Z256", v32i8x_info, v16i8x_info,
+                              v2i64x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasDQI, HasVLX]>;
+
+defm : vextract_for_mask_cast<"VEXTRACTF32x4Z", v8f64_info, v2f64x_info,
+                              v4f32x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasAVX512]>;
+defm : vextract_for_mask_cast<"VEXTRACTF64x2Z", v16f32_info, v4f32x_info,
+                              v2f64x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasDQI]>;
+
+defm : vextract_for_mask_cast<"VEXTRACTI32x4Z", v8i64_info, v2i64x_info,
+                              v4i32x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasAVX512]>;
+defm : vextract_for_mask_cast<"VEXTRACTI32x4Z", v32i16_info, v8i16x_info,
+                              v4i32x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasAVX512]>;
+defm : vextract_for_mask_cast<"VEXTRACTI32x4Z", v64i8_info, v16i8x_info,
+                              v4i32x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasAVX512]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x2Z", v16i32_info, v4i32x_info,
+                              v2i64x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasDQI]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x2Z", v32i16_info, v8i16x_info,
+                              v2i64x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasDQI]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x2Z", v64i8_info, v16i8x_info,
+                              v2i64x_info, vextract128_extract,
+                              EXTRACT_get_vextract128_imm, [HasDQI]>;
+
+defm : vextract_for_mask_cast<"VEXTRACTF32x8Z", v8f64_info, v4f64x_info,
+                              v8f32x_info, vextract256_extract,
+                              EXTRACT_get_vextract256_imm, [HasDQI]>;
+defm : vextract_for_mask_cast<"VEXTRACTF64x4Z", v16f32_info, v8f32x_info,
+                              v4f64x_info, vextract256_extract,
+                              EXTRACT_get_vextract256_imm, [HasAVX512]>;
+
+defm : vextract_for_mask_cast<"VEXTRACTI32x8Z", v8i64_info, v4i64x_info,
+                              v8i32x_info, vextract256_extract,
+                              EXTRACT_get_vextract256_imm, [HasDQI]>;
+defm : vextract_for_mask_cast<"VEXTRACTI32x8Z", v32i16_info, v16i16x_info,
+                              v8i32x_info, vextract256_extract,
+                              EXTRACT_get_vextract256_imm, [HasDQI]>;
+defm : vextract_for_mask_cast<"VEXTRACTI32x8Z", v64i8_info, v32i8x_info,
+                              v8i32x_info, vextract256_extract,
+                              EXTRACT_get_vextract256_imm, [HasDQI]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x4Z", v16i32_info, v8i32x_info,
+                              v4i64x_info, vextract256_extract,
+                              EXTRACT_get_vextract256_imm, [HasAVX512]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x4Z", v32i16_info, v16i16x_info,
+                              v4i64x_info, vextract256_extract,
+                              EXTRACT_get_vextract256_imm, [HasAVX512]>;
+defm : vextract_for_mask_cast<"VEXTRACTI64x4Z", v64i8_info, v32i8x_info,
+                              v4i64x_info, vextract256_extract,
+                              EXTRACT_get_vextract256_imm, [HasAVX512]>;
+
 // vextractps - extract 32 bits from XMM
 def VEXTRACTPSZrr : AVX512AIi8<0x17, MRMDestReg, (outs GR32:$dst),
       (ins VR128X:$src1, u8imm:$src2),

Modified: llvm/trunk/test/CodeGen/X86/vector-shuffle-masked.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/vector-shuffle-masked.ll?rev=311837&r1=311836&r2=311837&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/vector-shuffle-masked.ll (original)
+++ llvm/trunk/test/CodeGen/X86/vector-shuffle-masked.ll Sat Aug 26 15:24:57 2017
@@ -1031,6 +1031,19 @@ define <8 x i32> @mask_cast_extract_v8i6
   ret <8 x i32> %res
 }
 
+define <8 x i32> @mask_cast_extract_v8i64_v8i32_1_z(<8 x i64> %a, i8 %mask) {
+; CHECK-LABEL: mask_cast_extract_v8i64_v8i32_1_z:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    kmovw %edi, %k1
+; CHECK-NEXT:    vextracti32x8 $1, %zmm0, %ymm0 {%k1} {z}
+; CHECK-NEXT:    retq
+  %shuffle = shufflevector <8 x i64> %a, <8 x i64> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %shuffle.cast = bitcast <4 x i64> %shuffle to <8 x i32>
+  %mask.cast = bitcast i8 %mask to <8 x i1>
+  %res = select <8 x i1> %mask.cast, <8 x i32> %shuffle.cast, <8 x i32> zeroinitializer
+  ret <8 x i32> %res
+}
+
 define <8 x float> @mask_cast_extract_v8f64_v8f32_1(<8 x double> %a, <8 x float> %passthru, i8 %mask) {
 ; CHECK-LABEL: mask_cast_extract_v8f64_v8f32_1:
 ; CHECK:       # BB#0:
@@ -1045,6 +1058,19 @@ define <8 x float> @mask_cast_extract_v8
   ret <8 x float> %res
 }
 
+define <8 x float> @mask_cast_extract_v8f64_v8f32_1_z(<8 x double> %a, i8 %mask) {
+; CHECK-LABEL: mask_cast_extract_v8f64_v8f32_1_z:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    kmovw %edi, %k1
+; CHECK-NEXT:    vextractf32x8 $1, %zmm0, %ymm0 {%k1} {z}
+; CHECK-NEXT:    retq
+  %shuffle = shufflevector <8 x double> %a, <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %shuffle.cast = bitcast <4 x double> %shuffle to <8 x float>
+  %mask.cast = bitcast i8 %mask to <8 x i1>
+  %res = select <8 x i1> %mask.cast, <8 x float> %shuffle.cast, <8 x float> zeroinitializer
+  ret <8 x float> %res
+}
+
 define <4 x i32> @mask_cast_extract_v8i64_v4i32_1(<8 x i64> %a, <4 x i32> %passthru, i8 %mask) {
 ; CHECK-LABEL: mask_cast_extract_v8i64_v4i32_1:
 ; CHECK:       # BB#0:
@@ -1061,6 +1087,21 @@ define <4 x i32> @mask_cast_extract_v8i6
   ret <4 x i32> %res
 }
 
+define <4 x i32> @mask_cast_extract_v8i64_v4i32_1_z(<8 x i64> %a, i8 %mask) {
+; CHECK-LABEL: mask_cast_extract_v8i64_v4i32_1_z:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    kmovw %edi, %k1
+; CHECK-NEXT:    vextracti32x4 $1, %zmm0, %xmm0 {%k1} {z}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %shuffle = shufflevector <8 x i64> %a, <8 x i64> undef, <2 x i32> <i32 2, i32 3>
+  %shuffle.cast = bitcast <2 x i64> %shuffle to <4 x i32>
+  %mask.cast = bitcast i8 %mask to <8 x i1>
+  %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %res = select <4 x i1> %mask.extract, <4 x i32> %shuffle.cast, <4 x i32> zeroinitializer
+  ret <4 x i32> %res
+}
+
 define <4 x float> @mask_cast_extract_v8f64_v4f32_1(<8 x double> %a, <4 x float> %passthru, i8 %mask) {
 ; CHECK-LABEL: mask_cast_extract_v8f64_v4f32_1:
 ; CHECK:       # BB#0:
@@ -1077,6 +1118,21 @@ define <4 x float> @mask_cast_extract_v8
   ret <4 x float> %res
 }
 
+define <4 x float> @mask_cast_extract_v8f64_v4f32_1_z(<8 x double> %a, i8 %mask) {
+; CHECK-LABEL: mask_cast_extract_v8f64_v4f32_1_z:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    kmovw %edi, %k1
+; CHECK-NEXT:    vextractf32x4 $1, %zmm0, %xmm0 {%k1} {z}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %shuffle = shufflevector <8 x double> %a, <8 x double> undef, <2 x i32> <i32 2, i32 3>
+  %shuffle.cast = bitcast <2 x double> %shuffle to <4 x float>
+  %mask.cast = bitcast i8 %mask to <8 x i1>
+  %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %res = select <4 x i1> %mask.extract, <4 x float> %shuffle.cast, <4 x float> zeroinitializer
+  ret <4 x float> %res
+}
+
 define <4 x i64> @mask_cast_extract_v16i32_v4i64_1(<16 x i32> %a, <4 x i64> %passthru, i8 %mask) {
 ; CHECK-LABEL: mask_cast_extract_v16i32_v4i64_1:
 ; CHECK:       # BB#0:
@@ -1092,6 +1148,20 @@ define <4 x i64> @mask_cast_extract_v16i
   ret <4 x i64> %res
 }
 
+define <4 x i64> @mask_cast_extract_v16i32_v4i64_1_z(<16 x i32> %a, i8 %mask) {
+; CHECK-LABEL: mask_cast_extract_v16i32_v4i64_1_z:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    kmovw %edi, %k1
+; CHECK-NEXT:    vextracti64x4 $1, %zmm0, %ymm0 {%k1} {z}
+; CHECK-NEXT:    retq
+  %shuffle = shufflevector <16 x i32> %a, <16 x i32> undef, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %shuffle.cast = bitcast <8 x i32> %shuffle to <4 x i64>
+  %mask.cast = bitcast i8 %mask to <8 x i1>
+  %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %res = select <4 x i1> %mask.extract, <4 x i64> %shuffle.cast, <4 x i64> zeroinitializer
+  ret <4 x i64> %res
+}
+
 define <4 x double> @mask_cast_extract_v16f32_v4f64_1(<16 x float> %a, <4 x double> %passthru, i8 %mask) {
 ; CHECK-LABEL: mask_cast_extract_v16f32_v4f64_1:
 ; CHECK:       # BB#0:
@@ -1107,6 +1177,20 @@ define <4 x double> @mask_cast_extract_v
   ret <4 x double> %res
 }
 
+define <4 x double> @mask_cast_extract_v16f32_v4f64_1_z(<16 x float> %a, i8 %mask) {
+; CHECK-LABEL: mask_cast_extract_v16f32_v4f64_1_z:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    kmovw %edi, %k1
+; CHECK-NEXT:    vextractf64x4 $1, %zmm0, %ymm0 {%k1} {z}
+; CHECK-NEXT:    retq
+  %shuffle = shufflevector <16 x float> %a, <16 x float> undef, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %shuffle.cast = bitcast <8 x float> %shuffle to <4 x double>
+  %mask.cast = bitcast i8 %mask to <8 x i1>
+  %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %res = select <4 x i1> %mask.extract, <4 x double> %shuffle.cast, <4 x double> zeroinitializer
+  ret <4 x double> %res
+}
+
 define <2 x i64> @mask_cast_extract_v16i32_v2i64_1(<16 x i32> %a, <2 x i64> %passthru, i8 %mask) {
 ; CHECK-LABEL: mask_cast_extract_v16i32_v2i64_1:
 ; CHECK:       # BB#0:
@@ -1123,6 +1207,21 @@ define <2 x i64> @mask_cast_extract_v16i
   ret <2 x i64> %res
 }
 
+define <2 x i64> @mask_cast_extract_v16i32_v2i64_1_z(<16 x i32> %a, i8 %mask) {
+; CHECK-LABEL: mask_cast_extract_v16i32_v2i64_1_z:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    kmovw %edi, %k1
+; CHECK-NEXT:    vextracti64x2 $1, %zmm0, %xmm0 {%k1} {z}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %shuffle = shufflevector <16 x i32> %a, <16 x i32> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %shuffle.cast = bitcast <4 x i32> %shuffle to <2 x i64>
+  %mask.cast = bitcast i8 %mask to <8 x i1>
+  %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <2 x i32> <i32 0, i32 1>
+  %res = select <2 x i1> %mask.extract, <2 x i64> %shuffle.cast, <2 x i64> zeroinitializer
+  ret <2 x i64> %res
+}
+
 define <2 x double> @mask_cast_extract_v16f32_v2f64_1(<16 x float> %a, <2 x double> %passthru, i8 %mask) {
 ; CHECK-LABEL: mask_cast_extract_v16f32_v2f64_1:
 ; CHECK:       # BB#0:
@@ -1139,6 +1238,21 @@ define <2 x double> @mask_cast_extract_v
   ret <2 x double> %res
 }
 
+define <2 x double> @mask_cast_extract_v16f32_v2f64_1_z(<16 x float> %a, i8 %mask) {
+; CHECK-LABEL: mask_cast_extract_v16f32_v2f64_1_z:
+; CHECK:       # BB#0:
+; CHECK-NEXT:    kmovw %edi, %k1
+; CHECK-NEXT:    vextractf64x2 $1, %zmm0, %xmm0 {%k1} {z}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %shuffle = shufflevector <16 x float> %a, <16 x float> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %shuffle.cast = bitcast <4 x float> %shuffle to <2 x double>
+  %mask.cast = bitcast i8 %mask to <8 x i1>
+  %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <2 x i32> <i32 0, i32 1>
+  %res = select <2 x i1> %mask.extract, <2 x double> %shuffle.cast, <2 x double> zeroinitializer
+  ret <2 x double> %res
+}
+
 define <2 x double> @broadcast_v4f32_0101_from_v2f32_mask(double* %x, <2 x double> %passthru, i8 %mask) {
 ; CHECK-LABEL: broadcast_v4f32_0101_from_v2f32_mask:
 ; CHECK:       # BB#0:
