summaryrefslogtreecommitdiff
path: root/webrtc/modules/audio_processing/ns/prior_signal_model_estimator.cc
blob: c814658e57a56668514e2bc0bb84132fe297b1c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
/*
 *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/ns/prior_signal_model_estimator.h"

#include <math.h>
#include <algorithm>

#include "modules/audio_processing/ns/fast_math.h"
#include "rtc_base/checks.h"

namespace webrtc {

namespace {

// Identifies the first of the two largest peaks in the histogram.
void FindFirstOfTwoLargestPeaks(
    float bin_size,
    rtc::ArrayView<const int, kHistogramSize> spectral_flatness,
    float* peak_position,
    int* peak_weight) {
  RTC_DCHECK(peak_position);
  RTC_DCHECK(peak_weight);

  int peak_value = 0;
  int secondary_peak_value = 0;
  *peak_position = 0.f;
  float secondary_peak_position = 0.f;
  *peak_weight = 0;
  int secondary_peak_weight = 0;

  // Identify the two largest peaks.
  for (int i = 0; i < kHistogramSize; ++i) {
    const float bin_mid = (i + 0.5f) * bin_size;
    if (spectral_flatness[i] > peak_value) {
      // Found new "first" peak candidate.
      secondary_peak_value = peak_value;
      secondary_peak_weight = *peak_weight;
      secondary_peak_position = *peak_position;

      peak_value = spectral_flatness[i];
      *peak_weight = spectral_flatness[i];
      *peak_position = bin_mid;
    } else if (spectral_flatness[i] > secondary_peak_value) {
      // Found new "second" peak candidate.
      secondary_peak_value = spectral_flatness[i];
      secondary_peak_weight = spectral_flatness[i];
      secondary_peak_position = bin_mid;
    }
  }

  // Merge the peaks if they are close.
  if ((fabs(secondary_peak_position - *peak_position) < 2 * bin_size) &&
      (secondary_peak_weight > 0.5f * (*peak_weight))) {
    *peak_weight += secondary_peak_weight;
    *peak_position = 0.5f * (*peak_position + secondary_peak_position);
  }
}

void UpdateLrt(rtc::ArrayView<const int, kHistogramSize> lrt_histogram,
               float* prior_model_lrt,
               bool* low_lrt_fluctuations) {
  RTC_DCHECK(prior_model_lrt);
  RTC_DCHECK(low_lrt_fluctuations);

  float average = 0.f;
  float average_compl = 0.f;
  float average_squared = 0.f;
  int count = 0;

  for (int i = 0; i < 10; ++i) {
    float bin_mid = (i + 0.5f) * kBinSizeLrt;
    average += lrt_histogram[i] * bin_mid;
    count += lrt_histogram[i];
  }
  if (count > 0) {
    average = average / count;
  }

  for (int i = 0; i < kHistogramSize; ++i) {
    float bin_mid = (i + 0.5f) * kBinSizeLrt;
    average_squared += lrt_histogram[i] * bin_mid * bin_mid;
    average_compl += lrt_histogram[i] * bin_mid;
  }
  constexpr float kOneFeatureUpdateWindowSize = 1.f / kFeatureUpdateWindowSize;
  average_squared = average_squared * kOneFeatureUpdateWindowSize;
  average_compl = average_compl * kOneFeatureUpdateWindowSize;

  // Fluctuation limit of LRT feature.
  *low_lrt_fluctuations = average_squared - average * average_compl < 0.05f;

  // Get threshold for LRT feature.
  constexpr float kMaxLrt = 1.f;
  constexpr float kMinLrt = .2f;
  if (*low_lrt_fluctuations) {
    // Very low fluctuation, so likely noise.
    *prior_model_lrt = kMaxLrt;
  } else {
    *prior_model_lrt = std::min(kMaxLrt, std::max(kMinLrt, 1.2f * average));
  }
}

}  // namespace

PriorSignalModelEstimator::PriorSignalModelEstimator(float lrt_initial_value)
    : prior_model_(lrt_initial_value) {}

// Extract thresholds for feature parameters and computes the threshold/weights.
void PriorSignalModelEstimator::Update(const Histograms& histograms) {
  bool low_lrt_fluctuations;
  UpdateLrt(histograms.get_lrt(), &prior_model_.lrt, &low_lrt_fluctuations);

  // For spectral flatness and spectral difference: compute the main peaks of
  // the histograms.
  float spectral_flatness_peak_position;
  int spectral_flatness_peak_weight;
  FindFirstOfTwoLargestPeaks(
      kBinSizeSpecFlat, histograms.get_spectral_flatness(),
      &spectral_flatness_peak_position, &spectral_flatness_peak_weight);

  float spectral_diff_peak_position = 0.f;
  int spectral_diff_peak_weight = 0;
  FindFirstOfTwoLargestPeaks(kBinSizeSpecDiff, histograms.get_spectral_diff(),
                             &spectral_diff_peak_position,
                             &spectral_diff_peak_weight);

  // Reject if weight of peaks is not large enough, or peak value too small.
  // Peak limit for spectral flatness (varies between 0 and 1).
  const int use_spec_flat = spectral_flatness_peak_weight < 0.3f * 500 ||
                                    spectral_flatness_peak_position < 0.6f
                                ? 0
                                : 1;

  // Reject if weight of peaks is not large enough or if fluctuation of the LRT
  // feature are very low, indicating a noise state.
  const int use_spec_diff =
      spectral_diff_peak_weight < 0.3f * 500 || low_lrt_fluctuations ? 0 : 1;

  // Update the model.
  prior_model_.template_diff_threshold = 1.2f * spectral_diff_peak_position;
  prior_model_.template_diff_threshold =
      std::min(1.f, std::max(0.16f, prior_model_.template_diff_threshold));

  float one_by_feature_sum = 1.f / (1.f + use_spec_flat + use_spec_diff);
  prior_model_.lrt_weighting = one_by_feature_sum;

  if (use_spec_flat == 1) {
    prior_model_.flatness_threshold = 0.9f * spectral_flatness_peak_position;
    prior_model_.flatness_threshold =
        std::min(.95f, std::max(0.1f, prior_model_.flatness_threshold));
    prior_model_.flatness_weighting = one_by_feature_sum;
  } else {
    prior_model_.flatness_weighting = 0.f;
  }

  if (use_spec_diff == 1) {
    prior_model_.difference_weighting = one_by_feature_sum;
  } else {
    prior_model_.difference_weighting = 0.f;
  }
}

}  // namespace webrtc