[llvm] 8d04181 - [ValueTracking] Use assumptions in computeConstantRange.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sat May 23 12:19:23 PDT 2020


Author: Florian Hahn
Date: 2020-05-23T20:07:52+01:00
New Revision: 8d041811983b0b8b9df5fa5f76d8b9d5178d03ea

URL: https://github.com/llvm/llvm-project/commit/8d041811983b0b8b9df5fa5f76d8b9d5178d03ea
DIFF: https://github.com/llvm/llvm-project/commit/8d041811983b0b8b9df5fa5f76d8b9d5178d03ea.diff

LOG: [ValueTracking] Use assumptions in computeConstantRange.

This patch updates computeConstantRange to optionally take an assumption
cache and a context instruction as arguments, and to use the assumptions
that are valid at the context instruction to limit the range of the result.

Currently this is limited to assumptions that are integer comparisons
(icmp) whose first operand is the queried value.

Reviewers: reames, nikic, spatel, jdoerfert, lebedev.ri

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D76193

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ValueTracking.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/unittests/Analysis/ValueTrackingTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index caec9326755e..9510739ef5ab 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -531,7 +531,13 @@ class Value;
 
-  /// Determine the possible constant range of an integer or vector of integer
-  /// value. This is intended as a cheap, non-recursive check.
-  ConstantRange computeConstantRange(const Value *V, bool UseInstrInfo = true);
+  /// Determine the possible constant range of an integer or vector of integer
+  /// value. This is intended as a cheap check. If both \p AC and \p CtxI are
+  /// provided, assumptions of the form `icmp Pred V, X` that are valid at
+  /// \p CtxI are used to narrow the result. The range of X is computed
+  /// recursively; \p Depth tracks the recursion and bounds it.
+  ConstantRange computeConstantRange(const Value *V, bool UseInstrInfo = true,
+                                     AssumptionCache *AC = nullptr,
+                                     const Instruction *CtxI = nullptr,
+                                     unsigned Depth = 0);
 
   /// Return true if this function can prove that the instruction I will
   /// always transfer execution to one of its successors (including the next

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index a5fb6fba4d5a..1b73a2062095 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -6367,9 +6367,15 @@ static void setLimitsForSelectPattern(const SelectInst &SI, APInt &Lower,
   }
 }
 
-ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) {
+ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo,
+                                         AssumptionCache *AC,
+                                         const Instruction *CtxI,
+                                         unsigned Depth) {
   assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");
 
+  if (Depth >= MaxDepth)
+    return ConstantRange::getFull(V->getType()->getScalarSizeInBits());
+
   const APInt *C;
   if (match(V, m_APInt(C)))
     return ConstantRange(*C);
@@ -6391,6 +6397,34 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) {
     if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
       CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
 
+  if (CtxI && AC) {
+    // Try to restrict the range based on information from assumptions.
+    for (auto &AssumeVH : AC->assumptionsFor(V)) {
+      if (!AssumeVH)
+        continue;
+      CallInst *I = cast<CallInst>(AssumeVH);
+      assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
+             "Got assumption for the wrong function!");
+      assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+             "must be an assume intrinsic");
+
+      if (!isValidAssumeForContext(I, CtxI, nullptr))
+        continue;
+      Value *Arg = I->getArgOperand(0);
+      ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
+      // Currently we just use information from comparisons.
+      if (!Cmp || Cmp->getOperand(0) != V)
+        continue;
+      ConstantRange RHS = computeConstantRange(Cmp->getOperand(1), UseInstrInfo,
+                                               AC, I, Depth + 1);
+      // RHS over-approximates the other operand, so V may be any value that
+      // satisfies the predicate for *some* element of RHS: use the allowed
+      // region, not the (under-approximating) satisfying region.
+      CR = CR.intersectWith(
+          ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
+    }
+  }
+
   return CR;
 }
 

diff  --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index a5ebb7ff5ce1..7dcb6204ba40 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
@@ -23,6 +24,14 @@ using namespace llvm;
 
 namespace {
 
+static Instruction &findInstructionByName(Function *F, StringRef Name) {
+  for (Instruction &I : instructions(F))
+    if (I.getName() == Name)
+      return I;
+
+  llvm_unreachable("Expected value not found");
+}
+
 class ValueTrackingTest : public testing::Test {
 protected:
   std::unique_ptr<Module> parseModule(StringRef Assembly) {
@@ -46,13 +55,7 @@ class ValueTrackingTest : public testing::Test {
     if (!F)
       return;
 
-    A = nullptr;
-    for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
-      if (I->hasName()) {
-        if (I->getName() == "A")
-          A = &*I;
-      }
-    }
+    A = &findInstructionByName(F, "A");
     ASSERT_TRUE(A) << "@test must have an instruction %A";
   }
 
@@ -1246,3 +1249,174 @@ TEST_P(IsBytewiseValueTest, IsBytewiseValue) {
     S << *Actual;
   EXPECT_EQ(GetParam().first, S.str());
 }
+
+TEST_F(ValueTrackingTest, ComputeConstantRange) {
+  {
+    // Assumptions:
+    //  * stride >= 5
+    //  * stride < 10
+    //
+    // stride = [5, 10)
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %stride) {
+    %gt = icmp uge i32 %stride, 5
+    call void @llvm.assume(i1 %gt)
+    %lt = icmp ult i32 %stride, 10
+    call void @llvm.assume(i1 %lt)
+    %stride.plus.one = add nsw nuw i32 %stride, 1
+    ret i32 %stride.plus.one
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *Stride = &*F->arg_begin();
+    ConstantRange CR1 = computeConstantRange(Stride, true, &AC, nullptr);
+    EXPECT_TRUE(CR1.isFullSet());
+
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR2 = computeConstantRange(Stride, true, &AC, I);
+    EXPECT_EQ(5, CR2.getLower());
+    EXPECT_EQ(10, CR2.getUpper());
+  }
+
+  {
+    // Assumptions:
+    //  * stride >= 5
+    //  * stride < 200
+    //  * stride == 99
+    //
+    // stride = [99, 100)
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %stride) {
+    %gt = icmp uge i32 %stride, 5
+    call void @llvm.assume(i1 %gt)
+    %lt = icmp ult i32 %stride, 200
+    call void @llvm.assume(i1 %lt)
+    %eq = icmp eq i32 %stride, 99
+    call void @llvm.assume(i1 %eq)
+    %stride.plus.one = add nsw nuw i32 %stride, 1
+    ret i32 %stride.plus.one
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *Stride = &*F->arg_begin();
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR = computeConstantRange(Stride, true, &AC, I);
+    EXPECT_EQ(99, *CR.getSingleElement());
+  }
+
+  {
+    // Assumptions:
+    //  * stride >= 5
+    //  * stride >= 50
+    //  * stride < 100
+    //  * stride < 200
+    //
+    // stride = [50, 100)
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %stride, i1 %cond) {
+    %gt = icmp uge i32 %stride, 5
+    call void @llvm.assume(i1 %gt)
+    %gt.2 = icmp uge i32 %stride, 50
+    call void @llvm.assume(i1 %gt.2)
+    br i1 %cond, label %bb1, label %bb2
+
+  bb1:
+    %lt = icmp ult i32 %stride, 200
+    call void @llvm.assume(i1 %lt)
+    %lt.2 = icmp ult i32 %stride, 100
+    call void @llvm.assume(i1 %lt.2)
+    %stride.plus.one = add nsw nuw i32 %stride, 1
+    ret i32 %stride.plus.one
+
+  bb2:
+    ret i32 0
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *Stride = &*F->arg_begin();
+    Instruction *GT2 = &findInstructionByName(F, "gt.2");
+    ConstantRange CR = computeConstantRange(Stride, true, &AC, GT2);
+    EXPECT_EQ(5, CR.getLower());
+    EXPECT_EQ(0, CR.getUpper());
+
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR2 = computeConstantRange(Stride, true, &AC, I);
+    EXPECT_EQ(50, CR2.getLower());
+    EXPECT_EQ(100, CR2.getUpper());
+  }
+
+  {
+    // Assumptions:
+    //  * stride > 5
+    //  * stride < 5
+    //
+    // stride = empty range, as the assumptions contradict each other.
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %stride, i1 %cond) {
+    %gt = icmp ugt i32 %stride, 5
+    call void @llvm.assume(i1 %gt)
+    %lt = icmp ult i32 %stride, 5
+    call void @llvm.assume(i1 %lt)
+    %stride.plus.one = add nsw nuw i32 %stride, 1
+    ret i32 %stride.plus.one
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *Stride = &*F->arg_begin();
+
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR = computeConstantRange(Stride, true, &AC, I);
+    EXPECT_TRUE(CR.isEmptySet());
+  }
+
+  {
+    // Assumptions:
+    //  * x.1 >= 5
+    //  * x.2 < x.1
+    //
+    // x.1 = [5, 0), i.e. [5, UINT32_MAX]
+    // x.2 = [0, UINT32_MAX), as x.1 may be as large as UINT32_MAX.
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %x.1, i32 %x.2) {
+    %gt = icmp uge i32 %x.1, 5
+    call void @llvm.assume(i1 %gt)
+    %lt = icmp ult i32 %x.2, %x.1
+    call void @llvm.assume(i1 %lt)
+    %stride.plus.one = add nsw nuw i32 %x.1, 1
+    ret i32 %stride.plus.one
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *X1 = &*F->arg_begin();
+    Value *X2 = &*std::next(F->arg_begin());
+
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR1 = computeConstantRange(X1, true, &AC, I);
+    EXPECT_EQ(5, CR1.getLower());
+    EXPECT_EQ(0, CR1.getUpper());
+
+    ConstantRange CR2 = computeConstantRange(X2, true, &AC, I);
+    EXPECT_EQ(0, CR2.getLower());
+    EXPECT_EQ(0xffffffff, CR2.getUpper());
+
+    // Check the depth cutoff results in a conservative result (full set) by
+    // passing Depth == MaxDepth == 6.
+    ConstantRange CR3 = computeConstantRange(X2, true, &AC, I, 6);
+    EXPECT_TRUE(CR3.isFullSet());
+  }
+}


        


More information about the llvm-commits mailing list