[llvm] [InstCombine] Fold `lshr -> zext -> shl` patterns (PR #147737)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 18 10:24:25 PDT 2025


https://github.com/zGoldthorpe updated https://github.com/llvm/llvm-project/pull/147737

From 3048c6d54aaa3f410b45af4c0ae3ad26c9f9b114 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Tue, 8 Jul 2025 17:52:34 -0500
Subject: [PATCH 1/3] Added patch to fold `lshr + zext + shl` patterns

---
 .../InstCombine/InstCombineShifts.cpp         | 49 ++++++++++++-
 .../Analysis/ValueTracking/numsignbits-shl.ll |  6 +-
 .../Transforms/InstCombine/iX-ext-split.ll    |  6 +-
 .../InstCombine/shifts-around-zext.ll         | 69 +++++++++++++++++++
 4 files changed, 122 insertions(+), 8 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/shifts-around-zext.ll
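
For a quick sense of the rewrite this first patch implements, the @simple test
added in shifts-around-zext.ll (further down in this patch) is folded so that
the inner lshr becomes a mask and the outer shl shrinks by the inner shift
amount:

  ; before
  %lshr = lshr i32 %x, 8
  %zext = zext i32 %lshr to i64
  %shl = shl i64 %zext, 32
  ; after (as in the test's CHECK lines)
  %lshr = and i32 %x, -256
  %zext = zext i32 %lshr to i64
  %shl = shl nuw nsw i64 %zext, 24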

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 550f095b26ba4..b0b1301cd2580 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -978,6 +978,47 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
   return new ZExtInst(Overflow, Ty);
 }
 
+/// If the operand of a zext-ed left shift \p V is a logically right-shifted
+/// value, try to fold the opposing shifts.
+static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
+                                            unsigned ShlAmt,
+                                            InstCombinerImpl &IC,
+                                            const DataLayout &DL) {
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return nullptr;
+
+  // Dig through operations until the first shift.
+  while (!I->isShift())
+    if (!match(I, m_BinOp(m_OneUse(m_Instruction(I)), m_Constant())))
+      return nullptr;
+
+  // Fold only if the inner shift is a logical right-shift.
+  uint64_t InnerShrAmt;
+  if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt))))
+    return nullptr;
+
+  if (InnerShrAmt >= ShlAmt) {
+    const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
+    if (!canEvaluateShifted(V, ReducedShrAmt, /*IsLeftShift=*/false, IC,
+                            nullptr))
+      return nullptr;
+    Value *NewInner =
+        getShiftedValue(V, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
+    return new ZExtInst(NewInner, DestTy);
+  }
+
+  if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+    return nullptr;
+
+  const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
+  Value *NewInner =
+      getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+  Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy);
+  return BinaryOperator::CreateShl(NewZExt,
+                                   ConstantInt::get(DestTy, ReducedShlAmt));
+}
+
 // Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
 static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
   assert(I.isShift() && "Expected a shift as input");
@@ -1062,14 +1103,18 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
   if (match(Op1, m_APInt(C))) {
     unsigned ShAmtC = C->getZExtValue();
 
-    // shl (zext X), C --> zext (shl X, C)
-    // This is only valid if X would have zeros shifted out.
     Value *X;
     if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
+      // shl (zext X), C --> zext (shl X, C)
+      // This is only valid if X would have zeros shifted out.
       unsigned SrcWidth = X->getType()->getScalarSizeInBits();
       if (ShAmtC < SrcWidth &&
           MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), &I))
         return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
+
+      // Otherwise, try to cancel the outer shl with a lshr inside the zext.
+      if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL))
+        return V;
     }
 
     // (X >> C) << C --> X & (-1 << C)
diff --git a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
index 5224d75a157d5..8330fd09090c8 100644
--- a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
+++ b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
@@ -101,9 +101,9 @@ define void @numsignbits_shl_zext_extended_bits_remains(i8 %x) {
 define void @numsignbits_shl_zext_all_bits_shifted_out(i8 %x) {
 ; CHECK-LABEL: define void @numsignbits_shl_zext_all_bits_shifted_out(
 ; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT:    [[ASHR:%.*]] = lshr i8 [[X]], 5
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i8 [[ASHR]] to i16
-; CHECK-NEXT:    [[NSB1:%.*]] = shl i16 [[ZEXT]], 14
+; CHECK-NEXT:    [[ASHR:%.*]] = and i8 [[X]], 96
+; CHECK-NEXT:    [[TMP1:%.*]] = zext nneg i8 [[ASHR]] to i16
+; CHECK-NEXT:    [[NSB1:%.*]] = shl nuw i16 [[TMP1]], 9
 ; CHECK-NEXT:    [[AND14:%.*]] = and i16 [[NSB1]], 16384
 ; CHECK-NEXT:    [[ADD14:%.*]] = add i16 [[AND14]], [[NSB1]]
 ; CHECK-NEXT:    call void @escape(i16 [[ADD14]])
diff --git a/llvm/test/Transforms/InstCombine/iX-ext-split.ll b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
index fc804df0e4bec..b8e056725f122 100644
--- a/llvm/test/Transforms/InstCombine/iX-ext-split.ll
+++ b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
@@ -197,9 +197,9 @@ define i128 @i128_ext_split_neg4(i32 %x) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
 ; CHECK-NEXT:    [[LOWERSRC:%.*]] = sext i32 [[X]] to i64
 ; CHECK-NEXT:    [[LO:%.*]] = zext i64 [[LOWERSRC]] to i128
-; CHECK-NEXT:    [[SIGN:%.*]] = lshr i32 [[X]], 31
-; CHECK-NEXT:    [[WIDEN:%.*]] = zext nneg i32 [[SIGN]] to i128
-; CHECK-NEXT:    [[HI:%.*]] = shl nuw nsw i128 [[WIDEN]], 64
+; CHECK-NEXT:    [[SIGN:%.*]] = and i32 [[X]], -2147483648
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[SIGN]] to i128
+; CHECK-NEXT:    [[HI:%.*]] = shl nuw nsw i128 [[TMP0]], 33
 ; CHECK-NEXT:    [[RES:%.*]] = or disjoint i128 [[HI]], [[LO]]
 ; CHECK-NEXT:    ret i128 [[RES]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
new file mode 100644
index 0000000000000..517783fcbcb5c
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine %s | FileCheck %s
+
+define i64 @simple(i32 %x) {
+; CHECK-LABEL: define i64 @simple(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = and i32 [[X]], -256
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[LSHR]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24
+; CHECK-NEXT:    ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 8
+  %zext = zext i32 %lshr to i64
+  %shl = shl i64 %zext, 32
+  ret i64 %shl
+}
+
+;; u0xff0 = 4080
+define i64 @masked(i32 %x) {
+; CHECK-LABEL: define i64 @masked(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[X]], 4080
+; CHECK-NEXT:    [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44
+; CHECK-NEXT:    ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 4
+  %mask = and i32 %lshr, u0xff
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+  ret i64 %shl
+}
+
+define i64 @combine(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT:    [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[TMP1]], 32
+; CHECK-NEXT:    [[O_3:%.*]] = or disjoint i64 [[TMP2]], [[BASE]]
+; CHECK-NEXT:    ret i64 [[O_3]]
+;
+  %base = zext i32 %lower to i64
+
+  %u.0 = and i32 %upper, u0xff
+  %z.0 = zext i32 %u.0 to i64
+  %s.0 = shl i64 %z.0, 32
+  %o.0 = or i64 %base, %s.0
+
+  %r.1 = lshr i32 %upper, 8
+  %u.1 = and i32 %r.1, u0xff
+  %z.1 = zext i32 %u.1 to i64
+  %s.1 = shl i64 %z.1, 40
+  %o.1 = or i64 %o.0, %s.1
+
+  %r.2 = lshr i32 %upper, 16
+  %u.2 = and i32 %r.2, u0xff
+  %z.2 = zext i32 %u.2 to i64
+  %s.2 = shl i64 %z.2, 48
+  %o.2 = or i64 %o.1, %s.2
+
+  %r.3 = lshr i32 %upper, 24
+  %u.3 = and i32 %r.3, u0xff
+  %z.3 = zext i32 %u.3 to i64
+  %s.3 = shl i64 %z.3, 56
+  %o.3 = or i64 %o.2, %s.3
+
+  ret i64 %o.3
+}
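
The fold also digs through single-use intermediate operations with constant
operands sitting between the zext and the right shift. In the @masked test
above, an and separates the lshr from the zext, and the opposing shifts are
still cancelled:

  ; before
  %lshr = lshr i32 %x, 4
  %mask = and i32 %lshr, 255
  %zext = zext i32 %mask to i64
  %shl = shl i64 %zext, 48
  ; after (as in the test's CHECK lines)
  %mask = and i32 %x, 4080
  %zext = zext nneg i32 %mask to i64
  %shl = shl nuw nsw i64 %zext, 44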

From 8ea66689bca1005b09f842500d7ece5ad4881386 Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Fri, 11 Jul 2025 15:33:46 -0500
Subject: [PATCH 2/3] Incorporated straightforward reviewer feedback.

---
 .../InstCombine/InstCombineShifts.cpp         | 48 +++++++++++--------
 .../InstCombine/shifts-around-zext.ll         | 22 +++++++--
 2 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index b0b1301cd2580..edcd963e3a7db 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -978,45 +978,53 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
   return new ZExtInst(Overflow, Ty);
 }
 
-/// If the operand of a zext-ed left shift \p V is a logically right-shifted
-/// value, try to fold the opposing shifts.
-static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
+/// If the operand \p Op of a zext-ed left shift \p I is a logically
+/// right-shifted value, try to fold the opposing shifts.
+static Instruction *foldShrThroughZExtedShl(BinaryOperator &I, Value *Op,
                                             unsigned ShlAmt,
                                             InstCombinerImpl &IC,
                                             const DataLayout &DL) {
-  auto *I = dyn_cast<Instruction>(V);
-  if (!I)
+  Type *DestTy = I.getType();
+
+  auto *Inner = dyn_cast<Instruction>(Op);
+  if (!Inner)
     return nullptr;
 
   // Dig through operations until the first shift.
-  while (!I->isShift())
-    if (!match(I, m_BinOp(m_OneUse(m_Instruction(I)), m_Constant())))
+  while (!Inner->isShift())
+    if (!match(Inner, m_BinOp(m_OneUse(m_Instruction(Inner)), m_Constant())))
       return nullptr;
 
   // Fold only if the inner shift is a logical right-shift.
-  uint64_t InnerShrAmt;
-  if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt))))
+  const APInt *InnerShrConst;
+  if (!match(Inner, m_LShr(m_Value(), m_APInt(InnerShrConst))))
     return nullptr;
 
+  const uint64_t InnerShrAmt = InnerShrConst->getZExtValue();
   if (InnerShrAmt >= ShlAmt) {
     const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
-    if (!canEvaluateShifted(V, ReducedShrAmt, /*IsLeftShift=*/false, IC,
+    if (!canEvaluateShifted(Op, ReducedShrAmt, /*IsLeftShift=*/false, IC,
                             nullptr))
       return nullptr;
-    Value *NewInner =
-        getShiftedValue(V, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
-    return new ZExtInst(NewInner, DestTy);
+    Value *NewOp =
+        getShiftedValue(Op, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
+    return new ZExtInst(NewOp, DestTy);
   }
 
-  if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+  if (!canEvaluateShifted(Op, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
     return nullptr;
 
   const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
-  Value *NewInner =
-      getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
-  Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy);
-  return BinaryOperator::CreateShl(NewZExt,
-                                   ConstantInt::get(DestTy, ReducedShlAmt));
+  Value *NewOp = getShiftedValue(Op, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+  Value *NewZExt = IC.Builder.CreateZExt(NewOp, DestTy);
+  NewZExt->takeName(I.getOperand(0));
+  auto *NewShl = BinaryOperator::CreateShl(
+      NewZExt, ConstantInt::get(DestTy, ReducedShlAmt));
+
+  // New shl inherits all flags from the original shl instruction.
+  NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
+  NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+  return NewShl;
 }
 
 // Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
@@ -1113,7 +1121,7 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
         return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
 
       // Otherwise, try to cancel the outer shl with a lshr inside the zext.
-      if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL))
+      if (Instruction *V = foldShrThroughZExtedShl(I, X, ShAmtC, *this, DL))
         return V;
     }
 
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
index 517783fcbcb5c..818e7b0fc735c 100644
--- a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -5,8 +5,8 @@ define i64 @simple(i32 %x) {
 ; CHECK-LABEL: define i64 @simple(
 ; CHECK-SAME: i32 [[X:%.*]]) {
 ; CHECK-NEXT:    [[LSHR:%.*]] = and i32 [[X]], -256
-; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[LSHR]] to i64
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i32 [[LSHR]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 24
 ; CHECK-NEXT:    ret i64 [[SHL]]
 ;
   %lshr = lshr i32 %x, 8
@@ -20,8 +20,8 @@ define i64 @masked(i32 %x) {
 ; CHECK-LABEL: define i64 @masked(
 ; CHECK-SAME: i32 [[X:%.*]]) {
 ; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[X]], 4080
-; CHECK-NEXT:    [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 44
 ; CHECK-NEXT:    ret i64 [[SHL]]
 ;
   %lshr = lshr i32 %x, 4
@@ -67,3 +67,17 @@ define i64 @combine(i32 %lower, i32 %upper) {
 
   ret i64 %o.3
 }
+
+define <2 x i64> @simple.vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @simple.vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256)
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24)
+; CHECK-NEXT:    ret <2 x i64> [[SHL]]
+;
+  %lshr = lshr <2 x i32> %v, splat(i32 8)
+  %zext = zext <2 x i32> %lshr to <2 x i64>
+  %shl = shl <2 x i64> %zext, splat(i64 32)
+  ret <2 x i64> %shl
+}
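
The third patch, below, reworks canEvaluateShifted into a mask-based
refineEvaluableShiftMask, so the combine can discover the maximal inner
right-shift amount that is profitable to undo rather than testing a single
candidate, and it adds multi-use coverage. When the right-shifted
intermediate has another use, the rewrite is not applied, as in
@masked.multi_use.0 from the updated test file (output shown as in its CHECK
lines):

  %lshr = lshr i32 %x, 4
  call void @clobber.i32(i32 %lshr)   ; extra use keeps the lshr in place
  %mask = and i32 %lshr, 255
  %zext = zext nneg i32 %mask to i64
  %shl = shl nuw nsw i64 %zext, 48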

From 730920c9517d77623ce53ae1d2edb7625b4ef2df Mon Sep 17 00:00:00 2001
From: Zach Goldthorpe <Zach.Goldthorpe at amd.com>
Date: Mon, 14 Jul 2025 17:19:54 -0500
Subject: [PATCH 3/3] Refactored `canEvaluateShifted` to identify candidates
 for simplification.

---
 .../InstCombine/InstCombineShifts.cpp         | 186 +++++++++++-------
 .../InstCombine/shifts-around-zext.ll         | 107 ++++++++--
 2 files changed, 207 insertions(+), 86 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index edcd963e3a7db..1436e4bd5854f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -530,92 +530,116 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
   return nullptr;
 }
 
-/// Return true if we can simplify two logical (either left or right) shifts
-/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
-static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
-                                    Instruction *InnerShift,
-                                    InstCombinerImpl &IC, Instruction *CxtI) {
+/// Return a bitmask of all constant outer shift amounts that can be simplified
+/// by foldShiftedShift().
+static APInt getEvaluableShiftedShiftMask(bool IsOuterShl,
+                                          Instruction *InnerShift,
+                                          InstCombinerImpl &IC,
+                                          Instruction *CxtI) {
   assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
 
+  const unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
+
   // We need constant scalar or constant splat shifts.
   const APInt *InnerShiftConst;
   if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
-    return false;
+    return APInt::getZero(TypeWidth);
 
-  // Two logical shifts in the same direction:
+  if (InnerShiftConst->uge(TypeWidth))
+    return APInt::getZero(TypeWidth);
+
+  const unsigned InnerShAmt = InnerShiftConst->getZExtValue();
+
+  // Two logical shifts in the same direction can always be simplified, so long
+  // as the total shift amount is legal.
   // shl (shl X, C1), C2 -->  shl X, C1 + C2
   // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
   bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
   if (IsInnerShl == IsOuterShl)
-    return true;
+    return APInt::getLowBitsSet(TypeWidth, TypeWidth - InnerShAmt);
 
+  APInt ShMask = APInt::getZero(TypeWidth);
   // Equal shift amounts in opposite directions become bitwise 'and':
   // lshr (shl X, C), C --> and X, C'
   // shl (lshr X, C), C --> and X, C'
-  if (*InnerShiftConst == OuterShAmt)
-    return true;
+  ShMask.setBit(InnerShAmt);
 
-  // If the 2nd shift is bigger than the 1st, we can fold:
+  // If the inner shift is bigger than the outer, we can fold:
   // lshr (shl X, C1), C2 -->  and (shl X, C1 - C2), C3
   // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
-  // but it isn't profitable unless we know the and'd out bits are already zero.
-  // Also, check that the inner shift is valid (less than the type width) or
-  // we'll crash trying to produce the bit mask for the 'and'.
-  unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
-  if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
-    unsigned InnerShAmt = InnerShiftConst->getZExtValue();
-    unsigned MaskShift =
-        IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
-    APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
-    if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, CxtI))
-      return true;
-  }
-
-  return false;
+  // but it isn't profitable unless we know the masked out bits are already
+  // zero.
+  KnownBits Known = IC.computeKnownBits(InnerShift->getOperand(0), CxtI);
+  // Isolate the bits that are annihilated by the inner shift.
+  APInt InnerShMask = IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt)
+                                 : Known.Zero.trunc(InnerShAmt);
+  // Isolate the upper (resp. lower) InnerShAmt bits of the base operand of the
+  // inner shl (resp. lshr).
+  // Then:
+  // - lshr (shl X, C1), C2 == (shl X, C1 - C2) if the bottom C2 of the isolated
+  //   bits are zero
+  // - shl (lshr X, C1), C2 == (lshr X, C1 - C2) if the top C2 of the isolated
+  //   bits are zero
+  const unsigned MaxOuterShAmt =
+      IsInnerShl ? Known.Zero.lshr(TypeWidth - InnerShAmt).countr_one()
+                 : Known.Zero.trunc(InnerShAmt).countl_one();
+  ShMask.setLowBits(MaxOuterShAmt);
+  return ShMask;
 }
 
-/// See if we can compute the specified value, but shifted logically to the left
-/// or right by some number of bits. This should return true if the expression
-/// can be computed for the same cost as the current expression tree. This is
-/// used to eliminate extraneous shifting from things like:
-///      %C = shl i128 %A, 64
-///      %D = shl i128 %B, 96
-///      %E = or i128 %C, %D
-///      %F = lshr i128 %E, 64
-/// where the client will ask if E can be computed shifted right by 64-bits. If
-/// this succeeds, getShiftedValue() will be called to produce the value.
-static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
-                               InstCombinerImpl &IC, Instruction *CxtI) {
+/// Given a bitmask \p ShiftMask of desired shift amounts, determine the submask
+/// of bits corresponding to shift amounts X for which the given expression \p V
+/// can be computed for at worst the same cost as the current expression tree
+/// when shifted by X. For each set bit in the \p ShiftMask afterward,
+/// getShiftedValue() can produce the corresponding value.
+///
+/// \returns true if and only if at least one bit of the \p ShiftMask is set
+/// after refinement.
+static bool refineEvaluableShiftMask(Value *V, APInt &ShiftMask,
+                                     bool IsLeftShift, InstCombinerImpl &IC,
+                                     Instruction *CxtI) {
   // We can always evaluate immediate constants.
   if (match(V, m_ImmConstant()))
     return true;
 
   Instruction *I = dyn_cast<Instruction>(V);
-  if (!I) return false;
+  if (!I) {
+    ShiftMask.clearAllBits();
+    return false;
+  }
 
   // We can't mutate something that has multiple uses: doing so would
   // require duplicating the instruction in general, which isn't profitable.
-  if (!I->hasOneUse()) return false;
+  if (!I->hasOneUse()) {
+    ShiftMask.clearAllBits();
+    return false;
+  }
 
   switch (I->getOpcode()) {
-  default: return false;
+  default: {
+    ShiftMask.clearAllBits();
+    return false;
+  }
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor:
-    // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
-    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
-           canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
+    return refineEvaluableShiftMask(I->getOperand(0), ShiftMask, IsLeftShift,
+                                    IC, I) &&
+           refineEvaluableShiftMask(I->getOperand(1), ShiftMask, IsLeftShift,
+                                    IC, I);
 
   case Instruction::Shl:
-  case Instruction::LShr:
-    return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
+  case Instruction::LShr: {
+    ShiftMask &= getEvaluableShiftedShiftMask(IsLeftShift, I, IC, CxtI);
+    return !ShiftMask.isZero();
+  }
 
   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
     Value *TrueVal = SI->getTrueValue();
     Value *FalseVal = SI->getFalseValue();
-    return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
-           canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
+    return refineEvaluableShiftMask(TrueVal, ShiftMask, IsLeftShift, IC, SI) &&
+           refineEvaluableShiftMask(FalseVal, ShiftMask, IsLeftShift, IC, SI);
   }
   case Instruction::PHI: {
     // We can change a phi if we can change all operands.  Note that we never
@@ -623,19 +647,42 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
     // instructions with a single use.
     PHINode *PN = cast<PHINode>(I);
     for (Value *IncValue : PN->incoming_values())
-      if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
+      if (!refineEvaluableShiftMask(IncValue, ShiftMask, IsLeftShift, IC, PN))
         return false;
     return true;
   }
   case Instruction::Mul: {
     const APInt *MulConst;
     // We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`)
-    return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
-           MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits;
+    if (IsLeftShift || !match(I->getOperand(1), m_APInt(MulConst)) ||
+        !MulConst->isNegatedPowerOf2()) {
+      ShiftMask.clearAllBits();
+      return false;
+    }
+    ShiftMask &=
+        APInt::getOneBitSet(ShiftMask.getBitWidth(), MulConst->countr_zero());
+    return !ShiftMask.isZero();
   }
   }
 }
 
+/// See if we can compute the specified value, but shifted logically to the left
+/// or right by some number of bits. This should return true if the expression
+/// can be computed for the same cost as the current expression tree. This is
+/// used to eliminate extraneous shifting from things like:
+///      %C = shl i128 %A, 64
+///      %D = shl i128 %B, 96
+///      %E = or i128 %C, %D
+///      %F = lshr i128 %E, 64
+/// where the client will ask if E can be computed shifted right by 64-bits. If
+/// this succeeds, getShiftedValue() will be called to produce the value.
+static bool canEvaluateShifted(Value *V, unsigned ShAmt, bool IsLeftShift,
+                               InstCombinerImpl &IC, Instruction *CxtI) {
+  APInt ShiftMask =
+      APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), ShAmt);
+  return refineEvaluableShiftMask(V, ShiftMask, IsLeftShift, IC, CxtI);
+}
+
 /// Fold OuterShift (InnerShift X, C1), C2.
 /// See canEvaluateShiftedShift() for the constraints on these instructions.
 static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
@@ -985,37 +1032,32 @@ static Instruction *foldShrThroughZExtedShl(BinaryOperator &I, Value *Op,
                                             InstCombinerImpl &IC,
                                             const DataLayout &DL) {
   Type *DestTy = I.getType();
+  const unsigned InnerBitWidth = Op->getType()->getScalarSizeInBits();
 
-  auto *Inner = dyn_cast<Instruction>(Op);
-  if (!Inner)
+  // Determine if the operand is effectively right-shifted by counting the
+  // known leading zero bits.
+  KnownBits Known = IC.computeKnownBits(Op, nullptr);
+  const unsigned MaxInnerShrAmt = Known.countMinLeadingZeros();
+  if (MaxInnerShrAmt == 0)
     return nullptr;
+  APInt ShrMask =
+      APInt::getLowBitsSet(InnerBitWidth, std::min(MaxInnerShrAmt, ShlAmt) + 1);
 
-  // Dig through operations until the first shift.
-  while (!Inner->isShift())
-    if (!match(Inner, m_BinOp(m_OneUse(m_Instruction(Inner)), m_Constant())))
-      return nullptr;
-
-  // Fold only if the inner shift is a logical right-shift.
-  const APInt *InnerShrConst;
-  if (!match(Inner, m_LShr(m_Value(), m_APInt(InnerShrConst))))
+  // Undo the maximal inner right shift amount that simplifies the overall
+  // computation.
+  if (!refineEvaluableShiftMask(Op, ShrMask, /*IsLeftShift=*/true, IC, nullptr))
     return nullptr;
 
-  const uint64_t InnerShrAmt = InnerShrConst->getZExtValue();
-  if (InnerShrAmt >= ShlAmt) {
-    const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
-    if (!canEvaluateShifted(Op, ReducedShrAmt, /*IsLeftShift=*/false, IC,
-                            nullptr))
-      return nullptr;
-    Value *NewOp =
-        getShiftedValue(Op, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
-    return new ZExtInst(NewOp, DestTy);
-  }
-
-  if (!canEvaluateShifted(Op, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+  const unsigned InnerShrAmt = ShrMask.getActiveBits() - 1;
+  if (InnerShrAmt == 0)
     return nullptr;
+  assert(InnerShrAmt <= ShlAmt);
 
   const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
   Value *NewOp = getShiftedValue(Op, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+  if (ReducedShlAmt == 0)
+    return new ZExtInst(NewOp, DestTy);
+
   Value *NewZExt = IC.Builder.CreateZExt(NewOp, DestTy);
   NewZExt->takeName(I.getOperand(0));
   auto *NewShl = BinaryOperator::CreateShl(
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
index 818e7b0fc735c..82ed2985b3b16 100644
--- a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -1,6 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -S -passes=instcombine %s | FileCheck %s
 
+declare void @clobber.i32(i32)
+
 define i64 @simple(i32 %x) {
 ; CHECK-LABEL: define i64 @simple(
 ; CHECK-SAME: i32 [[X:%.*]]) {
@@ -15,6 +17,20 @@ define i64 @simple(i32 %x) {
   ret i64 %shl
 }
 
+define <2 x i64> @simple.vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @simple.vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256)
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24)
+; CHECK-NEXT:    ret <2 x i64> [[SHL]]
+;
+  %lshr = lshr <2 x i32> %v, splat(i32 8)
+  %zext = zext <2 x i32> %lshr to <2 x i64>
+  %shl = shl <2 x i64> %zext, splat(i64 32)
+  ret <2 x i64> %shl
+}
+
 ;; u0xff0 = 4080
 define i64 @masked(i32 %x) {
 ; CHECK-LABEL: define i64 @masked(
@@ -31,6 +47,83 @@ define i64 @masked(i32 %x) {
   ret i64 %shl
 }
 
+define i64 @masked.multi_use.0(i32 %x) {
+; CHECK-LABEL: define i64 @masked.multi_use.0(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT:    call void @clobber.i32(i32 [[LSHR]])
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT:    ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 4
+  call void @clobber.i32(i32 %lshr)
+  %mask = and i32 %lshr, u0xff
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+  ret i64 %shl
+}
+
+define i64 @masked.multi_use.1(i32 %x) {
+; CHECK-LABEL: define i64 @masked.multi_use.1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT:    call void @clobber.i32(i32 [[MASK]])
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT:    ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 4
+  %mask = and i32 %lshr, u0xff
+  call void @clobber.i32(i32 %mask)
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+  ret i64 %shl
+}
+
+define <2 x i64> @masked.multi_use.2(i32 %x) {
+; CHECK-LABEL: define <2 x i64> @masked.multi_use.2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[X]], 4
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[LSHR]], 255
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[ZEXT]], 48
+; CHECK-NEXT:    [[CLOBBER:%.*]] = xor i32 [[MASK]], 255
+; CHECK-NEXT:    [[CLOBBER_Z:%.*]] = zext nneg i32 [[CLOBBER]] to i64
+; CHECK-NEXT:    [[V_0:%.*]] = insertelement <2 x i64> poison, i64 [[SHL]], i64 0
+; CHECK-NEXT:    [[V_1:%.*]] = insertelement <2 x i64> [[V_0]], i64 [[CLOBBER_Z]], i64 1
+; CHECK-NEXT:    ret <2 x i64> [[V_1]]
+;
+  %lshr = lshr i32 %x, 4
+  %mask = and i32 %lshr, u0xff
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+
+  %clobber = xor i32 %mask, u0xff
+  %clobber.z = zext i32 %clobber to i64
+  %v.0 = insertelement <2 x i64> poison, i64 %shl, i32 0
+  %v.1 = insertelement <2 x i64> %v.0, i64 %clobber.z, i32 1
+  ret <2 x i64> %v.1
+}
+
+;; u0xff0 = 4080
+define <2 x i64> @masked.vec(<2 x i32> %v) {
+; CHECK-LABEL: define <2 x i64> @masked.vec(
+; CHECK-SAME: <2 x i32> [[V:%.*]]) {
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[V]], splat (i32 4080)
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg <2 x i32> [[MASK]] to <2 x i64>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 44)
+; CHECK-NEXT:    ret <2 x i64> [[SHL]]
+;
+  %lshr = lshr <2 x i32> %v, splat(i32 4)
+  %mask = and <2 x i32> %lshr, splat(i32 u0xff)
+  %zext = zext <2 x i32> %mask to <2 x i64>
+  %shl = shl <2 x i64> %zext, splat(i64 48)
+  ret <2 x i64> %shl
+}
+
 define i64 @combine(i32 %lower, i32 %upper) {
 ; CHECK-LABEL: define i64 @combine(
 ; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
@@ -67,17 +160,3 @@ define i64 @combine(i32 %lower, i32 %upper) {
 
   ret i64 %o.3
 }
-
-define <2 x i64> @simple.vec(<2 x i32> %v) {
-; CHECK-LABEL: define <2 x i64> @simple.vec(
-; CHECK-SAME: <2 x i32> [[V:%.*]]) {
-; CHECK-NEXT:    [[LSHR:%.*]] = and <2 x i32> [[V]], splat (i32 -256)
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext <2 x i32> [[LSHR]] to <2 x i64>
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw <2 x i64> [[ZEXT]], splat (i64 24)
-; CHECK-NEXT:    ret <2 x i64> [[SHL]]
-;
-  %lshr = lshr <2 x i32> %v, splat(i32 8)
-  %zext = zext <2 x i32> %lshr to <2 x i64>
-  %shl = shl <2 x i64> %zext, splat(i64 32)
-  ret <2 x i64> %shl
-}


