[llvm] d2502eb - [KnownBits] Add support for nuw/nsw on shifts
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Thu May 25 01:17:19 PDT 2023
Author: Nikita Popov
Date: 2023-05-25T10:17:10+02:00
New Revision: d2502eb091fabc36463e491b066bb002b47ba521
URL: https://github.com/llvm/llvm-project/commit/d2502eb091fabc36463e491b066bb002b47ba521
DIFF: https://github.com/llvm/llvm-project/commit/d2502eb091fabc36463e491b066bb002b47ba521.diff
LOG: [KnownBits] Add support for nuw/nsw on shifts
Implement precise nuw/nsw support in the KnownBits implementation,
replacing the rather crude handling in ValueTracking.
Differential Revision: https://reviews.llvm.org/D151208
Added:
Modified:
llvm/include/llvm/Support/KnownBits.h
llvm/lib/Analysis/ValueTracking.cpp
llvm/lib/Support/KnownBits.cpp
llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
llvm/unittests/Support/KnownBitsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 6ba4dd4f82540..9229a4d61d4b4 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -382,7 +382,8 @@ struct KnownBits {
/// Compute known bits for shl(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
- static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS);
+ static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
+ bool NUW = false, bool NSW = false);
/// Compute known bits for lshr(LHS, RHS).
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 90dbfcb85fb2c..7ec34cdca0be5 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1353,20 +1353,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::Shl: {
+ bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
- auto KF = [NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
- KnownBits Result = KnownBits::shl(KnownVal, KnownAmt);
- // If this shift has "nsw" keyword, then the result is either a poison
- // value or has the same sign bit as the first operand.
- if (NSW) {
- if (KnownVal.Zero.isSignBitSet())
- Result.Zero.setSignBit();
- if (KnownVal.One.isSignBitSet())
- Result.One.setSignBit();
- if (Result.hasConflict())
- Result.setAllZero();
- }
- return Result;
+ auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
+ return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW);
};
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
KF);
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index d52f08739be79..8bb236baf4ae5 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -164,21 +164,51 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
return Flip(umax(Flip(LHS), Flip(RHS)));
}
-KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
+KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
+ bool NSW) {
unsigned BitWidth = LHS.getBitWidth();
- KnownBits Known(BitWidth);
+ auto ShiftByConst = [&](const KnownBits &LHS,
+ uint64_t ShiftAmt) -> std::optional<KnownBits> {
+ KnownBits Known;
+ Known.Zero = LHS.Zero << ShiftAmt;
+ Known.Zero.setLowBits(ShiftAmt);
+ Known.One = LHS.One << ShiftAmt;
+ if ((!NUW && !NSW) || ShiftAmt == 0)
+ return Known;
+
+ KnownBits ShiftedOutBits = LHS.extractBits(ShiftAmt, BitWidth - ShiftAmt);
+ if (NUW && !ShiftedOutBits.One.isZero())
+ // One bit has been shifted out.
+ return std::nullopt;
+ if (NSW) {
+ if (!ShiftedOutBits.Zero.isZero() && !ShiftedOutBits.One.isZero())
+ // Both zeros and ones have been shifted out.
+ return std::nullopt;
+ if (NUW || !ShiftedOutBits.Zero.isZero()) {
+ if (Known.isNegative())
+ // Zero bit has been shifted out, but result sign is negative.
+ return std::nullopt;
+ Known.makeNonNegative();
+ } else if (!ShiftedOutBits.One.isZero()) {
+ if (Known.isNonNegative())
+ // One bit has been shifted out, but result sign is non-negative.
+ return std::nullopt;
+ Known.makeNegative();
+ }
+ }
+ return Known;
+ };
// If the shift amount is a valid constant then transform LHS directly.
if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
- unsigned Shift = RHS.getConstant().getZExtValue();
- Known = LHS;
- Known.Zero <<= Shift;
- Known.One <<= Shift;
- // Low bits are known zero.
- Known.Zero.setLowBits(Shift);
+ if (auto Res = ShiftByConst(LHS, RHS.getConstant().getZExtValue()))
+ return *Res;
+ KnownBits Known(BitWidth);
+ Known.setAllZero();
return Known;
}
+ KnownBits Known(BitWidth);
APInt MinShiftAmount = RHS.getMinValue();
if (MinShiftAmount.uge(BitWidth)) {
// Always poison. Return zero because we don't like returning conflict.
@@ -193,6 +223,8 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
MinTrailingZeros += MinShiftAmount.getZExtValue();
MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
Known.Zero.setLowBits(MinTrailingZeros);
+ if (NUW && NSW && !MinShiftAmount.isZero())
+ Known.makeNonNegative();
return Known;
}
@@ -210,15 +242,20 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
(ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
continue;
- KnownBits SpecificShift;
- SpecificShift.Zero = LHS.Zero << ShiftAmt;
- SpecificShift.Zero.setLowBits(ShiftAmt);
- SpecificShift.One = LHS.One << ShiftAmt;
- Known = Known.intersectWith(SpecificShift);
+ auto Res = ShiftByConst(LHS, ShiftAmt);
+ if (!Res)
+ // All larger shift amounts will overflow as well.
+ break;
+ Known = Known.intersectWith(*Res);
if (Known.isUnknown())
break;
}
+ // All shift amounts may result in poison.
+ if (Known.hasConflict()) {
+ assert((NUW || NSW) && "Can only happen with nowrap flags");
+ Known.setAllZero();
+ }
return Known;
}
diff --git a/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll b/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
index 524a13fadf661..51eb96c38c1c8 100644
--- a/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
+++ b/llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll
@@ -513,7 +513,7 @@ define void @test.ult.gep.shl(ptr readonly %src, ptr readnone %max, i8 %idx) {
; CHECK-NEXT: [[IDX_SHL_1:%.*]] = shl nuw nsw i8 [[IDX]], 1
; CHECK-NEXT: [[ADD_PTR_SHL_1:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_1]]
; CHECK-NEXT: [[C_MAX_0:%.*]] = icmp ult ptr [[ADD_PTR_SHL_1]], [[MAX]]
-; CHECK-NEXT: call void @use(i1 [[C_MAX_0]])
+; CHECK-NEXT: call void @use(i1 true)
; CHECK-NEXT: [[IDX_SHL_2:%.*]] = shl nuw i8 [[IDX]], 2
; CHECK-NEXT: [[ADD_PTR_SHL_2:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_2]]
; CHECK-NEXT: [[C_MAX_1:%.*]] = icmp ult ptr [[ADD_PTR_SHL_2]], [[MAX]]
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
index 24f6a5a0b1eb3..7c91170bd051b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll
@@ -82,7 +82,7 @@ define void @test_array_load2_store2(i32 %C, i32 %D) #1 {
; CHECK-NEXT: [[ARRAYIDX3:%.*]] = getelementptr inbounds [1024 x i32], ptr @CD, i64 0, i64 [[OR]]
; CHECK-NEXT: store i32 [[MUL]], ptr [[ARRAYIDX3]], align 4
; CHECK-NEXT: [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 2
-; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV]], 1022
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[INDVARS_IV]], 1022
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END]], !llvm.loop [[LOOP3:![0-9]+]]
; CHECK: for.end:
; CHECK-NEXT: ret void
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 1607ddfdec502..427b323912abb 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -343,6 +343,41 @@ TEST(KnownBitsTest, BinaryExhaustive) {
return std::nullopt;
return N1.shl(N2);
});
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::shl(Known1, Known2, /* NUW */ true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ bool Overflow;
+ APInt Res = N1.ushl_ov(N2, Overflow);
+ if (Overflow)
+ return std::nullopt;
+ return Res;
+ });
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::shl(Known1, Known2, /* NUW */ false, /* NSW */ true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ bool Overflow;
+ APInt Res = N1.sshl_ov(N2, Overflow);
+ if (Overflow)
+ return std::nullopt;
+ return Res;
+ });
+ testBinaryOpExhaustive(
+ [](const KnownBits &Known1, const KnownBits &Known2) {
+ return KnownBits::shl(Known1, Known2, /* NUW */ true, /* NSW */ true);
+ },
+ [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
+ bool OverflowUnsigned, OverflowSigned;
+ APInt Res = N1.ushl_ov(N2, OverflowUnsigned);
+ (void)N1.sshl_ov(N2, OverflowSigned);
+ if (OverflowUnsigned || OverflowSigned)
+ return std::nullopt;
+ return Res;
+ });
+
testBinaryOpExhaustive(
[](const KnownBits &Known1, const KnownBits &Known2) {
return KnownBits::lshr(Known1, Known2);
More information about the llvm-commits
mailing list