// Copyright 2018 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #ifndef MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_ #define MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_ #include #include #include #include "base/component_export.h" #include "base/macros.h" #include "base/memory/ref_counted.h" #include "media/learning/common/value.h" namespace media { namespace learning { // Vector of features, for training or prediction. // To interpret the features, one probably needs to check a LearningTask. It // provides a description for each index. For example, [0]=="height", // [1]=="url", etc. using FeatureVector = std::vector; using WeightType = size_t; // One training example == group of feature values, plus the desired target. struct COMPONENT_EXPORT(LEARNING_COMMON) LabelledExample { LabelledExample(); LabelledExample(FeatureVector feature_vector, TargetValue target); LabelledExample(std::initializer_list init_list, TargetValue target); LabelledExample(const LabelledExample& rhs); LabelledExample(LabelledExample&& rhs); ~LabelledExample(); // Comparisons ignore weight, because it's convenient. bool operator==(const LabelledExample& rhs) const; bool operator!=(const LabelledExample& rhs) const; bool operator<(const LabelledExample& rhs) const; LabelledExample& operator=(const LabelledExample& rhs); LabelledExample& operator=(LabelledExample&& rhs) noexcept; // Observed feature values. // Note that to interpret these values, you probably need to have the // LearningTask that they're supposed to be used with. FeatureVector features; // Observed output value, when given |features| as input. TargetValue target_value; WeightType weight = 1u; // Copy / assignment is allowed. }; // TODO(liberato): This should probably move to impl/ . class COMPONENT_EXPORT(LEARNING_COMMON) TrainingData { public: using ExampleVector = std::vector; using const_iterator = ExampleVector::const_iterator; TrainingData(); TrainingData(const TrainingData& rhs); TrainingData(TrainingData&& rhs); TrainingData& operator=(const TrainingData& rhs); TrainingData& operator=(TrainingData&& rhs); ~TrainingData(); // Add |example| with weight |weight|. void push_back(const LabelledExample& example) { DCHECK_GT(example.weight, 0u); examples_.push_back(example); total_weight_ += example.weight; } bool empty() const { return !total_weight_; } size_t size() const { return examples_.size(); } // Returns the number of instances, taking into account their weight. For // example, if one adds an example with weight 2, then this will return two // more than it did before. WeightType total_weight() const { return total_weight_; } const_iterator begin() const { return examples_.begin(); } const_iterator end() const { return examples_.end(); } bool is_unweighted() const { return examples_.size() == total_weight_; } // Provide the |i|-th example, over [0, size()). const LabelledExample& operator[](size_t i) const { return examples_[i]; } LabelledExample& operator[](size_t i) { return examples_[i]; } // Return a copy of this data with duplicate entries merged. Example weights // will be summed. TrainingData DeDuplicate() const; private: ExampleVector examples_; WeightType total_weight_ = 0u; // Copy / assignment is allowed. }; COMPONENT_EXPORT(LEARNING_COMMON) std::ostream& operator<<(std::ostream& out, const LabelledExample& example); COMPONENT_EXPORT(LEARNING_COMMON) std::ostream& operator<<(std::ostream& out, const FeatureVector& features); } // namespace learning } // namespace media #endif // MEDIA_LEARNING_COMMON_LABELLED_EXAMPLE_H_