OLD | NEW |
---|---|
1 /* | 1 /* |
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. |
3 * | 3 * |
4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
9 */ | 9 */ |
10 | 10 |
11 #include "webrtc/modules/audio_processing/aec3/suppression_gain.h" | 11 #include "webrtc/modules/audio_processing/aec3/suppression_gain.h" |
12 | 12 |
13 #include "webrtc/typedefs.h" | 13 #include "webrtc/typedefs.h" |
14 #if defined(WEBRTC_ARCH_X86_FAMILY) | 14 #if defined(WEBRTC_ARCH_X86_FAMILY) |
15 #include <emmintrin.h> | 15 #include <emmintrin.h> |
16 #endif | 16 #endif |
17 #include <math.h> | 17 #include <math.h> |
18 #include <algorithm> | 18 #include <algorithm> |
19 #include <functional> | 19 #include <functional> |
20 #include <numeric> | |
20 | 21 |
21 namespace webrtc { | 22 namespace webrtc { |
22 namespace { | 23 namespace { |
23 | 24 |
24 void GainPostProcessing(std::array<float, kFftLengthBy2Plus1>* gain_squared) { | 25 void GainPostProcessing(std::array<float, kFftLengthBy2Plus1>* gain_squared) { |
25 // Limit the low frequency gains to avoid the impact of the high-pass filter | 26 // Limit the low frequency gains to avoid the impact of the high-pass filter |
26 // on the lower-frequency gain influencing the overall achieved gain. | 27 // on the lower-frequency gain influencing the overall achieved gain. |
27 (*gain_squared)[1] = std::min((*gain_squared)[1], (*gain_squared)[2]); | 28 (*gain_squared)[1] = std::min((*gain_squared)[1], (*gain_squared)[2]); |
28 (*gain_squared)[0] = (*gain_squared)[1]; | 29 (*gain_squared)[0] = (*gain_squared)[1]; |
29 | 30 |
30 // Limit the high frequency gains to avoid the impact of the anti-aliasing | 31 // Limit the high frequency gains to avoid the impact of the anti-aliasing |
31 // filter on the upper-frequency gains influencing the overall achieved | 32 // filter on the upper-frequency gains influencing the overall achieved |
32 // gain. TODO(peah): Update this when new anti-aliasing filters are | 33 // gain. TODO(peah): Update this when new anti-aliasing filters are |
33 // implemented. | 34 // implemented. |
34 constexpr size_t kAntiAliasingImpactLimit = 64 * 0.7f; | 35 constexpr size_t kAntiAliasingImpactLimit = (64 * 2000) / 8000; |
36 float minG = (*gain_squared)[2]; | |
37 for (size_t k = 2; k < kAntiAliasingImpactLimit; ++k) { | |
38 minG = std::min(minG, (*gain_squared)[k]); | |
39 } | |
aleloi
2017/04/06 14:34:22
This can be simplified to something like this:
mi
peah-webrtc
2017/04/06 21:11:14
True. Furthermore, it turned out that minG was not
| |
40 | |
35 std::for_each(gain_squared->begin() + kAntiAliasingImpactLimit, | 41 std::for_each(gain_squared->begin() + kAntiAliasingImpactLimit, |
36 gain_squared->end(), | 42 gain_squared->end(), |
37 [gain_squared, kAntiAliasingImpactLimit](float& a) { | 43 [gain_squared, kAntiAliasingImpactLimit](float& a) { |
38 a = std::min(a, (*gain_squared)[kAntiAliasingImpactLimit]); | 44 a = std::min(a, (*gain_squared)[kAntiAliasingImpactLimit]); |
39 }); | 45 }); |
40 (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; | 46 (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; |
41 } | 47 } |
42 | 48 |
43 constexpr int kNumIterations = 2; | 49 constexpr int kNumIterations = 2; |
44 constexpr float kEchoMaskingMargin = 1.f / 10.f; | 50 constexpr float kEchoMaskingMargin = 1.f / 20.f; |
45 constexpr float kBandMaskingFactor = 1.f / 2.f; | 51 constexpr float kBandMaskingFactor = 1.f / 10.f; |
46 constexpr float kTimeMaskingFactor = 1.f / 10.f; | 52 constexpr float kTimeMaskingFactor = 1.f / 10.f; |
47 | 53 |
48 } // namespace | 54 } // namespace |
49 | 55 |
50 namespace aec3 { | 56 namespace aec3 { |
51 | 57 |
52 #if defined(WEBRTC_ARCH_X86_FAMILY) | 58 #if defined(WEBRTC_ARCH_X86_FAMILY) |
53 | 59 |
54 // Optimized SSE2 code for the gain computation. | 60 // Optimized SSE2 code for the gain computation. |
55 // TODO(peah): Add further optimizations, in particular for the divisions. | 61 // TODO(peah): Add further optimizations, in particular for the divisions. |
(...skipping 72 matching lines...) | |
128 | 134 |
129 // Limit gain for bands with strong nearend. | 135 // Limit gain for bands with strong nearend. |
130 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 136 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
131 strong_nearend.begin(), gain_squared->begin() + 1, | 137 strong_nearend.begin(), gain_squared->begin() + 1, |
132 [](float a, bool b) { return b ? 1.f : a; }); | 138 [](float a, bool b) { return b ? 1.f : a; }); |
133 | 139 |
134 // Limit the allowed gain update over time. | 140 // Limit the allowed gain update over time. |
135 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 141 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
136 previous_gain_squared->begin(), gain_squared->begin() + 1, | 142 previous_gain_squared->begin(), gain_squared->begin() + 1, |
137 [](float a, float b) { | 143 [](float a, float b) { |
138 return b < 0.0001f ? std::min(a, 0.0001f) | 144 return b < 0.001f ? std::min(a, 0.001f) |
139 : std::min(a, b * 2.f); | 145 : std::min(a, b * 2.f); |
140 }); | 146 }); |
141 | 147 |
142 // Process the gains to avoid artefacts caused by gain realization in the | 148 // Process the gains to avoid artefacts caused by gain realization in the |
143 // filterbank and impact of external pre-processing of the signal. | 149 // filterbank and impact of external pre-processing of the signal. |
144 GainPostProcessing(gain_squared); | 150 GainPostProcessing(gain_squared); |
145 } | 151 } |
146 | 152 |
147 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, | 153 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, |
148 previous_gain_squared->begin()); | 154 previous_gain_squared->begin()); |
149 | 155 |
(...skipping 115 matching lines...) | |
265 std::transform(previous_masker->begin(), previous_masker->end(), | 271 std::transform(previous_masker->begin(), previous_masker->end(), |
266 comfort_noise_power.begin() + 1, previous_masker->begin(), | 272 comfort_noise_power.begin() + 1, previous_masker->begin(), |
267 std::plus<float>()); | 273 std::plus<float>()); |
268 | 274 |
269 std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), | 275 std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), |
270 [](float a) { return sqrtf(a); }); | 276 [](float a) { return sqrtf(a); }); |
271 } | 277 } |
272 | 278 |
273 } // namespace aec3 | 279 } // namespace aec3 |
274 | 280 |
281 // Computes an upper bound on the gain to apply for high frequencies. | |
282 float HighFrequencyGainBound(bool saturated_echo, | |
283 const std::vector<std::vector<float>>& render) { | |
284 if (render.size() == 1) { | |
285 return 1.f; | |
286 } | |
287 | |
288 // Always attenuate the upper bands when there is saturated echo. | |
289 if (saturated_echo) { | |
290 return 0.001; | |
291 } | |
292 | |
293 // Compute the upper and lower band energies. | |
294 float low_band_energy = | |
295 std::accumulate(render[0].begin(), render[0].end(), 0.f, | |
296 [](float a, float b) -> float { return a + b * b; }); | |
297 float high_band_energies = 0.f; | |
298 for (size_t k = 1; k < render.size(); ++k) { | |
299 high_band_energies = std::max( | |
300 high_band_energies, | |
301 std::accumulate(render[k].begin(), render[k].end(), 0.f, | |
302 [](float a, float b) -> float { return a + b * b; })); | |
303 } | |
304 | |
305 // If there is more power in the lower frequencies than the upper frequencies, | |
306 // or if the power in upper frequencies is low, do not bound the gain in the | |
307 // upper bands. | |
308 if (high_band_energies < low_band_energy || | |
309 high_band_energies < kSubBlockSize * 10.f * 10.f) { | |
310 return 1.f; | |
311 } | |
312 | |
313 // In all other cases, bound the gain for upper frequencies. | |
314 RTC_DCHECK_LE(low_band_energy, high_band_energies); | |
315 return 0.01f * sqrtf(low_band_energy / high_band_energies); | |
316 } | |
317 | |
275 SuppressionGain::SuppressionGain(Aec3Optimization optimization) | 318 SuppressionGain::SuppressionGain(Aec3Optimization optimization) |
276 : optimization_(optimization) { | 319 : optimization_(optimization) { |
277 previous_gain_squared_.fill(1.f); | 320 previous_gain_squared_.fill(1.f); |
278 previous_masker_.fill(0.f); | 321 previous_masker_.fill(0.f); |
279 } | 322 } |
280 | 323 |
281 void SuppressionGain::GetGain( | 324 void SuppressionGain::GetGain( |
282 const std::array<float, kFftLengthBy2Plus1>& nearend_power, | 325 const std::array<float, kFftLengthBy2Plus1>& nearend_power, |
283 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, | 326 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, |
284 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, | 327 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, |
285 float strong_nearend_margin, | 328 bool saturated_echo, |
286 std::array<float, kFftLengthBy2Plus1>* gain) { | 329 const std::vector<std::vector<float>>& render, |
287 RTC_DCHECK(gain); | 330 size_t num_capture_bands, |
331 float* high_bands_gain, | |
332 std::array<float, kFftLengthBy2Plus1>* low_band_gain) { | |
333 RTC_DCHECK(high_bands_gain); | |
334 RTC_DCHECK(low_band_gain); | |
335 | |
336 // Choose margin to use. | |
337 const float margin = saturated_echo ? 0.001f : 0.01f; | |
288 switch (optimization_) { | 338 switch (optimization_) { |
289 #if defined(WEBRTC_ARCH_X86_FAMILY) | 339 #if defined(WEBRTC_ARCH_X86_FAMILY) |
290 case Aec3Optimization::kSse2: | 340 case Aec3Optimization::kSse2: |
291 aec3::ComputeGains_SSE2(nearend_power, residual_echo_power, | 341 aec3::ComputeGains_SSE2( |
292 comfort_noise_power, strong_nearend_margin, | 342 nearend_power, residual_echo_power, comfort_noise_power, margin, |
293 &previous_gain_squared_, &previous_masker_, gain); | 343 &previous_gain_squared_, &previous_masker_, low_band_gain); |
294 break; | 344 break; |
295 #endif | 345 #endif |
296 default: | 346 default: |
297 aec3::ComputeGains(nearend_power, residual_echo_power, | 347 aec3::ComputeGains(nearend_power, residual_echo_power, |
298 comfort_noise_power, strong_nearend_margin, | 348 comfort_noise_power, margin, &previous_gain_squared_, |
299 &previous_gain_squared_, &previous_masker_, gain); | 349 &previous_masker_, low_band_gain); |
350 } | |
351 | |
352 if (num_capture_bands > 1) { | |
353 // Compute the gain for upper frequencies. | |
354 const float min_high_band_gain = | |
355 HighFrequencyGainBound(saturated_echo, render); | |
356 *high_bands_gain = | |
357 *std::min_element(low_band_gain->begin() + 32, low_band_gain->end()); | |
358 | |
359 *high_bands_gain = std::min(*high_bands_gain, min_high_band_gain); | |
360 | |
361 } else { | |
362 *high_bands_gain = 1.f; | |
300 } | 363 } |
301 } | 364 } |
302 | 365 |
303 } // namespace webrtc | 366 } // namespace webrtc |
OLD | NEW |