[llvm] 231fa27 - [InstCombine] Generate better code for std::bit_ceil

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 23 19:27:00 PDT 2023


Author: Kazu Hirata
Date: 2023-03-23T19:26:43-07:00
New Revision: 231fa27435105e980b113754c112980ebeb8927d

URL: https://github.com/llvm/llvm-project/commit/231fa27435105e980b113754c112980ebeb8927d
DIFF: https://github.com/llvm/llvm-project/commit/231fa27435105e980b113754c112980ebeb8927d.diff

LOG: [InstCombine] Generate better code for std::bit_ceil

Without this patch, std::bit_ceil<uint32_t> is compiled as:

  %dec = add i32 %x, -1
  %lz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
  %sub = sub i32 32, %lz
  %res = shl i32 1, %sub
  %ugt = icmp ugt i32 %x, 1
  %sel = select i1 %ugt, i32 %res, i32 1

With this patch, we generate:

  %dec = add i32 %x, -1
  %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
  %sub = sub nsw i32 0, %ctlz
  %and = and i32 %sub, 31
  %sel = shl nuw i32 1, %and
  ret i32 %sel

https://alive2.llvm.org/ce/z/pwezvF
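
As a quick sanity check of the arithmetic (not part of this patch), the
rewritten form can be modeled in standalone C++.  The guard models the
ctlz(0) == 32 semantics of the intrinsic with is_zero_poison == false;
bit_ceil_fold is a hypothetical name used only for illustration:

  #include <cassert>
  #include <cstdint>

  // Models the optimized IR: 1 << (-ctlz(x - 1) & 31), where ctlz(0) == 32.
  // (For x > 2^31, std::bit_ceil is undefined; this model happens to return 1.)
  uint32_t bit_ceil_fold(uint32_t x) {
    uint32_t dec = x - 1;
    // __builtin_clz (GCC/Clang) is undefined for 0, so model ctlz(0) == 32.
    uint32_t ctlz = dec == 0 ? 32 : __builtin_clz(dec);
    // -32 & 31 == 0 and -0 & 31 == 0, so x <= 1 yields 1 << 0 == 1, which is
    // exactly why the select can be dropped.
    return uint32_t{1} << (-ctlz & 31);
  }

  int main() {
    assert(bit_ceil_fold(0) == 1);  // dec = 0xffffffff, ctlz = 0
    assert(bit_ceil_fold(1) == 1);  // dec = 0, ctlz = 32
    assert(bit_ceil_fold(5) == 8);  // dec = 4, ctlz = 29, -29 & 31 == 3
    assert(bit_ceil_fold(8) == 8);  // dec = 7, ctlz = 29
    return 0;
  }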

This patch recognizes the specific pattern from std::bit_ceil in
libc++ and libstdc++ and drops the conditional move.  In addition to
the LLVM IR generated for std::bit_ceil(X), this patch recognizes
variants like:

  std::bit_ceil(X - 1)
  std::bit_ceil(X + 1)
  std::bit_ceil(X + 2)
  std::bit_ceil(-X)
  std::bit_ceil(~X)
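
For reference, the recognized base pattern corresponds to a bit_ceil
implementation of roughly this shape (a simplified sketch, not the exact
libc++/libstdc++ source; real implementations differ in details but lower
to the same decrement/ctlz/shl/select IR):

  #include <cstdint>

  // Simplified std::bit_ceil<uint32_t>: smallest power of two >= x.
  // Precondition (as for std::bit_ceil): the result must be representable,
  // i.e. x <= 2^31.
  uint32_t bit_ceil_sketch(uint32_t x) {
    if (x <= 1)
      return 1;  // becomes the false arm of the select in the IR above
    // 1 << (32 - ctlz(x - 1)); __builtin_clz is the GCC/Clang ctlz builtin.
    return uint32_t{1} << (32 - __builtin_clz(x - 1));
  }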

This patch fixes:

https://github.com/llvm/llvm-project/issues/60802

Differential Revision: https://reviews.llvm.org/D145299

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/bit_ceil.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 1f2441bc9fcf9..3d1dbdd6270d5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3163,6 +3163,134 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
   return nullptr;
 }
 
+// Return true if we can safely remove the select instruction for the
+// std::bit_ceil pattern.
+static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
+                                        const APInt *Cond1, Value *CtlzOp,
+                                        unsigned BitWidth) {
+  // The challenge in recognizing std::bit_ceil(X) is that the operand feeds
+  // both the ctlz intrinsic and the select condition, each possibly through
+  // some operation like an add or sub.
+  //
+  // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 whenever the
+  // select instruction would select 1, so that the shift also produces 1 and
+  // the select can be removed.
+  //
+  // To see if we can do so, we do some symbolic execution with ConstantRange.
+  // Specifically, we compute the range of values that Cond0 could take when
+  // the select condition is false.  Then we successively transform the range
+  // until we obtain the range of values that CtlzOp could take.
+  //
+  // Conceptually, we follow the def-use chain backward from Cond0 while
+  // transforming the range for Cond0 until we meet the common ancestor of Cond0
+  // and CtlzOp.  Then we follow the def-use chain forward until we obtain the
+  // range for CtlzOp.  That said, we follow at most one ancestor from
+  // Cond0.  Likewise, we follow at most one ancestor from CtlzOp.
+
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(
+      CmpInst::getInversePredicate(Pred), *Cond1);
+
+  // Match the operation that's used to compute CtlzOp from CommonAncestor.  If
+  // CtlzOp == CommonAncestor, return true as no operation is needed.  If a
+  // match is found, execute the operation on CR, update CR, and return true.
+  // Otherwise, return false.
+  auto MatchForward = [&](Value *CommonAncestor) {
+    const APInt *C = nullptr;
+    if (CtlzOp == CommonAncestor)
+      return true;
+    if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) {
+      CR = CR.add(*C);
+      return true;
+    }
+    if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
+      CR = ConstantRange(*C).sub(CR);
+      return true;
+    }
+    if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) {
+      CR = CR.binaryNot();
+      return true;
+    }
+    return false;
+  };
+
+  const APInt *C = nullptr;
+  Value *CommonAncestor;
+  if (MatchForward(Cond0)) {
+    // Cond0 is either CtlzOp or CtlzOp's parent.  CR has been updated.
+  } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) {
+    CR = CR.sub(*C);
+    if (!MatchForward(CommonAncestor))
+      return false;
+    // Cond0's parent is either CtlzOp or CtlzOp's parent.  CR has been updated.
+  } else {
+    return false;
+  }
+
+  // Return true if all the values in the range are either 0 or negative (if
+  // treated as signed).  We do so by evaluating:
+  //
+  //   CR - 1 u>= (1 << (BitWidth - 1)) - 1.
+  APInt IntMax = APInt::getSignMask(BitWidth) - 1;
+  CR = CR.sub(APInt(BitWidth, 1));
+  return CR.icmp(ICmpInst::ICMP_UGE, IntMax);
+}
+
+// Transform the std::bit_ceil(X) pattern like:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %sub = sub i32 32, %ctlz
+//   %shl = shl i32 1, %sub
+//   %ugt = icmp ugt i32 %x, 1
+//   %sel = select i1 %ugt, i32 %shl, i32 1
+//
+// into:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %neg = sub i32 0, %ctlz
+//   %masked = and i32 %neg, 31
+//   %shl = shl i32 1, %masked
+//
+// Note that the select is optimized away, and the shift count is instead
+// masked with BitWidth - 1 (31 here).  We also handle variants of the input
+// operand, such as std::bit_ceil(X + 1).
+static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
+  Type *SelType = SI.getType();
+  unsigned BitWidth = SelType->getScalarSizeInBits();
+
+  Value *FalseVal = SI.getFalseValue();
+  Value *TrueVal = SI.getTrueValue();
+  ICmpInst::Predicate Pred;
+  const APInt *Cond1;
+  Value *Cond0, *Ctlz, *CtlzOp;
+  if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
+    return nullptr;
+
+  if (match(TrueVal, m_One())) {
+    std::swap(FalseVal, TrueVal);
+    Pred = CmpInst::getInversePredicate(Pred);
+  }
+
+  if (!match(FalseVal, m_One()) ||
+      !match(TrueVal,
+             m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+                                                    m_Value(Ctlz)))))) ||
+      !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
+      !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
+    return nullptr;
+
+  // Build 1 << (-CTLZ & (BitWidth-1)).  The negation likely corresponds to a
+  // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth
+  // is an integer constant.  Masking with BitWidth-1 comes free on some
+  // hardware as part of the shift instruction.
+  Value *Neg = Builder.CreateNeg(Ctlz);
+  Value *Masked =
+      Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
+  return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
+                                Masked);
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -3590,5 +3718,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (sinkNotIntoOtherHandOfLogicalOp(SI))
     return &SI;
 
+  if (Instruction *I = foldBitCeil(SI, Builder))
+    return I;
+
   return nullptr;
 }

diff --git a/llvm/test/Transforms/InstCombine/bit_ceil.ll b/llvm/test/Transforms/InstCombine/bit_ceil.ll
index 98f4cdb6fb834..6f714153a598a 100644
--- a/llvm/test/Transforms/InstCombine/bit_ceil.ll
+++ b/llvm/test/Transforms/InstCombine/bit_ceil.ll
@@ -6,10 +6,9 @@ define i32 @bit_ceil_32(i32 %x) {
 ; CHECK-LABEL: @bit_ceil_32(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt i32 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %dec = add i32 %x, -1
@@ -26,10 +25,9 @@ define i64 @bit_ceil_64(i64 %x) {
 ; CHECK-LABEL: @bit_ceil_64(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i64 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[DEC]], i1 false), !range [[RNG1:![0-9]+]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i64 64, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i64 1, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt i64 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT]], i64 [[SHL]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i64 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i64 [[TMP1]], 63
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i64 1, [[TMP2]]
 ; CHECK-NEXT:    ret i64 [[SEL]]
 ;
   %dec = add i64 %x, -1
@@ -47,11 +45,9 @@ define i32 @bit_ceil_32_minus_1(i32 %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[X:%.*]], -2
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[X]], -3
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[ADD]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -69,11 +65,9 @@ entry:
 define i32 @bit_ceil_32_plus_1(i32 %x) {
 ; CHECK-LABEL: @bit_ceil_32_plus_1(
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[X:%.*]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X]], -1
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[DEC]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 false)
@@ -91,10 +85,9 @@ define i32 @bit_ceil_plus_2(i32 %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[X:%.*]], 1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[X]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -113,11 +106,9 @@ define i32 @bit_ceil_32_neg(i32 %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = xor i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[NOTSUB:%.*]] = add i32 [[X]], -1
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[NOTSUB]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -137,10 +128,9 @@ define i32 @bit_ceil_not(i32 %x) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[SUB:%.*]] = sub i32 -2, [[X:%.*]]
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT:    [[ULT:%.*]] = icmp ult i32 [[X]], -2
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
 entry:
@@ -158,18 +148,17 @@ define i32 @bit_ceil_commuted_operands(i32 %x) {
 ; CHECK-LABEL: @bit_ceil_commuted_operands(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT:    [[UGT_INV:%.*]] = icmp ugt i32 [[X]], 1
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[UGT_INV]], i32 [[SHL]], i32 1
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %dec = add i32 %x, -1
   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
   %sub = sub i32 32, %ctlz
   %shl = shl i32 1, %sub
-  %ugt = icmp ule i32 %x, 1
-  %sel = select i1 %ugt, i32 1, i32 %shl
+  %eq = icmp eq i32 %dec, 0
+  %sel = select i1 %eq, i32 1, i32 %shl
   ret i32 %sel
 }
 
@@ -282,10 +271,9 @@ define <4 x i32> @bit_ceil_v4i32(<4 x i32> %x) {
 ; CHECK-LABEL: @bit_ceil_v4i32(
 ; CHECK-NEXT:    [[DEC:%.*]] = add <4 x i32> [[X:%.*]], <i32 -1, i32 -1, i32 -1, i32 -1>
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[DEC]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw <4 x i32> <i32 32, i32 32, i32 32, i32 32>, [[CTLZ]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw <4 x i32> <i32 1, i32 1, i32 1, i32 1>, [[SUB]]
-; CHECK-NEXT:    [[UGT:%.*]] = icmp ugt <4 x i32> [[X]], <i32 1, i32 1, i32 1, i32 1>
-; CHECK-NEXT:    [[SEL:%.*]] = select <4 x i1> [[UGT]], <4 x i32> [[SHL]], <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw <4 x i32> zeroinitializer, [[CTLZ]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <4 x i32> [[TMP1]], <i32 31, i32 31, i32 31, i32 31>
+; CHECK-NEXT:    [[SEL:%.*]] = shl nuw <4 x i32> <i32 1, i32 1, i32 1, i32 1>, [[TMP2]]
 ; CHECK-NEXT:    ret <4 x i32> [[SEL]]
 ;
   %dec = add <4 x i32> %x, <i32 -1, i32 -1, i32 -1, i32 -1>

