[llvm] [InstCombine] Fold `lshr -> zext -> shl` patterns (PR #147737)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 9 07:17:24 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-analysis
Author: None (zGoldthorpe)
This patch to the InstCombiner eliminates one of the shift operations from `lshr -> zext -> shl` patterns such as
```llvm
define i64 @masked(i32 %x) {
%lshr = lshr i32 %x, 4
%mask = and i32 %lshr, u0xff
%zext = zext i32 %mask to i64
%shl = shl i64 %zext, 48
ret i64 %shl
}
```
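With this fold, the example above collapses into a single mask, zero-extend, and smaller left shift. The IR below is the expected output checked by the `@masked` case in the new `shifts-around-zext.ll` test added in this patch (value names adjusted for readability):
```llvm
define i64 @masked(i32 %x) {
  ; ((x >> 4) & 0xff) << 48  ==  zext(x & 0xff0) << 44, so the lshr is absorbed
  ; into the mask and the outer shift amount shrinks from 48 to 44.
  %mask = and i32 %x, 4080        ; u0xff0
  %zext = zext nneg i32 %mask to i64
  %shl = shl nuw nsw i64 %zext, 44
  ret i64 %shl
}
```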
This, in turn, enables the InstCombiner to fully simplify certain integer unpack/repack patterns such as
```llvm
define i64 @combine(i32 %lower, i32 %upper) {
%base = zext i32 %lower to i64
%u.0 = and i32 %upper, u0xff
%z.0 = zext i32 %u.0 to i64
%s.0 = shl i64 %z.0, 32
%o.0 = or i64 %base, %s.0
%r.1 = lshr i32 %upper, 8
%u.1 = and i32 %r.1, u0xff
%z.1 = zext i32 %u.1 to i64
%s.1 = shl i64 %z.1, 40
%o.1 = or i64 %o.0, %s.1
%r.2 = lshr i32 %upper, 16
%u.2 = and i32 %r.2, u0xff
%z.2 = zext i32 %u.2 to i64
%s.2 = shl i64 %z.2, 48
%o.2 = or i64 %o.1, %s.2
%r.3 = lshr i32 %upper, 24
%u.3 = and i32 %r.3, u0xff
%z.3 = zext i32 %u.3 to i64
%s.3 = shl i64 %z.3, 56
%o.3 = or i64 %o.2, %s.3
ret i64 %o.3
}
```
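After this patch, the whole repack reduces to a single zero-extend, shift, and disjoint `or`. This matches the CHECK lines of the `@combine` case in the new test added below (value names adjusted for readability):
```llvm
define i64 @combine(i32 %lower, i32 %upper) {
  ; The four byte-wise extractions of %upper recombine into one 32-bit insert.
  %base = zext i32 %lower to i64
  %wide = zext i32 %upper to i64
  %hi = shl nuw i64 %wide, 32
  %res = or disjoint i64 %hi, %base
  ret i64 %res
}
```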
---
Full diff: https://github.com/llvm/llvm-project/pull/147737.diff
4 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+47-2)
- (modified) llvm/test/Analysis/ValueTracking/numsignbits-shl.ll (+3-3)
- (modified) llvm/test/Transforms/InstCombine/iX-ext-split.ll (+3-3)
- (added) llvm/test/Transforms/InstCombine/shifts-around-zext.ll (+69)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 550f095b26ba4..b0b1301cd2580 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -978,6 +978,47 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
return new ZExtInst(Overflow, Ty);
}
+/// If the operand of a zext-ed left shift \p V is a logically right-shifted
+/// value, try to fold the opposing shifts.
+static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
+ unsigned ShlAmt,
+ InstCombinerImpl &IC,
+ const DataLayout &DL) {
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return nullptr;
+
+ // Dig through operations until the first shift.
+ while (!I->isShift())
+ if (!match(I, m_BinOp(m_OneUse(m_Instruction(I)), m_Constant())))
+ return nullptr;
+
+ // Fold only if the inner shift is a logical right-shift.
+ uint64_t InnerShrAmt;
+ if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt))))
+ return nullptr;
+
+ if (InnerShrAmt >= ShlAmt) {
+ const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
+ if (!canEvaluateShifted(V, ReducedShrAmt, /*IsLeftShift=*/false, IC,
+ nullptr))
+ return nullptr;
+ Value *NewInner =
+ getShiftedValue(V, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
+ return new ZExtInst(NewInner, DestTy);
+ }
+
+ if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+ return nullptr;
+
+ const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
+ Value *NewInner =
+ getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+ Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy);
+ return BinaryOperator::CreateShl(NewZExt,
+ ConstantInt::get(DestTy, ReducedShlAmt));
+}
+
// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
assert(I.isShift() && "Expected a shift as input");
@@ -1062,14 +1103,18 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
if (match(Op1, m_APInt(C))) {
unsigned ShAmtC = C->getZExtValue();
- // shl (zext X), C --> zext (shl X, C)
- // This is only valid if X would have zeros shifted out.
Value *X;
if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
+ // shl (zext X), C --> zext (shl X, C)
+ // This is only valid if X would have zeros shifted out.
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
if (ShAmtC < SrcWidth &&
MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), &I))
return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
+
+ // Otherwise, try to cancel the outer shl with a lshr inside the zext.
+ if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL))
+ return V;
}
// (X >> C) << C --> X & (-1 << C)
diff --git a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
index 5224d75a157d5..8330fd09090c8 100644
--- a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
+++ b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
@@ -101,9 +101,9 @@ define void @numsignbits_shl_zext_extended_bits_remains(i8 %x) {
define void @numsignbits_shl_zext_all_bits_shifted_out(i8 %x) {
; CHECK-LABEL: define void @numsignbits_shl_zext_all_bits_shifted_out(
; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT: [[ASHR:%.*]] = lshr i8 [[X]], 5
-; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i8 [[ASHR]] to i16
-; CHECK-NEXT: [[NSB1:%.*]] = shl i16 [[ZEXT]], 14
+; CHECK-NEXT: [[ASHR:%.*]] = and i8 [[X]], 96
+; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i8 [[ASHR]] to i16
+; CHECK-NEXT: [[NSB1:%.*]] = shl nuw i16 [[TMP1]], 9
; CHECK-NEXT: [[AND14:%.*]] = and i16 [[NSB1]], 16384
; CHECK-NEXT: [[ADD14:%.*]] = add i16 [[AND14]], [[NSB1]]
; CHECK-NEXT: call void @escape(i16 [[ADD14]])
diff --git a/llvm/test/Transforms/InstCombine/iX-ext-split.ll b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
index fc804df0e4bec..b8e056725f122 100644
--- a/llvm/test/Transforms/InstCombine/iX-ext-split.ll
+++ b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
@@ -197,9 +197,9 @@ define i128 @i128_ext_split_neg4(i32 %x) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[LOWERSRC:%.*]] = sext i32 [[X]] to i64
; CHECK-NEXT: [[LO:%.*]] = zext i64 [[LOWERSRC]] to i128
-; CHECK-NEXT: [[SIGN:%.*]] = lshr i32 [[X]], 31
-; CHECK-NEXT: [[WIDEN:%.*]] = zext nneg i32 [[SIGN]] to i128
-; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[WIDEN]], 64
+; CHECK-NEXT: [[SIGN:%.*]] = and i32 [[X]], -2147483648
+; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[SIGN]] to i128
+; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[TMP0]], 33
; CHECK-NEXT: [[RES:%.*]] = or disjoint i128 [[HI]], [[LO]]
; CHECK-NEXT: ret i128 [[RES]]
;
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
new file mode 100644
index 0000000000000..517783fcbcb5c
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine %s | FileCheck %s
+
+define i64 @simple(i32 %x) {
+; CHECK-LABEL: define i64 @simple(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[LSHR:%.*]] = and i32 [[X]], -256
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[LSHR]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24
+; CHECK-NEXT: ret i64 [[SHL]]
+;
+ %lshr = lshr i32 %x, 8
+ %zext = zext i32 %lshr to i64
+ %shl = shl i64 %zext, 32
+ ret i64 %shl
+}
+
+;; u0xff0 = 4080
+define i64 @masked(i32 %x) {
+; CHECK-LABEL: define i64 @masked(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[X]], 4080
+; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44
+; CHECK-NEXT: ret i64 [[SHL]]
+;
+ %lshr = lshr i32 %x, 4
+ %mask = and i32 %lshr, u0xff
+ %zext = zext i32 %mask to i64
+ %shl = shl i64 %zext, 48
+ ret i64 %shl
+}
+
+define i64 @combine(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT: [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[TMP1]], 32
+; CHECK-NEXT: [[O_3:%.*]] = or disjoint i64 [[TMP2]], [[BASE]]
+; CHECK-NEXT: ret i64 [[O_3]]
+;
+ %base = zext i32 %lower to i64
+
+ %u.0 = and i32 %upper, u0xff
+ %z.0 = zext i32 %u.0 to i64
+ %s.0 = shl i64 %z.0, 32
+ %o.0 = or i64 %base, %s.0
+
+ %r.1 = lshr i32 %upper, 8
+ %u.1 = and i32 %r.1, u0xff
+ %z.1 = zext i32 %u.1 to i64
+ %s.1 = shl i64 %z.1, 40
+ %o.1 = or i64 %o.0, %s.1
+
+ %r.2 = lshr i32 %upper, 16
+ %u.2 = and i32 %r.2, u0xff
+ %z.2 = zext i32 %u.2 to i64
+ %s.2 = shl i64 %z.2, 48
+ %o.2 = or i64 %o.1, %s.2
+
+ %r.3 = lshr i32 %upper, 24
+ %u.3 = and i32 %r.3, u0xff
+ %z.3 = zext i32 %u.3 to i64
+ %s.3 = shl i64 %z.3, 56
+ %o.3 = or i64 %o.2, %s.3
+
+ ret i64 %o.3
+}
``````````
https://github.com/llvm/llvm-project/pull/147737