[llvm] 34659de - [InstCombine][X86] simplifyX86immShift - convert variable in-range vector shift by scalar amounts to generic shifts (PR40391)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 20 08:48:28 PDT 2020


Author: Simon Pilgrim
Date: 2020-03-20T15:48:06Z
New Revision: 34659de5fdd18468f062f301e64d7c883c9a5f14

URL: https://github.com/llvm/llvm-project/commit/34659de5fdd18468f062f301e64d7c883c9a5f14
DIFF: https://github.com/llvm/llvm-project/commit/34659de5fdd18468f062f301e64d7c883c9a5f14.diff

LOG: [InstCombine][X86] simplifyX86immShift - convert variable in-range vector shift by scalar amounts to generic shifts (PR40391)

The sll/srl/sra shift-by-scalar vector intrinsics can be replaced with generic IR shifts if the shift amount is known to be in range, i.e. less than the element bit width.

This also required exposing public DemandedElts variants of llvm::computeKnownBits (PR36319).
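
For reference, the in-range check at the heart of the patch is a known-bits
query over just the demanded lanes of the amount vector. A minimal sketch of
the idea (variable names match the diff below; this is an illustration, not
the verbatim implementation):

  // Prove that element 0 of the 128-bit shift-amount vector is smaller
  // than the element bit width, using the new DemandedElts-aware
  // computeKnownBits overload.
  const DataLayout &DL = II.getModule()->getDataLayout();
  APInt DemandedLower = APInt::getOneBitSet(NumAmtElts, 0);
  KnownBits KnownAmt = llvm::computeKnownBits(Amt, DemandedLower, DL);
  bool AmtInRange = KnownAmt.getMaxValue().ult(BitWidth);

If the amount is provably in range, the intrinsic can be rewritten as a splat
of element 0 followed by a plain shl/lshr/ashr, as the tests below show.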

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ValueTracking.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index dc28d2375a78..577c03bb1dd2 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -59,6 +59,22 @@ class Value;
                         OptimizationRemarkEmitter *ORE = nullptr,
                         bool UseInstrInfo = true);
 
+  /// Determine which bits of V are known to be either zero or one and return
+  /// them in the KnownZero/KnownOne bit sets.
+  ///
+  /// This function is defined on values with integer type, values with pointer
+  /// type, and vectors of integers.  In the case
+  /// where V is a vector, the known zero and known one values are the
+  /// same width as the vector element, and the bit is set only if it is true
+  /// for all of the demanded elements in the vector.
+  void computeKnownBits(const Value *V, const APInt &DemandedElts,
+                        KnownBits &Known, const DataLayout &DL,
+                        unsigned Depth = 0, AssumptionCache *AC = nullptr,
+                        const Instruction *CxtI = nullptr,
+                        const DominatorTree *DT = nullptr,
+                        OptimizationRemarkEmitter *ORE = nullptr,
+                        bool UseInstrInfo = true);
+
   /// Returns the known bits rather than passing by reference.
   KnownBits computeKnownBits(const Value *V, const DataLayout &DL,
                              unsigned Depth = 0, AssumptionCache *AC = nullptr,
@@ -67,6 +83,15 @@ class Value;
                              OptimizationRemarkEmitter *ORE = nullptr,
                              bool UseInstrInfo = true);
 
+  /// Returns the known bits rather than passing by reference.
+  KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
+                             const DataLayout &DL, unsigned Depth = 0,
+                             AssumptionCache *AC = nullptr,
+                             const Instruction *CxtI = nullptr,
+                             const DominatorTree *DT = nullptr,
+                             OptimizationRemarkEmitter *ORE = nullptr,
+                             bool UseInstrInfo = true);
+
   /// Compute known bits from the range metadata.
   /// \p KnownZero the set of bits that are known to be zero
   /// \p KnownOne the set of bits that are known to be one
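
To make the semantics documented above concrete: for a vector-typed value the
returned KnownBits is element-sized, and a bit is only "known" if it is known
for every demanded lane; lanes outside DemandedElts cannot affect the result.
A hypothetical caller (not part of this patch; assumes V is a Value* of type
<4 x i32> and DL is the module's DataLayout):

  // Query the known bits of lanes 0 and 1 only; lanes 2 and 3 of V are
  // ignored entirely.
  APInt DemandedElts(4, 0b0011);
  KnownBits Known = llvm::computeKnownBits(V, DemandedElts, DL);
  assert(Known.getBitWidth() == 32 && "result is per-element width");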

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index b7af0d2b0ce1..773f79c3c4f4 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -215,6 +215,15 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known,
                      Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
 }
 
+void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
+                            KnownBits &Known, const DataLayout &DL,
+                            unsigned Depth, AssumptionCache *AC,
+                            const Instruction *CxtI, const DominatorTree *DT,
+                            OptimizationRemarkEmitter *ORE, bool UseInstrInfo) {
+  ::computeKnownBits(V, DemandedElts, Known, Depth,
+                     Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
+}
+
 static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
                                   unsigned Depth, const Query &Q);
 
@@ -231,6 +240,17 @@ KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
       V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
 }
 
+KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
+                                 const DataLayout &DL, unsigned Depth,
+                                 AssumptionCache *AC, const Instruction *CxtI,
+                                 const DominatorTree *DT,
+                                 OptimizationRemarkEmitter *ORE,
+                                 bool UseInstrInfo) {
+  return ::computeKnownBits(
+      V, DemandedElts, Depth,
+      Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
+}
+
 bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
                                const DataLayout &DL, AssumptionCache *AC,
                                const Instruction *CxtI, const DominatorTree *DT,
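
These wrappers just forward to the existing internal implementation.
Conceptually (a sketch, not code from this patch), the pre-existing
whole-vector entry point is the special case that demands every lane:

  // Demanding all lanes reproduces llvm::computeKnownBits(V, DL) for a
  // vector-typed V.
  unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
  KnownBits KnownAll =
      llvm::computeKnownBits(V, APInt::getAllOnesValue(NumElts), DL);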

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e0a35e74adbd..25872e85616a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -376,6 +376,7 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
   auto Amt = II.getArgOperand(1);
   auto VT = cast<VectorType>(Vec->getType());
   auto SVT = VT->getElementType();
+  auto AmtVT = Amt->getType();
   unsigned VWidth = VT->getNumElements();
   unsigned BitWidth = SVT->getPrimitiveSizeInBits();
 
@@ -383,7 +384,7 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
   // generic shift. If it's guaranteed to be out of range, logical shifts combine to
   // zero and arithmetic shifts are clamped to (BitWidth - 1).
   if (IsImm) {
-    assert(Amt->getType()->isIntegerTy(32) &&
+    assert(AmtVT->isIntegerTy(32) &&
            "Unexpected shift-by-immediate type");
     KnownBits KnownAmtBits =
         llvm::computeKnownBits(Amt, II.getModule()->getDataLayout());
@@ -400,6 +401,27 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
       Amt = ConstantInt::get(SVT, BitWidth - 1);
       return Builder.CreateAShr(Vec, Builder.CreateVectorSplat(VWidth, Amt));
     }
+  } else {
+    // Ensure the first element has an in-range value and the rest of the
+    // elements in the bottom 64 bits are zero.
+    assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 &&
+           cast<VectorType>(AmtVT)->getElementType() == SVT &&
+           "Unexpected shift-by-scalar type");
+    unsigned NumAmtElts = cast<VectorType>(AmtVT)->getNumElements();
+    APInt DemandedLower = APInt::getOneBitSet(NumAmtElts, 0);
+    APInt DemandedUpper = APInt::getBitsSet(NumAmtElts, 1, NumAmtElts / 2);
+    KnownBits KnownLowerBits = llvm::computeKnownBits(
+        Amt, DemandedLower, II.getModule()->getDataLayout());
+    KnownBits KnownUpperBits = llvm::computeKnownBits(
+        Amt, DemandedUpper, II.getModule()->getDataLayout());
+    if (KnownLowerBits.getMaxValue().ult(BitWidth) &&
+        (DemandedUpper.isNullValue() || KnownUpperBits.isZero())) {
+      SmallVector<uint32_t, 16> ZeroSplat(VWidth, 0);
+      Amt = Builder.CreateShuffleVector(Amt, Amt, ZeroSplat);
+      return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
+                                        : Builder.CreateLShr(Vec, Amt))
+                           : Builder.CreateAShr(Vec, Amt));
+    }
   }
 
   // Simplify if count is constant vector.
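
A note on the two demanded-element masks above: the hardware shift only reads
the low 64 bits of the 128-bit amount vector, so element 0 must be in range
and elements 1..NumAmtElts/2-1 must be known zero. A standalone illustration
of what the masks look like (compiles against LLVM's APInt; not part of the
patch):

  #include "llvm/ADT/APInt.h"
  #include "llvm/Support/raw_ostream.h"
  using namespace llvm;

  int main() {
    for (unsigned NumAmtElts : {2u, 4u, 8u}) {
      APInt DemandedLower = APInt::getOneBitSet(NumAmtElts, 0);
      APInt DemandedUpper = APInt::getBitsSet(NumAmtElts, 1, NumAmtElts / 2);
      // e.g. NumAmtElts=8 -> lower=1 (0b00000001), upper=14 (0b00001110)
      outs() << NumAmtElts << " elts: lower=" << DemandedLower.getZExtValue()
             << " upper=" << DemandedUpper.getZExtValue() << "\n";
    }
    // For <2 x i64> amounts (NumAmtElts == 2) the upper mask is empty,
    // which is why the transform also accepts DemandedUpper.isNullValue().
    return 0;
  }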

diff --git a/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll b/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
index b8fa5c7f4999..8d75cc34be38 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
@@ -2680,9 +2680,10 @@ define <32 x i16> @avx512_psllv_w_512_undef(<32 x i16> %v) {
 
 define <8 x i16> @sse2_psra_w_128_masked(<8 x i16> %v, <8 x i16> %a) {
 ; CHECK-LABEL: @sse2_psra_w_128_masked(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <8 x i16> @llvm.x86.sse2.psra.w(<8 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
-; CHECK-NEXT:    ret <8 x i16> [[TMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = ashr <8 x i16> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <8 x i16> [[TMP3]]
 ;
   %1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
   %2 = tail call <8 x i16> @llvm.x86.sse2.psra.w(<8 x i16> %v, <8 x i16> %1)
@@ -2691,9 +2692,10 @@ define <8 x i16> @sse2_psra_w_128_masked(<8 x i16> %v, <8 x i16> %a) {
 
 define <8 x i32> @avx2_psra_d_256_masked(<8 x i32> %v, <4 x i32> %a) {
 ; CHECK-LABEL: @avx2_psra_d_256_masked(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <8 x i32> @llvm.x86.avx2.psra.d(<8 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
-; CHECK-NEXT:    ret <8 x i32> [[TMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = ashr <8 x i32> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <8 x i32> [[TMP3]]
 ;
   %1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
   %2 = tail call <8 x i32> @llvm.x86.avx2.psra.d(<8 x i32> %v, <4 x i32> %1)
@@ -2703,8 +2705,9 @@ define <8 x i32> @avx2_psra_d_256_masked(<8 x i32> %v, <4 x i32> %a) {
 define <8 x i64> @avx512_psra_q_512_masked(<8 x i64> %v, <2 x i64> %a) {
 ; CHECK-LABEL: @avx512_psra_q_512_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <8 x i64> @llvm.x86.avx512.psra.q.512(<8 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
-; CHECK-NEXT:    ret <8 x i64> [[TMP2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = ashr <8 x i64> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <8 x i64> [[TMP3]]
 ;
   %1 = and <2 x i64> %a, <i64 63, i64 undef>
   %2 = tail call <8 x i64> @llvm.x86.avx512.psra.q.512(<8 x i64> %v, <2 x i64> %1)
@@ -2713,9 +2716,10 @@ define <8 x i64> @avx512_psra_q_512_masked(<8 x i64> %v, <2 x i64> %a) {
 
 define <4 x i32> @sse2_psrl_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 ; CHECK-LABEL: @sse2_psrl_d_128_masked(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
-; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr <4 x i32> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <4 x i32> [[TMP3]]
 ;
   %1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
   %2 = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %v, <4 x i32> %1)
@@ -2725,8 +2729,9 @@ define <4 x i32> @sse2_psrl_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
 define <4 x i64> @avx2_psrl_q_256_masked(<4 x i64> %v, <2 x i64> %a) {
 ; CHECK-LABEL: @avx2_psrl_q_256_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
-; CHECK-NEXT:    ret <4 x i64> [[TMP2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr <4 x i64> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <4 x i64> [[TMP3]]
 ;
   %1 = and <2 x i64> %a, <i64 63, i64 undef>
   %2 = tail call <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64> %v, <2 x i64> %1)
@@ -2735,9 +2740,10 @@ define <4 x i64> @avx2_psrl_q_256_masked(<4 x i64> %v, <2 x i64> %a) {
 
 define <32 x i16> @avx512_psrl_w_512_masked(<32 x i16> %v, <8 x i16> %a) {
 ; CHECK-LABEL: @avx512_psrl_w_512_masked(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <32 x i16> @llvm.x86.avx512.psrl.w.512(<32 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
-; CHECK-NEXT:    ret <32 x i16> [[TMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <32 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr <32 x i16> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <32 x i16> [[TMP3]]
 ;
   %1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
   %2 = tail call <32 x i16> @llvm.x86.avx512.psrl.w.512(<32 x i16> %v, <8 x i16> %1)
@@ -2747,8 +2753,9 @@ define <32 x i16> @avx512_psrl_w_512_masked(<32 x i16> %v, <8 x i16> %a) {
 define <2 x i64> @sse2_psll_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
 ; CHECK-LABEL: @sse2_psll_q_128_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <2 x i64> @llvm.x86.sse2.psll.q(<2 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
-; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = shl <2 x i64> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <2 x i64> [[TMP3]]
 ;
   %1 = and <2 x i64> %a, <i64 63, i64 undef>
   %2 = tail call <2 x i64> @llvm.x86.sse2.psll.q(<2 x i64> %v, <2 x i64> %1)
@@ -2757,9 +2764,10 @@ define <2 x i64> @sse2_psll_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
 
 define <16 x i16> @avx2_psll_w_256_masked(<16 x i16> %v, <8 x i16> %a) {
 ; CHECK-LABEL: @avx2_psll_w_256_masked(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <16 x i16> @llvm.x86.avx2.psll.w(<16 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
-; CHECK-NEXT:    ret <16 x i16> [[TMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <16 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = shl <16 x i16> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <16 x i16> [[TMP3]]
 ;
   %1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
   %2 = tail call <16 x i16> @llvm.x86.avx2.psll.w(<16 x i16> %v, <8 x i16> %1)
@@ -2768,9 +2776,10 @@ define <16 x i16> @avx2_psll_w_256_masked(<16 x i16> %v, <8 x i16> %a) {
 
 define <16 x i32> @avx512_psll_d_512_masked(<16 x i32> %v, <4 x i32> %a) {
 ; CHECK-LABEL: @avx512_psll_d_512_masked(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
-; CHECK-NEXT:    [[TMP2:%.*]] = tail call <16 x i32> @llvm.x86.avx512.psll.d.512(<16 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
-; CHECK-NEXT:    ret <16 x i32> [[TMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <16 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = shl <16 x i32> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret <16 x i32> [[TMP3]]
 ;
   %1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
   %2 = tail call <16 x i32> @llvm.x86.avx512.psll.d.512(<16 x i32> %v, <4 x i32> %1)


        

