[llvm] [ValueTracking] Handle range attributes (PR #85143)

Andreas Jonson via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 13 15:43:51 PDT 2024


https://github.com/andjo403 created https://github.com/llvm/llvm-project/pull/85143

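Handle the `range` attribute on function arguments and call return values in ValueTracking, so that computeKnownBits, isKnownNonZero and computeConstantRange take it into account.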

>From a15ab683aee72b66b751ac38be7acae4593d9aa5 Mon Sep 17 00:00:00 2001
From: Andreas Jonson <andjo403 at hotmail.com>
Date: Wed, 13 Mar 2024 19:14:16 +0100
Subject: [PATCH] [ValueTracking] Handle range attributes

---
 llvm/lib/Analysis/ValueTracking.cpp           |  46 +++++-
 llvm/unittests/Analysis/ValueTrackingTest.cpp | 137 ++++++++++++++++++
 2 files changed, 180 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 8a4a2c4f92a0dc..2a2336801990a8 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1467,14 +1467,21 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::Call:
-  case Instruction::Invoke:
+  case Instruction::Invoke: {
     // If range metadata is attached to this call, set known bits from that,
     // and then intersect with known bits based on other properties of the
     // function.
     if (MDNode *MD =
             Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
       computeKnownBitsFromRangeMetadata(*MD, Known);
-    if (const Value *RV = cast<CallBase>(I)->getReturnedArgOperand()) {
+
+    const auto *CB = cast<CallBase>(I);
+    // Merge (rather than overwrite) the bits implied by a range attribute.
+    Attribute RangeAttr = CB->getRetAttr(Attribute::Range);
+    if (RangeAttr.isValid())
+      Known = Known.unionWith(RangeAttr.getRange().toKnownBits());
+
+    if (const Value *RV = CB->getReturnedArgOperand()) {
       if (RV->getType() == I->getType()) {
         computeKnownBits(RV, Known2, Depth + 1, Q);
         Known = Known.unionWith(Known2);
@@ -1646,6 +1653,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
       }
     }
     break;
+  }
   case Instruction::ShuffleVector: {
     auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
     // FIXME: Do we need to handle ConstantExpr involving shufflevectors?
@@ -1900,6 +1908,12 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
   // assumptions.  Confirm that we've handled them all.
   assert(!isa<ConstantData>(V) && "Unhandled constant data!");
 
+  // A range attribute on an argument directly bounds its known bits.
+  if (const auto *A = dyn_cast<Argument>(V)) {
+    Attribute RangeAttr = A->getAttribute(Attribute::Range);
+    if (RangeAttr.isValid())
+      Known = RangeAttr.getRange().toKnownBits();
+  }
   // All recursive calls that increase depth must come after this.
   if (Depth == MaxAnalysisRecursionDepth)
     return;
@@ -2893,6 +2907,20 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
     }
   }
 
+  // A range attribute on a call result or an argument that excludes zero
+  // proves the value is non-zero.
+  Attribute RangeAttr;
+  if (const auto *CB = dyn_cast<CallBase>(V))
+    RangeAttr = CB->getRetAttr(Attribute::Range);
+  else if (const auto *A = dyn_cast<Argument>(V))
+    RangeAttr = A->getAttribute(Attribute::Range);
+  if (RangeAttr.isValid()) {
+    const ConstantRange Range = RangeAttr.getRange();
+    const APInt ZeroValue(Range.getBitWidth(), 0);
+    if (!Range.contains(ZeroValue))
+      return true;
+  }
+
   if (!isa<Constant>(V) && isKnownNonZeroFromAssume(V, Q))
     return true;
 
@@ -9114,11 +9142,23 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
     // TODO: Return ConstantRange.
     setLimitForFPToI(cast<Instruction>(V), Lower, Upper);
     CR = ConstantRange::getNonEmpty(Lower, Upper);
+  } else if (const auto *A = dyn_cast<Argument>(V)) {
+    // A range attribute on an argument bounds the argument directly.
+    Attribute RangeAttr = A->getAttribute(Attribute::Range);
+    if (RangeAttr.isValid())
+      CR = RangeAttr.getRange();
   }
 
-  if (auto *I = dyn_cast<Instruction>(V))
+  if (auto *I = dyn_cast<Instruction>(V)) {
     if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
       CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
+    // Intersect with the range attribute of the call site or callee, if any.
+    if (const auto *CB = dyn_cast<CallBase>(I)) {
+      Attribute RangeAttr = CB->getRetAttr(Attribute::Range);
+      if (RangeAttr.isValid())
+        CR = CR.intersectWith(RangeAttr.getRange());
+    }
+  }
 
   if (CtxI && AC) {
     // Try to restrict the range based on information from assumptions.
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index 6c6897d83a256e..706b87f5355598 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -2085,6 +2085,76 @@ TEST_F(ValueTrackingTest, KnownNonZeroFromDomCond2) {
   EXPECT_EQ(isKnownNonZero(A, DL, 0, &AC, CxtI2, &DT), false);
 }
 
+TEST_F(ValueTrackingTest, KnownNonZeroFromRangeAttributeArgument) {
+  parseAssembly(R"(
+    define i8 @test(i8 range(i8 1, 5) %q) {
+      %A = bitcast i8 %q to i8
+      ret i8 %A
+    }
+  )"); // %A forwards %q unchanged (i8 -> i8 bitcast).
+  const DataLayout &DL = M->getDataLayout();
+  EXPECT_TRUE(isKnownNonZero(A, DL, 0)); // [1, 5) excludes 0.
+}
+
+TEST_F(ValueTrackingTest, PossibleZeroFromRangeAttributeArgument) {
+  parseAssembly(R"(
+    define i8 @test(i8 range(i8 0, 5) %q) {
+      %A = bitcast i8 %q to i8
+      ret i8 %A
+    }
+  )"); // %A forwards %q unchanged (i8 -> i8 bitcast).
+  const DataLayout &DL = M->getDataLayout();
+  EXPECT_FALSE(isKnownNonZero(A, DL, 0)); // [0, 5) includes 0.
+}
+
+TEST_F(ValueTrackingTest, KnownNonZeroFromCallRangeAttribute) {
+  parseAssembly(R"(
+    declare i8 @f()
+    define i8 @test() {
+      %A = call range(i8 1, 5) i8 @f()
+      ret i8 %A
+    }
+  )");
+  const DataLayout &DL = M->getDataLayout();
+  EXPECT_TRUE(isKnownNonZero(A, DL, 0)); // Call-site [1, 5) excludes 0.
+}
+
+TEST_F(ValueTrackingTest, PossibleZeroFromCallRangeAttribute) {
+  parseAssembly(R"(
+    declare i8 @f()
+    define i8 @test() {
+      %A = call range(i8 0, 5) i8 @f()
+      ret i8 %A
+    }
+  )");
+  const DataLayout &DL = M->getDataLayout();
+  EXPECT_FALSE(isKnownNonZero(A, DL, 0)); // Call-site [0, 5) includes 0.
+}
+
+TEST_F(ValueTrackingTest, KnownNonZeroFromRangeAttributeResult) {
+  parseAssembly(R"(
+    declare range(i8 1, 5) i8 @f()
+    define i8 @test() {
+      %A = call i8 @f()
+      ret i8 %A
+    }
+  )");
+  const DataLayout &DL = M->getDataLayout();
+  EXPECT_TRUE(isKnownNonZero(A, DL, 0)); // @f's range [1, 5) excludes 0.
+}
+
+TEST_F(ValueTrackingTest, PossibleZeroFromRangeAttributeResult) {
+  parseAssembly(R"(
+    declare range(i8 0, 5) i8 @f()
+    define i8 @test() {
+      %A = call i8 @f()
+      ret i8 %A
+    }
+  )");
+  const DataLayout &DL = M->getDataLayout();
+  EXPECT_FALSE(isKnownNonZero(A, DL, 0)); // @f's range [0, 5) includes 0.
+}
+
 TEST_F(ValueTrackingTest, IsImpliedConditionAnd) {
   parseAssembly(R"(
     define void @test(i32 %x, i32 %y) {
@@ -2601,6 +2671,38 @@ TEST_F(ComputeKnownBitsTest, ComputeKnownBitsAbsoluteSymbol) {
   EXPECT_EQ(0u, Known_0_256_Align8.countMinTrailingOnes());
 }
 
+TEST_F(ComputeKnownBitsTest, KnownBitsFromRangeAttributeArgument) {
+  parseAssembly(R"(
+    define i8 @test(i8 range(i8 80, 84) %q) {
+      %A = bitcast i8 %q to i8
+      ret i8 %A
+    }
+  )"); // %A forwards %q unchanged (i8 -> i8 bitcast).
+  expectKnownBits(/*zero*/ 172u, /*one*/ 80u); // [80, 84) is 0b010100xx.
+}
+
+TEST_F(ComputeKnownBitsTest, KnownBitsFromCallRangeAttribute) {
+  parseAssembly(R"(
+    declare i8 @f()
+    define i8 @test() {
+      %A = call range(i8 80, 84) i8 @f()
+      ret i8 %A
+    }
+  )");
+  expectKnownBits(/*zero*/ 172u, /*one*/ 80u); // [80, 84) is 0b010100xx.
+}
+
+TEST_F(ComputeKnownBitsTest, KnownBitsFromRangeAttributeResult) {
+  parseAssembly(R"(
+    declare range(i8 80, 84) i8 @f()
+    define i8 @test() {
+      %A = call i8 @f()
+      ret i8 %A
+    }
+  )");
+  expectKnownBits(/*zero*/ 172u, /*one*/ 80u); // [80, 84) is 0b010100xx.
+}
+
 TEST_F(ValueTrackingTest, HaveNoCommonBitsSet) {
   {
     // Check for an inverted mask: (X & ~M) op (Y & M).
@@ -3125,6 +3227,41 @@ TEST_F(ValueTrackingTest, ComputeConstantRange) {
   }
 }
 
+TEST_F(ValueTrackingTest, ComputeConstantRangeFromRangeAttributeArgument) {
+  parseAssembly(R"(
+    define i8 @test(i8 range(i8 32, 36) %q) {
+      %A = bitcast i8 %q to i8
+      ret i8 %A
+    }
+  )");
+  EXPECT_EQ(computeConstantRange(F->arg_begin(), false),
+            ConstantRange(APInt(8, 32), APInt(8, 36))); // Exactly [32, 36).
+}
+
+TEST_F(ValueTrackingTest, ComputeConstantRangeFromCallRangeAttribute) {
+  parseAssembly(R"(
+    declare i8 @f()
+    define i8 @test() {
+      %A = call range(i8 32, 36) i8 @f()
+      ret i8 %A
+    }
+  )");
+  EXPECT_EQ(computeConstantRange(A, false),
+            ConstantRange(APInt(8, 32), APInt(8, 36))); // Exactly [32, 36).
+}
+
+TEST_F(ValueTrackingTest, ComputeConstantRangeFromRangeAttributeResult) {
+  parseAssembly(R"(
+    declare range(i8 32, 36) i8 @f()
+    define i8 @test() {
+      %A = call i8 @f()
+      ret i8 %A
+    }
+  )");
+  EXPECT_EQ(computeConstantRange(A, false),
+            ConstantRange(APInt(8, 32), APInt(8, 36))); // Exactly [32, 36).
+}
+
 struct FindAllocaForValueTestParams {
   const char *IR;
   bool AnyOffsetResult;



More information about the llvm-commits mailing list