[llvm] 231fa27 - [InstCombine] Generate better code for std::bit_ceil
Kazu Hirata via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 23 19:27:00 PDT 2023
Author: Kazu Hirata
Date: 2023-03-23T19:26:43-07:00
New Revision: 231fa27435105e980b113754c112980ebeb8927d
URL: https://github.com/llvm/llvm-project/commit/231fa27435105e980b113754c112980ebeb8927d
DIFF: https://github.com/llvm/llvm-project/commit/231fa27435105e980b113754c112980ebeb8927d.diff
LOG: [InstCombine] Generate better code for std::bit_ceil
Without this patch, std::bit_ceil<uint32_t> is compiled as:
%dec = add i32 %x, -1
%lz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
%sub = sub i32 32, %lz
%res = shl i32 1, %sub
%ugt = icmp ugt i32 %x, 1
%sel = select i1 %ugt, i32 %res, i32 1
With this patch, we generate:
%dec = add i32 %x, -1
%ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
%sub = sub nsw i32 0, %ctlz
%and = and i32 %sub, 31
%sel = shl nuw i32 1, %and
ret i32 %sel
https://alive2.llvm.org/ce/z/pwezvF
This patch recognizes the specific pattern from std::bit_ceil in
libc++ and libstdc++ and drops the conditional move. In addition to
the LLVM IR generated for std::bit_ceil(X), this patch recognizes
variants like:
std::bit_ceil(X - 1)
std::bit_ceil(X + 1)
std::bit_ceil(X + 2)
std::bit_ceil(-X)
std::bit_ceil(~X)
This patch fixes:
https://github.com/llvm/llvm-project/issues/60802
Differential Revision: https://reviews.llvm.org/D145299
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
llvm/test/Transforms/InstCombine/bit_ceil.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 1f2441bc9fcf9..3d1dbdd6270d5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3163,6 +3163,134 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
return nullptr;
}
+// Return true if we can safely remove the select instruction for the
+// std::bit_ceil pattern.
+//
+// On success, ShouldDropNoWrap is set to true when CtlzOp is computed from the
+// common ancestor by an add/sub whose nuw/nsw flags must be cleared: once the
+// select is gone, CtlzOp is also evaluated for the inputs the select used to
+// discard, and a no-wrap flag could turn those values into poison.
+static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0,
+                                        const APInt *Cond1, Value *CtlzOp,
+                                        unsigned BitWidth,
+                                        bool &ShouldDropNoWrap) {
+  // The challenge in recognizing std::bit_ceil(X) is that the operand is used
+  // for the CTLZ proper and select condition, each possibly with some
+  // operation like add and sub.
+  //
+  // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the
+  // select instruction would select 1, which allows us to get rid of the select
+  // instruction.
+  //
+  // To see if we can do so, we do some symbolic execution with ConstantRange.
+  // Specifically, we compute the range of values that Cond0 could take when
+  // the select condition is false. Then we successively transform the range
+  // until we obtain the range of values that CtlzOp could take.
+  //
+  // Conceptually, we follow the def-use chain backward from Cond0 while
+  // transforming the range for Cond0 until we meet the common ancestor of Cond0
+  // and CtlzOp. Then we follow the def-use chain forward until we obtain the
+  // range for CtlzOp. That said, we only follow at most one ancestor from
+  // Cond0. Likewise, we only follow at most one ancestor from CtlzOp.
+  ShouldDropNoWrap = false;
+
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(
+      CmpInst::getInversePredicate(Pred), *Cond1);
+
+  // Match the operation that's used to compute CtlzOp from CommonAncestor. If
+  // CtlzOp == CommonAncestor, return true as no operation is needed. If a
+  // match is found, execute the operation on CR, update CR, and return true.
+  // Otherwise, return false.
+  auto MatchForward = [&](Value *CommonAncestor) {
+    const APInt *C = nullptr;
+    if (CtlzOp == CommonAncestor)
+      return true;
+    if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) {
+      // CR models the wrapping result; a nuw/nsw flag on CtlzOp could make
+      // exactly those wrapped values poison, so the flags must be dropped.
+      ShouldDropNoWrap = true;
+      CR = CR.add(*C);
+      return true;
+    }
+    if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
+      // Same reasoning as for the add above.
+      ShouldDropNoWrap = true;
+      CR = ConstantRange(*C).sub(CR);
+      return true;
+    }
+    if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) {
+      CR = CR.binaryNot();
+      return true;
+    }
+    return false;
+  };
+
+  const APInt *C = nullptr;
+  Value *CommonAncestor;
+  if (MatchForward(Cond0)) {
+    // Cond0 is either CtlzOp or CtlzOp's parent. CR has been updated.
+  } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) {
+    CR = CR.sub(*C);
+    if (!MatchForward(CommonAncestor))
+      return false;
+    // Cond0's parent is either CtlzOp or CtlzOp's parent. CR has been updated.
+  } else {
+    return false;
+  }
+
+  // Return true if all the values in the range are either 0 or negative (if
+  // treated as signed). We do so by evaluating:
+  //
+  //   CR - 1 u>= (1 << (BitWidth - 1)) - 1.
+  //
+  // Subtracting 1 maps 0 to the all-ones value and every negative value to
+  // something u>= the signed maximum, while every positive value lands below
+  // the signed maximum.
+  APInt IntMax = APInt::getSignMask(BitWidth) - 1;
+  CR = CR.sub(APInt(BitWidth, 1));
+  return CR.icmp(ICmpInst::ICMP_UGE, IntMax);
+}
+
+// Transform the std::bit_ceil(X) pattern like:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %sub = sub i32 32, %ctlz
+//   %shl = shl i32 1, %sub
+//   %ugt = icmp ugt i32 %x, 1
+//   %sel = select i1 %ugt, i32 %shl, i32 1
+//
+// into:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %neg = sub i32 0, %ctlz
+//   %masked = and i32 %neg, 31
+//   %shl = shl i32 1, %masked
+//
+// Note that the select is optimized away while the shift count is masked with
+// 31. We handle some variations of the input operand like std::bit_ceil(X +
+// 1). When needed, nuw/nsw flags on the add/sub feeding the ctlz are dropped,
+// because that operation is now also evaluated on the inputs for which the
+// select used to return 1.
+static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
+  Type *SelType = SI.getType();
+  unsigned BitWidth = SelType->getScalarSizeInBits();
+
+  Value *FalseVal = SI.getFalseValue();
+  Value *TrueVal = SI.getTrueValue();
+  ICmpInst::Predicate Pred;
+  const APInt *Cond1;
+  Value *Cond0, *Ctlz, *CtlzOp;
+  if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
+    return nullptr;
+
+  // Canonicalize so that the constant 1 is the false operand of the select.
+  if (match(TrueVal, m_One())) {
+    std::swap(FalseVal, TrueVal);
+    Pred = CmpInst::getInversePredicate(Pred);
+  }
+
+  bool ShouldDropNoWrap = false;
+  if (!match(FalseVal, m_One()) ||
+      !match(TrueVal,
+             m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+                                                    m_Value(Ctlz)))))) ||
+      !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
+      !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth,
+                                   ShouldDropNoWrap))
+    return nullptr;
+
+  if (ShouldDropNoWrap) {
+    // Only instructions carry mutable no-wrap flags; bail out on anything else
+    // (e.g. a constant expression) before any IR is modified or created.
+    auto *CtlzOpInst = dyn_cast<Instruction>(CtlzOp);
+    if (!CtlzOpInst)
+      return nullptr;
+    CtlzOpInst->setHasNoUnsignedWrap(false);
+    CtlzOpInst->setHasNoSignedWrap(false);
+  }
+
+  // Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a
+  // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth
+  // is an integer constant. Masking with BitWidth-1 comes free on some
+  // hardware as part of the shift instruction.
+  Value *Neg = Builder.CreateNeg(Ctlz);
+  Value *Masked =
+      Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1));
+  return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1),
+                                Masked);
+}
+
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -3590,5 +3718,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (sinkNotIntoOtherHandOfLogicalOp(SI))
return &SI;
+ if (Instruction *I = foldBitCeil(SI, Builder))
+ return I;
+
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/bit_ceil.ll b/llvm/test/Transforms/InstCombine/bit_ceil.ll
index 98f4cdb6fb834..6f714153a598a 100644
--- a/llvm/test/Transforms/InstCombine/bit_ceil.ll
+++ b/llvm/test/Transforms/InstCombine/bit_ceil.ll
@@ -6,10 +6,9 @@ define i32 @bit_ceil_32(i32 %x) {
; CHECK-LABEL: @bit_ceil_32(
; CHECK-NEXT: [[DEC:%.*]] = add i32 [[X:%.*]], -1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0:![0-9]+]]
-; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT: [[UGT:%.*]] = icmp ugt i32 [[X]], 1
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[UGT]], i32 [[SHL]], i32 1
+; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
; CHECK-NEXT: ret i32 [[SEL]]
;
%dec = add i32 %x, -1
@@ -26,10 +25,9 @@ define i64 @bit_ceil_64(i64 %x) {
; CHECK-LABEL: @bit_ceil_64(
; CHECK-NEXT: [[DEC:%.*]] = add i64 [[X:%.*]], -1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[DEC]], i1 false), !range [[RNG1:![0-9]+]]
-; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i64 64, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw i64 1, [[SUB]]
-; CHECK-NEXT: [[UGT:%.*]] = icmp ugt i64 [[X]], 1
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[UGT]], i64 [[SHL]], i64 1
+; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i64 0, [[CTLZ]]
+; CHECK-NEXT: [[TMP2:%.*]] = and i64 [[TMP1]], 63
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw i64 1, [[TMP2]]
; CHECK-NEXT: ret i64 [[SEL]]
;
%dec = add i64 %x, -1
@@ -47,11 +45,9 @@ define i32 @bit_ceil_32_minus_1(i32 %x) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[X:%.*]], -2
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT: [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], -3
-; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[ADD]], -2
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
; CHECK-NEXT: ret i32 [[SEL]]
;
entry:
@@ -69,11 +65,9 @@ entry:
define i32 @bit_ceil_32_plus_1(i32 %x) {
; CHECK-LABEL: @bit_ceil_32_plus_1(
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[X:%.*]], i1 false), !range [[RNG0]]
-; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT: [[DEC:%.*]] = add i32 [[X]], -1
-; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[DEC]], -2
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
; CHECK-NEXT: ret i32 [[SEL]]
;
%ctlz = tail call i32 @llvm.ctlz.i32(i32 %x, i1 false)
@@ -91,10 +85,9 @@ define i32 @bit_ceil_plus_2(i32 %x) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SUB:%.*]] = add i32 [[X:%.*]], 1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT: [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[X]], -2
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
; CHECK-NEXT: ret i32 [[SEL]]
;
entry:
@@ -113,11 +106,9 @@ define i32 @bit_ceil_32_neg(i32 %x) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SUB:%.*]] = xor i32 [[X:%.*]], -1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT: [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT: [[NOTSUB:%.*]] = add i32 [[X]], -1
-; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[NOTSUB]], -2
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
; CHECK-NEXT: ret i32 [[SEL]]
;
entry:
@@ -137,10 +128,9 @@ define i32 @bit_ceil_not(i32 %x) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SUB:%.*]] = sub i32 -2, [[X:%.*]]
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[SUB]], i1 false), !range [[RNG0]]
-; CHECK-NEXT: [[SUB2:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB2]]
-; CHECK-NEXT: [[ULT:%.*]] = icmp ult i32 [[X]], -2
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[ULT]], i32 [[SHL]], i32 1
+; CHECK-NEXT: [[TMP0:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[TMP0]], 31
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP1]]
; CHECK-NEXT: ret i32 [[SEL]]
;
entry:
@@ -158,18 +148,17 @@ define i32 @bit_ceil_commuted_operands(i32 %x) {
; CHECK-LABEL: @bit_ceil_commuted_operands(
; CHECK-NEXT: [[DEC:%.*]] = add i32 [[X:%.*]], -1
; CHECK-NEXT: [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0]]
-; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
-; CHECK-NEXT: [[UGT_INV:%.*]] = icmp ugt i32 [[X]], 1
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[UGT_INV]], i32 [[SHL]], i32 1
+; CHECK-NEXT: [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
; CHECK-NEXT: ret i32 [[SEL]]
;
%dec = add i32 %x, -1
%ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
%sub = sub i32 32, %ctlz
%shl = shl i32 1, %sub
- %ugt = icmp ule i32 %x, 1
- %sel = select i1 %ugt, i32 1, i32 %shl
+ %eq = icmp eq i32 %dec, 0
+ %sel = select i1 %eq, i32 1, i32 %shl
ret i32 %sel
}
@@ -282,10 +271,9 @@ define <4 x i32> @bit_ceil_v4i32(<4 x i32> %x) {
; CHECK-LABEL: @bit_ceil_v4i32(
; CHECK-NEXT: [[DEC:%.*]] = add <4 x i32> [[X:%.*]], <i32 -1, i32 -1, i32 -1, i32 -1>
; CHECK-NEXT: [[CTLZ:%.*]] = tail call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[DEC]], i1 false), !range [[RNG0]]
-; CHECK-NEXT: [[SUB:%.*]] = sub nuw nsw <4 x i32> <i32 32, i32 32, i32 32, i32 32>, [[CTLZ]]
-; CHECK-NEXT: [[SHL:%.*]] = shl nuw <4 x i32> <i32 1, i32 1, i32 1, i32 1>, [[SUB]]
-; CHECK-NEXT: [[UGT:%.*]] = icmp ugt <4 x i32> [[X]], <i32 1, i32 1, i32 1, i32 1>
-; CHECK-NEXT: [[SEL:%.*]] = select <4 x i1> [[UGT]], <4 x i32> [[SHL]], <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+; CHECK-NEXT: [[TMP1:%.*]] = sub nsw <4 x i32> zeroinitializer, [[CTLZ]]
+; CHECK-NEXT: [[TMP2:%.*]] = and <4 x i32> [[TMP1]], <i32 31, i32 31, i32 31, i32 31>
+; CHECK-NEXT: [[SEL:%.*]] = shl nuw <4 x i32> <i32 1, i32 1, i32 1, i32 1>, [[TMP2]]
; CHECK-NEXT: ret <4 x i32> [[SEL]]
;
%dec = add <4 x i32> %x, <i32 -1, i32 -1, i32 -1, i32 -1>
More information about the llvm-commits
mailing list