summaryrefslogtreecommitdiff
path: root/clang-tools-extra/clangd/DecisionForest.cpp
blob: a05a9ee436f1d572b50ad2a595b3a3c77cd2c300 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
//===--- DecisionForest.cpp --------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Features.inc"

#if !CLANGD_DECISION_FOREST
#include "Quality.h"
#include <cstdlib>

namespace clang {
namespace clangd {
/// Placeholder compiled in when CLANGD_DECISION_FOREST is disabled.
///
/// There is no model to evaluate in this configuration, so the function
/// writes a diagnostic to llvm::errs() and terminates the process through
/// std::abort(). It never produces a score; all arguments are ignored.
DecisionForestScores
evaluateDecisionForest(const SymbolQualitySignals &Quality,
                       const SymbolRelevanceSignals &Relevance, float Base) {
  auto &Diag = llvm::errs();
  Diag << "Clangd was compiled without decision forest support.\n";
  std::abort();
}

} // namespace clangd
} // namespace clang

#else // !CLANGD_DECISION_FOREST

#include "CompletionModel.h"
#include "Quality.h"
#include <cmath>

namespace clang {
namespace clangd {

/// Scores a completion candidate with the decision forest model.
///
/// The quality and relevance signals are copied into an `Example`, the forest
/// is evaluated on it, and the prediction is used as an exponent of \p Base.
/// A few hand-tuned multipliers are then applied for cases the model was not
/// trained on. The returned `Total` additionally folds in
/// `Relevance.NameMatch`, while `ExcludingName` leaves it out.
DecisionForestScores
evaluateDecisionForest(const SymbolQualitySignals &Quality,
                       const SymbolRelevanceSignals &Relevance, float Base) {
  Example Features;

  // Signals describing the symbol itself.
  Features.setIsDeprecated(Quality.Deprecated);
  Features.setIsReservedName(Quality.ReservedName);
  Features.setIsImplementationDetail(Quality.ImplementationDetail);
  Features.setNumReferences(Quality.References);
  Features.setSymbolCategory(Quality.Category);

  SymbolRelevanceSignals::DerivedSignals Proximity =
      Relevance.calculateDerivedSignals();

  // Count the context words that occur (case-insensitively) inside the
  // symbol name, and express that count as a fraction of all context words.
  // Both stay at zero when there are no context words to compare against.
  int MatchingWords = 0;
  double MatchingFraction = 0;
  if (Relevance.ContextWords) {
    for (const auto &ContextWord : Relevance.ContextWords->keys())
      if (Relevance.Name.contains_insensitive(ContextWord))
        ++MatchingWords;
    if (!Relevance.ContextWords->empty())
      MatchingFraction = MatchingWords * 1.0 / Relevance.ContextWords->size();
  }
  Features.setIsNameInContext(MatchingWords > 0);
  Features.setNumNameInContext(MatchingWords);
  Features.setFractionNameInContext(MatchingFraction);

  // Signals describing how the symbol relates to the completion context.
  Features.setIsInBaseClass(Relevance.InBaseClass);
  Features.setFileProximityDistanceCost(Proximity.FileProximityDistance);
  Features.setSemaFileProximityScore(Relevance.SemaFileProximityScore);
  Features.setSymbolScopeDistanceCost(Proximity.ScopeProximityDistance);
  Features.setSemaSaysInScope(Relevance.SemaSaysInScope);
  Features.setScope(Relevance.Scope);
  Features.setContextKind(Relevance.Context);
  Features.setIsInstanceMember(Relevance.IsInstanceMember);
  Features.setHadContextType(Relevance.HadContextType);
  Features.setHadSymbolType(Relevance.HadSymbolType);
  Features.setTypeMatchesPreferred(Relevance.TypeMatchesPreferred);

  DecisionForestScores Result;
  // Using the forest prediction as an exponent turns the score of each tree
  // into a multiplicative boost (like NameMatch), so the prediction and
  // NameMatch can be weighed against each other appropriately.
  Result.ExcludingName = pow(Base, Evaluate(Features));

  // The generated training dataset does not cover the cases below, so they
  // are adjusted by hand:
  //  - Symbols with `NeedsFixIts`.
  //  - Forbidden symbols.
  //  - Keywords: the dataset contains only macros and decls.
  if (Relevance.NeedsFixIts)
    Result.ExcludingName *= 0.5;
  if (Relevance.Forbidden)
    Result.ExcludingName *= 0;
  if (Quality.Category == SymbolQualitySignals::Keyword)
    Result.ExcludingName *= 4;

  // Keep NameMatch as a plain multiplier on the total so that rescoring
  // remains possible.
  Result.Total = Relevance.NameMatch * Result.ExcludingName;
  return Result;
}

} // namespace clangd
} // namespace clang

#endif // !CLANGD_DECISION_FOREST