[llvm] f4e495a - [InstCombine][X86] simplifyX86varShift - convert variable in-range per-element shift amounts to generic shifts (PR40391)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 18 04:32:39 PDT 2020


Author: Simon Pilgrim
Date: 2020-03-18T11:26:54Z
New Revision: f4e495a18e80e3d8e9324a5a54966cf3b9780a61

URL: https://github.com/llvm/llvm-project/commit/f4e495a18e80e3d8e9324a5a54966cf3b9780a61
DIFF: https://github.com/llvm/llvm-project/commit/f4e495a18e80e3d8e9324a5a54966cf3b9780a61.diff

LOG: [InstCombine][X86] simplifyX86varShift - convert variable in-range per-element shift amounts to generic shifts (PR40391)

AVX2/AVX512 per-element variable shifts can be replaced with generic IR shifts (shl/lshr/ashr) if every shift amount is guaranteed to be in range, i.e. all bits of each amount above the low log2(element bit width) bits are known to be zero.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 953e5ef66b18..2138017606b7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -472,17 +472,29 @@ static Value *simplifyX86varShift(const IntrinsicInst &II,
   }
   assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
 
-  // Simplify if all shift amounts are constant/undef.
-  auto *CShift = dyn_cast<Constant>(II.getArgOperand(1));
-  if (!CShift)
-    return nullptr;
-
   auto Vec = II.getArgOperand(0);
+  auto Amt = II.getArgOperand(1);
   auto VT = cast<VectorType>(II.getType());
   auto SVT = VT->getVectorElementType();
   int NumElts = VT->getNumElements();
   int BitWidth = SVT->getIntegerBitWidth();
 
+  // If the shift amount is guaranteed to be in-range we can replace it with a
+  // generic shift.
+  APInt UpperBits =
+      APInt::getHighBitsSet(BitWidth, BitWidth - Log2_32(BitWidth));
+  if (llvm::MaskedValueIsZero(Amt, UpperBits,
+                              II.getModule()->getDataLayout())) {
+    return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
+                                      : Builder.CreateLShr(Vec, Amt))
+                         : Builder.CreateAShr(Vec, Amt));
+  }
+
+  // Simplify if all shift amounts are constant/undef.
+  auto *CShift = dyn_cast<Constant>(Amt);
+  if (!CShift)
+    return nullptr;
+
   // Collect each element's shift amount.
   // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth.
   bool AnyOutOfRange = false;

diff  --git a/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll b/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
index f14bcdf4f478..9794573fdd9e 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
@@ -2681,7 +2681,7 @@ define <32 x i16> @avx512_psllv_w_512_undef(<32 x i16> %v) {
 define <4 x i32> @avx2_psrav_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 ; CHECK-LABEL: @avx2_psrav_d_128_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 31, i32 31, i32 31>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <4 x i32> @llvm.x86.avx2.psrav.d(<4 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr <4 x i32> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
   %1 = and <4 x i32> %a, <i32 31, i32 31, i32 31, i32 31>
@@ -2692,7 +2692,7 @@ define <4 x i32> @avx2_psrav_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 define <8 x i32> @avx2_psrav_d_256_masked(<8 x i32> %v, <8 x i32> %a) {
 ; CHECK-LABEL: @avx2_psrav_d_256_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i32> [[A:%.*]], <i32 0, i32 1, i32 7, i32 15, i32 16, i32 30, i32 31, i32 31>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <8 x i32> @llvm.x86.avx2.psrav.d.256(<8 x i32> [[V:%.*]], <8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr <8 x i32> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <8 x i32> [[TMP2]]
 ;
   %1 = and <8 x i32> %a, <i32 0, i32 1, i32 7, i32 15, i32 16, i32 30, i32 31, i32 31>
@@ -2703,7 +2703,7 @@ define <8 x i32> @avx2_psrav_d_256_masked(<8 x i32> %v, <8 x i32> %a) {
 define <32 x i16> @avx512_psrav_w_512_masked(<32 x i16> %v, <32 x i16> %a) {
 ; CHECK-LABEL: @avx512_psrav_w_512_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <32 x i16> [[A:%.*]], <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <32 x i16> @llvm.x86.avx512.psrav.w.512(<32 x i16> [[V:%.*]], <32 x i16> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr <32 x i16> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <32 x i16> [[TMP2]]
 ;
   %1 = and <32 x i16> %a, <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>
@@ -2714,7 +2714,7 @@ define <32 x i16> @avx512_psrav_w_512_masked(<32 x i16> %v, <32 x i16> %a) {
 define <2 x i64> @avx2_psrlv_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
 ; CHECK-LABEL: @avx2_psrlv_q_128_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 32, i64 63>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <2 x i64> @llvm.x86.avx2.psrlv.q(<2 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <2 x i64> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %1 = and <2 x i64> %a, <i64 32, i64 63>
@@ -2725,7 +2725,7 @@ define <2 x i64> @avx2_psrlv_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
 define <8 x i32> @avx2_psrlv_d_256_masked(<8 x i32> %v, <8 x i32> %a) {
 ; CHECK-LABEL: @avx2_psrlv_d_256_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i32> [[A:%.*]], <i32 0, i32 1, i32 7, i32 15, i32 16, i32 30, i32 31, i32 31>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <8 x i32> @llvm.x86.avx2.psrlv.d.256(<8 x i32> [[V:%.*]], <8 x i32> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <8 x i32> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <8 x i32> [[TMP2]]
 ;
   %1 = and <8 x i32> %a, <i32 0, i32 1, i32 7, i32 15, i32 16, i32 30, i32 31, i32 31>
@@ -2736,7 +2736,7 @@ define <8 x i32> @avx2_psrlv_d_256_masked(<8 x i32> %v, <8 x i32> %a) {
 define <8 x i64> @avx512_psrlv_q_512_masked(<8 x i64> %v, <8 x i64> %a) {
 ; CHECK-LABEL: @avx512_psrlv_q_512_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i64> [[A:%.*]], <i64 0, i64 1, i64 4, i64 16, i64 32, i64 47, i64 62, i64 63>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <8 x i64> @llvm.x86.avx512.psrlv.q.512(<8 x i64> [[V:%.*]], <8 x i64> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <8 x i64> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <8 x i64> [[TMP2]]
 ;
   %1 = and <8 x i64> %a, <i64 0, i64 1, i64 4, i64 16, i64 32, i64 47, i64 62, i64 63>
@@ -2747,7 +2747,7 @@ define <8 x i64> @avx512_psrlv_q_512_masked(<8 x i64> %v, <8 x i64> %a) {
 define <4 x i32> @avx2_psllv_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 ; CHECK-LABEL: @avx2_psllv_d_128_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 0, i32 15, i32 16, i32 31>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <4 x i32> @llvm.x86.avx2.psllv.d(<4 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = shl <4 x i32> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
   %1 = and <4 x i32> %a, <i32 0, i32 15, i32 16, i32 31>
@@ -2758,7 +2758,7 @@ define <4 x i32> @avx2_psllv_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 define <4 x i64> @avx2_psllv_q_256_masked(<4 x i64> %v, <4 x i64> %a) {
 ; CHECK-LABEL: @avx2_psllv_q_256_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i64> [[A:%.*]], <i64 0, i64 16, i64 32, i64 63>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <4 x i64> @llvm.x86.avx2.psllv.q.256(<4 x i64> [[V:%.*]], <4 x i64> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = shl <4 x i64> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <4 x i64> [[TMP2]]
 ;
   %1 = and <4 x i64> %a, <i64 0, i64 16, i64 32, i64 63>
@@ -2769,7 +2769,7 @@ define <4 x i64> @avx2_psllv_q_256_masked(<4 x i64> %v, <4 x i64> %a) {
 define <32 x i16> @avx512_psllv_w_512_masked(<32 x i16> %v, <32 x i16> %a) {
 ; CHECK-LABEL: @avx512_psllv_w_512_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <32 x i16> [[A:%.*]], <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <32 x i16> @llvm.x86.avx512.psllv.w.512(<32 x i16> [[V:%.*]], <32 x i16> [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = shl <32 x i16> [[V:%.*]], [[TMP1]]
 ; CHECK-NEXT:    ret <32 x i16> [[TMP2]]
 ;
   %1 = and <32 x i16> %a, <i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 0, i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7>


        


More information about the llvm-commits mailing list