OLD | NEW |
1 /* | 1 /* |
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. |
3 * | 3 * |
4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
9 */ | 9 */ |
10 | 10 |
11 #include "webrtc/modules/audio_processing/aec3/suppression_gain.h" | 11 #include "webrtc/modules/audio_processing/aec3/suppression_gain.h" |
12 | 12 |
13 #include "webrtc/typedefs.h" | 13 #include "webrtc/typedefs.h" |
14 #if defined(WEBRTC_ARCH_X86_FAMILY) | 14 #if defined(WEBRTC_ARCH_X86_FAMILY) |
15 #include <emmintrin.h> | 15 #include <emmintrin.h> |
16 #endif | 16 #endif |
17 #include <math.h> | 17 #include <math.h> |
18 #include <algorithm> | 18 #include <algorithm> |
19 #include <functional> | 19 #include <functional> |
| 20 #include <numeric> |
20 | 21 |
21 #include "webrtc/base/checks.h" | 22 #include "webrtc/base/checks.h" |
22 | 23 |
23 namespace webrtc { | 24 namespace webrtc { |
24 namespace { | 25 namespace { |
25 | 26 |
26 void GainPostProcessing(std::array<float, kFftLengthBy2Plus1>* gain_squared) { | 27 void GainPostProcessing(std::array<float, kFftLengthBy2Plus1>* gain_squared) { |
27 // Limit the low frequency gains to avoid the impact of the high-pass filter | 28 // Limit the low frequency gains to avoid the impact of the high-pass filter |
28 // on the lower-frequency gain influencing the overall achieved gain. | 29 // on the lower-frequency gain influencing the overall achieved gain. |
29 (*gain_squared)[1] = std::min((*gain_squared)[1], (*gain_squared)[2]); | 30 (*gain_squared)[1] = std::min((*gain_squared)[1], (*gain_squared)[2]); |
30 (*gain_squared)[0] = (*gain_squared)[1]; | 31 (*gain_squared)[0] = (*gain_squared)[1]; |
31 | 32 |
32 // Limit the high frequency gains to avoid the impact of the anti-aliasing | 33 // Limit the high frequency gains to avoid the impact of the anti-aliasing |
33 // filter on the upper-frequency gains influencing the overall achieved | 34 // filter on the upper-frequency gains influencing the overall achieved |
34 // gain. TODO(peah): Update this when new anti-aliasing filters are | 35 // gain. TODO(peah): Update this when new anti-aliasing filters are |
35 // implemented. | 36 // implemented. |
36 constexpr size_t kAntiAliasingImpactLimit = 64 * 0.7f; | 37 constexpr size_t kAntiAliasingImpactLimit = (64 * 2000) / 8000; |
37 std::for_each(gain_squared->begin() + kAntiAliasingImpactLimit, | 38 std::for_each(gain_squared->begin() + kAntiAliasingImpactLimit, |
38 gain_squared->end(), | 39 gain_squared->end() - 1, |
39 [gain_squared, kAntiAliasingImpactLimit](float& a) { | 40 [gain_squared, kAntiAliasingImpactLimit](float& a) { |
40 a = std::min(a, (*gain_squared)[kAntiAliasingImpactLimit]); | 41 a = std::min(a, (*gain_squared)[kAntiAliasingImpactLimit]); |
41 }); | 42 }); |
42 (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; | 43 (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; |
43 } | 44 } |
44 | 45 |
45 constexpr int kNumIterations = 2; | 46 constexpr int kNumIterations = 2; |
46 constexpr float kEchoMaskingMargin = 1.f / 10.f; | 47 constexpr float kEchoMaskingMargin = 1.f / 20.f; |
47 constexpr float kBandMaskingFactor = 1.f / 2.f; | 48 constexpr float kBandMaskingFactor = 1.f / 10.f; |
48 constexpr float kTimeMaskingFactor = 1.f / 10.f; | 49 constexpr float kTimeMaskingFactor = 1.f / 10.f; |
49 | 50 |
50 } // namespace | 51 } // namespace |
51 | 52 |
52 namespace aec3 { | 53 namespace aec3 { |
53 | 54 |
54 #if defined(WEBRTC_ARCH_X86_FAMILY) | 55 #if defined(WEBRTC_ARCH_X86_FAMILY) |
55 | 56 |
56 // Optimized SSE2 code for the gain computation. | 57 // Optimized SSE2 code for the gain computation. |
57 // TODO(peah): Add further optimizations, in particular for the divisions. | 58 // TODO(peah): Add further optimizations, in particular for the divisions. |
(...skipping 72 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
130 | 131 |
131 // Limit gain for bands with strong nearend. | 132 // Limit gain for bands with strong nearend. |
132 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 133 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
133 strong_nearend.begin(), gain_squared->begin() + 1, | 134 strong_nearend.begin(), gain_squared->begin() + 1, |
134 [](float a, bool b) { return b ? 1.f : a; }); | 135 [](float a, bool b) { return b ? 1.f : a; }); |
135 | 136 |
136 // Limit the allowed gain update over time. | 137 // Limit the allowed gain update over time. |
137 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 138 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
138 previous_gain_squared->begin(), gain_squared->begin() + 1, | 139 previous_gain_squared->begin(), gain_squared->begin() + 1, |
139 [](float a, float b) { | 140 [](float a, float b) { |
140 return b < 0.0001f ? std::min(a, 0.0001f) | 141 return b < 0.001f ? std::min(a, 0.001f) |
141 : std::min(a, b * 2.f); | 142 : std::min(a, b * 2.f); |
142 }); | 143 }); |
143 | 144 |
144 // Process the gains to avoid artefacts caused by gain realization in the | 145 // Process the gains to avoid artefacts caused by gain realization in the |
145 // filterbank and impact of external pre-processing of the signal. | 146 // filterbank and impact of external pre-processing of the signal. |
146 GainPostProcessing(gain_squared); | 147 GainPostProcessing(gain_squared); |
147 } | 148 } |
148 | 149 |
149 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, | 150 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, |
150 previous_gain_squared->begin()); | 151 previous_gain_squared->begin()); |
151 | 152 |
(...skipping 90 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
242 | 243 |
243 // Limit gain for bands with strong nearend. | 244 // Limit gain for bands with strong nearend. |
244 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 245 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
245 strong_nearend.begin(), gain_squared->begin() + 1, | 246 strong_nearend.begin(), gain_squared->begin() + 1, |
246 [](float a, bool b) { return b ? 1.f : a; }); | 247 [](float a, bool b) { return b ? 1.f : a; }); |
247 | 248 |
248 // Limit the allowed gain update over time. | 249 // Limit the allowed gain update over time. |
249 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 250 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
250 previous_gain_squared->begin(), gain_squared->begin() + 1, | 251 previous_gain_squared->begin(), gain_squared->begin() + 1, |
251 [](float a, float b) { | 252 [](float a, float b) { |
252 return b < 0.0001f ? std::min(a, 0.0001f) | 253 return b < 0.001f ? std::min(a, 0.001f) |
253 : std::min(a, b * 2.f); | 254 : std::min(a, b * 2.f); |
254 }); | 255 }); |
255 | 256 |
256 // Process the gains to avoid artefacts caused by gain realization in the | 257 // Process the gains to avoid artefacts caused by gain realization in the |
257 // filterbank and impact of external pre-processing of the signal. | 258 // filterbank and impact of external pre-processing of the signal. |
258 GainPostProcessing(gain_squared); | 259 GainPostProcessing(gain_squared); |
259 } | 260 } |
260 | 261 |
261 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, | 262 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, |
262 previous_gain_squared->begin()); | 263 previous_gain_squared->begin()); |
263 | 264 |
264 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 265 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
265 nearend_power.begin() + 1, previous_masker->begin(), | 266 nearend_power.begin() + 1, previous_masker->begin(), |
266 std::multiplies<float>()); | 267 std::multiplies<float>()); |
267 std::transform(previous_masker->begin(), previous_masker->end(), | 268 std::transform(previous_masker->begin(), previous_masker->end(), |
268 comfort_noise_power.begin() + 1, previous_masker->begin(), | 269 comfort_noise_power.begin() + 1, previous_masker->begin(), |
269 std::plus<float>()); | 270 std::plus<float>()); |
270 | 271 |
271 std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), | 272 std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), |
272 [](float a) { return sqrtf(a); }); | 273 [](float a) { return sqrtf(a); }); |
273 } | 274 } |
274 | 275 |
275 } // namespace aec3 | 276 } // namespace aec3 |
276 | 277 |
| 278 // Computes an upper bound on the gain to apply for high frequencies. |
| 279 float HighFrequencyGainBound(bool saturated_echo, |
| 280 const std::vector<std::vector<float>>& render) { |
| 281 if (render.size() == 1) { |
| 282 return 1.f; |
| 283 } |
| 284 |
| 285 // Always attenuate the upper bands when there is saturated echo. |
| 286 if (saturated_echo) { |
| 287 return 0.001f; |
| 288 } |
| 289 |
| 290 // Compute the upper and lower band energies. |
| 291 float low_band_energy = |
| 292 std::accumulate(render[0].begin(), render[0].end(), 0.f, |
| 293 [](float a, float b) -> float { return a + b * b; }); |
| 294 float high_band_energies = 0.f; |
| 295 for (size_t k = 1; k < render.size(); ++k) { |
| 296 high_band_energies = std::max( |
| 297 high_band_energies, |
| 298 std::accumulate(render[k].begin(), render[k].end(), 0.f, |
| 299 [](float a, float b) -> float { return a + b * b; })); |
| 300 } |
| 301 |
| 302 // If there is more power in the lower frequencies than the upper frequencies, |
| 303 // or if the power in upper frequencies is low, do not bound the gain in the |
| 304 // upper bands. |
| 305 if (high_band_energies < low_band_energy || |
| 306 high_band_energies < kSubBlockSize * 10.f * 10.f) { |
| 307 return 1.f; |
| 308 } |
| 309 |
| 310 // In all other cases, bound the gain for upper frequencies. |
| 311 RTC_DCHECK_LE(low_band_energy, high_band_energies); |
| 312 return 0.01f * sqrtf(low_band_energy / high_band_energies); |
| 313 } |
| 314 |
277 SuppressionGain::SuppressionGain(Aec3Optimization optimization) | 315 SuppressionGain::SuppressionGain(Aec3Optimization optimization) |
278 : optimization_(optimization) { | 316 : optimization_(optimization) { |
279 previous_gain_squared_.fill(1.f); | 317 previous_gain_squared_.fill(1.f); |
280 previous_masker_.fill(0.f); | 318 previous_masker_.fill(0.f); |
281 } | 319 } |
282 | 320 |
283 void SuppressionGain::GetGain( | 321 void SuppressionGain::GetGain( |
284 const std::array<float, kFftLengthBy2Plus1>& nearend_power, | 322 const std::array<float, kFftLengthBy2Plus1>& nearend_power, |
285 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, | 323 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, |
286 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, | 324 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, |
287 float strong_nearend_margin, | 325 bool saturated_echo, |
288 std::array<float, kFftLengthBy2Plus1>* gain) { | 326 const std::vector<std::vector<float>>& render, |
289 RTC_DCHECK(gain); | 327 size_t num_capture_bands, |
| 328 float* high_bands_gain, |
| 329 std::array<float, kFftLengthBy2Plus1>* low_band_gain) { |
| 330 RTC_DCHECK(high_bands_gain); |
| 331 RTC_DCHECK(low_band_gain); |
| 332 |
| 333 // Choose margin to use. |
| 334 const float margin = saturated_echo ? 0.001f : 0.01f; |
290 switch (optimization_) { | 335 switch (optimization_) { |
291 #if defined(WEBRTC_ARCH_X86_FAMILY) | 336 #if defined(WEBRTC_ARCH_X86_FAMILY) |
292 case Aec3Optimization::kSse2: | 337 case Aec3Optimization::kSse2: |
293 aec3::ComputeGains_SSE2(nearend_power, residual_echo_power, | 338 aec3::ComputeGains_SSE2( |
294 comfort_noise_power, strong_nearend_margin, | 339 nearend_power, residual_echo_power, comfort_noise_power, margin, |
295 &previous_gain_squared_, &previous_masker_, gain); | 340 &previous_gain_squared_, &previous_masker_, low_band_gain); |
296 break; | 341 break; |
297 #endif | 342 #endif |
298 default: | 343 default: |
299 aec3::ComputeGains(nearend_power, residual_echo_power, | 344 aec3::ComputeGains(nearend_power, residual_echo_power, |
300 comfort_noise_power, strong_nearend_margin, | 345 comfort_noise_power, margin, &previous_gain_squared_, |
301 &previous_gain_squared_, &previous_masker_, gain); | 346 &previous_masker_, low_band_gain); |
| 347 } |
| 348 |
| 349 if (num_capture_bands > 1) { |
| 350 // Compute the gain for upper frequencies. |
| 351 const float min_high_band_gain = |
| 352 HighFrequencyGainBound(saturated_echo, render); |
| 353 *high_bands_gain = |
| 354 *std::min_element(low_band_gain->begin() + 32, low_band_gain->end()); |
| 355 |
| 356 *high_bands_gain = std::min(*high_bands_gain, min_high_band_gain); |
| 357 |
| 358 } else { |
| 359 *high_bands_gain = 1.f; |
302 } | 360 } |
303 } | 361 } |
304 | 362 |
305 } // namespace webrtc | 363 } // namespace webrtc |
OLD | NEW |