[llvm] [InstCombine] Fold `lshr -> zext -> shl` patterns (PR #147737)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 18 10:24:25 PDT 2025
https://github.com/zGoldthorpe updated https://github.com/llvm/llvm-project/pull/147737
>From 3048c6d54aaa3f410b45af4c0ae3ad26c9f9b114 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Tue, 8 Jul 2025 17:52:34 -0500
Subject: [PATCH 1/3] Added patch to fold `lshr + zext + shl` patterns
---
.../InstCombine/InstCombineShifts.cpp | 49 ++++++++++++-
.../Analysis/ValueTracking/numsignbits-shl.ll | 6 +-
.../Transforms/InstCombine/iX-ext-split.ll | 6 +-
.../InstCombine/shifts-around-zext.ll | 69 +++++++++++++++++++
4 files changed, 122 insertions(+), 8 deletions(-)
create mode 100644 llvm/test/Transforms/InstCombine/shifts-around-zext.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 550f095b26ba4..b0b1301cd2580 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -978,6 +978,47 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
return new ZExtInst(Overflow, Ty);
}
+/// If the operand of a zext-ed left shift \p V is a logically right-shifted
+/// value, try to fold the opposing shifts.
+static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
+ unsigned ShlAmt,
+ InstCombinerImpl &IC,
+ const DataLayout &DL) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return nullptr;
+
+ // Dig through operations until the first shift.
+ while (!I->isShift())
+ if (!match(I, m_BinOp(m_OneUse(m_Instruction(I)), m_Constant())))
+ return nullptr;
+
+ // Fold only if the inner shift is a logical right-shift.
+ uint64_t InnerShrAmt;
+ if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt))))
+ return nullptr;
+
+ if (InnerShrAmt >= ShlAmt) {
+ const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
+ if (!canEvaluateShifted(V, ReducedShrAmt, /*IsLeftShift=*/false, IC,
+ nullptr))
+ return nullptr;
+ Value *NewInner =
+ getShiftedValue(V, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
+ return new ZExtInst(NewInner, DestTy);
+ }
+
+ if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+ return nullptr;
+
+ const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
+ Value *NewInner =
+ getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+ Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy);
+ return BinaryOperator::CreateShl(NewZExt,
+ ConstantInt::get(DestTy, ReducedShlAmt));
+}
+
// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
assert(I.isShift() && "Expected a shift as input");
@@ -1062,14 +1103,18 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
if (match(Op1, m_APInt(C))) {
unsigned ShAmtC = C->getZExtValue();
- // shl (zext X), C --> zext (shl X, C)
- // This is only valid if X would have zeros shifted out.
Value *X;
if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
+ // shl (zext X), C --> zext (shl X, C)
+ // This is only valid if X would have zeros shifted out.
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
if (ShAmtC < SrcWidth &&
MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), &I))
return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
+
+ // Otherwise, try to cancel the outer shl with a lshr inside the zext.
+ if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL))
+ return V;
}
// (X >> C) << C --> X & (-1 << C)
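As a concrete sanity check of the bit identity this fold relies on (matching the new @simple and @masked tests added below; the loop bound and multiplier are arbitrary, and the snippet is only an illustration, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (uint64_t i = 0; i < 1000000; ++i) {
    uint32_t x = (uint32_t)(i * 2654435761u); // arbitrary sample values
    // @simple: shl (zext (lshr x, 8)), 32  ==  shl (zext (x & 0xFFFFFF00)), 24
    assert(((uint64_t)(x >> 8) << 32) == ((uint64_t)(x & 0xFFFFFF00u) << 24));
    // @masked: shl (zext ((lshr x, 4) & 0xFF)), 48  ==  shl (zext (x & 0xFF0)), 44
    assert(((uint64_t)((x >> 4) & 0xFFu) << 48) == ((uint64_t)(x & 0xFF0u) << 44));
  }
  return 0;
}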
diff --git a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
index 5224d75a157d5..8330fd09090c8 100644
--- a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
+++ b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
@@ -101,9 +101,9 @@ define void @numsignbits_shl_zext_extended_bits_remains(i8 %x) {
define void @numsignbits_shl_zext_all_bits_shifted_out(i8 %x) {
; CHECK-LABEL: define void @numsignbits_shl_zext_all_bits_shifted_out(
; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT: [[ASHR:%.*]] = lshr i8 [[X]], 5
-; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i8 [[ASHR]] to i16
-; CHECK-NEXT: [[NSB1:%.*]] = shl i16 [[ZEXT]], 14
+; CHECK-NEXT: [[ASHR:%.*]] = and i8 [[X]], 96
+; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i8 [[ASHR]] to i16
+; CHECK-NEXT: [[NSB1:%.*]] = shl nuw i16 [[TMP1]], 9
; CHECK-NEXT: [[AND14:%.*]] = and i16 [[NSB1]], 16384
; CHECK-NEXT: [[ADD14:%.*]] = add i16 [[AND14]], [[NSB1]]
; CHECK-NEXT: call void @escape(i16 [[ADD14]])
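The updated CHECK lines above are equivalent to the old ones because bit 7 of %x is shifted past bit 15 of the i16 result and therefore dropped, which is why the mask is 96 (bits 5 and 6 only) and the shift amount becomes 9 with nuw. A standalone C++ spot-check of that equivalence (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned x = 0; x < 256; ++x) {
    uint16_t old_form = (uint16_t)((x >> 5) << 14); // wraps in 16 bits, like shl i16
    uint16_t new_form = (uint16_t)((x & 96u) << 9); // never exceeds 16 bits, hence nuw
    assert(old_form == new_form);
  }
  return 0;
}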
diff --git a/llvm/test/Transforms/InstCombine/iX-ext-split.ll b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
index fc804df0e4bec..b8e056725f122 100644
--- a/llvm/test/Transforms/InstCombine/iX-ext-split.ll
+++ b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
@@ -197,9 +197,9 @@ define i128 @i128_ext_split_neg4(i32 %x) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[LOWERSRC:%.*]] = sext i32 [[X]] to i64
; CHECK-NEXT: [[LO:%.*]] = zext i64 [[LOWERSRC]] to i128
-; CHECK-NEXT: [[SIGN:%.*]] = lshr i32 [[X]], 31
-; CHECK-NEXT: [[WIDEN:%.*]] = zext nneg i32 [[SIGN]] to i128
-; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[WIDEN]], 64
+; CHECK-NEXT: [[SIGN:%.*]] = and i32 [[X]], -2147483648
+; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[SIGN]] to i128
+; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[TMP0]], 33
; CHECK-NEXT: [[RES:%.*]] = or disjoint i128 [[HI]], [[LO]]
; CHECK-NEXT: ret i128 [[RES]]
;
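The same cancellation applied to the i128 sign-extension split: the lshr 31 / shl 64 pair around the zext becomes a sign-bit mask plus a single shl 33. A quick check of that identity (requires a compiler providing unsigned __int128; illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (uint64_t i = 0; i < 100000; ++i) {
    uint32_t x = (uint32_t)(i * 2654435761u); // arbitrary sample values
    unsigned __int128 old_form = (unsigned __int128)(x >> 31) << 64;
    unsigned __int128 new_form = (unsigned __int128)(x & 0x80000000u) << 33;
    assert(old_form == new_form);
  }
  return 0;
}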
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
new file mode 100644
index 0000000000000..517783fcbcb5c
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine %s | FileCheck %s
+
+define i64 @simple(i32 %x) {
+; CHECK-LABEL: define i64 @simple(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[LSHR:%.*]] = and i32 [[X]], -256
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[LSHR]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24
+; CHECK-NEXT: ret i64 [[SHL]]
+;
+ %lshr = lshr i32 %x, 8
+ %zext = zext i32 %lshr to i64
+ %shl = shl i64 %zext, 32
+ ret i64 %shl
+}
+
+;; u0xff0 = 4080
+define i64 @masked(i32 %x) {
+; CHECK-LABEL: define i64 @masked(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[X]], 4080
+; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44
+; CHECK-NEXT: ret i64 [[SHL]]
+;
+ %lshr = lshr i32 %x, 4
+ %mask = and i32 %lshr, u0xff
+ %zext = zext i32 %mask to i64
+ %shl = shl i64 %zext, 48
+ ret i64 %shl
+}
+
+define i64 @combine(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT: [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[TMP1]], 32
+; CHECK-NEXT: [[O_3:%.*]] = or disjoint i64 [[TMP2]], [[BASE]]
+; CHECK-NEXT: ret i64 [[O_3]]
+;
+ %base = zext i32 %lower to i64
+
+ %u.0 = and i32 %upper, u0xff
+ %z.0 = zext i32 %u.0 to i64
+ %s.0 = shl i64 %z.0, 32
+ %o.0 = or i64 %base, %s.0
+
+ %r.1 = lshr i32 %upper, 8
+ %u.1 = and i32 %r.1, u0xff
+ %z.1 = zext i32 %u.1 to i64
+ %s.1 = shl i64 %z.1, 40
+ %o.1 = or i64 %o.0, %s.1
+
+ %r.2 = lshr i32 %upper, 16
+ %u.2 = and i32 %r.2, u0xff
+ %z.2 = zext i32 %u.2 to i64
+ %s.2 = shl i64 %z.2, 48
+ %o.2 = or i64 %o.1, %s.2
+
+ %r.3 = lshr i32 %upper, 24
+ %u.3 = and i32 %r.3, u0xff
+ %z.3 = zext i32 %u.3 to i64
+ %s.3 = shl i64 %z.3, 56
+ %o.3 = or i64 %o.2, %s.3
+
+ ret i64 %o.3
+}
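The @combine test above checks that the per-byte reassembly of %upper collapses into a single shl nuw plus or disjoint, as its CHECK lines show. The underlying identity, spot-checked in plain C++ (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (uint64_t i = 0; i < 100000; ++i) {
    uint32_t lower = (uint32_t)(i * 2654435761u); // arbitrary sample values
    uint32_t upper = (uint32_t)(i * 40503u + 12345u);
    uint64_t o = (uint64_t)lower;
    o |= (uint64_t)(upper & 0xFFu) << 32;
    o |= (uint64_t)((upper >> 8) & 0xFFu) << 40;
    o |= (uint64_t)((upper >> 16) & 0xFFu) << 48;
    o |= (uint64_t)((upper >> 24) & 0xFFu) << 56;
    assert(o == (((uint64_t)upper << 32) | lower));
  }
  return 0;
}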
>From 8ea66689bca1005b09f842500d7ece5ad4881386 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Fri, 11 Jul 2025 15:33:46 -0500
Subject: [PATCH 2/3] Incorporated straightforward reviewer feedback.
---
.../InstCombine/InstCombineShifts.cpp | 48 +++++++++++--------
.../InstCombine/shifts-around-zext.ll | 22 +++++++--
2 files changed, 46 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index b0b1301cd2580..edcd963e3a7db 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -978,45 +978,53 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
return new ZExtInst(Overflow, Ty);
}
-/// If the operand of a zext-ed left shift \p V is a logically right-shifted
-/// value, try to fold the opposing shifts.
-static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
+/// If the operand \p Op of a zext-ed left shift \p I is a logically
+/// right-shifted value, try to fold the opposing shifts.
+static Instruction *foldShrThroughZExtedShl(BinaryOperator &I, Value *Op,
unsigned ShlAmt,
InstCombinerImpl &IC,
const DataLayout &DL) {
- auto *I = dyn_cast<Instruction>(V);
- if (!I)
+ Type *DestTy = I.getType();
+
+ auto *Inner = dyn_cast<Instruction>(Op);
+ if (!Inner)
return nullptr;
// Dig through operations until the first shift.
- while (!I->isShift())
- if (!match(I, m_BinOp(m_OneUse(m_Instruction(I)), m_Constant())))
+ while (!Inner->isShift())
+ if (!match(Inner, m_BinOp(m_OneUse(m_Instruction(Inner)), m_Constant())))
return nullptr;
// Fold only if the inner shift is a logical right-shift.
- uint64_t InnerShrAmt;
- if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt))))
+ const APInt *InnerShrConst;
+ if (!match(Inner, m_LShr(m_Value(), m_APInt(InnerShrConst))))
return nullptr;
+ const uint64_t InnerShrAmt = InnerShrConst->getZExtValue();
if (InnerShrAmt >= ShlAmt) {
const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
- if (!canEvaluateShifted(V, ReducedShrAmt, /*IsLeftShift=*/false, IC,
+ if (!canEvaluateShifted(Op, ReducedShrAmt, /*IsLeftShift=*/false, IC,
nullptr))
return nullptr;
- Value *NewInner =
- getShiftedValue(V, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
- return new ZExtInst(NewInner, DestTy);
+ Value *NewOp =
+ getShiftedValue(Op, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
+ return new ZExtInst(NewOp, DestTy);
}
- if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+ if (!canEvaluateShifted(Op, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
return nullptr;
const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
- Value *NewInner =
- getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
- Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy);
- return BinaryOperator::CreateShl(NewZExt,
- ConstantInt::get(DestTy, ReducedShlAmt));
+ Value *NewOp = getShiftedValue(Op, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+ Value *NewZExt = IC.Builder.CreateZExt(NewOp, DestTy);
+ NewZExt->takeName(I.getOperand(0));
+ auto *NewShl = BinaryOperator::CreateShl(
+ NewZExt, ConstantInt::get(DestTy, ReducedShlAmt));
+
+ // New shl inherits all flags from the original shl instruction.
+ NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
+ NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ return NewShl;
}
// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
@@ -1113,7 +1121,7 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
// Otherwise, try to cancel the outer shl with a lshr inside the zext.
- if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL))
+ if (Instruction *V = foldShrThroughZExtedShl(I, X, ShAmtC, *this, DL))
return V;
}
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
index 517783fcbcb5c..818e7b0fc735c 100644
--- a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -5,8 +5,8 @@ define i64 @simple(i32 %x) {
; CHECK-LABEL: define i64 @simple(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[LSHR:%.*]] = and i32 [[X]], -256
-; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[LSHR]] to i64
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24
+; CHECK-NEXT: [[ZEXT:%.*]] = zext i32 [[LSHR]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 24
; CHECK-NEXT: ret i64 [[SHL]]
;
%lshr = lshr i32 %x, 8
@@ -20,8 +20,8 @@ define i64 @masked(i32 %x) {
; CHECK-LABEL: define i64 @masked(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[MASK:%.*]] = and i32 [[X]], 4080
-; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44
+; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 44
; CHECK-NEXT: ret i64 [[SHL]]
;
%lshr = lshr i32 %x, 4
@@ -67,3 +67,17 @@ define i64 @combine(i32 %lower, i32 %upper) {
ret i64 %o.3
}
+
+define <2 x i64> @simple.vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @simple.vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT: [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256)
+; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64>
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24)
+; CHECK-NEXT: ret <2 x i64> [[SHL]]
+;
+ %lshr = lshr <2 x i32> %v, splat(i32 8)
+ %zext = zext <2 x i32> %lshr to <2 x i64>
+ %shl = shl <2 x i64> %zext, splat(i64 32)
+ ret <2 x i64> %shl
+}
>From 730920c9517d77623ce53ae1d2edb7625b4ef2df Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Mon, 14 Jul 2025 17:19:54 -0500
Subject: [PATCH 3/3] Refactored `canEvaluateShifted` to identify candidates
for simplification.
---
.../InstCombine/InstCombineShifts.cpp | 186 +++++++++++-------
.../InstCombine/shifts-around-zext.ll | 107 ++++++++--
2 files changed, 207 insertions(+), 86 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index edcd963e3a7db..1436e4bd5854f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -530,92 +530,116 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
return nullptr;
}
-/// Return true if we can simplify two logical (either left or right) shifts
-/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
-static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
- Instruction *InnerShift,
- InstCombinerImpl &IC, Instruction *CxtI) {
+/// Return a bitmask of all constant outer shift amounts that can be simplified
+/// by foldShiftedShift().
+static APInt getEvaluableShiftedShiftMask(bool IsOuterShl,
+ Instruction *InnerShift,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
+ const unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
+
// We need constant scalar or constant splat shifts.
const APInt *InnerShiftConst;
if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
- return false;
+ return APInt::getZero(TypeWidth);
- // Two logical shifts in the same direction:
+ if (InnerShiftConst->uge(TypeWidth))
+ return APInt::getZero(TypeWidth);
+
+ const unsigned InnerShAmt = InnerShiftConst->getZExtValue();
+
+ // Two logical shifts in the same direction can always be simplified, so long
+ // as the total shift amount is legal.
// shl (shl X, C1), C2 --> shl X, C1 + C2
// lshr (lshr X, C1), C2 --> lshr X, C1 + C2
bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
if (IsInnerShl == IsOuterShl)
- return true;
+ return APInt::getLowBitsSet(TypeWidth, TypeWidth - InnerShAmt);
+ APInt ShMask = APInt::getZero(TypeWidth);
// Equal shift amounts in opposite directions become bitwise 'and':
// lshr (shl X, C), C --> and X, C'
// shl (lshr X, C), C --> and X, C'
- if (*InnerShiftConst == OuterShAmt)
- return true;
+ ShMask.setBit(InnerShAmt);
- // If the 2nd shift is bigger than the 1st, we can fold:
+ // If the inner shift is bigger than the outer, we can fold:
// lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
// shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
- // but it isn't profitable unless we know the and'd out bits are already zero.
- // Also, check that the inner shift is valid (less than the type width) or
- // we'll crash trying to produce the bit mask for the 'and'.
- unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
- if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
- unsigned InnerShAmt = InnerShiftConst->getZExtValue();
- unsigned MaskShift =
- IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
- APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
- if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, CxtI))
- return true;
- }
-
- return false;
+ // but it isn't profitable unless we know the masked out bits are already
+ // zero.
+ KnownBits Known = IC.computeKnownBits(InnerShift->getOperand(0), CxtI);
+ // Isolate the bits that are annihilated by the inner shift.
+ APInt InnerShMask = IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt)
+ : Known.Zero.trunc(InnerShAmt);
+ // Isolate the upper (resp. lower) InnerShAmt bits of the base operand of the
+ // inner shl (resp. lshr).
+ // Then:
+ // - lshr (shl X, C1), C2 == (shl X, C1 - C2) if the bottom C2 of the isolated
+ // bits are zero
+ // - shl (lshr X, C1), C2 == (lshr X, C1 - C2) if the top C2 of the isolated
+ // bits are zero
+ const unsigned MaxOuterShAmt =
+ IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt).countr_one()
+ : Known.Zero.trunc(InnerShAmt).countl_one();
+ ShMask.setLowBits(MaxOuterShAmt);
+ return ShMask;
}
-/// See if we can compute the specified value, but shifted logically to the left
-/// or right by some number of bits. This should return true if the expression
-/// can be computed for the same cost as the current expression tree. This is
-/// used to eliminate extraneous shifting from things like:
-/// %C = shl i128 %A, 64
-/// %D = shl i128 %B, 96
-/// %E = or i128 %C, %D
-/// %F = lshr i128 %E, 64
-/// where the client will ask if E can be computed shifted right by 64-bits. If
-/// this succeeds, getShiftedValue() will be called to produce the value.
-static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
- InstCombinerImpl &IC, Instruction *CxtI) {
+/// Given a bitmask \p ShiftMask of desired shift amounts, determine the submask
+/// of bits corresponding to shift amounts X for which the given expression \p V
+/// can be computed for at worst the same cost as the current expression tree
+/// when shifted by X. For each set bit in the \p ShiftMask afterward,
+/// getShiftedValue() can produce the corresponding value.
+///
+/// \returns true if and only if at least one bit of the \p ShiftMask is set
+/// after refinement.
+static bool refineEvaluableShiftMask(Value *V, APInt &ShiftMask,
+ bool IsLeftShift, InstCombinerImpl &IC,
+ Instruction *CxtI) {
// We can always evaluate immediate constants.
if (match(V, m_ImmConstant()))
return true;
Instruction *I = dyn_cast<Instruction>(V);
- if (!I) return false;
+ if (!I) {
+ ShiftMask.clearAllBits();
+ return false;
+ }
// We can't mutate something that has multiple uses: doing so would
// require duplicating the instruction in general, which isn't profitable.
- if (!I->hasOneUse()) return false;
+ if (!I->hasOneUse()) {
+ ShiftMask.clearAllBits();
+ return false;
+ }
switch (I->getOpcode()) {
- default: return false;
+ default: {
+ ShiftMask.clearAllBits();
+ return false;
+ }
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
- // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
- return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
- canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
+ return refineEvaluableShiftMask(I->getOperand(0), ShiftMask, IsLeftShift,
+ IC, I) &&
+ refineEvaluableShiftMask(I->getOperand(1), ShiftMask, IsLeftShift,
+ IC, I);
case Instruction::Shl:
- case Instruction::LShr:
- return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
+ case Instruction::LShr: {
+ ShiftMask &= getEvaluableShiftedShiftMask(IsLeftShift, I, IC, CxtI);
+ return !ShiftMask.isZero();
+ }
case Instruction::Select: {
SelectInst *SI = cast<SelectInst>(I);
Value *TrueVal = SI->getTrueValue();
Value *FalseVal = SI->getFalseValue();
- return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
- canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
+ return refineEvaluableShiftMask(TrueVal, ShiftMask, IsLeftShift, IC, SI) &&
+ refineEvaluableShiftMask(FalseVal, ShiftMask, IsLeftShift, IC, SI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
@@ -623,19 +647,42 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (Value *IncValue : PN->incoming_values())
- if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
+ if (!refineEvaluableShiftMask(IncValue, ShiftMask, IsLeftShift, IC, PN))
return false;
return true;
}
case Instruction::Mul: {
const APInt *MulConst;
// We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`)
- return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
- MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits;
+ if (IsLeftShift || !match(I->getOperand(1), m_APInt(MulConst)) ||
+ !MulConst->isNegatedPowerOf2()) {
+ ShiftMask.clearAllBits();
+ return false;
+ }
+ ShiftMask &=
+ APInt::getOneBitSet(ShiftMask.getBitWidth(), MulConst->countr_zero());
+ return !ShiftMask.isZero();
}
}
}
+/// See if we can compute the specified value, but shifted logically to the left
+/// or right by some number of bits. This should return true if the expression
+/// can be computed for the same cost as the current expression tree. This is
+/// used to eliminate extraneous shifting from things like:
+/// %C = shl i128 %A, 64
+/// %D = shl i128 %B, 96
+/// %E = or i128 %C, %D
+/// %F = lshr i128 %E, 64
+/// where the client will ask if E can be computed shifted right by 64-bits. If
+/// this succeeds, getShiftedValue() will be called to produce the value.
+static bool canEvaluateShifted(Value *V, unsigned ShAmt, bool IsLeftShift,
+ InstCombinerImpl &IC, Instruction *CxtI) {
+ APInt ShiftMask =
+ APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), ShAmt);
+ return refineEvaluableShiftMask(V, ShiftMask, IsLeftShift, IC, CxtI);
+}
+
/// Fold OuterShift (InnerShift X, C1), C2.
/// See canEvaluateShiftedShift() for the constraints on these instructions.
static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
@@ -985,37 +1032,32 @@ static Instruction *foldShrThroughZExtedShl(BinaryOperator &I, Value *Op,
InstCombinerImpl &IC,
const DataLayout &DL) {
Type *DestTy = I.getType();
+ const unsigned InnerBitWidth = Op->getType()->getScalarSizeInBits();
- auto *Inner = dyn_cast<Instruction>(Op);
- if (!Inner)
+ // Determine if the operand is effectively right-shifted by counting the
+ // known leading zero bits.
+ KnownBits Known = IC.computeKnownBits(Op, nullptr);
+ const unsigned MaxInnerShrAmt = Known.countMinLeadingZeros();
+ if (MaxInnerShrAmt == 0)
return nullptr;
+ APInt ShrMask =
+ APInt::getLowBitsSet(InnerBitWidth, std::min(MaxInnerShrAmt, ShlAmt) + 1);
- // Dig through operations until the first shift.
- while (!Inner->isShift())
- if (!match(Inner, m_BinOp(m_OneUse(m_Instruction(Inner)), m_Constant())))
- return nullptr;
-
- // Fold only if the inner shift is a logical right-shift.
- const APInt *InnerShrConst;
- if (!match(Inner, m_LShr(m_Value(), m_APInt(InnerShrConst))))
+ // Undo the maximal inner right shift amount that simplifies the overall
+ // computation.
+ if (!refineEvaluableShiftMask(Op, ShrMask, /*IsLeftShift=*/true, IC, nullptr))
return nullptr;
- const uint64_t InnerShrAmt = InnerShrConst->getZExtValue();
- if (InnerShrAmt >= ShlAmt) {
- const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
- if (!canEvaluateShifted(Op, ReducedShrAmt, /*IsLeftShift=*/false, IC,
- nullptr))
- return nullptr;
- Value *NewOp =
- getShiftedValue(Op, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
- return new ZExtInst(NewOp, DestTy);
- }
-
- if (!canEvaluateShifted(Op, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+ const unsigned InnerShrAmt = ShrMask.getActiveBits() - 1;
+ if (InnerShrAmt == 0)
return nullptr;
+ assert(InnerShrAmt <= ShlAmt);
const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
Value *NewOp = getShiftedValue(Op, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+ if (ReducedShlAmt == 0)
+ return new ZExtInst(NewOp, DestTy);
+
Value *NewZExt = IC.Builder.CreateZExt(NewOp, DestTy);
NewZExt->takeName(I.getOperand(0));
auto *NewShl = BinaryOperator::CreateShl(
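The rewritten comments in getEvaluableShiftedShiftMask() state the condition under which opposing shifts with a larger inner amount fold: lshr (shl X, C1), C2 == shl X, C1 - C2 when the bottom C2 of the top C1 bits of X are known zero (and symmetrically for shl of lshr). A standalone spot-check of the lshr-of-shl direction, with C1 = 8 and C2 = 3 chosen arbitrarily (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (uint64_t i = 0; i < 1000000; ++i) {
    // Force bits 24..26 (the bottom 3 of the top 8 bits) to zero.
    uint32_t x = (uint32_t)(i * 2654435761u) & ~0x07000000u;
    assert(((x << 8) >> 3) == (x << 5));
  }
  return 0;
}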
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
index 818e7b0fc735c..82ed2985b3b16 100644
--- a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -1,6 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=instcombine %s | FileCheck %s
+declare void @clobber.i32(i32)
+
define i64 @simple(i32 %x) {
; CHECK-LABEL: define i64 @simple(
; CHECK-SAME: i32 [[X:%.*]]) {
@@ -15,6 +17,20 @@ define i64 @simple(i32 %x) {
ret i64 %shl
}
+define <2 x i64> @simple.vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @simple.vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT: [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256)
+; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64>
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24)
+; CHECK-NEXT: ret <2 x i64> [[SHL]]
+;
+ %lshr = lshr <2 x i32> %v, splat(i32 8)
+ %zext = zext <2 x i32> %lshr to <2 x i64>
+ %shl = shl <2 x i64> %zext, splat(i64 32)
+ ret <2 x i64> %shl
+}
+
;; u0xff0 = 4080
define i64 @masked(i32 %x) {
; CHECK-LABEL: define i64 @masked(
@@ -31,6 +47,83 @@ define i64 @masked(i32 %x) {
ret i64 %shl
}
+define i64 @masked.multi_use.0(i32 %x) {
+; CHECK-LABEL: define i64 @masked.multi_use.0(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT: call void @clobber.i32(i32 [[LSHR]])
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT: ret i64 [[SHL]]
+;
+ %lshr = lshr i32 %x, 4
+ call void @clobber.i32(i32 %lshr)
+ %mask = and i32 %lshr, u0xff
+ %zext = zext i32 %mask to i64
+ %shl = shl i64 %zext, 48
+ ret i64 %shl
+}
+
+define i64 @masked.multi_use.1(i32 %x) {
+; CHECK-LABEL: define i64 @masked.multi_use.1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT: call void @clobber.i32(i32 [[MASK]])
+; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT: ret i64 [[SHL]]
+;
+ %lshr = lshr i32 %x, 4
+ %mask = and i32 %lshr, u0xff
+ call void @clobber.i32(i32 %mask)
+ %zext = zext i32 %mask to i64
+ %shl = shl i64 %zext, 48
+ ret i64 %shl
+}
+
+define <2 x i64> @masked.multi_use.2(i32 %x) {
+; CHECK-LABEL: define <2 x i64> @masked.multi_use.2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT: [[CLOBBER:%.*]] = xor i32 [[MASK]], 255
+; CHECK-NEXT: [[CLOBBER_Z:%.*]] = zext nneg i32 [[CLOBBER]] to i64
+; CHECK-NEXT: [[V_0:%.*]] = insertelement <2 x i64> poison, i64 [[SHL]], i64 0
+; CHECK-NEXT: [[V_1:%.*]] = insertelement <2 x i64> [[V_0]], i64 [[CLOBBER_Z]], i64 1
+; CHECK-NEXT: ret <2 x i64> [[V_1]]
+;
+ %lshr = lshr i32 %x, 4
+ %mask = and i32 %lshr, u0xff
+ %zext = zext i32 %mask to i64
+ %shl = shl i64 %zext, 48
+
+ %clobber = xor i32 %mask, u0xff
+ %clobber.z = zext i32 %clobber to i64
+ %v.0 = insertelement <2 x i64> poison, i64 %shl, i32 0
+ %v.1 = insertelement <2 x i64> %v.0, i64 %clobber.z, i32 1
+ ret <2 x i64> %v.1
+}
+
+;; u0xff0 = 4080
+define <2 x i64> @masked.vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @masked.vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT: [[MASK:%.*]] = and <2 x i32> [[V]], splat (i32 4080)
+; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg <2 x i32> [[MASK]] to <2 x i64>
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 44)
+; CHECK-NEXT: ret <2 x i64> [[SHL]]
+;
+ %lshr = lshr <2 x i32> %v, splat(i32 4)
+ %mask = and <2 x i32> %lshr, splat(i32 u0xff)
+ %zext = zext <2 x i32> %mask to <2 x i64>
+ %shl = shl <2 x i64> %zext, splat(i64 48)
+ ret <2 x i64> %shl
+}
+
define i64 @combine(i32 %lower, i32 %upper) {
; CHECK-LABEL: define i64 @combine(
; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
@@ -67,17 +160,3 @@ define i64 @combine(i32 %lower, i32 %upper) {
ret i64 %o.3
}
-
-define <2 x i64> @simple.vec(<2 x i32> %v) {
-; CHECK-LABEL: define <2 x i64> @simple.vec(
-; CHECK-SAME: <2 x i32> [[V:%.*]]) {
-; CHECK-NEXT: [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256)
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64>
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24)
-; CHECK-NEXT: ret <2 x i64> [[SHL]]
-;
- %lshr = lshr <2 x i32> %v, splat(i32 8)
- %zext = zext <2 x i32> %lshr to <2 x i64>
- %shl = shl <2 x i64> %zext, splat(i64 32)
- ret <2 x i64> %shl
-}