[llvm] 34659de - [InstCombine][X86] simplifyX86immShift - convert variable in-range vector shift by scalar amounts to generic shifts (PR40391)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 20 08:48:28 PDT 2020
Author: Simon Pilgrim
Date: 2020-03-20T15:48:06Z
New Revision: 34659de5fdd18468f062f301e64d7c883c9a5f14
URL: https://github.com/llvm/llvm-project/commit/34659de5fdd18468f062f301e64d7c883c9a5f14
DIFF: https://github.com/llvm/llvm-project/commit/34659de5fdd18468f062f301e64d7c883c9a5f14.diff
LOG: [InstCombine][X86] simplifyX86immShift - convert variable in-range vector shift by scalar amounts to generic shifts (PR40391)
The sll/srl/sra shift-by-scalar vector intrinsics can be replaced with generic shifts if the shift amount is known to be in range.
This also required exposing public DemandedElts variants of llvm::computeKnownBits (PR36319).
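As a rough illustration (hand-written IR, not taken from the patch; the function names are hypothetical and the real coverage is in x86-vector-shifts.ll below), a shift-by-scalar whose bottom element is known to be in range and whose remaining low 64 bits are known zero can now fold to a splat of lane 0 plus a generic IR shift:

; Before: the mask guarantees lane 0 is < 32 and the rest of the low 64 bits are zero.
define <4 x i32> @psrl_d_masked_before(<4 x i32> %v, <4 x i32> %a) {
  %amt = and <4 x i32> %a, <i32 31, i32 0, i32 0, i32 0>
  %r = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %v, <4 x i32> %amt)
  ret <4 x i32> %r
}

; After: the intrinsic is rewritten as a splat of lane 0 feeding a generic lshr.
define <4 x i32> @psrl_d_masked_after(<4 x i32> %v, <4 x i32> %a) {
  %amt = and <4 x i32> %a, <i32 31, i32 0, i32 0, i32 0>
  %splat = shufflevector <4 x i32> %amt, <4 x i32> undef, <4 x i32> zeroinitializer
  %r = lshr <4 x i32> %v, %splat
  ret <4 x i32> %r
}

declare <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32>, <4 x i32>)

Subsequent demanded-elements simplification can further relax the unused mask lanes to undef, which is what the updated CHECK lines in the test diff show.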
Added:
Modified:
llvm/include/llvm/Analysis/ValueTracking.h
llvm/lib/Analysis/ValueTracking.cpp
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index dc28d2375a78..577c03bb1dd2 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -59,6 +59,22 @@ class Value;
OptimizationRemarkEmitter *ORE = nullptr,
bool UseInstrInfo = true);
+ /// Determine which bits of V are known to be either zero or one and return
+ /// them in the KnownZero/KnownOne bit sets.
+ ///
+ /// This function is defined on values with integer type, values with pointer
+ /// type, and vectors of integers. In the case
+ /// where V is a vector, the known zero and known one values are the
+ /// same width as the vector element, and the bit is set only if it is true
+ /// for all of the demanded elements in the vector.
+ void computeKnownBits(const Value *V, const APInt &DemandedElts,
+ KnownBits &Known, const DataLayout &DL,
+ unsigned Depth = 0, AssumptionCache *AC = nullptr,
+ const Instruction *CxtI = nullptr,
+ const DominatorTree *DT = nullptr,
+ OptimizationRemarkEmitter *ORE = nullptr,
+ bool UseInstrInfo = true);
+
/// Returns the known bits rather than passing by reference.
KnownBits computeKnownBits(const Value *V, const DataLayout &DL,
unsigned Depth = 0, AssumptionCache *AC = nullptr,
@@ -67,6 +83,15 @@ class Value;
OptimizationRemarkEmitter *ORE = nullptr,
bool UseInstrInfo = true);
+ /// Returns the known bits rather than passing by reference.
+ KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
+ const DataLayout &DL, unsigned Depth = 0,
+ AssumptionCache *AC = nullptr,
+ const Instruction *CxtI = nullptr,
+ const DominatorTree *DT = nullptr,
+ OptimizationRemarkEmitter *ORE = nullptr,
+ bool UseInstrInfo = true);
+
/// Compute known bits from the range metadata.
/// \p KnownZero the set of bits that are known to be zero
/// \p KnownOne the set of bits that are known to be one
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index b7af0d2b0ce1..773f79c3c4f4 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -215,6 +215,15 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known,
Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
}
+void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
+ KnownBits &Known, const DataLayout &DL,
+ unsigned Depth, AssumptionCache *AC,
+ const Instruction *CxtI, const DominatorTree *DT,
+ OptimizationRemarkEmitter *ORE, bool UseInstrInfo) {
+ ::computeKnownBits(V, DemandedElts, Known, Depth,
+ Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
+}
+
static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
unsigned Depth, const Query &Q);
@@ -231,6 +240,17 @@ KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
}
+KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
+ const DataLayout &DL, unsigned Depth,
+ AssumptionCache *AC, const Instruction *CxtI,
+ const DominatorTree *DT,
+ OptimizationRemarkEmitter *ORE,
+ bool UseInstrInfo) {
+ return ::computeKnownBits(
+ V, DemandedElts, Depth,
+ Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
+}
+
bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
const DataLayout &DL, AssumptionCache *AC,
const Instruction *CxtI, const DominatorTree *DT,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e0a35e74adbd..25872e85616a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -376,6 +376,7 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
auto Amt = II.getArgOperand(1);
auto VT = cast<VectorType>(Vec->getType());
auto SVT = VT->getElementType();
+ auto AmtVT = Amt->getType();
unsigned VWidth = VT->getNumElements();
unsigned BitWidth = SVT->getPrimitiveSizeInBits();
@@ -383,7 +384,7 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
// generic shift. If it's guaranteed to be out of range, logical shifts combine to
// zero and arithmetic shifts are clamped to (BitWidth - 1).
if (IsImm) {
- assert(Amt->getType()->isIntegerTy(32) &&
+ assert(AmtVT->isIntegerTy(32) &&
"Unexpected shift-by-immediate type");
KnownBits KnownAmtBits =
llvm::computeKnownBits(Amt, II.getModule()->getDataLayout());
@@ -400,6 +401,27 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
Amt = ConstantInt::get(SVT, BitWidth - 1);
return Builder.CreateAShr(Vec, Builder.CreateVectorSplat(VWidth, Amt));
}
+ } else {
+ // Ensure the first element has an in-range value and the rest of the
+ // elements in the bottom 64 bits are zero.
+ assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 &&
+ cast<VectorType>(AmtVT)->getElementType() == SVT &&
+ "Unexpected shift-by-scalar type");
+ unsigned NumAmtElts = cast<VectorType>(AmtVT)->getNumElements();
+ APInt DemandedLower = APInt::getOneBitSet(NumAmtElts, 0);
+ APInt DemandedUpper = APInt::getBitsSet(NumAmtElts, 1, NumAmtElts / 2);
+ KnownBits KnownLowerBits = llvm::computeKnownBits(
+ Amt, DemandedLower, II.getModule()->getDataLayout());
+ KnownBits KnownUpperBits = llvm::computeKnownBits(
+ Amt, DemandedUpper, II.getModule()->getDataLayout());
+ if (KnownLowerBits.getMaxValue().ult(BitWidth) &&
+ (DemandedUpper.isNullValue() || KnownUpperBits.isZero())) {
+ SmallVector<uint32_t, 16> ZeroSplat(VWidth, 0);
+ Amt = Builder.CreateShuffleVector(Amt, Amt, ZeroSplat);
+ return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
+ : Builder.CreateLShr(Vec, Amt))
+ : Builder.CreateAShr(Vec, Amt));
+ }
}
// Simplify if count is constant vector.
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll b/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
index b8fa5c7f4999..8d75cc34be38 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
@@ -2680,9 +2680,10 @@ define <32 x i16> @avx512_psllv_w_512_undef(<32 x i16> %v) {
define <8 x i16> @sse2_psra_w_128_masked(<8 x i16> %v, <8 x i16> %a) {
; CHECK-LABEL: @sse2_psra_w_128_masked(
-; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i16> @llvm.x86.sse2.psra.w(<8 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
-; CHECK-NEXT: ret <8 x i16> [[TMP2]]
+; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <8 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = ashr <8 x i16> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <8 x i16> [[TMP3]]
;
%1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
%2 = tail call <8 x i16> @llvm.x86.sse2.psra.w(<8 x i16> %v, <8 x i16> %1)
@@ -2691,9 +2692,10 @@ define <8 x i16> @sse2_psra_w_128_masked(<8 x i16> %v, <8 x i16> %a) {
define <8 x i32> @avx2_psra_d_256_masked(<8 x i32> %v, <4 x i32> %a) {
; CHECK-LABEL: @avx2_psra_d_256_masked(
-; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i32> @llvm.x86.avx2.psra.d(<8 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
-; CHECK-NEXT: ret <8 x i32> [[TMP2]]
+; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <8 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = ashr <8 x i32> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <8 x i32> [[TMP3]]
;
%1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
%2 = tail call <8 x i32> @llvm.x86.avx2.psra.d(<8 x i32> %v, <4 x i32> %1)
@@ -2703,8 +2705,9 @@ define <8 x i32> @avx2_psra_d_256_masked(<8 x i32> %v, <4 x i32> %a) {
define <8 x i64> @avx512_psra_q_512_masked(<8 x i64> %v, <2 x i64> %a) {
; CHECK-LABEL: @avx512_psra_q_512_masked(
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i64> @llvm.x86.avx512.psra.q.512(<8 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
-; CHECK-NEXT: ret <8 x i64> [[TMP2]]
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <8 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = ashr <8 x i64> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <8 x i64> [[TMP3]]
;
%1 = and <2 x i64> %a, <i64 63, i64 undef>
%2 = tail call <8 x i64> @llvm.x86.avx512.psra.q.512(<8 x i64> %v, <2 x i64> %1)
@@ -2713,9 +2716,10 @@ define <8 x i64> @avx512_psra_q_512_masked(<8 x i64> %v, <2 x i64> %a) {
define <4 x i32> @sse2_psrl_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
; CHECK-LABEL: @sse2_psrl_d_128_masked(
-; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
-; CHECK-NEXT: ret <4 x i32> [[TMP2]]
+; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <4 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = lshr <4 x i32> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <4 x i32> [[TMP3]]
;
%1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
%2 = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %v, <4 x i32> %1)
@@ -2725,8 +2729,9 @@ define <4 x i32> @sse2_psrl_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
define <4 x i64> @avx2_psrl_q_256_masked(<4 x i64> %v, <2 x i64> %a) {
; CHECK-LABEL: @avx2_psrl_q_256_masked(
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
-; CHECK-NEXT: ret <4 x i64> [[TMP2]]
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <4 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = lshr <4 x i64> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <4 x i64> [[TMP3]]
;
%1 = and <2 x i64> %a, <i64 63, i64 undef>
%2 = tail call <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64> %v, <2 x i64> %1)
@@ -2735,9 +2740,10 @@ define <4 x i64> @avx2_psrl_q_256_masked(<4 x i64> %v, <2 x i64> %a) {
define <32 x i16> @avx512_psrl_w_512_masked(<32 x i16> %v, <8 x i16> %a) {
; CHECK-LABEL: @avx512_psrl_w_512_masked(
-; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <32 x i16> @llvm.x86.avx512.psrl.w.512(<32 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
-; CHECK-NEXT: ret <32 x i16> [[TMP2]]
+; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <32 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = lshr <32 x i16> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <32 x i16> [[TMP3]]
;
%1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
%2 = tail call <32 x i16> @llvm.x86.avx512.psrl.w.512(<32 x i16> %v, <8 x i16> %1)
@@ -2747,8 +2753,9 @@ define <32 x i16> @avx512_psrl_w_512_masked(<32 x i16> %v, <8 x i16> %a) {
define <2 x i64> @sse2_psll_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
; CHECK-LABEL: @sse2_psll_q_128_masked(
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <2 x i64> @llvm.x86.sse2.psll.q(<2 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
-; CHECK-NEXT: ret <2 x i64> [[TMP2]]
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = shl <2 x i64> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <2 x i64> [[TMP3]]
;
%1 = and <2 x i64> %a, <i64 63, i64 undef>
%2 = tail call <2 x i64> @llvm.x86.sse2.psll.q(<2 x i64> %v, <2 x i64> %1)
@@ -2757,9 +2764,10 @@ define <2 x i64> @sse2_psll_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
define <16 x i16> @avx2_psll_w_256_masked(<16 x i16> %v, <8 x i16> %a) {
; CHECK-LABEL: @avx2_psll_w_256_masked(
-; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <16 x i16> @llvm.x86.avx2.psll.w(<16 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
-; CHECK-NEXT: ret <16 x i16> [[TMP2]]
+; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <16 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = shl <16 x i16> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <16 x i16> [[TMP3]]
;
%1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
%2 = tail call <16 x i16> @llvm.x86.avx2.psll.w(<16 x i16> %v, <8 x i16> %1)
@@ -2768,9 +2776,10 @@ define <16 x i16> @avx2_psll_w_256_masked(<16 x i16> %v, <8 x i16> %a) {
define <16 x i32> @avx512_psll_d_512_masked(<16 x i32> %v, <4 x i32> %a) {
; CHECK-LABEL: @avx512_psll_d_512_masked(
-; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = tail call <16 x i32> @llvm.x86.avx512.psll.d.512(<16 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
-; CHECK-NEXT: ret <16 x i32> [[TMP2]]
+; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <16 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = shl <16 x i32> [[V:%.*]], [[TMP2]]
+; CHECK-NEXT: ret <16 x i32> [[TMP3]]
;
%1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
%2 = tail call <16 x i32> @llvm.x86.avx512.psll.d.512(<16 x i32> %v, <4 x i32> %1)