[llvm] [ValueTracking] Handle range attributes (PR #85143)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 13 15:44:55 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Andreas Jonson (andjo403)
<details>
<summary>Changes</summary>
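Handle the `range` attribute in ValueTracking: ranges attached to function arguments and to call return values are now used by computeKnownBits, isKnownNonZero, and computeConstantRange.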
---
Full diff: https://github.com/llvm/llvm-project/pull/85143.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+43-3)
- (modified) llvm/unittests/Analysis/ValueTrackingTest.cpp (+137)
``````````diff
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 8a4a2c4f92a0dc..2a2336801990a8 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1467,14 +1467,21 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::Call:
- case Instruction::Invoke:
+ case Instruction::Invoke:{
// If range metadata is attached to this call, set known bits from that,
// and then intersect with known bits based on other properties of the
// function.
if (MDNode *MD =
Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
computeKnownBitsFromRangeMetadata(*MD, Known);
- if (const Value *RV = cast<CallBase>(I)->getReturnedArgOperand()) {
+
+ const CallBase *CB = cast<CallBase>(I);
+
+ const Attribute RangeAttr = CB->getRetAttr(llvm::Attribute::Range);
+ if (RangeAttr.isValid())
+ Known = RangeAttr.getRange().toKnownBits();
+
+ if (const Value *RV = CB->getReturnedArgOperand()) {
if (RV->getType() == I->getType()) {
computeKnownBits(RV, Known2, Depth + 1, Q);
Known = Known.unionWith(Known2);
@@ -1646,6 +1653,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
}
break;
+ }
case Instruction::ShuffleVector: {
auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
// FIXME: Do we need to handle ConstantExpr involving shufflevectors?
@@ -1900,6 +1908,12 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
// assumptions. Confirm that we've handled them all.
assert(!isa<ConstantData>(V) && "Unhandled constant data!");
+ if (const Argument *A = dyn_cast<Argument>(V)) {
+ Attribute Range = A->getAttribute(llvm::Attribute::Range);
+ if (Range.isValid())
+ Known = Range.getRange().toKnownBits();
+ }
+
// All recursive calls that increase depth must come after this.
if (Depth == MaxAnalysisRecursionDepth)
return;
@@ -2893,6 +2907,20 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
}
}
+ Attribute RangeAttr;
+ if (const CallBase *CB = dyn_cast<CallBase>(V))
+ RangeAttr = CB->getRetAttr(llvm::Attribute::Range);
+
+ if (const Argument *A = dyn_cast<Argument>(V))
+ RangeAttr = A->getAttribute(llvm::Attribute::Range);
+
+ if (RangeAttr.isValid()) {
+ const ConstantRange Range = RangeAttr.getRange();
+ const APInt ZeroValue(Range.getBitWidth(), 0);
+ if (!Range.contains(ZeroValue))
+ return true;
+ }
+
if (!isa<Constant>(V) && isKnownNonZeroFromAssume(V, Q))
return true;
@@ -9114,11 +9142,23 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
// TODO: Return ConstantRange.
setLimitForFPToI(cast<Instruction>(V), Lower, Upper);
CR = ConstantRange::getNonEmpty(Lower, Upper);
+ } else if (const Argument *A = dyn_cast<Argument>(V)) {
+ const Attribute RangeAttr = A->getAttribute(llvm::Attribute::Range);
+ if (RangeAttr.isValid()) {
+ CR = RangeAttr.getRange();
+ }
}
- if (auto *I = dyn_cast<Instruction>(V))
+ if (auto *I = dyn_cast<Instruction>(V)){
if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
+
+ if (const CallBase *CB = dyn_cast<CallBase>(V)) {
+ const Attribute RangeAttr = CB->getRetAttr(llvm::Attribute::Range);
+ if (RangeAttr.isValid())
+ CR = CR.intersectWith(RangeAttr.getRange());
+ }
+ }
if (CtxI && AC) {
// Try to restrict the range based on information from assumptions.
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index 6c6897d83a256e..706b87f5355598 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -2085,6 +2085,76 @@ TEST_F(ValueTrackingTest, KnownNonZeroFromDomCond2) {
EXPECT_EQ(isKnownNonZero(A, DL, 0, &AC, CxtI2, &DT), false);
}
+TEST_F(ValueTrackingTest, KnownNonZeroFromRangeAttributeArgument) {
+ parseAssembly(R"(
+ define i8 @test(i8 range(i8 1, 5) %q) {
+ %A = bitcast i8 %q to i8
+ ret i8 %A
+ }
+ )");
+ const DataLayout &DL = M->getDataLayout();
+ EXPECT_TRUE(isKnownNonZero(A, DL, 0));
+}
+
+TEST_F(ValueTrackingTest, PossibleZeroFromRangeAttributeArgument) {
+ parseAssembly(R"(
+ define i8 @test(i8 range(i8 0, 5) %q) {
+ %A = bitcast i8 %q to i8
+ ret i8 %A
+ }
+ )");
+ const DataLayout &DL = M->getDataLayout();
+ EXPECT_FALSE(isKnownNonZero(A, DL, 0));
+}
+
+TEST_F(ValueTrackingTest, KnownNonZeroFromCallRangeAttribute) {
+ parseAssembly(R"(
+ declare i8 @f()
+ define i8 @test() {
+ %A = call range(i8 1, 5) i8 @f()
+ ret i8 %A
+ }
+ )");
+ const DataLayout &DL = M->getDataLayout();
+ EXPECT_TRUE(isKnownNonZero(A, DL, 0));
+}
+
+TEST_F(ValueTrackingTest, PossibleZeroFromCallRangeAttribute) {
+ parseAssembly(R"(
+ declare i8 @f()
+ define i8 @test() {
+ %A = call range(i8 0, 5) i8 @f()
+ ret i8 %A
+ }
+ )");
+ const DataLayout &DL = M->getDataLayout();
+ EXPECT_FALSE(isKnownNonZero(A, DL, 0));
+}
+
+TEST_F(ValueTrackingTest, KnownNonZeroFromRangeAttributeResult) {
+ parseAssembly(R"(
+ declare range(i8 1, 5) i8 @f()
+ define i8 @test() {
+ %A = call i8 @f()
+ ret i8 %A
+ }
+ )");
+ const DataLayout &DL = M->getDataLayout();
+ EXPECT_TRUE(isKnownNonZero(A, DL, 0));
+}
+
+TEST_F(ValueTrackingTest, PossibleZeroFromRangeAttributeResult) {
+ parseAssembly(R"(
+ declare range(i8 0, 5) i8 @f()
+ define i8 @test() {
+ %A = call i8 @f()
+ ret i8 %A
+ }
+ )");
+ const DataLayout &DL = M->getDataLayout();
+ EXPECT_FALSE(isKnownNonZero(A, DL, 0));
+}
+
TEST_F(ValueTrackingTest, IsImpliedConditionAnd) {
parseAssembly(R"(
define void @test(i32 %x, i32 %y) {
@@ -2601,6 +2671,38 @@ TEST_F(ComputeKnownBitsTest, ComputeKnownBitsAbsoluteSymbol) {
EXPECT_EQ(0u, Known_0_256_Align8.countMinTrailingOnes());
}
+TEST_F(ComputeKnownBitsTest, KnownBitsFromRangeAttributeArgument) {
+ parseAssembly(R"(
+ define i8 @test(i8 range(i8 80, 84) %q) {
+ %A = bitcast i8 %q to i8
+ ret i8 %A
+ }
+ )");
+ expectKnownBits(/*zero*/ 172u, /*one*/ 80u);
+}
+
+TEST_F(ComputeKnownBitsTest, KnownBitsFromCallRangeAttribute) {
+ parseAssembly(R"(
+ declare i8 @f()
+ define i8 @test() {
+ %A = call range(i8 80, 84) i8 @f()
+ ret i8 %A
+ }
+ )");
+ expectKnownBits(/*zero*/ 172u, /*one*/ 80u);
+}
+
+TEST_F(ComputeKnownBitsTest, KnownBitsFromRangeAttributeResult) {
+ parseAssembly(R"(
+ declare range(i8 80, 84) i8 @f()
+ define i8 @test() {
+ %A = call i8 @f()
+ ret i8 %A
+ }
+ )");
+ expectKnownBits(/*zero*/ 172u, /*one*/ 80u);
+}
+
TEST_F(ValueTrackingTest, HaveNoCommonBitsSet) {
{
// Check for an inverted mask: (X & ~M) op (Y & M).
@@ -3125,6 +3227,41 @@ TEST_F(ValueTrackingTest, ComputeConstantRange) {
}
}
+TEST_F(ValueTrackingTest, ComputeConstantRangeFromRangeAttributeArgument) {
+ parseAssembly(R"(
+ define i8 @test(i8 range(i8 32, 36) %q) {
+ %A = bitcast i8 %q to i8
+ ret i8 %A
+ }
+ )");
+ EXPECT_EQ(computeConstantRange(F->arg_begin(), false),
+ ConstantRange(APInt(8, 32), APInt(8, 36)));
+}
+
+TEST_F(ValueTrackingTest, ComputeConstantRangeFromCallRangeAttribute) {
+ parseAssembly(R"(
+ declare i8 @f()
+ define i8 @test() {
+ %A = call range(i8 32, 36) i8 @f()
+ ret i8 %A
+ }
+ )");
+ EXPECT_EQ(computeConstantRange(A, false),
+ ConstantRange(APInt(8, 32), APInt(8, 36)));
+}
+
+TEST_F(ValueTrackingTest, ComputeConstantRangeFromRangeAttributeResult) {
+ parseAssembly(R"(
+ declare range(i8 32, 36) i8 @f()
+ define i8 @test() {
+ %A = call i8 @f()
+ ret i8 %A
+ }
+ )");
+ EXPECT_EQ(computeConstantRange(A, false),
+ ConstantRange(APInt(8, 32), APInt(8, 36)));
+}
+
struct FindAllocaForValueTestParams {
const char *IR;
bool AnyOffsetResult;
``````````
</details>
https://github.com/llvm/llvm-project/pull/85143
More information about the llvm-commits mailing list