[clang-tools-extra] a8b55b6 - [clangd] Use Decision Forest to score code completions.

Utkarsh Saxena via cfe-commits cfe-commits at lists.llvm.org
Mon Sep 28 09:59:39 PDT 2020


Author: Utkarsh Saxena
Date: 2020-09-28T18:59:29+02:00
New Revision: a8b55b6939a5962d5b2bf1a57980562d6f3045e5

URL: https://github.com/llvm/llvm-project/commit/a8b55b6939a5962d5b2bf1a57980562d6f3045e5
DIFF: https://github.com/llvm/llvm-project/commit/a8b55b6939a5962d5b2bf1a57980562d6f3045e5.diff

LOG: [clangd] Use Decision Forest to score code completions.

By default, clangd scores code completion items using a heuristics model.

Scoring can instead be done by the Decision Forest model by passing
`--ranking-model=decision_forest` to clangd.

Features omitted from the model:
- `NameMatch` is excluded because the final score must be multiplicative in `NameMatch` to allow rescoring by the editor.
- `NeedsFixIts` is excluded because generating a training dataset containing completions that need fixits is non-trivial.

There are multiple ways (heuristics) to combine the above two features with the prediction of the Decision Forest (DF):
- `NeedsFixIts` is applied as a fixed multiplicative penalty of `0.5`.

Alternatives considered for combining NameMatch `N` with the Decision Forest prediction `P`:
- N * scale(P, 0, 1): Linearly scale the output of model to range [0, 1]
- N * a^P:
  - More natural: the prediction of each decision tree acts as a multiplicative boost (like NameMatch).
  - Ordering is independent of the absolute value of P: for equal NameMatch, the ratio of two items' scores is `a^{difference in model prediction scores}`. A higher `a` gives more weight to the model output relative to the NameMatch score.

Baseline MRR = 0.619
MRR for various combinations:
N * P = 0.6346, advantage%=2.5768
N * 1.1^P = 0.6600, advantage%=6.6853
N * **1.2**^P = 0.6669, advantage%=**7.8005**
N * **1.3**^P = 0.6668, advantage%=**7.7795**
N * **1.4**^P = 0.6659, advantage%=**7.6270**
N * 1.5^P = 0.6646, advantage%=7.4200
N * 1.6^P = 0.6636, advantage%=7.2671
N * 1.7^P = 0.6629, advantage%=7.1450
N * 2^P = 0.6612, advantage%=6.8673
N * 2.5^P = 0.6598, advantage%=6.6491
N * 3^P = 0.6590, advantage%=6.5242
N * scaled[0, 1] = 0.6465, advantage%=4.5054

Differential Revision: https://reviews.llvm.org/D88281

Added: 
    

Modified: 
    clang-tools-extra/clangd/CodeComplete.cpp
    clang-tools-extra/clangd/CodeComplete.h
    clang-tools-extra/clangd/Quality.cpp
    clang-tools-extra/clangd/Quality.h
    clang-tools-extra/clangd/tool/ClangdMain.cpp
    clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/clangd/CodeComplete.cpp b/clang-tools-extra/clangd/CodeComplete.cpp
index 4d5b2975c9ae..90e793f11564 100644
--- a/clang-tools-extra/clangd/CodeComplete.cpp
+++ b/clang-tools-extra/clangd/CodeComplete.cpp
@@ -1625,6 +1625,43 @@ class CodeCompleteFlow {
     return Filter->match(C.Name);
   }
 
+  CodeCompletion::Scores
+  evaluateCompletion(const SymbolQualitySignals &Quality,
+                     const SymbolRelevanceSignals &Relevance) {
+    using RM = CodeCompleteOptions::CodeCompletionRankingModel;
+    CodeCompletion::Scores Scores;
+    switch (Opts.RankingModel) {
+    case RM::Heuristics:
+      Scores.Quality = Quality.evaluate();
+      Scores.Relevance = Relevance.evaluate();
+      Scores.Total =
+          evaluateSymbolAndRelevance(Scores.Quality, Scores.Relevance);
+      // NameMatch is in fact a multiplier on total score, so rescoring is
+      // sound.
+      Scores.ExcludingName = Relevance.NameMatch
+                                 ? Scores.Total / Relevance.NameMatch
+                                 : Scores.Quality;
+      return Scores;
+
+    case RM::DecisionForest:
+      Scores.Quality = 0;
+      Scores.Relevance = 0;
+      // Exponentiating DecisionForest prediction makes the score of each tree a
+      // multiplicative boost (like NameMatch). This allows us to weigh the
+      // prediction score and NameMatch appropriately.
+      Scores.ExcludingName = pow(Opts.DecisionForestBase,
+                                 evaluateDecisionForest(Quality, Relevance));
+      // NeedsFixIts is not part of the DecisionForest as generating training
+      // data that needs fixits is not feasible.
+      if (Relevance.NeedsFixIts)
+        Scores.ExcludingName *= 0.5;
+      // NameMatch should be a multiplier on total score to support rescoring.
+      Scores.Total = Relevance.NameMatch * Scores.ExcludingName;
+      return Scores;
+    }
+    llvm_unreachable("Unhandled CodeCompletion ranking model.");
+  }
+
   // Scores a candidate and adds it to the TopN structure.
   void addCandidate(TopN<ScoredBundle, ScoredBundleGreater> &Candidates,
                     CompletionCandidate::Bundle Bundle) {
@@ -1632,6 +1669,7 @@ class CodeCompleteFlow {
     SymbolRelevanceSignals Relevance;
     Relevance.Context = CCContextKind;
     Relevance.Name = Bundle.front().Name;
+    Relevance.FilterLength = HeuristicPrefix.Name.size();
     Relevance.Query = SymbolRelevanceSignals::CodeComplete;
     Relevance.FileProximityMatch = FileProximity.getPointer();
     if (ScopeProximity)
@@ -1680,15 +1718,7 @@ class CodeCompleteFlow {
       }
     }
 
-    CodeCompletion::Scores Scores;
-    Scores.Quality = Quality.evaluate();
-    Scores.Relevance = Relevance.evaluate();
-    Scores.Total = evaluateSymbolAndRelevance(Scores.Quality, Scores.Relevance);
-    // NameMatch is in fact a multiplier on total score, so rescoring is sound.
-    Scores.ExcludingName = Relevance.NameMatch
-                               ? Scores.Total / Relevance.NameMatch
-                               : Scores.Quality;
-
+    CodeCompletion::Scores Scores = evaluateCompletion(Quality, Relevance);
     if (Opts.RecordCCResult)
       Opts.RecordCCResult(toCodeCompletion(Bundle), Quality, Relevance,
                           Scores.Total);

diff  --git a/clang-tools-extra/clangd/CodeComplete.h b/clang-tools-extra/clangd/CodeComplete.h
index beffabd19f3b..82a2656f172e 100644
--- a/clang-tools-extra/clangd/CodeComplete.h
+++ b/clang-tools-extra/clangd/CodeComplete.h
@@ -147,6 +147,22 @@ struct CodeCompleteOptions {
   std::function<void(const CodeCompletion &, const SymbolQualitySignals &,
                      const SymbolRelevanceSignals &, float Score)>
       RecordCCResult;
+
+  /// Model to use for ranking code completion candidates.
+  enum CodeCompletionRankingModel {
+    Heuristics,
+    DecisionForest,
+  } RankingModel = Heuristics;
+
+  /// Base used to combine NameMatch with the DecisionForest prediction.
+  /// CompletionScore is NameMatch * pow(Base, Prediction).
+  /// The optimal value of Base largely depends on the semantics of the model
+  /// and prediction score (e.g. algorithm used during training, number of
+  /// trees, etc.). Usually if the range of Prediction is [-20, 20] then a Base
+  /// in [1.2, 1.7] works fine.
+  /// Semantics: E.g. the completion score reduces by 50% if the Prediction
+  /// score is reduced by 2.6 points for Base = 1.3.
+  float DecisionForestBase = 1.3f;
 };
 
 // Semi-structured representation of a code-complete suggestion for our C++ API.

diff  --git a/clang-tools-extra/clangd/Quality.cpp b/clang-tools-extra/clangd/Quality.cpp
index bf0c0957084c..37f1cf62821a 100644
--- a/clang-tools-extra/clangd/Quality.cpp
+++ b/clang-tools-extra/clangd/Quality.cpp
@@ -8,6 +8,7 @@
 
 #include "Quality.h"
 #include "AST.h"
+#include "CompletionModel.h"
 #include "FileDistance.h"
 #include "SourceCode.h"
 #include "URI.h"
@@ -486,6 +487,38 @@ float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance) {
   return SymbolQuality * SymbolRelevance;
 }
 
+// Converts the quality and relevance signals into a DecisionForest Example and
+// returns the model's raw prediction.
+// NameMatch and NeedsFixIts are deliberately not model features: the caller
+// combines them with the prediction (see evaluateCompletion in CodeComplete).
+float evaluateDecisionForest(const SymbolQualitySignals &Quality,
+                             const SymbolRelevanceSignals &Relevance) {
+  Example E;
+  E.setIsDeprecated(Quality.Deprecated);
+  E.setIsReservedName(Quality.ReservedName);
+  E.setIsImplementationDetail(Quality.ImplementationDetail);
+  E.setNumReferences(Quality.References);
+  E.setSymbolCategory(Quality.Category);
+
+  SymbolRelevanceSignals::DerivedSignals Derived =
+      Relevance.calculateDerivedSignals();
+  E.setIsNameInContext(Derived.NameMatchesContext);
+  E.setIsForbidden(Relevance.Forbidden);
+  E.setIsInBaseClass(Relevance.InBaseClass);
+  E.setFileProximityDistance(Derived.FileProximityDistance);
+  E.setSemaFileProximityScore(Relevance.SemaFileProximityScore);
+  E.setSymbolScopeDistance(Derived.ScopeProximityDistance);
+  E.setSemaSaysInScope(Relevance.SemaSaysInScope);
+  E.setScope(Relevance.Scope);
+  E.setContextKind(Relevance.Context);
+  E.setIsInstanceMember(Relevance.IsInstanceMember);
+  E.setHadContextType(Relevance.HadContextType);
+  E.setHadSymbolType(Relevance.HadSymbolType);
+  E.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred);
+  E.setFilterLength(Relevance.FilterLength);
+  return Evaluate(E);
+}
+
 // Produces an integer that sorts in the same order as F.
 // That is: a < b <==> encodeFloat(a) < encodeFloat(b).
 static uint32_t encodeFloat(float F) {

diff  --git a/clang-tools-extra/clangd/Quality.h b/clang-tools-extra/clangd/Quality.h
index 04c6ce211ca9..694653e1a714 100644
--- a/clang-tools-extra/clangd/Quality.h
+++ b/clang-tools-extra/clangd/Quality.h
@@ -77,6 +77,7 @@ struct SymbolQualitySignals {
   void merge(const CodeCompletionResult &SemaCCResult);
   void merge(const Symbol &IndexResult);
 
+  // FIXME(usx): Rename to evaluateHeuristics().
   // Condense these signals down to a single number, higher is better.
   float evaluate() const;
 };
@@ -136,6 +137,10 @@ struct SymbolRelevanceSignals {
   // Whether the item matches the type expected in the completion context.
   bool TypeMatchesPreferred = false;
 
+  /// Length of the unqualified partial name of Symbol typed in
+  /// CompletionPrefix.
+  unsigned FilterLength = 0;
+
   /// Set of derived signals computed by calculateDerivedSignals(). Must not be
   /// set explicitly.
   struct DerivedSignals {
@@ -161,6 +166,11 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &,
 /// Combine symbol quality and relevance into a single score.
 float evaluateSymbolAndRelevance(float SymbolQuality, float SymbolRelevance);
 
+/// Returns the raw DecisionForest prediction for a completion candidate
+/// described by the given quality and relevance signals.
+float evaluateDecisionForest(const SymbolQualitySignals &Quality,
+                             const SymbolRelevanceSignals &Relevance);
+
 /// TopN<T> is a lossy container that preserves only the "best" N elements.
 template <typename T, typename Compare = std::greater<T>> class TopN {
 public:

diff  --git a/clang-tools-extra/clangd/tool/ClangdMain.cpp b/clang-tools-extra/clangd/tool/ClangdMain.cpp
index 9660f1bd76f7..8e5d6cb97a32 100644
--- a/clang-tools-extra/clangd/tool/ClangdMain.cpp
+++ b/clang-tools-extra/clangd/tool/ClangdMain.cpp
@@ -167,6 +167,26 @@ opt<CodeCompleteOptions::CodeCompletionParse> CodeCompletionParse{
     Hidden,
 };
 
+opt<CodeCompleteOptions::CodeCompletionRankingModel> RankingModel{
+    "ranking-model",
+    cat(Features),
+    desc("Model to use to rank code-completion items"),
+    values(clEnumValN(CodeCompleteOptions::Heuristics, "heuristics",
+                      "Use heuristics to rank code completion items"),
+           clEnumValN(CodeCompleteOptions::DecisionForest, "decision_forest",
+                      "Use Decision Forest model to rank completion items")),
+    init(CodeCompleteOptions().RankingModel),
+    Hidden,
+};
+
+opt<float> DecisionForestBase{
+    "decision-forest-base",
+    cat(Features),
+    desc("Base for exponentiating the prediction from DecisionForest."),
+    init(CodeCompleteOptions().DecisionForestBase),
+    Hidden,
+};
+
 // FIXME: also support "plain" style where signatures are always omitted.
 enum CompletionStyleFlag { Detailed, Bundled };
 opt<CompletionStyleFlag> CompletionStyle{
@@ -739,6 +759,8 @@ clangd accepts flags on the commandline, and in the CLANGD_FLAGS environment var
   CCOpts.EnableFunctionArgSnippets = EnableFunctionArgSnippets;
   CCOpts.AllScopes = AllScopesCompletion;
   CCOpts.RunParser = CodeCompletionParse;
+  CCOpts.RankingModel = RankingModel;
+  CCOpts.DecisionForestBase = DecisionForestBase;
 
   RealThreadsafeFS TFS;
   std::vector<std::unique_ptr<config::Provider>> ProviderStack;

diff  --git a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
index 460976d64f9f..de73bc66a178 100644
--- a/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
+++ b/clang-tools-extra/clangd/unittests/CodeCompleteTests.cpp
@@ -10,7 +10,6 @@
 #include "ClangdServer.h"
 #include "CodeComplete.h"
 #include "Compiler.h"
-#include "CompletionModel.h"
 #include "Matchers.h"
 #include "Protocol.h"
 #include "Quality.h"
@@ -163,14 +162,42 @@ Symbol withReferences(int N, Symbol S) {
   return S;
 }
 
-TEST(DecisionForestRuntime, SanityTest) {
-  using Example = clangd::Example;
-  using clangd::Evaluate;
-  Example E1;
-  E1.setContextKind(ContextKind::CCC_ArrowMemberAccess);
-  Example E2;
-  E2.setContextKind(ContextKind::CCC_SymbolOrNewName);
-  EXPECT_GT(Evaluate(E1), Evaluate(E2));
+// With the DecisionForest model, NameMatch still drives ordering: the exact
+// match "ABG" must rank above the initialism match "AlphaBetaGamma".
+TEST(DecisionForestRankingModel, NameMatchSanityTest) {
+  clangd::CodeCompleteOptions Opts;
+  Opts.RankingModel = CodeCompleteOptions::DecisionForest;
+  auto Results = completions(
+      R"cpp(
+struct MemberAccess {
+  int ABG();
+  int AlphaBetaGamma();
+};
+int func() { MemberAccess().ABG^ }
+)cpp",
+      /*IndexSymbols=*/{}, Opts);
+  EXPECT_THAT(Results.Completions,
+              ElementsAre(Named("ABG"), Named("AlphaBetaGamma")));
+}
+
+// References is a model feature: of two equally-matched symbols, the heavily
+// referenced one ranks first regardless of symbol kind (func vs. namespace).
+TEST(DecisionForestRankingModel, ReferencesAffectRanking) {
+  clangd::CodeCompleteOptions Opts;
+  Opts.RankingModel = CodeCompleteOptions::DecisionForest;
+  constexpr int NumReferences = 100000;
+  EXPECT_THAT(
+      completions("int main() { clang^ }",
+                  {ns("clangA"), withReferences(NumReferences, func("clangD"))},
+                  Opts)
+          .Completions,
+      ElementsAre(Named("clangD"), Named("clangA")));
+  EXPECT_THAT(
+      completions("int main() { clang^ }",
+                  {withReferences(NumReferences, ns("clangA")), func("clangD")},
+                  Opts)
+          .Completions,
+      ElementsAre(Named("clangA"), Named("clangD")));
 }
 
 TEST(CompletionTest, Limit) {


        


More information about the cfe-commits mailing list