Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(383)

Side by Side Diff: webrtc/modules/audio_processing/intelligibility/intelligibility_enhancer.cc

Issue 1207353002: Add new variance update option and unittests for intelligibility (Closed) Base URL: https://chromium.googlesource.com/external/webrtc.git@master
Patch Set: Merge Created 5 years, 5 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 /* 1 /*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. 2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 * 3 *
4 * Use of this source code is governed by a BSD-style license 4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source 5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found 6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may 7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree. 8 * be found in the AUTHORS file in the root of the source tree.
9 */ 9 */
10 10
11 // 11 //
12 // Implements core class for intelligibility enhancer. 12 // Implements core class for intelligibility enhancer.
13 // 13 //
14 // Details of the model and algorithm can be found in the original paper: 14 // Details of the model and algorithm can be found in the original paper:
15 // http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6882788 15 // http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=6882788
16 // 16 //
17 17
18 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhanc er.h" 18 #include "webrtc/modules/audio_processing/intelligibility/intelligibility_enhanc er.h"
19 19
20 #include <cmath> 20 #include <math.h>
21 #include <cstdlib> 21 #include <stdlib.h>
22 22
23 #include <algorithm> 23 #include <algorithm>
24 #include <numeric> 24 #include <numeric>
25 25
26 #include "webrtc/base/checks.h" 26 #include "webrtc/base/checks.h"
27 #include "webrtc/common_audio/vad/include/webrtc_vad.h" 27 #include "webrtc/common_audio/vad/include/webrtc_vad.h"
28 #include "webrtc/common_audio/window_generator.h" 28 #include "webrtc/common_audio/window_generator.h"
29 29
30 namespace webrtc {
31
32 namespace {
33
34 const int kErbResolution = 2;
35 const int kWindowSizeMs = 2;
36 const int kChunkSizeMs = 10; // Size provided by APM.
37 const float kClipFreq = 200.0f;
38 const float kConfigRho = 0.02f; // Default production and interpretation SNR.
39 const float kKbdAlpha = 1.5f;
40 const float kLambdaBot = -1.0f; // Extreme values in bisection
41 const float kLambdaTop = -10e-18f; // search for lamda.
42
43 } // namespace
44
30 using std::complex; 45 using std::complex;
31 using std::max; 46 using std::max;
32 using std::min; 47 using std::min;
33
34 namespace webrtc {
35
36 const int IntelligibilityEnhancer::kErbResolution = 2;
37 const int IntelligibilityEnhancer::kWindowSizeMs = 2;
38 const int IntelligibilityEnhancer::kChunkSizeMs = 10; // Size provided by APM.
39 const int IntelligibilityEnhancer::kAnalyzeRate = 800;
40 const int IntelligibilityEnhancer::kVarianceRate = 2;
41 const float IntelligibilityEnhancer::kClipFreq = 200.0f;
42 const float IntelligibilityEnhancer::kConfigRho = 0.02f;
43 const float IntelligibilityEnhancer::kKbdAlpha = 1.5f;
44
45 // To disable gain update smoothing, set gain limit to be VERY high.
46 // TODO(ekmeyerson): Add option to disable gain smoothing altogether
47 // to avoid the extra computation.
48 const float IntelligibilityEnhancer::kGainChangeLimit = 0.0125f;
49
50 using VarianceType = intelligibility::VarianceArray::StepType; 48 using VarianceType = intelligibility::VarianceArray::StepType;
51 49
52 IntelligibilityEnhancer::TransformCallback::TransformCallback( 50 IntelligibilityEnhancer::TransformCallback::TransformCallback(
53 IntelligibilityEnhancer* parent, 51 IntelligibilityEnhancer* parent,
54 IntelligibilityEnhancer::AudioSource source) 52 IntelligibilityEnhancer::AudioSource source)
55 : parent_(parent), source_(source) { 53 : parent_(parent), source_(source) {
56 } 54 }
57 55
58 void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock( 56 void IntelligibilityEnhancer::TransformCallback::ProcessAudioBlock(
59 const complex<float>* const* in_block, 57 const complex<float>* const* in_block,
(...skipping 26 matching lines...) Expand all
86 channels_(channels), 84 channels_(channels),
87 analysis_rate_(analysis_rate), 85 analysis_rate_(analysis_rate),
88 variance_rate_(variance_rate), 86 variance_rate_(variance_rate),
89 clear_variance_(freqs_, 87 clear_variance_(freqs_,
90 static_cast<VarianceType>(cv_type), 88 static_cast<VarianceType>(cv_type),
91 cv_win, 89 cv_win,
92 cv_alpha), 90 cv_alpha),
93 noise_variance_(freqs_, VarianceType::kStepInfinite, 475, 0.01f), 91 noise_variance_(freqs_, VarianceType::kStepInfinite, 475, 0.01f),
94 filtered_clear_var_(new float[bank_size_]), 92 filtered_clear_var_(new float[bank_size_]),
95 filtered_noise_var_(new float[bank_size_]), 93 filtered_noise_var_(new float[bank_size_]),
96 filter_bank_(nullptr), 94 filter_bank_(bank_size_),
97 center_freqs_(new float[bank_size_]), 95 center_freqs_(new float[bank_size_]),
98 rho_(new float[bank_size_]), 96 rho_(new float[bank_size_]),
99 gains_eq_(new float[bank_size_]), 97 gains_eq_(new float[bank_size_]),
100 gain_applier_(freqs_, gain_limit), 98 gain_applier_(freqs_, gain_limit),
101 temp_out_buffer_(nullptr), 99 temp_out_buffer_(nullptr),
102 input_audio_(new float* [channels]), 100 input_audio_(new float* [channels]),
103 kbd_window_(new float[window_size_]), 101 kbd_window_(new float[window_size_]),
104 render_callback_(this, AudioSource::kRenderStream), 102 render_callback_(this, AudioSource::kRenderStream),
105 capture_callback_(this, AudioSource::kCaptureStream), 103 capture_callback_(this, AudioSource::kCaptureStream),
106 block_count_(0), 104 block_count_(0),
(...skipping 35 matching lines...) Expand 10 before | Expand all | Expand 10 after
142 channels_, channels_, chunk_length_, kbd_window_.get(), window_size_, 140 channels_, channels_, chunk_length_, kbd_window_.get(), window_size_,
143 window_size_ / 2, &render_callback_)); 141 window_size_ / 2, &render_callback_));
144 capture_mangler_.reset(new LappedTransform( 142 capture_mangler_.reset(new LappedTransform(
145 channels_, channels_, chunk_length_, kbd_window_.get(), window_size_, 143 channels_, channels_, chunk_length_, kbd_window_.get(), window_size_,
146 window_size_ / 2, &capture_callback_)); 144 window_size_ / 2, &capture_callback_));
147 } 145 }
148 146
149 IntelligibilityEnhancer::~IntelligibilityEnhancer() { 147 IntelligibilityEnhancer::~IntelligibilityEnhancer() {
150 WebRtcVad_Free(vad_low_); 148 WebRtcVad_Free(vad_low_);
151 WebRtcVad_Free(vad_high_); 149 WebRtcVad_Free(vad_high_);
152 free(filter_bank_); 150 free(temp_out_buffer_);
153 } 151 }
154 152
155 void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio) { 153 void IntelligibilityEnhancer::ProcessRenderAudio(float* const* audio) {
156 for (int i = 0; i < chunk_length_; ++i) { 154 for (int i = 0; i < chunk_length_; ++i) {
157 vad_tmp_buffer_[i] = (int16_t)audio[0][i]; 155 vad_tmp_buffer_[i] = (int16_t)audio[0][i];
158 } 156 }
159 has_voice_low_ = WebRtcVad_Process(vad_low_, sample_rate_hz_, 157 has_voice_low_ = WebRtcVad_Process(vad_low_, sample_rate_hz_,
160 vad_tmp_buffer_.get(), chunk_length_) == 1; 158 vad_tmp_buffer_.get(), chunk_length_) == 1;
161 159
162 // Process and enhance chunk of |audio| 160 // Process and enhance chunk of |audio|
(...skipping 33 matching lines...) Expand 10 before | Expand all | Expand 10 after
196 ProcessClearBlock(in_block, out_block); 194 ProcessClearBlock(in_block, out_block);
197 break; 195 break;
198 case kCaptureStream: 196 case kCaptureStream:
199 ProcessNoiseBlock(in_block, out_block); 197 ProcessNoiseBlock(in_block, out_block);
200 break; 198 break;
201 } 199 }
202 } 200 }
203 201
204 void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block, 202 void IntelligibilityEnhancer::ProcessClearBlock(const complex<float>* in_block,
205 complex<float>* out_block) { 203 complex<float>* out_block) {
206 float power_target;
207
208 if (block_count_ < 2) { 204 if (block_count_ < 2) {
209 memset(out_block, 0, freqs_ * sizeof(*out_block)); 205 memset(out_block, 0, freqs_ * sizeof(*out_block));
210 ++block_count_; 206 ++block_count_;
211 return; 207 return;
212 } 208 }
213 209
214 // For now, always assumes enhancement is necessary. 210 // For now, always assumes enhancement is necessary.
215 // TODO(ekmeyerson): Change to only enhance if necessary, 211 // TODO(ekmeyerson): Change to only enhance if necessary,
216 // based on experiments with different cutoffs. 212 // based on experiments with different cutoffs.
217 if (has_voice_low_ || true) { 213 if (has_voice_low_ || true) {
218 clear_variance_.Step(in_block, false); 214 clear_variance_.Step(in_block, false);
219 power_target = std::accumulate(clear_variance_.variance(), 215 const float power_target = std::accumulate(
220 clear_variance_.variance() + freqs_, 0.0f); 216 clear_variance_.variance(), clear_variance_.variance() + freqs_, 0.0f);
221 217
222 if (block_count_ % analysis_rate_ == analysis_rate_ - 1) { 218 if (block_count_ % analysis_rate_ == analysis_rate_ - 1) {
223 AnalyzeClearBlock(power_target); 219 AnalyzeClearBlock(power_target);
224 ++analysis_step_; 220 ++analysis_step_;
225 if (analysis_step_ == variance_rate_) { 221 if (analysis_step_ == variance_rate_) {
226 analysis_step_ = 0; 222 analysis_step_ = 0;
227 clear_variance_.Clear(); 223 clear_variance_.Clear();
228 noise_variance_.Clear(); 224 noise_variance_.Clear();
229 } 225 }
230 } 226 }
231 ++block_count_; 227 ++block_count_;
232 } 228 }
233 229
234 /* efidata(n,:) = sqrt(b(n)) * fidata(n,:) */ 230 /* efidata(n,:) = sqrt(b(n)) * fidata(n,:) */
235 gain_applier_.Apply(in_block, out_block); 231 gain_applier_.Apply(in_block, out_block);
236 } 232 }
237 233
238 void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) { 234 void IntelligibilityEnhancer::AnalyzeClearBlock(float power_target) {
239 FilterVariance(clear_variance_.variance(), filtered_clear_var_.get()); 235 FilterVariance(clear_variance_.variance(), filtered_clear_var_.get());
240 FilterVariance(noise_variance_.variance(), filtered_noise_var_.get()); 236 FilterVariance(noise_variance_.variance(), filtered_noise_var_.get());
241 237
242 // Bisection search for optimal |lambda| 238 SolveForGainsGivenLambda(kLambdaTop, start_freq_, gains_eq_.get());
239 const float power_top =
240 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
241 SolveForGainsGivenLambda(kLambdaBot, start_freq_, gains_eq_.get());
242 const float power_bot =
243 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
244 if (power_target >= power_bot && power_target <= power_top) {
245 SolveForLambda(power_target, power_bot, power_top);
246 UpdateErbGains();
247 } // Else experiencing variance underflow, so do nothing.
248 }
243 249
244 float lambda_bot = -1.0f, lambda_top = -10e-18f, lambda; 250 void IntelligibilityEnhancer::SolveForLambda(float power_target,
245 float power_bot, power_top, power; 251 float power_bot,
246 SolveForGainsGivenLambda(lambda_top, start_freq_, gains_eq_.get()); 252 float power_top) {
247 power_top =
248 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
249 SolveForGainsGivenLambda(lambda_bot, start_freq_, gains_eq_.get());
250 power_bot =
251 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
252 DCHECK(power_target >= power_bot && power_target <= power_top);
253
254 float power_ratio = 2.0f; // Ratio of achieved power to target power.
255 const float kConvergeThresh = 0.001f; // TODO(ekmeyerson): Find best values 253 const float kConvergeThresh = 0.001f; // TODO(ekmeyerson): Find best values
256 const int kMaxIters = 100; // for these, based on experiments. 254 const int kMaxIters = 100; // for these, based on experiments.
255
256 const float reciprocal_power_target = 1.f / power_target;
257 float lambda_bot = kLambdaBot;
258 float lambda_top = kLambdaTop;
259 float power_ratio = 2.0f; // Ratio of achieved power to target power.
257 int iters = 0; 260 int iters = 0;
258 while (fabs(power_ratio - 1.0f) > kConvergeThresh && iters <= kMaxIters) { 261 while (std::fabs(power_ratio - 1.0f) > kConvergeThresh &&
259 lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f; 262 iters <= kMaxIters) {
263 const float lambda = lambda_bot + (lambda_top - lambda_bot) / 2.0f;
260 SolveForGainsGivenLambda(lambda, start_freq_, gains_eq_.get()); 264 SolveForGainsGivenLambda(lambda, start_freq_, gains_eq_.get());
261 power = DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_); 265 const float power =
266 DotProduct(gains_eq_.get(), filtered_clear_var_.get(), bank_size_);
262 if (power < power_target) { 267 if (power < power_target) {
263 lambda_bot = lambda; 268 lambda_bot = lambda;
264 } else { 269 } else {
265 lambda_top = lambda; 270 lambda_top = lambda;
266 } 271 }
267 power_ratio = fabs(power / power_target); 272 power_ratio = std::fabs(power * reciprocal_power_target);
268 ++iters; 273 ++iters;
269 } 274 }
275 }
270 276
277 void IntelligibilityEnhancer::UpdateErbGains() {
271 // (ERB gain) = filterbank' * (freq gain) 278 // (ERB gain) = filterbank' * (freq gain)
272 float* gains = gain_applier_.target(); 279 float* gains = gain_applier_.target();
273 for (int i = 0; i < freqs_; ++i) { 280 for (int i = 0; i < freqs_; ++i) {
274 gains[i] = 0.0f; 281 gains[i] = 0.0f;
275 for (int j = 0; j < bank_size_; ++j) { 282 for (int j = 0; j < bank_size_; ++j) {
276 gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]); 283 gains[i] = fmaf(filter_bank_[j][i], gains_eq_[j], gains[i]);
277 } 284 }
278 } 285 }
279 } 286 }
280 287
(...skipping 15 matching lines...) Expand all
296 for (int i = 0; i < bank_size_; ++i) { 303 for (int i = 0; i < bank_size_; ++i) {
297 float abs_temp = fabsf((i + 1.0f) / static_cast<float>(erb_resolution_)); 304 float abs_temp = fabsf((i + 1.0f) / static_cast<float>(erb_resolution_));
298 center_freqs_[i] = 676170.4f / (47.06538f - expf(0.08950404f * abs_temp)); 305 center_freqs_[i] = 676170.4f / (47.06538f - expf(0.08950404f * abs_temp));
299 center_freqs_[i] -= 14678.49f; 306 center_freqs_[i] -= 14678.49f;
300 } 307 }
301 float last_center_freq = center_freqs_[bank_size_ - 1]; 308 float last_center_freq = center_freqs_[bank_size_ - 1];
302 for (int i = 0; i < bank_size_; ++i) { 309 for (int i = 0; i < bank_size_; ++i) {
303 center_freqs_[i] *= 0.5f * sample_rate_hz_ / last_center_freq; 310 center_freqs_[i] *= 0.5f * sample_rate_hz_ / last_center_freq;
304 } 311 }
305 312
306 filter_bank_ = static_cast<float**>(
307 malloc(sizeof(*filter_bank_) * bank_size_ +
308 sizeof(**filter_bank_) * freqs_ * bank_size_));
309 for (int i = 0; i < bank_size_; ++i) { 313 for (int i = 0; i < bank_size_; ++i) {
310 filter_bank_[i] = 314 filter_bank_[i].resize(freqs_);
311 reinterpret_cast<float*>(filter_bank_ + bank_size_) + freqs_ * i;
312 } 315 }
313 316
314 for (int i = 1; i <= bank_size_; ++i) { 317 for (int i = 1; i <= bank_size_; ++i) {
315 int lll, ll, rr, rrr; 318 int lll, ll, rr, rrr;
316 lll = round(center_freqs_[max(1, i - lf) - 1] * freqs_ / 319 lll = round(center_freqs_[max(1, i - lf) - 1] * freqs_ /
317 (0.5f * sample_rate_hz_)); 320 (0.5f * sample_rate_hz_));
318 ll = 321 ll =
319 round(center_freqs_[max(1, i) - 1] * freqs_ / (0.5f * sample_rate_hz_)); 322 round(center_freqs_[max(1, i) - 1] * freqs_ / (0.5f * sample_rate_hz_));
320 lll = min(freqs_, max(lll, 1)) - 1; 323 lll = min(freqs_, max(lll, 1)) - 1;
321 ll = min(freqs_, max(ll, 1)) - 1; 324 ll = min(freqs_, max(ll, 1)) - 1;
(...skipping 59 matching lines...) Expand 10 before | Expand all | Expand 10 after
381 (-beta0 - sqrtf(beta0 * beta0 - 4 * alpha0 * gamma0)) / (2 * alpha0); 384 (-beta0 - sqrtf(beta0 * beta0 - 4 * alpha0 * gamma0)) / (2 * alpha0);
382 } else { 385 } else {
383 sols[n] = -gamma0 / beta0; 386 sols[n] = -gamma0 / beta0;
384 } 387 }
385 sols[n] = fmax(0, sols[n]); 388 sols[n] = fmax(0, sols[n]);
386 } 389 }
387 } 390 }
388 391
389 void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) { 392 void IntelligibilityEnhancer::FilterVariance(const float* var, float* result) {
390 for (int i = 0; i < bank_size_; ++i) { 393 for (int i = 0; i < bank_size_; ++i) {
391 result[i] = DotProduct(filter_bank_[i], var, freqs_); 394 result[i] = DotProduct(filter_bank_[i].data(), var, freqs_);
392 } 395 }
393 } 396 }
394 397
395 float IntelligibilityEnhancer::DotProduct(const float* a, 398 float IntelligibilityEnhancer::DotProduct(const float* a,
396 const float* b, 399 const float* b,
397 int length) { 400 int length) {
398 float ret = 0.0f; 401 float ret = 0.0f;
399 402
400 for (int i = 0; i < length; ++i) { 403 for (int i = 0; i < length; ++i) {
401 ret = fmaf(a[i], b[i], ret); 404 ret = fmaf(a[i], b[i], ret);
402 } 405 }
403 return ret; 406 return ret;
404 } 407 }
405 408
406 } // namespace webrtc 409 } // namespace webrtc
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698