[llvm] 82caa25 - [InstCombine] Fold integer unpack/repack patterns through ZExt (#153583)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 15 11:48:36 PDT 2025
Author: zGoldthorpe
Date: 2025-08-15T12:48:32-06:00
New Revision: 82caa251d4e145b54ea76236213617076f254c2b
URL: https://github.com/llvm/llvm-project/commit/82caa251d4e145b54ea76236213617076f254c2b
DIFF: https://github.com/llvm/llvm-project/commit/82caa251d4e145b54ea76236213617076f254c2b.diff
LOG: [InstCombine] Fold integer unpack/repack patterns through ZExt (#153583)
This patch explicitly enables the InstCombiner to fold integer
unpack/repack patterns such as
```llvm
define i64 @src_combine(i32 %lower, i32 %upper) {
%base = zext i32 %lower to i64
%u.0 = and i32 %upper, u0xff
%z.0 = zext i32 %u.0 to i64
%s.0 = shl i64 %z.0, 32
%o.0 = or i64 %base, %s.0
%r.1 = lshr i32 %upper, 8
%u.1 = and i32 %r.1, u0xff
%z.1 = zext i32 %u.1 to i64
%s.1 = shl i64 %z.1, 40
%o.1 = or i64 %o.0, %s.1
%r.2 = lshr i32 %upper, 16
%u.2 = and i32 %r.2, u0xff
%z.2 = zext i32 %u.2 to i64
%s.2 = shl i64 %z.2, 48
%o.2 = or i64 %o.1, %s.2
%r.3 = lshr i32 %upper, 24
%u.3 = and i32 %r.3, u0xff
%z.3 = zext i32 %u.3 to i64
%s.3 = shl i64 %z.3, 56
%o.3 = or i64 %o.2, %s.3
ret i64 %o.3
}
; =>
define i64 @tgt_combine(i32 %lower, i32 %upper) {
%base = zext i32 %lower to i64
%upper.zext = zext i32 %upper to i64
%s.0 = shl nuw i64 %upper.zext, 32
%o.3 = or disjoint i64 %s.0, %base
ret i64 %o.3
}
```
Alive2 proofs: [YAy7ny](https://alive2.llvm.org/ce/z/YAy7ny)
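For context, chains like the one above typically arise when source code widens an integer and repacks it piece by piece. A minimal C++ illustration (hypothetical source, not taken from this patch or its tests) that a compiler would unroll into IR much like `@src_combine`:
```cpp
#include <cstdint>

// Pack the bytes of `upper` above `lower`, one byte at a time.
// Unrolled, each iteration becomes a lshr/and/zext/shl chain feeding
// an or; with this patch, InstCombine folds the chains back into the
// single zext + shl nuw + or disjoint form of @tgt_combine.
uint64_t repack(uint32_t lower, uint32_t upper) {
  uint64_t r = lower;
  for (unsigned i = 0; i < 4; ++i)
    r |= uint64_t((upper >> (8 * i)) & 0xff) << (32 + 8 * i);
  return r;
}
```
A worked sketch of the matcher's bookkeeping appears after the diff below.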
Added:
llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index d7971e8e3caea..6e46898634070 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3740,6 +3740,82 @@ static Instruction *foldIntegerPackFromVector(Instruction &I,
return CastInst::Create(Instruction::BitCast, MaskedVec, I.getType());
}
+/// Match \p V as "lshr -> mask -> zext -> shl".
+///
+/// \p Int is the underlying integer being extracted from.
+/// \p Mask is a bitmask identifying which bits of \p Int are being extracted.
+/// \p Offset identifies which bit of the result \p V corresponds to the
+/// least significant bit of \p Int.
+static bool matchZExtedSubInteger(Value *V, Value *&Int, APInt &Mask,
+ uint64_t &Offset, bool &IsShlNUW,
+ bool &IsShlNSW) {
+ Value *ShlOp0;
+ uint64_t ShlAmt = 0;
+ if (!match(V, m_OneUse(m_Shl(m_Value(ShlOp0), m_ConstantInt(ShlAmt)))))
+ return false;
+
+ IsShlNUW = cast<BinaryOperator>(V)->hasNoUnsignedWrap();
+ IsShlNSW = cast<BinaryOperator>(V)->hasNoSignedWrap();
+
+ Value *ZExtOp0;
+ if (!match(ShlOp0, m_OneUse(m_ZExt(m_Value(ZExtOp0)))))
+ return false;
+
+ Value *MaskedOp0;
+ const APInt *ShiftedMaskConst = nullptr;
+ if (!match(ZExtOp0, m_CombineOr(m_OneUse(m_And(m_Value(MaskedOp0),
+ m_APInt(ShiftedMaskConst))),
+ m_Value(MaskedOp0))))
+ return false;
+
+ uint64_t LShrAmt = 0;
+ if (!match(MaskedOp0,
+ m_CombineOr(m_OneUse(m_LShr(m_Value(Int), m_ConstantInt(LShrAmt))),
+ m_Value(Int))))
+ return false;
+
+ if (LShrAmt > ShlAmt)
+ return false;
+ Offset = ShlAmt - LShrAmt;
+
+ Mask = ShiftedMaskConst ? ShiftedMaskConst->shl(LShrAmt)
+ : APInt::getBitsSetFrom(
+ Int->getType()->getScalarSizeInBits(), LShrAmt);
+
+ return true;
+}
+
+/// Try to fold the bitwise or of two scalar integers whose bits are unpacked
+/// and zexted from the same source integer.
+static Value *foldIntegerRepackThroughZExt(Value *Lhs, Value *Rhs,
+ InstCombiner::BuilderTy &Builder) {
+
+ Value *LhsInt, *RhsInt;
+ APInt LhsMask, RhsMask;
+ uint64_t LhsOffset, RhsOffset;
+ bool IsLhsShlNUW, IsLhsShlNSW, IsRhsShlNUW, IsRhsShlNSW;
+ if (!matchZExtedSubInteger(Lhs, LhsInt, LhsMask, LhsOffset, IsLhsShlNUW,
+ IsLhsShlNSW))
+ return nullptr;
+ if (!matchZExtedSubInteger(Rhs, RhsInt, RhsMask, RhsOffset, IsRhsShlNUW,
+ IsRhsShlNSW))
+ return nullptr;
+ if (LhsInt != RhsInt || LhsOffset != RhsOffset)
+ return nullptr;
+
+ APInt Mask = LhsMask | RhsMask;
+
+ Type *DestTy = Lhs->getType();
+ Value *Res = Builder.CreateShl(
+ Builder.CreateZExt(
+ Builder.CreateAnd(LhsInt, Mask, LhsInt->getName() + ".mask"), DestTy,
+ LhsInt->getName() + ".zext"),
+ ConstantInt::get(DestTy, LhsOffset), "", IsLhsShlNUW && IsRhsShlNUW,
+ IsLhsShlNSW && IsRhsShlNSW);
+ Res->takeName(Lhs);
+ return Res;
+}
+
// A decomposition of ((X & Mask) * Factor). The NUW / NSW bools
// track these properties for preservation. Note that we can decompose the
// equivalent select form of this expression (e.g. (!(X & Mask) ? 0 : Mask *
@@ -3841,6 +3917,8 @@ static Value *foldBitmaskMul(Value *Op0, Value *Op1,
Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
if (Value *Res = foldBitmaskMul(LHS, RHS, Builder))
return Res;
+ if (Value *Res = foldIntegerRepackThroughZExt(LHS, RHS, Builder))
+ return Res;
return nullptr;
}
@@ -3973,7 +4051,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*NSW=*/true, /*NUW=*/true))
return R;
- if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder))
+ if (Value *Res = foldDisjointOr(I.getOperand(0), I.getOperand(1)))
return replaceInstUsesWith(I, Res);
if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
diff --git a/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll b/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll
new file mode 100644
index 0000000000000..c90f08b7322ac
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll
@@ -0,0 +1,242 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=instcombine %s -S | FileCheck %s
+
+declare void @use.i32(i32)
+declare void @use.i64(i64)
+
+define i64 @full_shl(i32 %x) {
+; CHECK-LABEL: define i64 @full_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[X_ZEXT:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw i64 [[X_ZEXT]], 24
+; CHECK-NEXT: ret i64 [[LO_SHL]]
+;
+ %lo = and i32 %x, u0xffff
+ %lo.zext = zext nneg i32 %lo to i64
+ %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+ %hi = lshr i32 %x, 16
+ %hi.zext = zext nneg i32 %hi to i64
+ %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+ %res = or disjoint i64 %lo.shl, %hi.shl
+ ret i64 %res
+}
+
+define <2 x i64> @full_shl_vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @full_shl_vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT: [[V_ZEXT:%.*]] = zext <2 x i32> [[V]] to <2 x i64>
+; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw <2 x i64> [[V_ZEXT]], splat (i64 24)
+; CHECK-NEXT: ret <2 x i64> [[LO_SHL]]
+;
+ %lo = and <2 x i32> %v, splat(i32 u0xffff)
+ %lo.zext = zext nneg <2 x i32> %lo to <2 x i64>
+ %lo.shl = shl nuw nsw <2 x i64> %lo.zext, splat(i64 24)
+
+ %hi = lshr <2 x i32> %v, splat(i32 16)
+ %hi.zext = zext nneg <2 x i32> %hi to <2 x i64>
+ %hi.shl = shl nuw nsw <2 x i64> %hi.zext, splat(i64 40)
+
+ %res = or disjoint <2 x i64> %lo.shl, %hi.shl
+ ret <2 x i64> %res
+}
+
+; u0xaabbccdd = -1430532899
+define i64 @partial_shl(i32 %x) {
+; CHECK-LABEL: define i64 @partial_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[X_MASK:%.*]] = and i32 [[X]], -1430532899
+; CHECK-NEXT: [[X_ZEXT:%.*]] = zext i32 [[X_MASK]] to i64
+; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw i64 [[X_ZEXT]], 24
+; CHECK-NEXT: ret i64 [[LO_SHL]]
+;
+ %lo = and i32 %x, u0xccdd
+ %lo.zext = zext nneg i32 %lo to i64
+ %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+ %hi = lshr i32 %x, 16
+ %hi.mask = and i32 %hi, u0xaabb
+ %hi.zext = zext nneg i32 %hi.mask to i64
+ %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+ %res = or disjoint i64 %lo.shl, %hi.shl
+ ret i64 %res
+}
+
+define i64 @shl_multi_use_shl(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[X]], 24
+; CHECK-NEXT: [[LO_SHL:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT: call void @use.i64(i64 [[LO_SHL]])
+; CHECK-NEXT: [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT: [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT: [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT: [[RES:%.*]] = or disjoint i64 [[HI_SHL]], [[LO_SHL]]
+; CHECK-NEXT: ret i64 [[RES]]
+;
+ %lo = and i32 %x, u0x00ff
+ %lo.zext = zext nneg i32 %lo to i64
+ %lo.shl = shl nuw nsw i64 %lo.zext, 24
+ call void @use.i64(i64 %lo.shl)
+
+ %hi = lshr i32 %x, 16
+ %hi.zext = zext nneg i32 %hi to i64
+ %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+ %res = or disjoint i64 %lo.shl, %hi.shl
+ ret i64 %res
+}
+
+define i64 @shl_multi_use_zext(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_zext(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[LO:%.*]] = and i32 [[X]], 255
+; CHECK-NEXT: [[LO_ZEXT:%.*]] = zext nneg i32 [[LO]] to i64
+; CHECK-NEXT: call void @use.i64(i64 [[LO_ZEXT]])
+; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw i64 [[LO_ZEXT]], 24
+; CHECK-NEXT: [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT: [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT: [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT: [[RES:%.*]] = or disjoint i64 [[LO_SHL]], [[HI_SHL]]
+; CHECK-NEXT: ret i64 [[RES]]
+;
+ %lo = and i32 %x, u0x00ff
+ %lo.zext = zext nneg i32 %lo to i64
+ call void @use.i64(i64 %lo.zext)
+ %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+ %hi = lshr i32 %x, 16
+ %hi.zext = zext nneg i32 %hi to i64
+ %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+ %res = or disjoint i64 %lo.shl, %hi.shl
+ ret i64 %res
+}
+
+define i64 @shl_multi_use_lshr(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_lshr(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[X]], 24
+; CHECK-NEXT: [[LO_SHL:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT: [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT: call void @use.i32(i32 [[HI]])
+; CHECK-NEXT: [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT: [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT: [[RES:%.*]] = or disjoint i64 [[HI_SHL]], [[LO_SHL]]
+; CHECK-NEXT: ret i64 [[RES]]
+;
+ %lo = and i32 %x, u0x00ff
+ %lo.zext = zext nneg i32 %lo to i64
+ %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+ %hi = lshr i32 %x, 16
+ call void @use.i32(i32 %hi)
+ %hi.zext = zext nneg i32 %hi to i64
+ %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+ %res = or disjoint i64 %lo.shl, %hi.shl
+ ret i64 %res
+}
+
+define i64 @shl_non_disjoint(i32 %x) {
+; CHECK-LABEL: define i64 @shl_non_disjoint(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT: [[LO:%.*]] = and i32 [[X]], 16711680
+; CHECK-NEXT: [[LO_ZEXT:%.*]] = zext nneg i32 [[LO]] to i64
+; CHECK-NEXT: [[LO_SHL:%.*]] = shl nuw nsw i64 [[LO_ZEXT]], 24
+; CHECK-NEXT: [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT: call void @use.i32(i32 [[HI]])
+; CHECK-NEXT: [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT: [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT: [[RES:%.*]] = or i64 [[LO_SHL]], [[HI_SHL]]
+; CHECK-NEXT: ret i64 [[RES]]
+;
+ %lo = and i32 %x, u0x00ff0000
+ %lo.zext = zext nneg i32 %lo to i64
+ %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+ %hi = lshr i32 %x, 16
+ call void @use.i32(i32 %hi)
+ %hi.zext = zext nneg i32 %hi to i64
+ %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+ %res = or i64 %lo.shl, %hi.shl
+ ret i64 %res
+}
+
+define i64 @combine(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT: [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT: [[UPPER_ZEXT:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT: [[S_0:%.*]] = shl nuw i64 [[UPPER_ZEXT]], 32
+; CHECK-NEXT: [[O_3:%.*]] = or disjoint i64 [[S_0]], [[BASE]]
+; CHECK-NEXT: ret i64 [[O_3]]
+;
+ %base = zext i32 %lower to i64
+
+ %u.0 = and i32 %upper, u0xff
+ %z.0 = zext i32 %u.0 to i64
+ %s.0 = shl i64 %z.0, 32
+ %o.0 = or i64 %base, %s.0
+
+ %r.1 = lshr i32 %upper, 8
+ %u.1 = and i32 %r.1, u0xff
+ %z.1 = zext i32 %u.1 to i64
+ %s.1 = shl i64 %z.1, 40
+ %o.1 = or i64 %o.0, %s.1
+
+ %r.2 = lshr i32 %upper, 16
+ %u.2 = and i32 %r.2, u0xff
+ %z.2 = zext i32 %u.2 to i64
+ %s.2 = shl i64 %z.2, 48
+ %o.2 = or i64 %o.1, %s.2
+
+ %r.3 = lshr i32 %upper, 24
+ %u.3 = and i32 %r.3, u0xff
+ %z.3 = zext i32 %u.3 to i64
+ %s.3 = shl i64 %z.3, 56
+ %o.3 = or i64 %o.2, %s.3
+
+ ret i64 %o.3
+}
+
+define i64 @combine_2(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine_2(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT: [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT: [[S_03:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT: [[O:%.*]] = shl nuw i64 [[S_03]], 32
+; CHECK-NEXT: [[RES:%.*]] = or disjoint i64 [[O]], [[BASE]]
+; CHECK-NEXT: ret i64 [[RES]]
+;
+ %base = zext i32 %lower to i64
+
+ %u.0 = and i32 %upper, u0xff
+ %z.0 = zext i32 %u.0 to i64
+ %s.0 = shl i64 %z.0, 32
+
+ %r.1 = lshr i32 %upper, 8
+ %u.1 = and i32 %r.1, u0xff
+ %z.1 = zext i32 %u.1 to i64
+ %s.1 = shl i64 %z.1, 40
+ %o.1 = or i64 %s.0, %s.1
+
+ %r.2 = lshr i32 %upper, 16
+ %u.2 = and i32 %r.2, u0xff
+ %z.2 = zext i32 %u.2 to i64
+ %s.2 = shl i64 %z.2, 48
+
+ %r.3 = lshr i32 %upper, 24
+ %u.3 = and i32 %r.3, u0xff
+ %z.3 = zext i32 %u.3 to i64
+ %s.3 = shl i64 %z.3, 56
+ %o.3 = or i64 %s.2, %s.3
+
+ %o = or i64 %o.1, %o.3
+ %res = or i64 %o, %base
+
+ ret i64 %res
+}
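A note on the new matcher: `matchZExtedSubInteger` normalizes each or-operand of the shape `shl(zext(and(lshr(X, L), M)), S)` (with `L <= S`) to the source integer `X`, a mask `M << L` expressed in `X`'s own bit positions, and an offset `S - L`; two operands merge exactly when `X` and the offset agree, by or'ing the masks. A small standalone C++ sketch of that bookkeeping (illustrative names, plain `uint64_t` in place of `APInt`, constants taken from the `@partial_shl` test above):
```cpp
#include <cassert>
#include <cstdint>

// Sketch of the normal form computed by matchZExtedSubInteger: Mask is
// the set of extracted bits in X's own positions, and bit i of X lands
// at bit i + Offset of the result.
struct SubInteger {
  uint64_t Mask;
  uint64_t Offset;
};

// The original four-instruction chain: shl(zext(and(lshr(X, L), M)), S).
uint64_t evalChain(uint32_t X, unsigned L, uint32_t M, unsigned S) {
  return (uint64_t((X >> L) & M)) << S;
}

// The normalized form the fold emits: zext(and(X, Mask)) << Offset.
uint64_t evalNormalized(uint32_t X, SubInteger Sub) {
  return (uint64_t(X) & Sub.Mask) << Sub.Offset;
}

int main() {
  uint32_t X = 0xaabbccdd;
  // The %hi chain of @partial_shl: lshr 16, and 0xaabb, shl 40.
  SubInteger Hi{uint64_t(0xaabb) << 16, 40 - 16};
  assert(evalChain(X, 16, 0xaabb, 40) == evalNormalized(X, Hi));
  // The %lo chain: no lshr, and 0xccdd, shl 24.
  SubInteger Lo{0xccdd, 24};
  assert(evalChain(X, 0, 0xccdd, 24) == evalNormalized(X, Lo));
  // Both offsets are 24, so the fold or's the masks:
  // 0xccdd | 0xaabb0000 = 0xaabbccdd, i.e. zext(X & 0xaabbccdd) << 24,
  // matching the CHECK lines of @partial_shl.
  SubInteger Merged{Lo.Mask | Hi.Mask, 24};
  assert((evalNormalized(X, Lo) | evalNormalized(X, Hi)) ==
         evalNormalized(X, Merged));
  return 0;
}
```
Expressing the mask in `X`'s bit positions (`ShiftedMaskConst->shl(LShrAmt)` in the patch) is what lets masks from different operands be or'ed directly.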