| OLD | NEW |
| 1 /* | 1 /* |
| 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. |
| 3 * | 3 * |
| 4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
| 5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
| 6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
| 7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
| 8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
| 9 */ | 9 */ |
| 10 | 10 |
| 11 #include "webrtc/modules/audio_processing/aec3/suppression_gain.h" | 11 #include "webrtc/modules/audio_processing/aec3/suppression_gain.h" |
| 12 | 12 |
| 13 #include "webrtc/typedefs.h" | 13 #include "webrtc/typedefs.h" |
| 14 #if defined(WEBRTC_ARCH_X86_FAMILY) | 14 #if defined(WEBRTC_ARCH_X86_FAMILY) |
| 15 #include <emmintrin.h> | 15 #include <emmintrin.h> |
| 16 #endif | 16 #endif |
| 17 #include <math.h> | 17 #include <math.h> |
| 18 #include <algorithm> | 18 #include <algorithm> |
| 19 #include <functional> | 19 #include <functional> |
| 20 #include <numeric> |
| 20 | 21 |
| 21 #include "webrtc/base/checks.h" | 22 #include "webrtc/base/checks.h" |
| 22 | 23 |
| 23 namespace webrtc { | 24 namespace webrtc { |
| 24 namespace { | 25 namespace { |
| 25 | 26 |
| 26 void GainPostProcessing(std::array<float, kFftLengthBy2Plus1>* gain_squared) { | 27 void GainPostProcessing(std::array<float, kFftLengthBy2Plus1>* gain_squared) { |
| 27 // Limit the low frequency gains to avoid the impact of the high-pass filter | 28 // Limit the low frequency gains to avoid the impact of the high-pass filter |
| 28 // on the lower-frequency gain influencing the overall achieved gain. | 29 // on the lower-frequency gain influencing the overall achieved gain. |
| 29 (*gain_squared)[1] = std::min((*gain_squared)[1], (*gain_squared)[2]); | 30 (*gain_squared)[1] = std::min((*gain_squared)[1], (*gain_squared)[2]); |
| 30 (*gain_squared)[0] = (*gain_squared)[1]; | 31 (*gain_squared)[0] = (*gain_squared)[1]; |
| 31 | 32 |
| 32 // Limit the high frequency gains to avoid the impact of the anti-aliasing | 33 // Limit the high frequency gains to avoid the impact of the anti-aliasing |
| 33 // filter on the upper-frequency gains influencing the overall achieved | 34 // filter on the upper-frequency gains influencing the overall achieved |
| 34 // gain. TODO(peah): Update this when new anti-aliasing filters are | 35 // gain. TODO(peah): Update this when new anti-aliasing filters are |
| 35 // implemented. | 36 // implemented. |
| 36 constexpr size_t kAntiAliasingImpactLimit = 64 * 0.7f; | 37 constexpr size_t kAntiAliasingImpactLimit = (64 * 2000) / 8000; |
| 37 std::for_each(gain_squared->begin() + kAntiAliasingImpactLimit, | 38 std::for_each(gain_squared->begin() + kAntiAliasingImpactLimit, |
| 38 gain_squared->end(), | 39 gain_squared->end() - 1, |
| 39 [gain_squared, kAntiAliasingImpactLimit](float& a) { | 40 [gain_squared, kAntiAliasingImpactLimit](float& a) { |
| 40 a = std::min(a, (*gain_squared)[kAntiAliasingImpactLimit]); | 41 a = std::min(a, (*gain_squared)[kAntiAliasingImpactLimit]); |
| 41 }); | 42 }); |
| 42 (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; | 43 (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; |
| 43 } | 44 } |
| 44 | 45 |
| 45 constexpr int kNumIterations = 2; | 46 constexpr int kNumIterations = 2; |
| 46 constexpr float kEchoMaskingMargin = 1.f / 10.f; | 47 constexpr float kEchoMaskingMargin = 1.f / 20.f; |
| 47 constexpr float kBandMaskingFactor = 1.f / 2.f; | 48 constexpr float kBandMaskingFactor = 1.f / 10.f; |
| 48 constexpr float kTimeMaskingFactor = 1.f / 10.f; | 49 constexpr float kTimeMaskingFactor = 1.f / 10.f; |
| 49 | 50 |
| 50 } // namespace | 51 } // namespace |
| 51 | 52 |
| 52 namespace aec3 { | 53 namespace aec3 { |
| 53 | 54 |
| 54 #if defined(WEBRTC_ARCH_X86_FAMILY) | 55 #if defined(WEBRTC_ARCH_X86_FAMILY) |
| 55 | 56 |
| 56 // Optimized SSE2 code for the gain computation. | 57 // Optimized SSE2 code for the gain computation. |
| 57 // TODO(peah): Add further optimizations, in particular for the divisions. | 58 // TODO(peah): Add further optimizations, in particular for the divisions. |
| (...skipping 72 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 130 | 131 |
| 131 // Limit gain for bands with strong nearend. | 132 // Limit gain for bands with strong nearend. |
| 132 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 133 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
| 133 strong_nearend.begin(), gain_squared->begin() + 1, | 134 strong_nearend.begin(), gain_squared->begin() + 1, |
| 134 [](float a, bool b) { return b ? 1.f : a; }); | 135 [](float a, bool b) { return b ? 1.f : a; }); |
| 135 | 136 |
| 136 // Limit the allowed gain update over time. | 137 // Limit the allowed gain update over time. |
| 137 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 138 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
| 138 previous_gain_squared->begin(), gain_squared->begin() + 1, | 139 previous_gain_squared->begin(), gain_squared->begin() + 1, |
| 139 [](float a, float b) { | 140 [](float a, float b) { |
| 140 return b < 0.0001f ? std::min(a, 0.0001f) | 141 return b < 0.001f ? std::min(a, 0.001f) |
| 141 : std::min(a, b * 2.f); | 142 : std::min(a, b * 2.f); |
| 142 }); | 143 }); |
| 143 | 144 |
| 144 // Process the gains to avoid artefacts caused by gain realization in the | 145 // Process the gains to avoid artefacts caused by gain realization in the |
| 145 // filterbank and impact of external pre-processing of the signal. | 146 // filterbank and impact of external pre-processing of the signal. |
| 146 GainPostProcessing(gain_squared); | 147 GainPostProcessing(gain_squared); |
| 147 } | 148 } |
| 148 | 149 |
| 149 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, | 150 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, |
| 150 previous_gain_squared->begin()); | 151 previous_gain_squared->begin()); |
| 151 | 152 |
| (...skipping 90 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 242 | 243 |
| 243 // Limit gain for bands with strong nearend. | 244 // Limit gain for bands with strong nearend. |
| 244 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 245 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
| 245 strong_nearend.begin(), gain_squared->begin() + 1, | 246 strong_nearend.begin(), gain_squared->begin() + 1, |
| 246 [](float a, bool b) { return b ? 1.f : a; }); | 247 [](float a, bool b) { return b ? 1.f : a; }); |
| 247 | 248 |
| 248 // Limit the allowed gain update over time. | 249 // Limit the allowed gain update over time. |
| 249 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 250 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
| 250 previous_gain_squared->begin(), gain_squared->begin() + 1, | 251 previous_gain_squared->begin(), gain_squared->begin() + 1, |
| 251 [](float a, float b) { | 252 [](float a, float b) { |
| 252 return b < 0.0001f ? std::min(a, 0.0001f) | 253 return b < 0.001f ? std::min(a, 0.001f) |
| 253 : std::min(a, b * 2.f); | 254 : std::min(a, b * 2.f); |
| 254 }); | 255 }); |
| 255 | 256 |
| 256 // Process the gains to avoid artefacts caused by gain realization in the | 257 // Process the gains to avoid artefacts caused by gain realization in the |
| 257 // filterbank and impact of external pre-processing of the signal. | 258 // filterbank and impact of external pre-processing of the signal. |
| 258 GainPostProcessing(gain_squared); | 259 GainPostProcessing(gain_squared); |
| 259 } | 260 } |
| 260 | 261 |
| 261 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, | 262 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, |
| 262 previous_gain_squared->begin()); | 263 previous_gain_squared->begin()); |
| 263 | 264 |
| 264 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 265 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, |
| 265 nearend_power.begin() + 1, previous_masker->begin(), | 266 nearend_power.begin() + 1, previous_masker->begin(), |
| 266 std::multiplies<float>()); | 267 std::multiplies<float>()); |
| 267 std::transform(previous_masker->begin(), previous_masker->end(), | 268 std::transform(previous_masker->begin(), previous_masker->end(), |
| 268 comfort_noise_power.begin() + 1, previous_masker->begin(), | 269 comfort_noise_power.begin() + 1, previous_masker->begin(), |
| 269 std::plus<float>()); | 270 std::plus<float>()); |
| 270 | 271 |
| 271 std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), | 272 std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), |
| 272 [](float a) { return sqrtf(a); }); | 273 [](float a) { return sqrtf(a); }); |
| 273 } | 274 } |
| 274 | 275 |
| 275 } // namespace aec3 | 276 } // namespace aec3 |
| 276 | 277 |
| 278 // Computes an upper bound on the gain to apply for high frequencies. |
| 279 float HighFrequencyGainBound(bool saturated_echo, |
| 280 const std::vector<std::vector<float>>& render) { |
| 281 if (render.size() == 1) { |
| 282 return 1.f; |
| 283 } |
| 284 |
| 285 // Always attenuate the upper bands when there is saturated echo. |
| 286 if (saturated_echo) { |
| 287 return 0.001f; |
| 288 } |
| 289 |
| 290 // Compute the upper and lower band energies. |
| 291 float low_band_energy = |
| 292 std::accumulate(render[0].begin(), render[0].end(), 0.f, |
| 293 [](float a, float b) -> float { return a + b * b; }); |
| 294 float high_band_energies = 0.f; |
| 295 for (size_t k = 1; k < render.size(); ++k) { |
| 296 high_band_energies = std::max( |
| 297 high_band_energies, |
| 298 std::accumulate(render[k].begin(), render[k].end(), 0.f, |
| 299 [](float a, float b) -> float { return a + b * b; })); |
| 300 } |
| 301 |
| 302 // If there is more power in the lower frequencies than the upper frequencies, |
| 303 // or if the power in upper frequencies is low, do not bound the gain in the |
| 304 // upper bands. |
| 305 if (high_band_energies < low_band_energy || |
| 306 high_band_energies < kSubBlockSize * 10.f * 10.f) { |
| 307 return 1.f; |
| 308 } |
| 309 |
| 310 // In all other cases, bound the gain for upper frequencies. |
| 311 RTC_DCHECK_LE(low_band_energy, high_band_energies); |
| 312 return 0.01f * sqrtf(low_band_energy / high_band_energies); |
| 313 } |
| 314 |
| 277 SuppressionGain::SuppressionGain(Aec3Optimization optimization) | 315 SuppressionGain::SuppressionGain(Aec3Optimization optimization) |
| 278 : optimization_(optimization) { | 316 : optimization_(optimization) { |
| 279 previous_gain_squared_.fill(1.f); | 317 previous_gain_squared_.fill(1.f); |
| 280 previous_masker_.fill(0.f); | 318 previous_masker_.fill(0.f); |
| 281 } | 319 } |
| 282 | 320 |
| 283 void SuppressionGain::GetGain( | 321 void SuppressionGain::GetGain( |
| 284 const std::array<float, kFftLengthBy2Plus1>& nearend_power, | 322 const std::array<float, kFftLengthBy2Plus1>& nearend_power, |
| 285 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, | 323 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, |
| 286 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, | 324 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, |
| 287 float strong_nearend_margin, | 325 bool saturated_echo, |
| 288 std::array<float, kFftLengthBy2Plus1>* gain) { | 326 const std::vector<std::vector<float>>& render, |
| 289 RTC_DCHECK(gain); | 327 size_t num_capture_bands, |
| 328 float* high_bands_gain, |
| 329 std::array<float, kFftLengthBy2Plus1>* low_band_gain) { |
| 330 RTC_DCHECK(high_bands_gain); |
| 331 RTC_DCHECK(low_band_gain); |
| 332 |
| 333 // Choose margin to use. |
| 334 const float margin = saturated_echo ? 0.001f : 0.01f; |
| 290 switch (optimization_) { | 335 switch (optimization_) { |
| 291 #if defined(WEBRTC_ARCH_X86_FAMILY) | 336 #if defined(WEBRTC_ARCH_X86_FAMILY) |
| 292 case Aec3Optimization::kSse2: | 337 case Aec3Optimization::kSse2: |
| 293 aec3::ComputeGains_SSE2(nearend_power, residual_echo_power, | 338 aec3::ComputeGains_SSE2( |
| 294 comfort_noise_power, strong_nearend_margin, | 339 nearend_power, residual_echo_power, comfort_noise_power, margin, |
| 295 &previous_gain_squared_, &previous_masker_, gain); | 340 &previous_gain_squared_, &previous_masker_, low_band_gain); |
| 296 break; | 341 break; |
| 297 #endif | 342 #endif |
| 298 default: | 343 default: |
| 299 aec3::ComputeGains(nearend_power, residual_echo_power, | 344 aec3::ComputeGains(nearend_power, residual_echo_power, |
| 300 comfort_noise_power, strong_nearend_margin, | 345 comfort_noise_power, margin, &previous_gain_squared_, |
| 301 &previous_gain_squared_, &previous_masker_, gain); | 346 &previous_masker_, low_band_gain); |
| 347 } |
| 348 |
| 349 if (num_capture_bands > 1) { |
| 350 // Compute the gain for upper frequencies. |
| 351 const float min_high_band_gain = |
| 352 HighFrequencyGainBound(saturated_echo, render); |
| 353 *high_bands_gain = |
| 354 *std::min_element(low_band_gain->begin() + 32, low_band_gain->end()); |
| 355 |
| 356 *high_bands_gain = std::min(*high_bands_gain, min_high_band_gain); |
| 357 |
| 358 } else { |
| 359 *high_bands_gain = 1.f; |
| 302 } | 360 } |
| 303 } | 361 } |
| 304 | 362 |
| 305 } // namespace webrtc | 363 } // namespace webrtc |
| OLD | NEW |