[llvm] [CVP] Implement type narrowing for LShr (PR #119577)

via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 13 11:42:11 PST 2024


https://github.com/adam-bzowski updated https://github.com/llvm/llvm-project/pull/119577

>From 44e162b13f21a98d9c138b444524b95f2389bfff Mon Sep 17 00:00:00 2001
From: "Bzowski, Adam" <adam.bzowski at intel.com>
Date: Wed, 11 Dec 2024 08:07:31 -0800
Subject: [PATCH 1/5] [CVP] Implement type narrowing for LShr

Implements type narrowing for LShr. The treatment is analogous to the type narrowing of UDiv. Since LShr is a relatively cheap instruction, the narrowing occurs only if the following conditions hold: i) all the users of the LShr instruction are already TruncInst; ii) the narrowing is carried out to the largest TruncInst following the LShr instruction. Additionally, the function optimizes the cases where the result of the LShr instruction is guaranteed to vanish or be equal to poison.
---
 .../Scalar/CorrelatedValuePropagation.cpp     | 133 ++++++++++
 .../lshr-plus-instcombine.ll                  | 114 +++++++++
 .../CorrelatedValuePropagation/lshr.ll        | 241 ++++++++++++++++++
 3 files changed, 488 insertions(+)
 create mode 100644 llvm/test/Transforms/CorrelatedValuePropagation/lshr-plus-instcombine.ll
 create mode 100644 llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll

diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 8e74b8645fad9a..9b81adf77bdf52 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -40,6 +40,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <cassert>
+#include <limits>
 #include <optional>
 #include <utility>
 
@@ -60,6 +61,8 @@ STATISTIC(NumUDivURemsNarrowed,
           "Number of udivs/urems whose width was decreased");
 STATISTIC(NumAShrsConverted, "Number of ashr converted to lshr");
 STATISTIC(NumAShrsRemoved, "Number of ashr removed");
+STATISTIC(NumLShrsRemoved, "Number of lshr removed");
+STATISTIC(NumLShrsNarrowed, "Number of lshrs whose width was decreased");
 STATISTIC(NumSRems,     "Number of srem converted to urem");
 STATISTIC(NumSExt,      "Number of sext converted to zext");
 STATISTIC(NumSIToFP,    "Number of sitofp converted to uitofp");
@@ -93,6 +96,10 @@ STATISTIC(NumUDivURemsNarrowedExpanded,
           "Number of bound udiv's/urem's expanded");
 STATISTIC(NumNNeg, "Number of zext/uitofp non-negative deductions");
 
+static cl::opt<bool>
+    NarrowLShr("correlated-propagation-narrow-lshr", cl::init(true), cl::Hidden,
+               cl::desc("Enable narrowing of LShr instructions."));
+
 static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) {
   if (Constant *C = LVI->getConstant(V, At))
     return C;
@@ -1067,6 +1074,124 @@ static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) {
   return narrowSDivOrSRem(Instr, LCR, RCR);
 }
 
+/// Narrows the type of an LShr instruction when the value ranges of both of
+/// its operands fit into a smaller integer type. Since LShr is a relatively
+/// cheap instruction, narrowing should not happen too often. Performance
+/// testing and compatibility with other passes indicate that the narrowing is
+/// beneficial under the following circumstances:
+///
+/// i) every user of the LShr instruction is already a TruncInst (an LShr
+/// without any users trivially satisfies this);
+///
+/// ii) the new width is never smaller than the widest TruncInst destination
+/// type among those users.
+///
+/// Additionally, the function folds the cases where the result of the LShr
+/// is guaranteed to be zero or poison. Returns true if the instruction was
+/// replaced or narrowed, in which case \p LShr has been erased; returns false
+/// if the IR was left untouched.
+static bool narrowLShr(BinaryOperator *LShr, LazyValueInfo *LVI) {
+
+  IntegerType *RetTy = dyn_cast<IntegerType>(LShr->getType());
+  if (!RetTy)
+    return false;
+
+  ConstantRange ArgRange = LVI->getConstantRangeAtUse(LShr->getOperandUse(0),
+                                                      /*UndefAllowed*/ false);
+  ConstantRange ShiftRange = LVI->getConstantRangeAtUse(LShr->getOperandUse(1),
+                                                        /*UndefAllowed*/ false);
+
+  unsigned OrigWidth = RetTy->getScalarSizeInBits();
+  unsigned MaxActiveBitsInArg = ArgRange.getActiveBits();
+  uint64_t MinShiftValue64 = ShiftRange.getUnsignedMin().getZExtValue();
+  unsigned MinShiftValue =
+      MinShiftValue64 < std::numeric_limits<unsigned>::max()
+          ? static_cast<unsigned>(MinShiftValue64)
+          : std::numeric_limits<unsigned>::max();
+
+  // First we deal with the cases where the result is guaranteed to vanish or be
+  // equal to poison.
+
+  auto replaceWith = [&](Value *V) -> void {
+    LShr->replaceAllUsesWith(V);
+    LShr->eraseFromParent();
+    ++NumLShrsRemoved;
+  };
+
+  // If the shift is larger than or equal to the bit width of the argument,
+  // the instruction returns a poison value.
+  if (MinShiftValue >= OrigWidth) {
+    replaceWith(PoisonValue::get(RetTy));
+    return true;
+  }
+
+  // If we are guaranteed to shift away all bits, replace the shift by zero.
+  // Exact shifts are skipped: for a nonzero argument the exact result would
+  // be poison, and this case is conservatively left alone.
+  // (Only the non-exact form is folded here.)
+  if (!LShr->isExact() && MinShiftValue >= MaxActiveBitsInArg) {
+    replaceWith(Constant::getNullValue(RetTy));
+    return true;
+  }
+
+  // That's how many bits we need.
+  unsigned MaxActiveBits =
+      std::max(MaxActiveBitsInArg, ShiftRange.getActiveBits());
+
+  // We could do better, but is it worth it?
+  // With the first argument being the n-bit integer, we may limit the value of
+  // the second argument to be less than n, as larger shifts would lead to a
+  // vanishing result or poison. Thus the number of bits in the second argument
+  // is limited by Log2(n). Unfortunately, this would require an introduction of
+  // a select instruction (or llvm.min) to make sure every argument larger than
+  // n is mapped to n and not just truncated. We do not implement it here.
+
+  // What is the smallest bit width that can accommodate the entire value ranges
+  // of both of the operands? Don't shrink below 8 bits wide.
+  unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MaxActiveBits), 8);
+
+  // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
+  // two.
+  if (NewWidth >= OrigWidth)
+    return false;
+
+  // This is the time to check if all the users are TruncInst
+  // and to figure out what the largest user is.
+  for (User *user : LShr->users()) {
+    if (TruncInst *TI = dyn_cast<TruncInst>(user)) {
+      NewWidth = std::max(NewWidth, TI->getDestTy()->getScalarSizeInBits());
+    } else {
+      return false;
+    }
+  }
+
+  // We are ready to truncate.
+  IRBuilder<> B(LShr);
+  Type *TruncTy = RetTy->getWithNewBitWidth(NewWidth);
+  Value *ArgTrunc = B.CreateTruncOrBitCast(LShr->getOperand(0), TruncTy,
+                                           LShr->getName() + ".arg.trunc");
+  Value *ShiftTrunc = B.CreateTruncOrBitCast(LShr->getOperand(1), TruncTy,
+                                             LShr->getName() + ".shift.trunc");
+  Value *LShrTrunc =
+      B.CreateBinOp(LShr->getOpcode(), ArgTrunc, ShiftTrunc, LShr->getName());
+  Value *Zext = B.CreateZExt(LShrTrunc, RetTy, LShr->getName() + ".zext");
+
+  // The builder may constant-fold; only annotate a real BinaryOperator.
+  if (BinaryOperator *LShrTruncBO = dyn_cast<BinaryOperator>(LShrTrunc)) {
+    LShrTruncBO->setDebugLoc(LShr->getDebugLoc());
+    LShrTruncBO->setIsExact(LShr->isExact());
+  }
+  LShr->replaceAllUsesWith(Zext);
+  LShr->eraseFromParent();
+
+  ++NumLShrsNarrowed;
+  return true;
+}
+
+static bool processLShr(BinaryOperator *LShr, LazyValueInfo *LVI) {
+  return NarrowLShr && narrowLShr(LShr, LVI);
+}
+
 static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
   ConstantRange LRange =
       LVI->getConstantRangeAtUse(SDI->getOperandUse(0), /*UndefAllowed*/ false);
@@ -1093,6 +1218,11 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
   SDI->replaceAllUsesWith(BO);
   SDI->eraseFromParent();
 
+  // Check if the new LShr can be narrowed. narrowLShr may erase BO, so BO
+  // must not be used after this call.
+  if (NarrowLShr)
+    narrowLShr(BO, LVI);
+
   return true;
 }
 
@@ -1254,6 +1384,9 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
       case Instruction::AShr:
         BBChanged |= processAShr(cast<BinaryOperator>(&II), LVI);
         break;
+      case Instruction::LShr:
+        BBChanged |= processLShr(cast<BinaryOperator>(&II), LVI);
+        break;
       case Instruction::SExt:
         BBChanged |= processSExt(cast<SExtInst>(&II), LVI);
         break;
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/lshr-plus-instcombine.ll b/llvm/test/Transforms/CorrelatedValuePropagation/lshr-plus-instcombine.ll
new file mode 100644
index 00000000000000..3646ff9aee4783
--- /dev/null
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/lshr-plus-instcombine.ll
@@ -0,0 +1,114 @@
+; RUN: opt < %s -passes="correlated-propagation,instcombine" -S | FileCheck %s
+
+; The tests below are the same as the trunc_test/zero_test cases in lshr.ll.
+; Here we check that the CorrelatedValuePropagation pass followed by
+; InstCombinePass produces the expected optimizations.
+
+; CHECK-LABEL: @trunc_test1
+; CHECK-NEXT: [[A1:%.*]] = lshr i32 [[A:%.*]], 16
+; CHECK-NEXT: [[CARG:%.*]] = trunc nuw i32 [[A1]] to i16
+; CHECK-NEXT: [[CSHIFT:%.*]] = trunc i32 [[B:%.*]] to i16
+; CHECK-NEXT: [[C1:%.*]] = lshr i16 [[CARG]], [[CSHIFT]]
+; CHECK-NEXT: ret i16 [[C1]]
+
+define i16 @trunc_test1(i32 %a, i32 %b) {
+  %a.eff.trunc = lshr i32 %a, 16
+  %b.eff.trunc = and i32 %b, 65535
+  %c = lshr i32 %a.eff.trunc, %b.eff.trunc
+  %c.trunc = trunc i32 %c to i16
+  ret i16 %c.trunc
+}
+
+; CHECK-LABEL: @trunc_test2
+; CHECK-NEXT: [[C1:%.*]] = lshr i16 [[A:%.*]], 2
+; CHECK-NEXT: ret i16 [[C1]]
+
+define i16 @trunc_test2(i16 %a) {
+  %a.ext = zext i16 %a to i32
+  %c = lshr i32 %a.ext, 2
+  %c.trunc = trunc i32 %c to i16
+  ret i16 %c.trunc
+}
+
+; CHECK-LABEL: @trunc_test3
+; CHECK-NEXT: [[B:%.*]] = lshr i16 [[A:%.*]], 2
+; CHECK-NEXT: [[C:%.*]] = add nuw nsw i16 [[B]], 123
+; CHECK-NEXT: ret i16 [[C]]
+
+define i16 @trunc_test3(i16 %a) {
+  %a.ext = zext i16 %a to i32
+  %b = lshr i32 %a.ext, 2
+  %c = add i32 %b, 123
+  %c.trunc = trunc i32 %c to i16
+  ret i16 %c.trunc
+}
+
+; CHECK-LABEL: @trunc_test4
+; CHECK-NEXT: [[A1:%.*]] = udiv i32 [[A:%.*]], 17000000
+; CHECK-NEXT: [[B:%.*]] = trunc nuw nsw i32 [[A1]] to i16
+; CHECK-NEXT: [[B1:%.*]] = lshr i16 [[B]], 2
+; CHECK-NEXT: ret i16 [[B1]]
+
+define i16 @trunc_test4(i32 %a) {
+  %a.eff.trunc = udiv i32 %a, 17000000  ; divisor > 2^24, so the quotient fits in i8
+  %b = lshr i32 %a.eff.trunc, 2 
+  %b.trunc.1 = trunc i32 %b to i16
+  %b.trunc.2 = trunc i32 %b to i8
+  ret i16 %b.trunc.1
+}
+
+; CHECK-LABEL: @trunc_test5
+; CHECK-NEXT: [[A1:%.*]] = udiv i32 [[A:%.*]], 17000000
+; CHECK-NEXT: [[B:%.*]] = lshr i32 [[A1]], 2
+; CHECK-NEXT: [[C:%.*]] = add nuw nsw i32 [[B]], 123
+; CHECK-NEXT: ret i32 [[C]]
+
+define i32 @trunc_test5(i32 %a) {
+  %a.eff.trunc = udiv i32 %a, 17000000  ; divisor > 2^24, so the quotient fits in i8
+  %b = lshr i32 %a.eff.trunc, 2 
+  %b.trunc.1 = trunc i32 %b to i16
+  %b.trunc.2 = trunc i32 %b to i8
+  %c = add i32 %b, 123
+  ret i32 %c
+}
+
+; CHECK-LABEL: @zero_test1
+; CHECK-NEXT: ret i32 poison
+  
+define i32 @zero_test1(i32 %a) {
+  %b = lshr i32 %a, 32
+  %c = add i32 %b, 123
+  ret i32 %c
+}
+
+; CHECK-LABEL: @zero_test2
+; CHECK-NEXT: ret i32 poison
+
+define i32 @zero_test2(i32 %a, i32 %b) {
+  %b.large = add nuw nsw i32 %b, 50
+  %c = lshr i32 %a, %b.large
+  %d = add i32 %c, 123
+  ret i32 %d
+}
+
+; CHECK-LABEL: @zero_test3
+; CHECK-NEXT: ret i32 123
+
+define i32 @zero_test3(i32 %a, i32 %b) {
+  %a.small = lshr i32 %a, 16
+  %b.large = add nuw nsw i32 %b, 20
+  %c = lshr i32 %a.small, %b.large
+  %d = add i32 %c, 123
+  ret i32 %d
+}
+
+; CHECK-LABEL: @zero_test4
+; CHECK-NEXT: ret i32 123
+
+define i32 @zero_test4(i32 %a, i32 %b) {
+  %a.small = lshr i32 %a, 16
+  %b.large = add nuw nsw i32 %b, 20
+  %c = lshr exact i32 %a.small, %b.large
+  %d = add i32 %c, 123
+  ret i32 %d
+}
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll b/llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll
new file mode 100644
index 00000000000000..3f129a9bd2e411
--- /dev/null
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll
@@ -0,0 +1,241 @@
+; RUN: opt < %s -passes=correlated-propagation -S | FileCheck %s
+
+; Tests: test_nop and tests 1 through 6 are taken from udiv.ll
+; with udiv replaced by lshr (plus some tweaks).
+; In those tests the lshr instruction has no users.
+
+; CHECK-LABEL: @test_nop
+define void @test_nop(i32 %n) {
+; CHECK: lshr i32 %n, 2
+  %shr = lshr i32 %n, 2
+  ret void
+}
+
+; CHECK-LABEL: @test1(
+define void @test1(i32 %n) {
+entry:
+  %cmp = icmp ule i32 %n, 65535
+  br i1 %cmp, label %bb, label %exit
+
+bb:
+; CHECK: lshr i16
+  %shr = lshr i32 %n, 2
+  br label %exit
+
+exit:
+  ret void
+}
+
+; CHECK-LABEL: @test2(
+define void @test2(i32 %n) {
+entry:
+  %cmp = icmp ule i32 %n, 65536
+  br i1 %cmp, label %bb, label %exit
+
+bb:
+; CHECK: lshr i32
+  %shr = lshr i32 %n, 2
+  br label %exit
+
+exit:
+  ret void
+}
+
+; CHECK-LABEL: @test3(
+define void @test3(i32 %m, i32 %n) {
+entry:
+  %cmp1 = icmp ult i32 %m, 65535
+  %cmp2 = icmp ult i32 %n, 65535
+  %cmp = and i1 %cmp1, %cmp2
+  br i1 %cmp, label %bb, label %exit
+
+bb:
+; CHECK: lshr i16
+  %shr = lshr i32 %m, %n
+  br label %exit
+
+exit:
+  ret void
+}
+
+; CHECK-LABEL: @test4(
+define void @test4(i32 %m, i32 %n) {
+entry:
+  %cmp1 = icmp ult i32 %m, 65535
+  %cmp2 = icmp ule i32 %n, 65536
+  %cmp = and i1 %cmp1, %cmp2
+  br i1 %cmp, label %bb, label %exit
+
+bb:
+; CHECK: lshr i32
+  %shr = lshr i32 %m, %n
+  br label %exit
+
+exit:
+  ret void
+}
+
+; CHECK-LABEL: @test5
+define void @test5(i32 %n) {
+  %trunc = and i32 %n, 65535
+  ; CHECK: lshr i16
+  %shr = lshr i32 %trunc, 2
+  ret void
+}
+
+; CHECK-LABEL: @test6
+define void @test6(i32 %n) {
+entry:
+  %cmp = icmp ule i32 %n, 255
+  br i1 %cmp, label %bb, label %exit
+
+bb:
+; CHECK: lshr i8
+  %shr = lshr i32 %n, 2
+  br label %exit
+
+exit:
+  ret void
+}
+
+; The tests below check that the narrowing occurs only if all users of the
+; lshr are `trunc` instructions.
+;
+; Just as in udiv.ll, additional zext and trunc instructions appear.
+; They are expected to be folded away by InstCombinePass later in the
+; pipeline (see lshr-plus-instcombine.ll).
+
+; CHECK-LABEL: @trunc_test1
+; CHECK-NEXT: [[A1:%.*]] = lshr i32 [[A:%.*]], 16
+; CHECK-NEXT: [[B1:%.*]] = and i32 [[B:%.*]], 65535
+; CHECK-NEXT: [[A2:%.*]] = trunc i32 [[A1]] to i16
+; CHECK-NEXT: [[B2:%.*]] = trunc i32 [[B1]] to i16
+; CHECK-NEXT: [[C1:%.*]] = lshr i16 [[A2]], [[B2]]
+; CHECK-NEXT: [[C2:%.*]] = zext i16 [[C1]] to i32
+; CHECK-NEXT: [[C3:%.*]] = trunc i32 [[C2]] to i16
+; CHECK-NEXT: ret i16 [[C3]]
+
+define i16 @trunc_test1(i32 %a, i32 %b) {
+  %a.eff.trunc = lshr i32 %a, 16
+  %b.eff.trunc = and i32 %b, 65535
+  %c = lshr i32 %a.eff.trunc, %b.eff.trunc
+  %c.trunc = trunc i32 %c to i16
+  ret i16 %c.trunc
+}
+
+; CHECK-LABEL: @trunc_test2
+; CHECK-NEXT: [[A1:%.*]] = zext i16 [[A:%.*]] to i32
+; CHECK-NEXT: [[A2:%.*]] = trunc i32 [[A1]] to i16
+; CHECK-NEXT: [[C1:%.*]] = lshr i16 [[A2]], 2
+; CHECK-NEXT: [[C2:%.*]] = zext i16 [[C1]] to i32
+; CHECK-NEXT: [[C3:%.*]] = trunc i32 [[C2]] to i16
+; CHECK-NEXT: ret i16 [[C3]]
+
+define i16 @trunc_test2(i16 %a) {
+  %a.ext = zext i16 %a to i32
+  %c = lshr i32 %a.ext, 2
+  %c.trunc = trunc i32 %c to i16
+  ret i16 %c.trunc
+}
+
+; CHECK-LABEL: @trunc_test3
+; CHECK-NEXT: [[A1:%.*]] = zext i16 [[A:%.*]] to i32
+; CHECK-NEXT: [[B:%.*]] = lshr i32 [[A1]], 2
+; CHECK-NEXT: [[C0:%.*]] = add nuw nsw i32 [[B]], 123
+; CHECK-NEXT: [[C1:%.*]] = trunc i32 [[C0]] to i16
+; CHECK-NEXT: ret i16 [[C1]]
+
+define i16 @trunc_test3(i16 %a) {
+  %a.ext = zext i16 %a to i32
+  %b = lshr i32 %a.ext, 2
+  %c = add i32 %b, 123
+  %c.trunc = trunc i32 %c to i16
+  ret i16 %c.trunc
+}
+
+; CHECK-LABEL: @trunc_test4
+; CHECK-NEXT: [[A1:%.*]] = udiv i32 [[A:%.*]], 17000000
+; CHECK-NEXT: [[B0:%.*]] = trunc i32 [[A1]] to i16
+; CHECK-NEXT: [[B1:%.*]] = lshr i16 [[B0]], 2
+; CHECK-NEXT: [[B2:%.*]] = zext i16 [[B1]] to i32
+; CHECK-NEXT: [[C1:%.*]] = trunc i32 [[B2]] to i16
+; CHECK-NEXT: [[C2:%.*]] = trunc i32 [[B2]] to i8
+; CHECK-NEXT: ret i16 [[C1]]
+
+define i16 @trunc_test4(i32 %a) {
+  %a.eff.trunc = udiv i32 %a, 17000000  ; divisor > 2^24, so the quotient fits in i8
+  %b = lshr i32 %a.eff.trunc, 2 
+  %b.trunc.1 = trunc i32 %b to i16
+  %b.trunc.2 = trunc i32 %b to i8
+  ret i16 %b.trunc.1
+}
+
+; CHECK-LABEL: @trunc_test5
+; CHECK-NEXT: [[A1:%.*]] = udiv i32 [[A:%.*]], 17000000
+; CHECK-NEXT: [[B:%.*]] = lshr i32 [[A1]], 2
+; CHECK-NEXT: [[B1:%.*]] = trunc i32 [[B]] to i16
+; CHECK-NEXT: [[B2:%.*]] = trunc i32 [[B]] to i8
+; CHECK-NEXT: [[C:%.*]] = add nuw nsw i32 [[B]], 123
+; CHECK-NEXT: ret i32 [[C]]
+
+define i32 @trunc_test5(i32 %a) {
+  %a.eff.trunc = udiv i32 %a, 17000000  ; divisor > 2^24, so the quotient fits in i8
+  %b = lshr i32 %a.eff.trunc, 2 
+  %b.trunc.1 = trunc i32 %b to i16
+  %b.trunc.2 = trunc i32 %b to i8
+  %c = add i32 %b, 123
+  ret i32 %c
+}
+
+; Test cases where lshr simplifies to zero or poison.
+
+; CHECK-LABEL: @zero_test1
+; CHECK-NEXT: [[C:%.*]] = add i32 poison, 123
+; CHECK-NEXT: ret i32 [[C]]
+  
+define i32 @zero_test1(i32 %a) {
+  %b = lshr i32 %a, 32
+  %c = add i32 %b, 123
+  ret i32 %c
+}
+
+; CHECK-LABEL: @zero_test2
+; CHECK-NEXT: [[B1:%.*]] = add nuw nsw i32 [[B:%.*]], 50
+; CHECK-NEXT: [[D:%.*]] = add i32 poison, 123
+; CHECK-NEXT: ret i32 [[D]]
+
+define i32 @zero_test2(i32 %a, i32 %b) {
+  %b.large = add nuw nsw i32 %b, 50
+  %c = lshr i32 %a, %b.large
+  %d = add i32 %c, 123
+  ret i32 %d
+}
+
+; CHECK-LABEL: @zero_test3
+; CHECK-NEXT: [[A1:%.*]] = lshr i32 [[A:%.*]], 16
+; CHECK-NEXT: [[B1:%.*]] = add nuw nsw i32 [[B:%.*]], 20
+; CHECK-NEXT: [[D:%.*]] = add nuw nsw i32 0, 123
+; CHECK-NEXT: ret i32 123
+
+define i32 @zero_test3(i32 %a, i32 %b) {
+  %a.small = lshr i32 %a, 16
+  %b.large = add nuw nsw i32 %b, 20
+  %c = lshr i32 %a.small, %b.large
+  %d = add i32 %c, 123
+  ret i32 %d
+}
+
+; CHECK-LABEL: @zero_test4
+; CHECK-NEXT: [[A1:%.*]] = lshr i32 [[A:%.*]], 16
+; CHECK-NEXT: [[B1:%.*]] = add nuw nsw i32 [[B:%.*]], 20
+; CHECK-NEXT: [[C:%.*]] = lshr exact i32 [[A1]], [[B1]]
+; CHECK-NEXT: [[D:%.*]] = add nuw nsw i32 [[C]], 123
+; CHECK-NEXT: ret i32 123
+
+define i32 @zero_test4(i32 %a, i32 %b) {
+  %a.small = lshr i32 %a, 16
+  %b.large = add nuw nsw i32 %b, 20
+  %c = lshr exact i32 %a.small, %b.large
+  %d = add i32 %c, 123
+  ret i32 %d
+}

>From 2e44ffdbbe82a9b22330db2d37e6ea6c29baab2d Mon Sep 17 00:00:00 2001
From: "Bzowski, Adam" <adam.bzowski at intel.com>
Date: Fri, 13 Dec 2024 05:30:07 -0800
Subject: [PATCH 2/5] [CVP] Implement type narrowing for LShr

Implements type narrowing for LShr. The treatment is analogous to the type narrowing of UDiv. Since LShr is a relatively cheap instruction, the narrowing occurs only if the following conditions hold: i) all the users of the LShr instruction are already TruncInst; ii) the narrowing is carried out to the largest TruncInst following the LShr instruction. Additionally, the function optimizes the cases where the result of the LShr instruction is guaranteed to vanish or be equal to poison.
---
 .../Scalar/CorrelatedValuePropagation.cpp     | 17 +++++++++
 .../lshr-plus-instcombine.ll                  |  5 ++-
 .../CorrelatedValuePropagation/lshr.ll        | 37 ++++++++++++++++---
 3 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 9b81adf77bdf52..88ec68857c82f5 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -1134,6 +1134,19 @@ static bool narrowLShr(BinaryOperator *LShr, LazyValueInfo *LVI) {
     return true;
   }
 
+  // Since LShr returns poison if the shift is larger than or equal to the bit
+  // width of the argument, we must make sure that the maximal possible value
+  // for the shift is smaller than the new width after narrowing. Otherwise some
+  // shifts that originally vanish would result in poison after the narrowing.
+  uint64_t MaxShiftValue64 = ShiftRange.getUnsignedMax().getZExtValue();
+  unsigned MaxShiftValue =
+      MaxShiftValue64 < std::numeric_limits<unsigned>::max()
+          ? static_cast<unsigned>(MaxShiftValue64)
+          : std::numeric_limits<unsigned>::max();
+
+  if (OrigWidth <= MaxShiftValue)
+    return false;
+
   // That's how many bits we need.
   unsigned MaxActiveBits =
       std::max(MaxActiveBitsInArg, ShiftRange.getActiveBits());
@@ -1165,6 +1178,10 @@ static bool narrowLShr(BinaryOperator *LShr, LazyValueInfo *LVI) {
     }
   }
 
+  // Every possible shift amount must be < NewWidth; see MaxShiftValue above.
+  if (NewWidth <= MaxShiftValue)
+    return false;
+
   // We are ready to truncate.
   IRBuilder<> B(LShr);
   Type *TruncTy = RetTy->getWithNewBitWidth(NewWidth);
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/lshr-plus-instcombine.ll b/llvm/test/Transforms/CorrelatedValuePropagation/lshr-plus-instcombine.ll
index 3646ff9aee4783..0599d8b9c02fe0 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/lshr-plus-instcombine.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/lshr-plus-instcombine.ll
@@ -7,13 +7,14 @@
 ; CHECK-LABEL: @trunc_test1
 ; CHECK-NEXT: [[A1:%.*]] = lshr i32 [[A:%.*]], 16
 ; CHECK-NEXT: [[CARG:%.*]] = trunc nuw i32 [[A1]] to i16
-; CHECK-NEXT: [[CSHIFT:%.*]] = trunc i32 [[B:%.*]] to i16
+; CHECK-NEXT: [[B1:%.*]] = trunc i32 [[B:%.*]] to i16
+; CHECK-NEXT: [[CSHIFT:%.*]] = and i16 [[B1]], 15
 ; CHECK-NEXT: [[C1:%.*]] = lshr i16 [[CARG]], [[CSHIFT]]
 ; CHECK-NEXT: ret i16 [[C1]]
 
 define i16 @trunc_test1(i32 %a, i32 %b) {
   %a.eff.trunc = lshr i32 %a, 16
-  %b.eff.trunc = and i32 %b, 65535
+  %b.eff.trunc = and i32 %b, 15
   %c = lshr i32 %a.eff.trunc, %b.eff.trunc
   %c.trunc = trunc i32 %c to i16
   ret i16 %c.trunc
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll b/llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll
index 3f129a9bd2e411..7f1e866874f3f2 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll
@@ -50,7 +50,7 @@ entry:
   br i1 %cmp, label %bb, label %exit
 
 bb:
-; CHECK: lshr i16
+; CHECK: lshr i32
   %shr = lshr i32 %m, %n
   br label %exit
 
@@ -58,11 +58,11 @@ exit:
   ret void
 }
 
-; CHECK-LABEL: @test4(
-define void @test4(i32 %m, i32 %n) {
+; CHECK-LABEL: @test3a(
+define void @test3a(i32 %m, i32 %n) {
 entry:
   %cmp1 = icmp ult i32 %m, 65535
-  %cmp2 = icmp ule i32 %n, 65536
+  %cmp2 = icmp ult i32 %n, 17
   %cmp = and i1 %cmp1, %cmp2
   br i1 %cmp, label %bb, label %exit
 
@@ -75,6 +75,23 @@ exit:
   ret void
 }
 
+; CHECK-LABEL: @test3b(
+define void @test3b(i32 %m, i32 %n) {
+entry:
+  %cmp1 = icmp ult i32 %m, 65535
+  %cmp2 = icmp ult i32 %n, 16
+  %cmp = and i1 %cmp1, %cmp2
+  br i1 %cmp, label %bb, label %exit
+
+bb:
+; CHECK: lshr i16
+  %shr = lshr i32 %m, %n
+  br label %exit
+
+exit:
+  ret void
+}
+
 ; CHECK-LABEL: @test5
 define void @test5(i32 %n) {
   %trunc = and i32 %n, 65535
@@ -83,6 +100,14 @@ define void @test5(i32 %n) {
   ret void
 }
 
+; CHECK-LABEL: @test5a
+define void @test5a(i32 %n) {
+  %trunc = and i32 %n, 65535
+  ; CHECK: lshr i16
+  %shr = lshr i32 %trunc, 15
+  ret void
+}
+
 ; CHECK-LABEL: @test6
 define void @test6(i32 %n) {
 entry:
@@ -107,7 +132,7 @@ exit:
 
 ; CHECK-LABEL: @trunc_test1
 ; CHECK-NEXT: [[A1:%.*]] = lshr i32 [[A:%.*]], 16
-; CHECK-NEXT: [[B1:%.*]] = and i32 [[B:%.*]], 65535
+; CHECK-NEXT: [[B1:%.*]] = and i32 [[B:%.*]], 15
 ; CHECK-NEXT: [[A2:%.*]] = trunc i32 [[A1]] to i16
 ; CHECK-NEXT: [[B2:%.*]] = trunc i32 [[B1]] to i16
 ; CHECK-NEXT: [[C1:%.*]] = lshr i16 [[A2]], [[B2]]
@@ -117,7 +142,7 @@ exit:
 
 define i16 @trunc_test1(i32 %a, i32 %b) {
   %a.eff.trunc = lshr i32 %a, 16
-  %b.eff.trunc = and i32 %b, 65535
+  %b.eff.trunc = and i32 %b, 15
   %c = lshr i32 %a.eff.trunc, %b.eff.trunc
   %c.trunc = trunc i32 %c to i16
   ret i16 %c.trunc

>From 0ce38dd4ad6e33b108b99150823c5866fe6694b0 Mon Sep 17 00:00:00 2001
From: "Bzowski, Adam" <adam.bzowski at intel.com>
Date: Fri, 13 Dec 2024 06:34:36 -0800
Subject: [PATCH 3/5] [CVP] Implement type narrowing for LShr

Implements type narrowing for LShr. The treatment is analogous to the type narrowing of UDiv. Since LShr is a relatively cheap instruction, the narrowing occurs only if the following conditions hold: i) all the users of the LShr instruction are already TruncInst; ii) the narrowing is carried out to the largest TruncInst following the LShr instruction. Additionally, the function optimizes the cases where the result of the LShr instruction is guaranteed to vanish or be equal to poison.

>From e2cfa08c7bf365ce36068391fbbf7e12fde69b37 Mon Sep 17 00:00:00 2001
From: "Bzowski, Adam" <adam.bzowski at intel.com>
Date: Fri, 13 Dec 2024 10:50:25 -0800
Subject: [PATCH 4/5] [CVP] Implement type narrowing for LShr

Implements type narrowing for LShr. The treatment is analogous to the type narrowing of UDiv. Since LShr is a relatively cheap instruction, the narrowing occurs only if the following conditions hold: i) all the users of the LShr instruction are already TruncInst; ii) the narrowing is carried out to the largest TruncInst following the LShr instruction. Additionally, the function optimizes the cases where the result of the LShr instruction is guaranteed to vanish or be equal to poison.
---
 .../Scalar/CorrelatedValuePropagation.cpp           | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 88ec68857c82f5..9539a2d6c4c398 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -1103,10 +1103,8 @@ static bool narrowLShr(BinaryOperator *LShr, LazyValueInfo *LVI) {
 
   unsigned OrigWidth = RetTy->getScalarSizeInBits();
   unsigned MaxActiveBitsInArg = ArgRange.getActiveBits();
-  uint64_t MinShiftValue64 = ShiftRange.getUnsignedMin().getZExtValue();
-  unsigned MinShiftValue =
-      MinShiftValue64 < std::numeric_limits<unsigned>::max()
-          ? static_cast<unsigned>(MinShiftValue64)
+  unsigned MinShiftValue = ShiftRange.getUnsignedMin().getActiveBits() <= 32
+          ? static_cast<unsigned>(ShiftRange.getUnsignedMin().getZExtValue())
           : std::numeric_limits<unsigned>::max();
 
   // First we deal with the cases where the result is guaranteed to vanish or be
@@ -1138,12 +1136,9 @@ static bool narrowLShr(BinaryOperator *LShr, LazyValueInfo *LVI) {
   // width of the argument, we must make sure that the maximal possible value
   // for the shift is smaller than the new width after narrowing. Otherwise some
   // shifts that originally vanish would result in poison after the narrowing.
-  uint64_t MaxShiftValue64 = ShiftRange.getUnsignedMax().getZExtValue();
-  unsigned MaxShiftValue =
-      MaxShiftValue64 < std::numeric_limits<unsigned>::max()
-          ? static_cast<unsigned>(MaxShiftValue64)
+  unsigned MaxShiftValue = ShiftRange.getUnsignedMax().getActiveBits() <= 32
+          ? static_cast<unsigned>(ShiftRange.getUnsignedMax().getZExtValue())
           : std::numeric_limits<unsigned>::max();
-
   if (OrigWidth <= MaxShiftValue)
     return false;
 

>From 2d9af1c7fad4e3deb314e5a79241c2330cd790ae Mon Sep 17 00:00:00 2001
From: "Bzowski, Adam" <adam.bzowski at intel.com>
Date: Fri, 13 Dec 2024 11:41:56 -0800
Subject: [PATCH 5/5] [CVP] Implement type narrowing for LShr

Implements type narrowing for LShr. The treatment is analogous to the type narrowing of UDiv. Since LShr is a relatively cheap instruction, the narrowing occurs only if the following conditions hold: i) all the users of the LShr instruction are already TruncInst; ii) the narrowing is carried out to the largest TruncInst following the LShr instruction. Additionally, the function optimizes the cases where the result of the LShr instruction is guaranteed to vanish or be equal to poison.
---
 .../Scalar/CorrelatedValuePropagation.cpp         | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 9539a2d6c4c398..b5322fddf7d3fd 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -758,6 +758,15 @@ static Domain getDomain(const ConstantRange &CR) {
   return Domain::Unknown;
 }
 
+/// Returns the bit width to narrow an operation to: the smallest power of two
+/// holding DesiredBitWidth bits, but never less than 8. No upper bound is
+/// imposed, so wide types (e.g. i128 values fitting in 64 bits) can still be
+/// narrowed by the udiv/urem/sdiv/srem/lshr narrowing code. The result is not
+/// always below the original width, so callers must check that themselves.
+static unsigned getNarrowedWidth(unsigned DesiredBitWidth) {
+  return std::max<unsigned>(PowerOf2Ceil(DesiredBitWidth), 8);
+}
+
 /// Try to shrink a sdiv/srem's width down to the smallest power of two that's
 /// sufficient to contain its operands.
 static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR,
@@ -781,7 +790,7 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR,
     ++MinSignedBits;
 
   // Don't shrink below 8 bits wide.
-  unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MinSignedBits), 8);
+  unsigned NewWidth = getNarrowedWidth(MinSignedBits);
 
   // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
   // two.
@@ -899,7 +908,7 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
   // of both of the operands?
   unsigned MaxActiveBits = std::max(XCR.getActiveBits(), YCR.getActiveBits());
   // Don't shrink below 8 bits wide.
-  unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MaxActiveBits), 8);
+  unsigned NewWidth = getNarrowedWidth(MaxActiveBits);
 
   // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
   // two.
@@ -1156,7 +1165,7 @@ static bool narrowLShr(BinaryOperator *LShr, LazyValueInfo *LVI) {
 
   // What is the smallest bit width that can accommodate the entire value ranges
   // of both of the operands? Don't shrink below 8 bits wide.
-  unsigned NewWidth = std::max<unsigned>(PowerOf2Ceil(MaxActiveBits), 8);
+  unsigned NewWidth = getNarrowedWidth(MaxActiveBits);
 
   // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
   // two.



More information about the llvm-commits mailing list