// Copyright 2018 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "media/learning/common/target_histogram.h" #include namespace media { namespace learning { TargetHistogram::TargetHistogram() = default; TargetHistogram::TargetHistogram(const TargetHistogram& rhs) = default; TargetHistogram::TargetHistogram(TargetHistogram&& rhs) = default; TargetHistogram::~TargetHistogram() = default; TargetHistogram& TargetHistogram::operator=(const TargetHistogram& rhs) = default; TargetHistogram& TargetHistogram::operator=(TargetHistogram&& rhs) = default; bool TargetHistogram::operator==(const TargetHistogram& rhs) const { return rhs.total_counts() == total_counts() && rhs.counts_ == counts_; } TargetHistogram& TargetHistogram::operator+=(const TargetHistogram& rhs) { for (auto& rhs_pair : rhs.counts()) counts_[rhs_pair.first] += rhs_pair.second; return *this; } TargetHistogram& TargetHistogram::operator+=(const TargetValue& rhs) { counts_[rhs]++; return *this; } TargetHistogram& TargetHistogram::operator+=(const LabelledExample& example) { counts_[example.target_value] += example.weight; return *this; } double TargetHistogram::operator[](const TargetValue& value) const { auto iter = counts_.find(value); if (iter == counts_.end()) return 0; return iter->second; } double& TargetHistogram::operator[](const TargetValue& value) { return counts_[value]; } bool TargetHistogram::FindSingularMax(TargetValue* value_out, double* counts_out) const { if (!counts_.size()) return false; double unused_counts; if (!counts_out) counts_out = &unused_counts; auto iter = counts_.begin(); *value_out = iter->first; *counts_out = iter->second; bool singular_max = true; for (iter++; iter != counts_.end(); iter++) { if (iter->second > *counts_out) { *value_out = iter->first; *counts_out = iter->second; singular_max = true; } else if (iter->second == *counts_out) { // If this turns out to be the max, then it's not singular. singular_max = false; } } return singular_max; } double TargetHistogram::Average() const { double total_value = 0.; double total_counts = 0; for (auto& iter : counts_) { total_value += iter.first.value() * iter.second; total_counts += iter.second; } if (!total_counts) return 0.; return total_value / total_counts; } void TargetHistogram::Normalize() { double total = total_counts(); for (auto& iter : counts_) iter.second /= total; } std::string TargetHistogram::ToString() const { std::ostringstream ss; ss << "["; for (auto& entry : counts_) ss << " " << entry.first << ":" << entry.second; ss << " ]"; return ss.str(); } std::ostream& operator<<(std::ostream& out, const media::learning::TargetHistogram& dist) { return out << dist.ToString(); } } // namespace learning } // namespace media