summaryrefslogtreecommitdiff
path: root/chromium/media/learning/common/labelled_example.h
blob: a728dc7049dceee9d00083975ccc66c5bfcead72 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_
#define MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_

#include <cstddef>
#include <initializer_list>
#include <ostream>
#include <utility>
#include <vector>

#include "base/component_export.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "media/learning/common/value.h"

namespace media {
namespace learning {

// Weight assigned to a single example; see LabelledExample::weight.
using WeightType = size_t;

// Feature values, used both for training and for prediction.  Each index has
// a meaning that is described by a LearningTask (e.g. [0]=="height",
// [1]=="url"), so interpreting the values usually requires that task.
using FeatureVector = std::vector<FeatureValue>;

// A single training example: a group of observed feature values, the target
// value that was observed for them, and a weight.
struct COMPONENT_EXPORT(LEARNING_COMMON) LabelledExample {
  LabelledExample();
  LabelledExample(FeatureVector feature_vector, TargetValue target);
  LabelledExample(std::initializer_list<FeatureValue> init_list,
                  TargetValue target);

  // Both copying and moving are supported.
  LabelledExample(const LabelledExample& rhs);
  LabelledExample(LabelledExample&& rhs) noexcept;
  LabelledExample& operator=(const LabelledExample& rhs);
  LabelledExample& operator=(LabelledExample&& rhs) noexcept;

  ~LabelledExample();

  // Equality and ordering look only at |features| and |target_value|; the
  // |weight| is intentionally left out, for convenience.
  bool operator==(const LabelledExample& rhs) const;
  bool operator!=(const LabelledExample& rhs) const;
  bool operator<(const LabelledExample& rhs) const;

  // The feature values that were observed.  Making sense of them generally
  // requires the LearningTask they are intended to be used with.
  FeatureVector features;

  // The output value that was observed when |features| was the input.
  TargetValue target_value;

  // How many instances this example counts as (see
  // TrainingData::total_weight()).  Defaults to 1.
  WeightType weight{1u};
};

// A list of LabelledExamples together with the sum of their weights.
// TODO(liberato): This should probably move to impl/ .
class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData {
 public:
  using ExampleVector = std::vector<LabelledExample>;
  using const_iterator = ExampleVector::const_iterator;

  TrainingData();
  TrainingData(const TrainingData& rhs);
  // NOTE(review): unlike LabelledExample's, these move operations are not
  // noexcept; marking them so would also require updating the definitions.
  TrainingData(TrainingData&& rhs);

  TrainingData& operator=(const TrainingData& rhs);
  TrainingData& operator=(TrainingData&& rhs);

  ~TrainingData();

  // Add a copy of |example|.  It contributes |example.weight| to
  // total_weight(); that weight must be positive.
  void push_back(const LabelledExample& example) {
    DCHECK_GT(example.weight, 0u);
    examples_.push_back(example);
    total_weight_ += example.weight;
  }

  // Same as above, but moves |example| in instead of copying its features.
  void push_back(LabelledExample&& example) {
    DCHECK_GT(example.weight, 0u);
    // Capture the weight before |example| is moved from.
    const WeightType weight = example.weight;
    examples_.push_back(std::move(example));
    total_weight_ += weight;
  }

  // Returns true if the total weight is zero.  Since push_back() requires
  // positive weights, this normally means that no examples have been added.
  bool empty() const { return !total_weight_; }

  // Returns the number of stored examples, ignoring their weights.
  size_t size() const { return examples_.size(); }

  // Returns the number of instances, taking into account their weight.  For
  // example, if one adds an example with weight 2, then this will return two
  // more than it did before.
  WeightType total_weight() const { return total_weight_; }

  const_iterator begin() const { return examples_.begin(); }
  const_iterator end() const { return examples_.end(); }

  // Since every weight is at least 1, the count equals the total weight
  // exactly when each example has weight 1.
  bool is_unweighted() const { return examples_.size() == total_weight_; }

  // Provide the |i|-th example, over [0, size()).
  const LabelledExample& operator[](size_t i) const { return examples_[i]; }
  // Note: changing an example's |weight| through this reference does not
  // update total_weight().
  LabelledExample& operator[](size_t i) { return examples_[i]; }

  // Return a copy of this data with duplicate entries merged.  Example weights
  // will be summed.
  TrainingData DeDuplicate() const;

 private:
  ExampleVector examples_;

  // Sum of the weights of every example added via push_back().
  WeightType total_weight_ = 0u;

  // Copy / assignment is allowed.
};

// Stream insertion operators for FeatureVector and LabelledExample.
COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const FeatureVector& features);

COMPONENT_EXPORT(LEARNING_COMMON)
std::ostream& operator<<(std::ostream& out, const LabelledExample& example);

}  // namespace learning
}  // namespace media

#endif  // MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_