[llvm] bdcbdb4 - [Attributor] Deduction based on path exploration

Hideto Ueno via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 8 22:42:11 PDT 2020


Author: Hideto Ueno
Date: 2020-03-09T14:29:26+09:00
New Revision: bdcbdb484829c518511eece0809cc8ce1baa73c8

URL: https://github.com/llvm/llvm-project/commit/bdcbdb484829c518511eece0809cc8ce1baa73c8
DIFF: https://github.com/llvm/llvm-project/commit/bdcbdb484829c518511eece0809cc8ce1baa73c8.diff

LOG: [Attributor] Deduction based on path exploration

This patch introduces the propagation of known information based on path exploration.
For example,
```
int u(int c, int *p){
  if(c) {
     return *p;
  } else {
     return *p + 1;
  }
}
```
The argument `p` is dereferenced regardless of the value of `c`.

For an instruction `CtxI`, we collect the conditional branch instructions in the must-be-executed context of `CtxI`. For each such branch, we take the conjunction of its successors' known states and merge the result into the known state.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D65593

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/MustExecute.h
    llvm/include/llvm/Transforms/IPO/Attributor.h
    llvm/lib/Transforms/IPO/Attributor.cpp
    llvm/test/Transforms/Attributor/IPConstantProp/openmp_parallel_for.ll
    llvm/test/Transforms/Attributor/dereferenceable-1.ll
    llvm/test/Transforms/Attributor/nonnull.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/MustExecute.h b/llvm/include/llvm/Analysis/MustExecute.h
index 56db454470b7..587ebcbfb8f8 100644
--- a/llvm/include/llvm/Analysis/MustExecute.h
+++ b/llvm/include/llvm/Analysis/MustExecute.h
@@ -461,6 +461,18 @@ struct MustBeExecutedContextExplorer {
   }
   ///}
 
+  /// Check \p Pred on all instructions in the context.
+  ///
+  /// This method will evaluate \p Pred and return
+  /// true if \p Pred holds in every instruction.
+  bool checkForAllContext(const Instruction *PP,
+                          const function_ref<bool(const Instruction *)> &Pred) {
+    for (auto EIt = begin(PP), EEnd = end(PP); EIt != EEnd; EIt++)
+      if (!Pred(*EIt))
+        return false;
+    return true;
+  }
+
   /// Helper to look for \p I in the context of \p PP.
   ///
   /// The context is expanded until \p I was found or no more expansion is

diff  --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 5d6c231a27bb..27a879df6674 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -1339,6 +1339,13 @@ struct IntegerStateBase : public AbstractState {
     handleNewAssumedValue(R.getAssumed());
   }
 
+  /// "Clamp" this state with \p R. The result is subtype dependent but it is
+  /// intended that information known in either state will be known in
+  /// this one afterwards.
+  void operator+=(const IntegerStateBase<base_t, BestState, WorstState> &R) {
+    handleNewKnownValue(R.getKnown());
+  }
+
   void operator|=(const IntegerStateBase<base_t, BestState, WorstState> &R) {
     joinOR(R.getAssumed(), R.getKnown());
   }
@@ -2294,6 +2301,13 @@ struct DerefState : AbstractState {
     return *this;
   }
 
+  /// See IntegerStateBase::operator+=
+  DerefState operator+=(const DerefState &R) {
+    DerefBytesState += R.DerefBytesState;
+    GlobalState += R.GlobalState;
+    return *this;
+  }
+
   /// See IntegerStateBase::operator&=
   DerefState operator&=(const DerefState &R) {
     DerefBytesState &= R.DerefBytesState;

diff  --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index 3aef6e9fe5b4..d0bd1cc6692c 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -980,6 +980,23 @@ struct AAFromMustBeExecutedContext : public Base {
       Uses.insert(&U);
   }
 
+  /// Helper function to accumulate uses.
+  void followUsesInContext(Attributor &A,
+                           MustBeExecutedContextExplorer &Explorer,
+                           const Instruction *CtxI,
+                           SetVector<const Use *> &Uses, StateType &State) {
+    auto EIt = Explorer.begin(CtxI), EEnd = Explorer.end(CtxI);
+    for (unsigned u = 0; u < Uses.size(); ++u) {
+      const Use *U = Uses[u];
+      if (const Instruction *UserI = dyn_cast<Instruction>(U->getUser())) {
+        bool Found = Explorer.findInContextOf(UserI, EIt, EEnd);
+        if (Found && Base::followUse(A, U, UserI, State))
+          for (const Use &Us : UserI->uses())
+            Uses.insert(&Us);
+      }
+    }
+  }
+
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     auto BeforeState = this->getState();
@@ -991,15 +1008,74 @@ struct AAFromMustBeExecutedContext : public Base {
     MustBeExecutedContextExplorer &Explorer =
         A.getInfoCache().getMustBeExecutedContextExplorer();
 
-    auto EIt = Explorer.begin(CtxI), EEnd = Explorer.end(CtxI);
-    for (unsigned u = 0; u < Uses.size(); ++u) {
-      const Use *U = Uses[u];
-      if (const Instruction *UserI = dyn_cast<Instruction>(U->getUser())) {
-        bool Found = Explorer.findInContextOf(UserI, EIt, EEnd);
-        if (Found && Base::followUse(A, U, UserI))
-          for (const Use &Us : UserI->uses())
-            Uses.insert(&Us);
+    followUsesInContext(A, Explorer, CtxI, Uses, S);
+
+    if (this->isAtFixpoint())
+      return ChangeStatus::CHANGED;
+
+    SmallVector<const BranchInst *, 4> BrInsts;
+    auto Pred = [&](const Instruction *I) {
+      if (const BranchInst *Br = dyn_cast<BranchInst>(I))
+        if (Br->isConditional())
+          BrInsts.push_back(Br);
+      return true;
+    };
+
+    // Here, accumulate conditional branch instructions in the context. We
+    // explore the child paths and collect the known states. The disjunction of
+    // those states can be merged to its own state. Let ParentState_i be a state
+    // to indicate the known information for an i-th branch instruction in the
+    // context. ChildStates are created for its successors respectively.
+    //
+    // ParentS_1 = ChildS_{1, 1} /\ ChildS_{1, 2} /\ ... /\ ChildS_{1, n_1}
+    // ParentS_2 = ChildS_{2, 1} /\ ChildS_{2, 2} /\ ... /\ ChildS_{2, n_2}
+    //      ...
+    // ParentS_m = ChildS_{m, 1} /\ ChildS_{m, 2} /\ ... /\ ChildS_{m, n_m}
+    //
+    // Known State |= ParentS_1 \/ ParentS_2 \/... \/ ParentS_m
+    //
+    // FIXME: Currently, recursive branches are not handled. For example, we
+    // can't deduce that ptr must be dereferenced in the function below.
+    //
+    // void f(int a, int b, int *ptr) {
+    //    if(a)
+    //      if (b) {
+    //        *ptr = 0;
+    //      } else {
+    //        *ptr = 1;
+    //      }
+    //    else {
+    //      if (b) {
+    //        *ptr = 0;
+    //      } else {
+    //        *ptr = 1;
+    //      }
+    //    }
+    // }
+
+    Explorer.checkForAllContext(CtxI, Pred);
+    for (const BranchInst *Br : BrInsts) {
+      StateType ParentState;
+
+      // The known state of the parent state is a conjunction of children's
+      // known states so it is initialized with a best state.
+      ParentState.indicateOptimisticFixpoint();
+
+      for (const BasicBlock *BB : Br->successors()) {
+        StateType ChildState;
+
+        size_t BeforeSize = Uses.size();
+        followUsesInContext(A, Explorer, &BB->front(), Uses, ChildState);
+
+        // Erase uses which only appear in the child.
+        for (auto It = Uses.begin() + BeforeSize; It != Uses.end();)
+          It = Uses.erase(It);
+
+        ParentState &= ChildState;
       }
+
+      // Use only known state.
+      S += ParentState;
     }
 
     return BeforeState == S ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED;
@@ -1900,7 +1976,7 @@ struct AANoFreeCallSiteReturned final : AANoFreeFloating {
 
 /// ------------------------ NonNull Argument Attribute ------------------------
 static int64_t getKnownNonNullAndDerefBytesForUse(
-    Attributor &A, AbstractAttribute &QueryingAA, Value &AssociatedValue,
+    Attributor &A, const AbstractAttribute &QueryingAA, Value &AssociatedValue,
     const Use *U, const Instruction *I, bool &IsNonNull, bool &TrackUse) {
   TrackUse = false;
 
@@ -1991,12 +2067,13 @@ struct AANonNullImpl : AANonNull {
   }
 
   /// See AAFromMustBeExecutedContext
-  bool followUse(Attributor &A, const Use *U, const Instruction *I) {
+  bool followUse(Attributor &A, const Use *U, const Instruction *I,
+                 AANonNull::StateType &State) {
     bool IsNonNull = false;
     bool TrackUse = false;
     getKnownNonNullAndDerefBytesForUse(A, *this, getAssociatedValue(), U, I,
                                        IsNonNull, TrackUse);
-    setKnown(IsNonNull);
+    State.setKnown(IsNonNull);
     return TrackUse;
   }
 
@@ -3549,8 +3626,8 @@ struct AADereferenceableImpl : AADereferenceable {
   /// }
 
   /// Helper function for collecting accessed bytes in must-be-executed-context
-  void addAccessedBytesForUse(Attributor &A, const Use *U,
-                              const Instruction *I) {
+  void addAccessedBytesForUse(Attributor &A, const Use *U, const Instruction *I,
+                              DerefState &State) {
     const Value *UseV = U->get();
     if (!UseV->getType()->isPointerTy())
       return;
@@ -3563,21 +3640,22 @@ struct AADereferenceableImpl : AADereferenceable {
       if (Base == &getAssociatedValue() &&
           getPointerOperand(I, /* AllowVolatile */ false) == UseV) {
         uint64_t Size = DL.getTypeStoreSize(PtrTy->getPointerElementType());
-        addAccessedBytes(Offset, Size);
+        State.addAccessedBytes(Offset, Size);
       }
     }
     return;
   }
 
   /// See AAFromMustBeExecutedContext
-  bool followUse(Attributor &A, const Use *U, const Instruction *I) {
+  bool followUse(Attributor &A, const Use *U, const Instruction *I,
+                 AADereferenceable::StateType &State) {
     bool IsNonNull = false;
     bool TrackUse = false;
     int64_t DerefBytes = getKnownNonNullAndDerefBytesForUse(
         A, *this, getAssociatedValue(), U, I, IsNonNull, TrackUse);
 
-    addAccessedBytesForUse(A, U, I);
-    takeKnownDerefBytesMaximum(DerefBytes);
+    addAccessedBytesForUse(A, U, I, State);
+    State.takeKnownDerefBytesMaximum(DerefBytes);
     return TrackUse;
   }
 
@@ -3871,12 +3949,13 @@ struct AAAlignImpl : AAAlign {
           Attribute::getWithAlignment(Ctx, Align(getAssumedAlign())));
   }
   /// See AAFromMustBeExecutedContext
-  bool followUse(Attributor &A, const Use *U, const Instruction *I) {
+  bool followUse(Attributor &A, const Use *U, const Instruction *I,
+                 AAAlign::StateType &State) {
     bool TrackUse = false;
 
     unsigned int KnownAlign =
         getKnownAlignForUse(A, *this, getAssociatedValue(), U, I, TrackUse);
-    takeKnownMaximum(KnownAlign);
+    State.takeKnownMaximum(KnownAlign);
 
     return TrackUse;
   }

diff  --git a/llvm/test/Transforms/Attributor/IPConstantProp/openmp_parallel_for.ll b/llvm/test/Transforms/Attributor/IPConstantProp/openmp_parallel_for.ll
index 7c51808deb9c..c919bdbec49b 100644
--- a/llvm/test/Transforms/Attributor/IPConstantProp/openmp_parallel_for.ll
+++ b/llvm/test/Transforms/Attributor/IPConstantProp/openmp_parallel_for.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes
-; RUN: opt -S -passes=attributor -aa-pipeline='basic-aa' -attributor-disable=false -attributor-max-iterations-verify -attributor-max-iterations=1 < %s | FileCheck %s
+; RUN: opt -S -passes=attributor -aa-pipeline='basic-aa' -attributor-disable=false -attributor-max-iterations-verify -attributor-max-iterations=2 < %s | FileCheck %s
 ;
 ;    void bar(int, float, double);
 ;

diff  --git a/llvm/test/Transforms/Attributor/dereferenceable-1.ll b/llvm/test/Transforms/Attributor/dereferenceable-1.ll
index dcec9e50a820..87fb7867eabd 100644
--- a/llvm/test/Transforms/Attributor/dereferenceable-1.ll
+++ b/llvm/test/Transforms/Attributor/dereferenceable-1.ll
@@ -308,5 +308,152 @@ entry:
   ret void
 }
 
+declare void @use0() willreturn nounwind
+declare void @use1(i8*) willreturn nounwind
+declare void @use2(i8*, i8*) willreturn nounwind
+declare void @use3(i8*, i8*, i8*) willreturn nounwind
+; simple path test
+; if(..)
+;   fun2(dereferenceable(8) %a, dereferenceable(8) %b)
+; else
+;   fun2(dereferenceable(4) %a, %b)
+; We can say that %a is dereferenceable(4) but %b is not.
+define void @simple-path(i8* %a, i8 * %b, i8 %c) {
+; ATTRIBUTOR: define void @simple-path(i8* nonnull dereferenceable(4) %a, i8* %b, i8 %c)
+  %cmp = icmp eq i8 %c, 0
+  br i1 %cmp, label %if.then, label %if.else
+if.then:
+  tail call void @use2(i8* dereferenceable(8) %a, i8* dereferenceable(8) %b)
+  ret void
+if.else:
+  tail call void @use2(i8* dereferenceable(4) %a, i8* %b)
+  ret void
+}
+; More complex test
+; {
+; fun1(dereferenceable(4) %a)
+; if(..)
+;    ... (willreturn & nounwind)
+;    fun1(dereferenceable(12) %a)
+; else
+;    ... (willreturn & nounwind)
+;    fun1(dereferenceable(16) %a)
+; fun1(dereferenceable(8) %a)
+; }
+; %a is dereferenceable(12)
+
+define void @complex-path(i8* %a, i8* %b, i8 %c) {
+; ATTRIBUTOR: define void @complex-path(i8* nonnull dereferenceable(12) %a, i8* nocapture nofree readnone %b, i8 %c)
+  %cmp = icmp eq i8 %c, 0
+  tail call void @use1(i8* dereferenceable(4) %a)
+  br i1 %cmp, label %cont.then, label %cont.else
+cont.then:
+  tail call void @use1(i8* dereferenceable(12) %a)
+  br label %cont2
+cont.else:
+  tail call void @use1(i8* dereferenceable(16) %a)
+  br label %cont2
+cont2:
+  tail call void @use1(i8* dereferenceable(8) %a)
+  ret void
+}
+
+;  void rec-branch-1(int a, int b, int c, int *ptr) {
+;    if (a) {
+;      if (b)
+;        *ptr = 1;
+;      else
+;        *ptr = 2;
+;    } else {
+;      if (c)
+;        *ptr = 3;
+;      else
+;        *ptr = 4;
+;    }
+;  }
+;
+; FIXME: %ptr should be dereferenceable(4)
+; ATTRIBUTOR: define dso_local void @rec-branch-1(i32 %a, i32 %b, i32 %c, i32* nocapture nofree writeonly %ptr)
+define dso_local void @rec-branch-1(i32 %a, i32 %b, i32 %c, i32* %ptr) {
+entry:
+  %tobool = icmp eq i32 %a, 0
+  br i1 %tobool, label %if.else3, label %if.then
+
+if.then:                                          ; preds = %entry
+  %tobool1 = icmp eq i32 %b, 0
+  br i1 %tobool1, label %if.else, label %if.then2
+
+if.then2:                                         ; preds = %if.then
+  store i32 1, i32* %ptr, align 4
+  br label %if.end8
+
+if.else:                                          ; preds = %if.then
+  store i32 2, i32* %ptr, align 4
+  br label %if.end8
+
+if.else3:                                         ; preds = %entry
+  %tobool4 = icmp eq i32 %c, 0
+  br i1 %tobool4, label %if.else6, label %if.then5
+
+if.then5:                                         ; preds = %if.else3
+  store i32 3, i32* %ptr, align 4
+  br label %if.end8
+
+if.else6:                                         ; preds = %if.else3
+  store i32 4, i32* %ptr, align 4
+  br label %if.end8
+
+if.end8:                                          ; preds = %if.then5, %if.else6, %if.then2, %if.else
+  ret void
+}
+
+;  void rec-branch-2(int a, int b, int c, int *ptr) {
+;    if (a) {
+;      if (b)
+;        *ptr = 1;
+;      else
+;        *ptr = 2;
+;    } else {
+;      if (c)
+;        *ptr = 3;
+;      else
+;        rec-branch-2(1, 1, 1, ptr);
+;    }
+;  }
+; FIXME: %ptr should be dereferenceable(4)
+; ATTRIBUTOR: define dso_local void @rec-branch-2(i32 %a, i32 %b, i32 %c, i32* nocapture nofree writeonly %ptr)
+define dso_local void @rec-branch-2(i32 %a, i32 %b, i32 %c, i32* %ptr) {
+entry:
+  %tobool = icmp eq i32 %a, 0
+  br i1 %tobool, label %if.else3, label %if.then
+
+if.then:                                          ; preds = %entry
+  %tobool1 = icmp eq i32 %b, 0
+  br i1 %tobool1, label %if.else, label %if.then2
+
+if.then2:                                         ; preds = %if.then
+  store i32 1, i32* %ptr, align 4
+  br label %if.end8
+
+if.else:                                          ; preds = %if.then
+  store i32 2, i32* %ptr, align 4
+  br label %if.end8
+
+if.else3:                                         ; preds = %entry
+  %tobool4 = icmp eq i32 %c, 0
+  br i1 %tobool4, label %if.else6, label %if.then5
+
+if.then5:                                         ; preds = %if.else3
+  store i32 3, i32* %ptr, align 4
+  br label %if.end8
+
+if.else6:                                         ; preds = %if.else3
+  tail call void @rec-branch-2(i32 1, i32 1, i32 1, i32* %ptr)
+  br label %if.end8
+
+if.end8:                                          ; preds = %if.then5, %if.else6, %if.then2, %if.else
+  ret void
+}
+
 !0 = !{i64 10, i64 100}
 

diff  --git a/llvm/test/Transforms/Attributor/nonnull.ll b/llvm/test/Transforms/Attributor/nonnull.ll
index d3ccccf35c09..d4b0f9af393a 100644
--- a/llvm/test/Transforms/Attributor/nonnull.ll
+++ b/llvm/test/Transforms/Attributor/nonnull.ll
@@ -257,8 +257,7 @@ declare void @fun3(i8*, i8*, i8*) #1
 ;   fun2(nonnull %a, %b)
 ; We can say that %a is nonnull but %b is not.
 define void @f16(i8* %a, i8 * %b, i8 %c) {
-; FIXME: missing nonnull on %a
-; ATTRIBUTOR: define void @f16(i8* %a, i8* %b, i8 %c)
+; ATTRIBUTOR: define void @f16(i8* nonnull %a, i8* %b, i8 %c)
   %cmp = icmp eq i8 %c, 0
   br i1 %cmp, label %if.then, label %if.else
 if.then:
@@ -327,8 +326,7 @@ cont2:
 ; TEST 19: Loop
 
 define void @f19(i8* %a, i8* %b, i8 %c) {
-; FIXME: missing nonnull on %b
-; ATTRIBUTOR: define void @f19(i8* %a, i8* %b, i8 %c)
+; ATTRIBUTOR: define void @f19(i8* %a, i8* nonnull %b, i8 %c)
   br label %loop.header
 loop.header:
   %cmp2 = icmp eq i8 %c, 0
@@ -658,7 +656,7 @@ hd2:
 define i32 @nonnull_exec_ctx_2(i32* %a, i32 %b) willreturn nounwind {
 ;
 ; ATTRIBUTOR-LABEL: define {{[^@]+}}@nonnull_exec_ctx_2
-; ATTRIBUTOR-SAME: (i32* [[A:%.*]], i32 [[B:%.*]])
+; ATTRIBUTOR-SAME: (i32* nonnull [[A:%.*]], i32 [[B:%.*]])
 ; ATTRIBUTOR-NEXT:  en:
 ; ATTRIBUTOR-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[B]], 0
 ; ATTRIBUTOR-NEXT:    br i1 [[TMP3]], label [[EX:%.*]], label [[HD:%.*]]
@@ -691,7 +689,7 @@ hd:
 define i32 @nonnull_exec_ctx_2b(i32* %a, i32 %b) willreturn nounwind {
 ;
 ; ATTRIBUTOR-LABEL: define {{[^@]+}}@nonnull_exec_ctx_2b
-; ATTRIBUTOR-SAME: (i32* [[A:%.*]], i32 [[B:%.*]])
+; ATTRIBUTOR-SAME: (i32* nonnull [[A:%.*]], i32 [[B:%.*]])
 ; ATTRIBUTOR-NEXT:  en:
 ; ATTRIBUTOR-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[B]], 0
 ; ATTRIBUTOR-NEXT:    br i1 [[TMP3]], label [[EX:%.*]], label [[HD:%.*]]


        


More information about the llvm-commits mailing list