[llvm] [InstCombine] Fold integer unpack/repack patterns through ZExt (PR #153583)

via llvm-commits <llvm-commits@lists.llvm.org>
Thu Aug 14 15:24:25 PDT 2025


https://github.com/zGoldthorpe updated https://github.com/llvm/llvm-project/pull/153583

From 738f098703e6d53e9317325431595cc36cfe4fb2 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe@amd.com>
Date: Wed, 13 Aug 2025 17:58:42 -0500
Subject: [PATCH 1/2] Implemented a simpler fold for integer
 unpack/repack patterns.

---
 .../InstCombine/InstCombineAndOrXor.cpp       |  87 +++++++
 .../InstCombine/repack-ints-thru-zext.ll      | 242 ++++++++++++++++++
 2 files changed, 329 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index d7971e8e3caea..637bf1bed605f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3740,6 +3740,87 @@ static Instruction *foldIntegerPackFromVector(Instruction &I,
   return CastInst::Create(Instruction::BitCast, MaskedVec, I.getType());
 }
 
+/// Match \p V as "lshr -> mask -> zext -> shl".
+///
+/// \p Int is the underlying integer being extracted from.
+/// \p Mask is a bitmask identifying which bits of the integer are being
+/// extracted. \p Offset identifies the bit of the result \p V that
+/// corresponds to the least significant bit of \p Int.
+static bool matchZExtedSubInteger(Value *V, Value *&Int, APInt &Mask,
+                                  uint64_t &Offset, bool &IsShlNUW,
+                                  bool &IsShlNSW) {
+  Value *ShlOp0;
+  const APInt *ShlConst;
+  if (!match(V, m_OneUse(m_Shl(m_Value(ShlOp0), m_APInt(ShlConst)))))
+    return false;
+
+  IsShlNUW = cast<BinaryOperator>(V)->hasNoUnsignedWrap();
+  IsShlNSW = cast<BinaryOperator>(V)->hasNoSignedWrap();
+
+  Value *ZExtOp0;
+  if (!match(ShlOp0, m_OneUse(m_ZExt(m_Value(ZExtOp0)))))
+    return false;
+
+  Value *MaskedOp0;
+  const APInt *ShiftedMaskConst = nullptr;
+  if (!match(ZExtOp0, m_CombineOr(m_OneUse(m_And(m_Value(MaskedOp0),
+                                                 m_APInt(ShiftedMaskConst))),
+                                  m_Value(MaskedOp0))))
+    return false;
+
+  const APInt *LShrConst = nullptr;
+  if (!match(MaskedOp0,
+             m_CombineOr(m_OneUse(m_LShr(m_Value(Int), m_APInt(LShrConst))),
+                         m_Value(Int))))
+    return false;
+
+  assert(ShlConst != nullptr);
+  const uint64_t ShlAmt = ShlConst->getZExtValue();
+  const uint64_t LShrAmt = LShrConst ? LShrConst->getZExtValue() : 0;
+  if (LShrAmt > ShlAmt)
+    return false;
+
+  Mask = (ShiftedMaskConst
+              ? *ShiftedMaskConst
+              : APInt::getAllOnes(Int->getType()->getScalarSizeInBits()))
+             .shl(LShrAmt);
+  assert(LShrAmt < INT64_MAX && ShlAmt < INT64_MAX);
+  Offset = static_cast<int64_t>(ShlAmt) - static_cast<int64_t>(LShrAmt);
+
+  return true;
+}
+
+/// Try to fold the join of two scalar integers whose bits are unpacked and
+/// zexted from the same source integer.
+static Value *foldIntegerRepackThroughZExt(Value *Lhs, Value *Rhs,
+                                           InstCombiner::BuilderTy &Builder) {
+
+  Value *LhsInt, *RhsInt;
+  APInt LhsMask, RhsMask;
+  uint64_t LhsOffset, RhsOffset;
+  bool IsLhsShlNUW, IsLhsShlNSW, IsRhsShlNUW, IsRhsShlNSW;
+  if (!matchZExtedSubInteger(Lhs, LhsInt, LhsMask, LhsOffset, IsLhsShlNUW,
+                             IsLhsShlNSW))
+    return nullptr;
+  if (!matchZExtedSubInteger(Rhs, RhsInt, RhsMask, RhsOffset, IsRhsShlNUW,
+                             IsRhsShlNSW))
+    return nullptr;
+  if (LhsInt != RhsInt || LhsOffset != RhsOffset)
+    return nullptr;
+
+  APInt Mask = LhsMask | RhsMask;
+
+  Type *DestTy = Lhs->getType();
+  Value *Res = Builder.CreateShl(
+      Builder.CreateZExt(
+          Builder.CreateAnd(LhsInt, Mask, LhsInt->getName() + ".mask"), DestTy,
+          LhsInt->getName() + ".zext"),
+      ConstantInt::get(DestTy, LhsOffset), "", IsLhsShlNUW && IsRhsShlNUW,
+      IsLhsShlNSW && IsRhsShlNSW);
+  Res->takeName(Lhs);
+  return Res;
+}
+
 // A decomposition of ((X & Mask) * Factor). The NUW / NSW bools
 // track these properties for preservation. Note that we can decompose the
 // equivalent select form of this expression (e.g. (!(X & Mask) ? 0 : Mask *
@@ -3841,6 +3922,8 @@ static Value *foldBitmaskMul(Value *Op0, Value *Op1,
 Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
   if (Value *Res = foldBitmaskMul(LHS, RHS, Builder))
     return Res;
+  if (Value *Res = foldIntegerRepackThroughZExt(LHS, RHS, Builder))
+    return Res;
 
   return nullptr;
 }
@@ -3976,6 +4059,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
     if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder))
       return replaceInstUsesWith(I, Res);
 
+    if (Value *Res = foldIntegerRepackThroughZExt(I.getOperand(0),
+                                                  I.getOperand(1), Builder))
+      return replaceInstUsesWith(I, Res);
+
     if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
       return replaceInstUsesWith(I, Res);
   }
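
For a concrete picture of the new fold: it rewrites a disjoint or of two
shl(zext(...)) chains into a single and + zext + shl whenever both chains
extract bits from the same source integer at the same relative offset. A
minimal IR sketch, condensed from the full_shl test below (wrap flags on
the input omitted; the "after" shape matches that test's autogenerated
CHECK lines):

  ; before: bit 0 of %x lands at bit 24 of %res in both chains
  %lo = and i32 %x, 65535       ; Mask = 0xffff,     Offset = 24 - 0  = 24
  %lo.zext = zext i32 %lo to i64
  %lo.shl = shl i64 %lo.zext, 24
  %hi = lshr i32 %x, 16         ; Mask = 0xffff0000, Offset = 40 - 16 = 24
  %hi.zext = zext i32 %hi to i64
  %hi.shl = shl i64 %hi.zext, 40
  %res = or disjoint i64 %lo.shl, %hi.shl

  ; after: the mask union 0xffff | 0xffff0000 is all-ones, so the 'and'
  ; disappears entirely
  %x.zext = zext i32 %x to i64
  %res = shl nuw nsw i64 %x.zext, 24
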
diff --git a/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll b/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll
new file mode 100644
index 0000000000000..c90f08b7322ac
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/repack-ints-thru-zext.ll
@@ -0,0 +1,242 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=instcombine %s -S | FileCheck %s
+
+declare void @use.i32(i32)
+declare void @use.i64(i64)
+
+define i64 @full_shl(i32 %x) {
+; CHECK-LABEL: define i64 @full_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[X_ZEXT:%.*]] = zext i32 [[X]] to i64
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw i64 [[X_ZEXT]], 24
+; CHECK-NEXT:    ret i64 [[LO_SHL]]
+;
+  %lo = and i32 %x, u0xffff
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define <2 x i64> @full_shl_vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @full_shl_vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT:    [[V_ZEXT:%.*]] = zext <2 x i32> [[V]] to <2 x i64>
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw <2 x i64> [[V_ZEXT]], splat (i64 24)
+; CHECK-NEXT:    ret <2 x i64> [[LO_SHL]]
+;
+  %lo = and <2 x i32> %v, splat(i32 u0xffff)
+  %lo.zext = zext nneg <2 x i32> %lo to <2 x i64>
+  %lo.shl = shl nuw nsw <2 x i64> %lo.zext, splat(i64 24)
+
+  %hi = lshr <2 x i32> %v, splat(i32 16)
+  %hi.zext = zext nneg <2 x i32> %hi to <2 x i64>
+  %hi.shl = shl nuw nsw <2 x i64> %hi.zext, splat(i64 40)
+
+  %res = or disjoint <2 x i64> %lo.shl, %hi.shl
+  ret <2 x i64> %res
+}
+
+; u0xaabbccdd = -1430532899
+define i64 @partial_shl(i32 %x) {
+; CHECK-LABEL: define i64 @partial_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[X_MASK:%.*]] = and i32 [[X]], -1430532899
+; CHECK-NEXT:    [[X_ZEXT:%.*]] = zext i32 [[X_MASK]] to i64
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw i64 [[X_ZEXT]], 24
+; CHECK-NEXT:    ret i64 [[LO_SHL]]
+;
+  %lo = and i32 %x, u0xccdd
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  %hi.mask = and i32 %hi, u0xaabb
+  %hi.zext = zext nneg i32 %hi.mask to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @shl_multi_use_shl(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_shl(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[X]], 24
+; CHECK-NEXT:    [[LO_SHL:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    call void @use.i64(i64 [[LO_SHL]])
+; CHECK-NEXT:    [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT:    [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i64 [[HI_SHL]], [[LO_SHL]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %lo = and i32 %x, u0x00ff
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+  call void @use.i64(i64 %lo.shl)
+
+  %hi = lshr i32 %x, 16
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @shl_multi_use_zext(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_zext(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LO:%.*]] = and i32 [[X]], 255
+; CHECK-NEXT:    [[LO_ZEXT:%.*]] = zext nneg i32 [[LO]] to i64
+; CHECK-NEXT:    call void @use.i64(i64 [[LO_ZEXT]])
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw i64 [[LO_ZEXT]], 24
+; CHECK-NEXT:    [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT:    [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i64 [[LO_SHL]], [[HI_SHL]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %lo = and i32 %x, u0x00ff
+  %lo.zext = zext nneg i32 %lo to i64
+  call void @use.i64(i64 %lo.zext)
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @shl_multi_use_lshr(i32 %x) {
+; CHECK-LABEL: define i64 @shl_multi_use_lshr(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[X]], 24
+; CHECK-NEXT:    [[LO_SHL:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    call void @use.i32(i32 [[HI]])
+; CHECK-NEXT:    [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT:    [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i64 [[HI_SHL]], [[LO_SHL]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %lo = and i32 %x, u0x00ff
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  call void @use.i32(i32 %hi)
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or disjoint i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @shl_non_disjoint(i32 %x) {
+; CHECK-LABEL: define i64 @shl_non_disjoint(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LO:%.*]] = and i32 [[X]], 16711680
+; CHECK-NEXT:    [[LO_ZEXT:%.*]] = zext nneg i32 [[LO]] to i64
+; CHECK-NEXT:    [[LO_SHL:%.*]] = shl nuw nsw i64 [[LO_ZEXT]], 24
+; CHECK-NEXT:    [[HI:%.*]] = lshr i32 [[X]], 16
+; CHECK-NEXT:    call void @use.i32(i32 [[HI]])
+; CHECK-NEXT:    [[HI_ZEXT:%.*]] = zext nneg i32 [[HI]] to i64
+; CHECK-NEXT:    [[HI_SHL:%.*]] = shl nuw nsw i64 [[HI_ZEXT]], 40
+; CHECK-NEXT:    [[RES:%.*]] = or i64 [[LO_SHL]], [[HI_SHL]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %lo = and i32 %x, u0x00ff0000
+  %lo.zext = zext nneg i32 %lo to i64
+  %lo.shl = shl nuw nsw i64 %lo.zext, 24
+
+  %hi = lshr i32 %x, 16
+  call void @use.i32(i32 %hi)
+  %hi.zext = zext nneg i32 %hi to i64
+  %hi.shl = shl nuw nsw i64 %hi.zext, 40
+
+  %res = or i64 %lo.shl, %hi.shl
+  ret i64 %res
+}
+
+define i64 @combine(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT:    [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT:    [[UPPER_ZEXT:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT:    [[S_0:%.*]] = shl nuw i64 [[UPPER_ZEXT]], 32
+; CHECK-NEXT:    [[O_3:%.*]] = or disjoint i64 [[S_0]], [[BASE]]
+; CHECK-NEXT:    ret i64 [[O_3]]
+;
+  %base = zext i32 %lower to i64
+
+  %u.0 = and i32 %upper, u0xff
+  %z.0 = zext i32 %u.0 to i64
+  %s.0 = shl i64 %z.0, 32
+  %o.0 = or i64 %base, %s.0
+
+  %r.1 = lshr i32 %upper, 8
+  %u.1 = and i32 %r.1, u0xff
+  %z.1 = zext i32 %u.1 to i64
+  %s.1 = shl i64 %z.1, 40
+  %o.1 = or i64 %o.0, %s.1
+
+  %r.2 = lshr i32 %upper, 16
+  %u.2 = and i32 %r.2, u0xff
+  %z.2 = zext i32 %u.2 to i64
+  %s.2 = shl i64 %z.2, 48
+  %o.2 = or i64 %o.1, %s.2
+
+  %r.3 = lshr i32 %upper, 24
+  %u.3 = and i32 %r.3, u0xff
+  %z.3 = zext i32 %u.3 to i64
+  %s.3 = shl i64 %z.3, 56
+  %o.3 = or i64 %o.2, %s.3
+
+  ret i64 %o.3
+}
+
+define i64 @combine_2(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine_2(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT:    [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT:    [[S_03:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT:    [[O:%.*]] = shl nuw i64 [[S_03]], 32
+; CHECK-NEXT:    [[RES:%.*]] = or disjoint i64 [[O]], [[BASE]]
+; CHECK-NEXT:    ret i64 [[RES]]
+;
+  %base = zext i32 %lower to i64
+
+  %u.0 = and i32 %upper, u0xff
+  %z.0 = zext i32 %u.0 to i64
+  %s.0 = shl i64 %z.0, 32
+
+  %r.1 = lshr i32 %upper, 8
+  %u.1 = and i32 %r.1, u0xff
+  %z.1 = zext i32 %u.1 to i64
+  %s.1 = shl i64 %z.1, 40
+  %o.1 = or i64 %s.0, %s.1
+
+  %r.2 = lshr i32 %upper, 16
+  %u.2 = and i32 %r.2, u0xff
+  %z.2 = zext i32 %u.2 to i64
+  %s.2 = shl i64 %z.2, 48
+
+  %r.3 = lshr i32 %upper, 24
+  %u.3 = and i32 %r.3, u0xff
+  %z.3 = zext i32 %u.3 to i64
+  %s.3 = shl i64 %z.3, 56
+  %o.3 = or i64 %s.2, %s.3
+
+  %o = or i64 %o.1, %o.3
+  %res = or i64 %o, %base
+
+  ret i64 %res
+}
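
To make matchZExtedSubInteger's outputs concrete, here is one of the
byte-extraction chains from the combine test above, annotated with the
values the matcher produces (the annotations are not part of the test):

  %r.1 = lshr i32 %upper, 8     ; LShrAmt = 8
  %u.1 = and i32 %r.1, u0xff    ; ShiftedMaskConst = 0xff
  %z.1 = zext i32 %u.1 to i64
  %s.1 = shl i64 %z.1, 40       ; ShlAmt = 40
  ; Int = %upper, Offset = 40 - 8 = 32, Mask = 0xff << 8 = 0xff00

All four chains in that test share Int = %upper and Offset = 32, so their
masks OR together to 0xffffffff and the whole tree collapses to a single
zext plus shl-by-32, as the CHECK lines show.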

From a5974960ccdc9e15960b139f10e140b4ad8ba825 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe@amd.com>
Date: Thu, 14 Aug 2025 14:41:24 -0500
Subject: [PATCH 2/2] Addressed reviewer feedback.

---
 .../InstCombine/InstCombineAndOrXor.cpp         | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 637bf1bed605f..f8764864398b7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3774,18 +3774,15 @@ static bool matchZExtedSubInteger(Value *V, Value *&Int, APInt &Mask,
                          m_Value(Int))))
     return false;
 
-  assert(ShlConst != nullptr);
   const uint64_t ShlAmt = ShlConst->getZExtValue();
   const uint64_t LShrAmt = LShrConst ? LShrConst->getZExtValue() : 0;
   if (LShrAmt > ShlAmt)
     return false;
+  Offset = ShlAmt - LShrAmt;
 
-  Mask = (ShiftedMaskConst
-              ? *ShiftedMaskConst
-              : APInt::getAllOnes(Int->getType()->getScalarSizeInBits()))
-             .shl(LShrAmt);
-  assert(LShrAmt < INT64_MAX && ShlAmt < INT64_MAX);
-  Offset = static_cast<int64_t>(ShlAmt) - static_cast<int64_t>(LShrAmt);
+  Mask = ShiftedMaskConst ? ShiftedMaskConst->shl(LShrAmt)
+                          : APInt::getBitsSetFrom(
+                                Int->getType()->getScalarSizeInBits(), LShrAmt);
 
   return true;
 }
@@ -4056,11 +4053,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
 
-    if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder))
-      return replaceInstUsesWith(I, Res);
-
-    if (Value *Res = foldIntegerRepackThroughZExt(I.getOperand(0),
-                                                  I.getOperand(1), Builder))
+    if (Value *Res = foldDisjointOr(I.getOperand(0), I.getOperand(1)))
       return replaceInstUsesWith(I, Res);
 
     if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))


