OLD | NEW |
1 /* | 1 /* |
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. |
3 * | 3 * |
4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
9 */ | 9 */ |
10 | 10 |
11 #include "webrtc/modules/audio_processing/aec3/suppression_gain.h" | 11 #include "webrtc/modules/audio_processing/aec3/suppression_gain.h" |
12 | 12 |
13 #include "webrtc/typedefs.h" | 13 #include "webrtc/typedefs.h" |
14 #if defined(WEBRTC_ARCH_X86_FAMILY) | 14 #if defined(WEBRTC_ARCH_X86_FAMILY) |
15 #include <emmintrin.h> | 15 #include <emmintrin.h> |
16 #endif | 16 #endif |
17 #include <math.h> | 17 #include <math.h> |
18 #include <algorithm> | 18 #include <algorithm> |
19 #include <functional> | 19 #include <functional> |
20 #include <numeric> | 20 #include <numeric> |
21 | 21 |
22 #include "webrtc/base/checks.h" | 22 #include "webrtc/base/checks.h" |
| 23 #include "webrtc/modules/audio_processing/aec3/vector_math.h" |
23 | 24 |
24 namespace webrtc { | 25 namespace webrtc { |
25 namespace { | 26 namespace { |
26 | 27 |
27 void GainPostProcessing(std::array<float, kFftLengthBy2Plus1>* gain_squared) { | 28 void GainPostProcessing(std::array<float, kFftLengthBy2Plus1>* gain_squared) { |
28 // Limit the low frequency gains to avoid the impact of the high-pass filter | 29 // Limit the low frequency gains to avoid the impact of the high-pass filter |
29 // on the lower-frequency gain influencing the overall achieved gain. | 30 // on the lower-frequency gain influencing the overall achieved gain. |
30 (*gain_squared)[1] = std::min((*gain_squared)[1], (*gain_squared)[2]); | 31 (*gain_squared)[1] = std::min((*gain_squared)[1], (*gain_squared)[2]); |
31 (*gain_squared)[0] = (*gain_squared)[1]; | 32 (*gain_squared)[0] = (*gain_squared)[1]; |
32 | 33 |
33 // Limit the high frequency gains to avoid the impact of the anti-aliasing | 34 // Limit the high frequency gains to avoid the impact of the anti-aliasing |
34 // filter on the upper-frequency gains influencing the overall achieved | 35 // filter on the upper-frequency gains influencing the overall achieved |
35 // gain. TODO(peah): Update this when new anti-aliasing filters are | 36 // gain. TODO(peah): Update this when new anti-aliasing filters are |
36 // implemented. | 37 // implemented. |
37 constexpr size_t kAntiAliasingImpactLimit = (64 * 2000) / 8000; | 38 constexpr size_t kAntiAliasingImpactLimit = (64 * 2000) / 8000; |
38 std::for_each(gain_squared->begin() + kAntiAliasingImpactLimit, | 39 std::for_each(gain_squared->begin() + kAntiAliasingImpactLimit, |
39 gain_squared->end() - 1, | 40 gain_squared->end() - 1, |
40 [gain_squared, kAntiAliasingImpactLimit](float& a) { | 41 [gain_squared, kAntiAliasingImpactLimit](float& a) { |
41 a = std::min(a, (*gain_squared)[kAntiAliasingImpactLimit]); | 42 a = std::min(a, (*gain_squared)[kAntiAliasingImpactLimit]); |
42 }); | 43 }); |
43 (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; | 44 (*gain_squared)[kFftLengthBy2] = (*gain_squared)[kFftLengthBy2Minus1]; |
44 } | 45 } |
45 | 46 |
46 constexpr int kNumIterations = 2; | 47 constexpr int kNumIterations = 2; |
47 constexpr float kEchoMaskingMargin = 1.f / 20.f; | 48 constexpr float kEchoMaskingMargin = 1.f / 20.f; |
48 constexpr float kBandMaskingFactor = 1.f / 10.f; | 49 constexpr float kBandMaskingFactor = 1.f / 10.f; |
49 constexpr float kTimeMaskingFactor = 1.f / 10.f; | 50 constexpr float kTimeMaskingFactor = 1.f / 10.f; |
50 | 51 |
51 } // namespace | |
52 | |
53 namespace aec3 { | |
54 | |
55 #if defined(WEBRTC_ARCH_X86_FAMILY) | |
56 | |
57 // Optimized SSE2 code for the gain computation. | |
58 // TODO(peah): Add further optimizations, in particular for the divisions. | 52 // TODO(peah): Add further optimizations, in particular for the divisions. |
59 void ComputeGains_SSE2( | 53 void ComputeGains( |
| 54 Aec3Optimization optimization, |
60 const std::array<float, kFftLengthBy2Plus1>& nearend_power, | 55 const std::array<float, kFftLengthBy2Plus1>& nearend_power, |
61 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, | 56 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, |
62 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, | 57 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, |
63 float strong_nearend_margin, | |
64 std::array<float, kFftLengthBy2Minus1>* previous_gain_squared, | |
65 std::array<float, kFftLengthBy2Minus1>* previous_masker, | |
66 std::array<float, kFftLengthBy2Plus1>* gain) { | |
67 std::array<float, kFftLengthBy2Minus1> masker; | |
68 std::array<float, kFftLengthBy2Minus1> same_band_masker; | |
69 std::array<float, kFftLengthBy2Minus1> one_by_residual_echo_power; | |
70 std::array<bool, kFftLengthBy2Minus1> strong_nearend; | |
71 std::array<float, kFftLengthBy2Plus1> neighboring_bands_masker; | |
72 std::array<float, kFftLengthBy2Plus1>* gain_squared = gain; | |
73 | |
74 // Precompute 1/residual_echo_power. | |
75 std::transform(residual_echo_power.begin() + 1, residual_echo_power.end() - 1, | |
76 one_by_residual_echo_power.begin(), | |
77 [](float a) { return a > 0.f ? 1.f / a : -1.f; }); | |
78 | |
79 // Precompute indicators for bands with strong nearend. | |
80 std::transform( | |
81 residual_echo_power.begin() + 1, residual_echo_power.end() - 1, | |
82 nearend_power.begin() + 1, strong_nearend.begin(), | |
83 [&](float a, float b) { return a <= strong_nearend_margin * b; }); | |
84 | |
85 // Precompute masker for the same band. | |
86 std::transform(comfort_noise_power.begin() + 1, comfort_noise_power.end() - 1, | |
87 previous_masker->begin(), same_band_masker.begin(), | |
88 [&](float a, float b) { return a + kTimeMaskingFactor * b; }); | |
89 | |
90 for (int k = 0; k < kNumIterations; ++k) { | |
91 if (k == 0) { | |
92 // Add masker from the same band. | |
93 std::copy(same_band_masker.begin(), same_band_masker.end(), | |
94 masker.begin()); | |
95 } else { | |
96 // Add masker for neighboring bands. | |
97 std::transform(nearend_power.begin(), nearend_power.end(), | |
98 gain_squared->begin(), neighboring_bands_masker.begin(), | |
99 std::multiplies<float>()); | |
100 std::transform(neighboring_bands_masker.begin(), | |
101 neighboring_bands_masker.end(), | |
102 comfort_noise_power.begin(), | |
103 neighboring_bands_masker.begin(), std::plus<float>()); | |
104 std::transform( | |
105 neighboring_bands_masker.begin(), neighboring_bands_masker.end() - 2, | |
106 neighboring_bands_masker.begin() + 2, masker.begin(), | |
107 [&](float a, float b) { return kBandMaskingFactor * (a + b); }); | |
108 | |
109 // Add masker from the same band. | |
110 std::transform(same_band_masker.begin(), same_band_masker.end(), | |
111 masker.begin(), masker.begin(), std::plus<float>()); | |
112 } | |
113 | |
114 // Compute new gain as: | |
115 // G2(t,f) = (comfort_noise_power(t,f) + G2(t-1)*nearend_power(t-1)) * | |
116 // kTimeMaskingFactor | |
117 // * kEchoMaskingMargin / residual_echo_power(t,f). | |
118 // or | |
119 // G2(t,f) = ((comfort_noise_power(t,f) + G2(t-1) * | |
120 // nearend_power(t-1)) * kTimeMaskingFactor + | |
121 // (comfort_noise_power(t, f-1) + comfort_noise_power(t, f+1) + | |
122 // (G2(t,f-1)*nearend_power(t, f-1) + | |
123 // G2(t,f+1)*nearend_power(t, f+1)) * | |
124 // kTimeMaskingFactor) * kBandMaskingFactor) | |
125 // * kEchoMaskingMargin / residual_echo_power(t,f). | |
126 std::transform( | |
127 masker.begin(), masker.end(), one_by_residual_echo_power.begin(), | |
128 gain_squared->begin() + 1, [&](float a, float b) { | |
129 return b >= 0 ? std::min(kEchoMaskingMargin * a * b, 1.f) : 1.f; | |
130 }); | |
131 | |
132 // Limit gain for bands with strong nearend. | |
133 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | |
134 strong_nearend.begin(), gain_squared->begin() + 1, | |
135 [](float a, bool b) { return b ? 1.f : a; }); | |
136 | |
137 // Limit the allowed gain update over time. | |
138 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | |
139 previous_gain_squared->begin(), gain_squared->begin() + 1, | |
140 [](float a, float b) { | |
141 return b < 0.001f ? std::min(a, 0.001f) | |
142 : std::min(a, b * 2.f); | |
143 }); | |
144 | |
145 // Process the gains to avoid artefacts caused by gain realization in the | |
146 // filterbank and impact of external pre-processing of the signal. | |
147 GainPostProcessing(gain_squared); | |
148 } | |
149 | |
150 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, | |
151 previous_gain_squared->begin()); | |
152 | |
153 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | |
154 nearend_power.begin() + 1, previous_masker->begin(), | |
155 std::multiplies<float>()); | |
156 std::transform(previous_masker->begin(), previous_masker->end(), | |
157 comfort_noise_power.begin() + 1, previous_masker->begin(), | |
158 std::plus<float>()); | |
159 | |
160 for (size_t k = 0; k < kFftLengthBy2; k += 4) { | |
161 __m128 g = _mm_loadu_ps(&(*gain_squared)[k]); | |
162 g = _mm_sqrt_ps(g); | |
163 _mm_storeu_ps(&(*gain)[k], g); | |
164 } | |
165 | |
166 (*gain)[kFftLengthBy2] = sqrtf((*gain)[kFftLengthBy2]); | |
167 } | |
168 | |
169 #endif | |
170 | |
171 void ComputeGains( | |
172 const std::array<float, kFftLengthBy2Plus1>& nearend_power, | |
173 const std::array<float, kFftLengthBy2Plus1>& residual_echo_power, | |
174 const std::array<float, kFftLengthBy2Plus1>& comfort_noise_power, | |
175 float strong_nearend_margin, | 58 float strong_nearend_margin, |
176 std::array<float, kFftLengthBy2Minus1>* previous_gain_squared, | 59 std::array<float, kFftLengthBy2Minus1>* previous_gain_squared, |
177 std::array<float, kFftLengthBy2Minus1>* previous_masker, | 60 std::array<float, kFftLengthBy2Minus1>* previous_masker, |
178 std::array<float, kFftLengthBy2Plus1>* gain) { | 61 std::array<float, kFftLengthBy2Plus1>* gain) { |
179 std::array<float, kFftLengthBy2Minus1> masker; | 62 std::array<float, kFftLengthBy2Minus1> masker; |
180 std::array<float, kFftLengthBy2Minus1> same_band_masker; | 63 std::array<float, kFftLengthBy2Minus1> same_band_masker; |
181 std::array<float, kFftLengthBy2Minus1> one_by_residual_echo_power; | 64 std::array<float, kFftLengthBy2Minus1> one_by_residual_echo_power; |
182 std::array<bool, kFftLengthBy2Minus1> strong_nearend; | 65 std::array<bool, kFftLengthBy2Minus1> strong_nearend; |
183 std::array<float, kFftLengthBy2Plus1> neighboring_bands_masker; | 66 std::array<float, kFftLengthBy2Plus1> neighboring_bands_masker; |
184 std::array<float, kFftLengthBy2Plus1>* gain_squared = gain; | 67 std::array<float, kFftLengthBy2Plus1>* gain_squared = gain; |
| 68 aec3::VectorMath math(optimization); |
185 | 69 |
186 // Precompute 1/residual_echo_power. | 70 // Precompute 1/residual_echo_power. |
187 std::transform(residual_echo_power.begin() + 1, residual_echo_power.end() - 1, | 71 std::transform(residual_echo_power.begin() + 1, residual_echo_power.end() - 1, |
188 one_by_residual_echo_power.begin(), | 72 one_by_residual_echo_power.begin(), |
189 [](float a) { return a > 0.f ? 1.f / a : -1.f; }); | 73 [](float a) { return a > 0.f ? 1.f / a : -1.f; }); |
190 | 74 |
191 // Precompute indicators for bands with strong nearend. | 75 // Precompute indicators for bands with strong nearend. |
192 std::transform( | 76 std::transform( |
193 residual_echo_power.begin() + 1, residual_echo_power.end() - 1, | 77 residual_echo_power.begin() + 1, residual_echo_power.end() - 1, |
194 nearend_power.begin() + 1, strong_nearend.begin(), | 78 nearend_power.begin() + 1, strong_nearend.begin(), |
195 [&](float a, float b) { return a <= strong_nearend_margin * b; }); | 79 [&](float a, float b) { return a <= strong_nearend_margin * b; }); |
196 | 80 |
197 // Precompute masker for the same band. | 81 // Precompute masker for the same band. |
198 std::transform(comfort_noise_power.begin() + 1, comfort_noise_power.end() - 1, | 82 std::transform(comfort_noise_power.begin() + 1, comfort_noise_power.end() - 1, |
199 previous_masker->begin(), same_band_masker.begin(), | 83 previous_masker->begin(), same_band_masker.begin(), |
200 [&](float a, float b) { return a + kTimeMaskingFactor * b; }); | 84 [&](float a, float b) { return a + kTimeMaskingFactor * b; }); |
201 | 85 |
202 for (int k = 0; k < kNumIterations; ++k) { | 86 for (int k = 0; k < kNumIterations; ++k) { |
203 if (k == 0) { | 87 if (k == 0) { |
204 // Add masker from the same band. | 88 // Add masker from the same band. |
205 std::copy(same_band_masker.begin(), same_band_masker.end(), | 89 std::copy(same_band_masker.begin(), same_band_masker.end(), |
206 masker.begin()); | 90 masker.begin()); |
207 } else { | 91 } else { |
208 // Add masker for neightboring bands. | 92 // Add masker for neighboring bands. |
209 std::transform(nearend_power.begin(), nearend_power.end(), | 93 math.Multiply(nearend_power, *gain_squared, neighboring_bands_masker); |
210 gain_squared->begin(), neighboring_bands_masker.begin(), | 94 math.Accumulate(comfort_noise_power, neighboring_bands_masker); |
211 std::multiplies<float>()); | |
212 std::transform(neighboring_bands_masker.begin(), | |
213 neighboring_bands_masker.end(), | |
214 comfort_noise_power.begin(), | |
215 neighboring_bands_masker.begin(), std::plus<float>()); | |
216 std::transform( | 95 std::transform( |
217 neighboring_bands_masker.begin(), neighboring_bands_masker.end() - 2, | 96 neighboring_bands_masker.begin(), neighboring_bands_masker.end() - 2, |
218 neighboring_bands_masker.begin() + 2, masker.begin(), | 97 neighboring_bands_masker.begin() + 2, masker.begin(), |
219 [&](float a, float b) { return kBandMaskingFactor * (a + b); }); | 98 [&](float a, float b) { return kBandMaskingFactor * (a + b); }); |
220 | 99 |
221 // Add masker from the same band. | 100 // Add masker from the same band. |
222 std::transform(same_band_masker.begin(), same_band_masker.end(), | 101 math.Accumulate(same_band_masker, masker); |
223 masker.begin(), masker.begin(), std::plus<float>()); | |
224 } | 102 } |
225 | 103 |
226 // Compute new gain as: | 104 // Compute new gain as: |
227 // G2(t,f) = (comfort_noise_power(t,f) + G2(t-1)*nearend_power(t-1)) * | 105 // G2(t,f) = (comfort_noise_power(t,f) + G2(t-1)*nearend_power(t-1)) * |
228 // kTimeMaskingFactor | 106 // kTimeMaskingFactor |
229 // * kEchoMaskingMargin / residual_echo_power(t,f). | 107 // * kEchoMaskingMargin / residual_echo_power(t,f). |
230 // or | 108 // or |
231 // G2(t,f) = ((comfort_noise_power(t,f) + G2(t-1) * | 109 // G2(t,f) = ((comfort_noise_power(t,f) + G2(t-1) * |
232 // nearend_power(t-1)) * kTimeMaskingFactor + | 110 // nearend_power(t-1)) * kTimeMaskingFactor + |
233 // (comfort_noise_power(t, f-1) + comfort_noise_power(t, f+1) + | 111 // (comfort_noise_power(t, f-1) + comfort_noise_power(t, f+1) + |
(...skipping 21 matching lines...) |
255 }); | 133 }); |
256 | 134 |
257 // Process the gains to avoid artefacts caused by gain realization in the | 135 // Process the gains to avoid artefacts caused by gain realization in the |
258 // filterbank and impact of external pre-processing of the signal. | 136 // filterbank and impact of external pre-processing of the signal. |
259 GainPostProcessing(gain_squared); | 137 GainPostProcessing(gain_squared); |
260 } | 138 } |
261 | 139 |
262 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, | 140 std::copy(gain_squared->begin() + 1, gain_squared->end() - 1, |
263 previous_gain_squared->begin()); | 141 previous_gain_squared->begin()); |
264 | 142 |
265 std::transform(gain_squared->begin() + 1, gain_squared->end() - 1, | 143 math.Multiply( |
266 nearend_power.begin() + 1, previous_masker->begin(), | 144 rtc::ArrayView<const float>(&(*gain_squared)[1], previous_masker->size()), |
267 std::multiplies<float>()); | 145 rtc::ArrayView<const float>(&nearend_power[1], previous_masker->size()), |
268 std::transform(previous_masker->begin(), previous_masker->end(), | 146 *previous_masker); |
269 comfort_noise_power.begin() + 1, previous_masker->begin(), | 147 math.Accumulate(rtc::ArrayView<const float>(&comfort_noise_power[1], |
270 std::plus<float>()); | 148 previous_masker->size()), |
271 | 149 *previous_masker); |
272 std::transform(gain_squared->begin(), gain_squared->end(), gain->begin(), | 150 math.Sqrt(*gain); |
273 [](float a) { return sqrtf(a); }); | |
274 } | 151 } |
275 | 152 |
276 } // namespace aec3 | 153 } // namespace |
277 | 154 |
278 // Computes an upper bound on the gain to apply for high frequencies. | 155 // Computes an upper bound on the gain to apply for high frequencies. |
279 float HighFrequencyGainBound(bool saturated_echo, | 156 float HighFrequencyGainBound(bool saturated_echo, |
280 const std::vector<std::vector<float>>& render) { | 157 const std::vector<std::vector<float>>& render) { |
281 if (render.size() == 1) { | 158 if (render.size() == 1) { |
282 return 1.f; | 159 return 1.f; |
283 } | 160 } |
284 | 161 |
285 // Always attenuate the upper bands when there is saturated echo. | 162 // Always attenuate the upper bands when there is saturated echo. |
286 if (saturated_echo) { | 163 if (saturated_echo) { |
(...skipping 48 matching lines...) |
335 previous_gain_squared_.fill(0.f); | 212 previous_gain_squared_.fill(0.f); |
336 std::copy(comfort_noise_power.begin() + 1, comfort_noise_power.end() - 1, | 213 std::copy(comfort_noise_power.begin() + 1, comfort_noise_power.end() - 1, |
337 previous_masker_.begin()); | 214 previous_masker_.begin()); |
338 low_band_gain->fill(0.f); | 215 low_band_gain->fill(0.f); |
339 *high_bands_gain = 0.f; | 216 *high_bands_gain = 0.f; |
340 return; | 217 return; |
341 } | 218 } |
342 | 219 |
343 // Choose margin to use. | 220 // Choose margin to use. |
344 const float margin = saturated_echo ? 0.001f : 0.01f; | 221 const float margin = saturated_echo ? 0.001f : 0.01f; |
345 switch (optimization_) { | 222 ComputeGains(optimization_, nearend_power, residual_echo_power, |
346 #if defined(WEBRTC_ARCH_X86_FAMILY) | 223 comfort_noise_power, margin, &previous_gain_squared_, |
347 case Aec3Optimization::kSse2: | 224 &previous_masker_, low_band_gain); |
348 aec3::ComputeGains_SSE2( | |
349 nearend_power, residual_echo_power, comfort_noise_power, margin, | |
350 &previous_gain_squared_, &previous_masker_, low_band_gain); | |
351 break; | |
352 #endif | |
353 default: | |
354 aec3::ComputeGains(nearend_power, residual_echo_power, | |
355 comfort_noise_power, margin, &previous_gain_squared_, | |
356 &previous_masker_, low_band_gain); | |
357 } | |
358 | 225 |
359 if (num_capture_bands > 1) { | 226 if (num_capture_bands > 1) { |
360 // Compute the gain for upper frequencies. | 227 // Compute the gain for upper frequencies. |
361 const float min_high_band_gain = | 228 const float min_high_band_gain = |
362 HighFrequencyGainBound(saturated_echo, render); | 229 HighFrequencyGainBound(saturated_echo, render); |
363 *high_bands_gain = | 230 *high_bands_gain = |
364 *std::min_element(low_band_gain->begin() + 32, low_band_gain->end()); | 231 *std::min_element(low_band_gain->begin() + 32, low_band_gain->end()); |
365 | 232 |
366 *high_bands_gain = std::min(*high_bands_gain, min_high_band_gain); | 233 *high_bands_gain = std::min(*high_bands_gain, min_high_band_gain); |
367 | 234 |
368 } else { | 235 } else { |
369 *high_bands_gain = 1.f; | 236 *high_bands_gain = 1.f; |
370 } | 237 } |
371 } | 238 } |
372 | 239 |
373 } // namespace webrtc | 240 } // namespace webrtc |
OLD | NEW |