summaryrefslogtreecommitdiff
path: root/webrtc/modules/audio_processing/aec3/transparent_mode.cc
blob: 1820e16808a9deb27d8e44b12495c8853e02425c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
/*
 *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "modules/audio_processing/aec3/transparent_mode.h"

#include "rtc_base/checks.h"
#include "system_wrappers/include/field_trial.h"

namespace webrtc {
namespace {

// Initial (and reset) value of the legacy classifier's count of blocks since a
// converged filter was last observed.
constexpr size_t kBlocksSinceConvergencedFilterInit{10000};
// Initial value of the legacy classifier's count of active-render blocks since
// a sane (consistent, low-delay) filter was last observed.
constexpr size_t kBlocksSinceConsistentEstimateInit{10000};

// Returns true when the "WebRTC-Aec3TransparentModeKillSwitch" field trial is
// enabled, in which case no transparent mode classifier is created.
bool DeactivateTransparentMode() {
  const bool kill_switch_enabled =
      field_trial::IsEnabled("WebRTC-Aec3TransparentModeKillSwitch");
  return kill_switch_enabled;
}

// Returns true when the "WebRTC-Aec3TransparentModeHmmKillSwitch" field trial
// is enabled, which selects the legacy classifier instead of the HMM one.
bool DeactivateTransparentModeHmm() {
  const bool hmm_kill_switch_enabled =
      field_trial::IsEnabled("WebRTC-Aec3TransparentModeHmmKillSwitch");
  return hmm_kill_switch_enabled;
}

}  // namespace

// Hidden Markov Model (HMM) based classifier that toggles transparent mode,
// i.e. reduced echo suppression, intended for headset usage.
//
// Two hidden states are tracked: "normal" and "transparent". Only blocks with
// active render update the model, and the sole observation used is whether any
// filter reports convergence. Filters are less likely to be reported as
// converged when there is no echo in the microphone signal.
class TransparentModeImpl : public TransparentMode {
 public:
  bool Active() const override { return transparency_activated_; }

  void Reset() override {
    // Restart with full belief in the "normal" state and transparency off.
    prob_transparent_state_ = 0.f;
    transparency_activated_ = false;
  }

  // Runs one forward (predict + correct) step of the HMM. Only
  // `any_filter_converged` and `active_render` affect the result; the other
  // inputs are accepted to satisfy the TransparentMode interface.
  void Update(int filter_delay_blocks,
              bool any_filter_consistent,
              bool any_filter_converged,
              bool all_filters_diverged,
              bool active_render,
              bool saturated_capture) override {
    // Without render activity there is nothing to learn from convergence.
    if (!active_render) {
      return;
    }

    // The constants were obtained by observing active_render and
    // any_filter_converged in varying call scenarios, then hand tuned to prefer
    // the normal state when uncertain (to avoid echo leaks).

    // Per-block probability of leaving the current state, and of staying.
    constexpr float kSwitch = 0.000001f;
    constexpr float kStay = 1.f - kSwitch;

    // Emission probabilities during active render:
    // P(converged | normal) and P(converged | transparent), plus complements.
    constexpr float kConvergedNormal = 0.03f;
    constexpr float kConvergedTransparent = 0.005f;
    constexpr float kNotConvergedNormal = 1.f - kConvergedNormal;
    constexpr float kNotConvergedTransparent = 1.f - kConvergedTransparent;

    // Prediction: propagate the previous belief through the transition model.
    const float prior_transparent = prob_transparent_state_;
    const float prior_normal = 1.f - prior_transparent;
    const float predicted_transparent =
        prior_normal * kSwitch + prior_transparent * kStay;
    const float predicted_normal = 1.f - predicted_transparent;

    // Likelihood of the current observation under each state.
    const float likelihood_normal =
        any_filter_converged ? kConvergedNormal : kNotConvergedNormal;
    const float likelihood_transparent =
        any_filter_converged ? kConvergedTransparent : kNotConvergedTransparent;

    // Correction: Bayes' rule, normalized over both states.
    const float joint_normal = predicted_normal * likelihood_normal;
    const float joint_transparent =
        predicted_transparent * likelihood_transparent;
    const float evidence = joint_normal + joint_transparent;
    RTC_DCHECK_GT(evidence, 0.f);
    prob_transparent_state_ = joint_transparent / evidence;

    // Hysteresis: activate only on high confidence, deactivate only once the
    // belief drops clearly; in between, keep the current decision.
    constexpr float kActivationThreshold = 0.95f;
    constexpr float kDeactivationThreshold = 0.5f;
    if (prob_transparent_state_ > kActivationThreshold) {
      transparency_activated_ = true;
    } else if (prob_transparent_state_ < kDeactivationThreshold) {
      transparency_activated_ = false;
    }
  }

 private:
  // Current transparent mode decision.
  bool transparency_activated_ = false;
  // Posterior probability of the "transparent" hidden state.
  float prob_transparent_state_ = 0.f;
};

// Legacy, counter-based classifier for toggling transparent mode. It is used by
// TransparentMode::Create() when the HMM kill-switch field trial is enabled.
// Transparency is enabled once enough strong, unsaturated render has been seen
// without evidence of a finite ERL or of a recently sane and converged filter.
class LegacyTransparentModeImpl : public TransparentMode {
 public:
  explicit LegacyTransparentModeImpl(const EchoCanceller3Config& config)
      : linear_and_stable_echo_path_(
            config.echo_removal_control.linear_and_stable_echo_path),
        active_blocks_since_sane_filter_(kBlocksSinceConsistentEstimateInit),
        non_converged_sequence_size_(kBlocksSinceConvergencedFilterInit) {}

  bool Active() const override { return transparency_activated_; }

  // Note: only a subset of the state is cleared here. Among others,
  // `transparency_activated_`, `num_converged_blocks_` and
  // `finite_erl_recently_detected_` survive a Reset().
  void Reset() override {
    diverged_sequence_size_ = 0;
    strong_not_saturated_render_blocks_ = 0;
    non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit;
    if (linear_and_stable_echo_path_) {
      recent_convergence_during_activity_ = false;
    }
  }

  void Update(int filter_delay_blocks,
              bool any_filter_consistent,
              bool any_filter_converged,
              bool all_filters_diverged,
              bool active_render,
              bool saturated_capture) override {
    ++capture_block_counter_;
    if (active_render && !saturated_capture) {
      ++strong_not_saturated_render_blocks_;
    }

    // The bookkeeping steps below must run in this order: later steps read
    // counters updated by earlier ones.
    const bool sane_filter_recently_seen = TrackSaneFilter(
        filter_delay_blocks, any_filter_consistent, active_render);
    TrackConvergence(any_filter_converged, active_render);
    TrackDivergence(all_filters_diverged);
    TrackFiniteErl();

    const bool evidence_against_transparency =
        finite_erl_recently_detected_ ||
        (sane_filter_recently_seen && recent_convergence_during_activity_);
    if (evidence_against_transparency) {
      transparency_activated_ = false;
    } else {
      // Only go transparent once the filter has had enough strong render to
      // converge if there were any echo.
      transparency_activated_ =
          strong_not_saturated_render_blocks_ > 6 * kNumBlocksPerSecond;
    }
  }

 private:
  // Records whether a sane filter (consistent and with a delay below 5 blocks)
  // is observed and returns whether one has been seen recently. Before the
  // first sane filter, "recently" means within the first 5 seconds' worth of
  // capture blocks; afterwards, within 30 seconds' worth of active blocks.
  bool TrackSaneFilter(int filter_delay_blocks,
                       bool any_filter_consistent,
                       bool active_render) {
    const bool sane_now = any_filter_consistent && filter_delay_blocks < 5;
    if (sane_now) {
      sane_filter_observed_ = true;
      active_blocks_since_sane_filter_ = 0;
    } else if (active_render) {
      ++active_blocks_since_sane_filter_;
    }

    return sane_filter_observed_
               ? active_blocks_since_sane_filter_ <= 30 * kNumBlocksPerSecond
               : capture_block_counter_ <= 5 * kNumBlocksPerSecond;
  }

  // Updates the convergence counters. A long non-converged stretch clears the
  // converged-block count, and a long non-converged stretch during active
  // render clears the "recent convergence" flag.
  void TrackConvergence(bool any_filter_converged, bool active_render) {
    if (any_filter_converged) {
      recent_convergence_during_activity_ = true;
      active_non_converged_sequence_size_ = 0;
      non_converged_sequence_size_ = 0;
      ++num_converged_blocks_;
      return;
    }

    ++non_converged_sequence_size_;
    if (non_converged_sequence_size_ > 20 * kNumBlocksPerSecond) {
      num_converged_blocks_ = 0;
    }

    if (active_render) {
      ++active_non_converged_sequence_size_;
      if (active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) {
        recent_convergence_during_activity_ = false;
      }
    }
  }

  // Tracks how long all filters have been diverged. After 60 consecutive
  // diverged blocks, the non-converged counter is pushed back to its initial
  // value.
  void TrackDivergence(bool all_filters_diverged) {
    if (!all_filters_diverged) {
      diverged_sequence_size_ = 0;
      return;
    }
    ++diverged_sequence_size_;
    if (diverged_sequence_size_ >= 60) {
      // TODO(peah): Change these lines to ensure proper triggering of usable
      // filter.
      non_converged_sequence_size_ = kBlocksSinceConvergencedFilterInit;
    }
  }

  // Updates the finite-ERL flag. Sustained convergence (more than 50 converged
  // blocks) takes precedence over a long active non-converged stretch.
  void TrackFiniteErl() {
    if (active_non_converged_sequence_size_ > 60 * kNumBlocksPerSecond) {
      finite_erl_recently_detected_ = false;
    }
    if (num_converged_blocks_ > 50) {
      finite_erl_recently_detected_ = true;
    }
  }

  const bool linear_and_stable_echo_path_;
  size_t capture_block_counter_ = 0;
  bool transparency_activated_ = false;
  size_t active_blocks_since_sane_filter_;
  bool sane_filter_observed_ = false;
  bool finite_erl_recently_detected_ = false;
  size_t non_converged_sequence_size_;
  size_t diverged_sequence_size_ = 0;
  size_t active_non_converged_sequence_size_ = 0;
  size_t num_converged_blocks_ = 0;
  bool recent_convergence_during_activity_ = false;
  size_t strong_not_saturated_render_blocks_ = 0;
};

// Factory for the transparent mode classifier. It returns:
//  - nullptr when the config declares a bounded ERL or the transparent mode
//    kill switch is enabled,
//  - the legacy counter-based classifier when the HMM kill switch is enabled,
//  - the HMM-based classifier otherwise.
std::unique_ptr<TransparentMode> TransparentMode::Create(
    const EchoCanceller3Config& config) {
  const bool transparent_mode_disabled =
      config.ep_strength.bounded_erl || DeactivateTransparentMode();
  if (transparent_mode_disabled) {
    return nullptr;
  }

  if (!DeactivateTransparentModeHmm()) {
    return std::make_unique<TransparentModeImpl>();
  }
  return std::make_unique<LegacyTransparentModeImpl>(config);
}

}  // namespace webrtc