Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(951)

Side by Side Diff: webrtc/modules/audio_coding/neteq/merge.cc

Issue 1908623002: Avoiding overflow in cross correlation in NetEq. (Closed) Base URL: https://chromium.googlesource.com/external/webrtc.git@master
Patch Set: "turn off ubsan as it was" | Created 4 years, 7 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
« no previous file with comments | « webrtc/modules/audio_coding/neteq/merge.h ('k') | webrtc/modules/audio_coding/neteq/neteq.gypi » ('j') | no next file with comments »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
OLDNEW
1 /* 1 /*
2 * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved. 2 * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
3 * 3 *
4 * Use of this source code is governed by a BSD-style license 4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source 5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found 6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may 7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree. 8 * be found in the AUTHORS file in the root of the source tree.
9 */ 9 */
10 10
11 #include "webrtc/modules/audio_coding/neteq/merge.h" 11 #include "webrtc/modules/audio_coding/neteq/merge.h"
12 12
13 #include <assert.h> 13 #include <assert.h>
14 #include <string.h> // memmove, memcpy, memset, size_t 14 #include <string.h> // memmove, memcpy, memset, size_t
15 15
16 #include <algorithm> // min, max 16 #include <algorithm> // min, max
17 #include <memory> 17 #include <memory>
18 18
19 #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h" 19 #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
20 #include "webrtc/modules/audio_coding/neteq/audio_multi_vector.h" 20 #include "webrtc/modules/audio_coding/neteq/audio_multi_vector.h"
21 #include "webrtc/modules/audio_coding/neteq/cross_correlation.h"
21 #include "webrtc/modules/audio_coding/neteq/dsp_helper.h" 22 #include "webrtc/modules/audio_coding/neteq/dsp_helper.h"
22 #include "webrtc/modules/audio_coding/neteq/expand.h" 23 #include "webrtc/modules/audio_coding/neteq/expand.h"
23 #include "webrtc/modules/audio_coding/neteq/sync_buffer.h" 24 #include "webrtc/modules/audio_coding/neteq/sync_buffer.h"
24 25
25 namespace webrtc { 26 namespace webrtc {
26 27
27 Merge::Merge(int fs_hz, 28 Merge::Merge(int fs_hz,
28 size_t num_channels, 29 size_t num_channels,
29 Expand* expand, 30 Expand* expand,
30 SyncBuffer* sync_buffer) 31 SyncBuffer* sync_buffer)
(...skipping 25 matching lines...) Expand all
56 input_vector.PushBackInterleaved(input, input_length); 57 input_vector.PushBackInterleaved(input, input_length);
57 size_t input_length_per_channel = input_vector.Size(); 58 size_t input_length_per_channel = input_vector.Size();
58 assert(input_length_per_channel == input_length / num_channels_); 59 assert(input_length_per_channel == input_length / num_channels_);
59 60
60 size_t best_correlation_index = 0; 61 size_t best_correlation_index = 0;
61 size_t output_length = 0; 62 size_t output_length = 0;
62 63
63 for (size_t channel = 0; channel < num_channels_; ++channel) { 64 for (size_t channel = 0; channel < num_channels_; ++channel) {
64 int16_t* input_channel = &input_vector[channel][0]; 65 int16_t* input_channel = &input_vector[channel][0];
65 int16_t* expanded_channel = &expanded_[channel][0]; 66 int16_t* expanded_channel = &expanded_[channel][0];
66 int16_t expanded_max, input_max;
67 int16_t new_mute_factor = SignalScaling( 67 int16_t new_mute_factor = SignalScaling(
68 input_channel, input_length_per_channel, expanded_channel, 68 input_channel, input_length_per_channel, expanded_channel);
69 &expanded_max, &input_max);
70 69
71 // Adjust muting factor (product of "main" muting factor and expand muting 70 // Adjust muting factor (product of "main" muting factor and expand muting
72 // factor). 71 // factor).
73 int16_t* external_mute_factor = &external_mute_factor_array[channel]; 72 int16_t* external_mute_factor = &external_mute_factor_array[channel];
74 *external_mute_factor = 73 *external_mute_factor =
75 (*external_mute_factor * expand_->MuteFactor(channel)) >> 14; 74 (*external_mute_factor * expand_->MuteFactor(channel)) >> 14;
76 75
77 // Update |external_mute_factor| if it is lower than |new_mute_factor|. 76 // Update |external_mute_factor| if it is lower than |new_mute_factor|.
78 if (new_mute_factor > *external_mute_factor) { 77 if (new_mute_factor > *external_mute_factor) {
79 *external_mute_factor = std::min(new_mute_factor, 78 *external_mute_factor = std::min(new_mute_factor,
80 static_cast<int16_t>(16384)); 79 static_cast<int16_t>(16384));
81 } 80 }
82 81
83 if (channel == 0) { 82 if (channel == 0) {
84 // Downsample, correlate, and find strongest correlation period for the 83 // Downsample, correlate, and find strongest correlation period for the
85 // master (i.e., first) channel only. 84 // master (i.e., first) channel only.
86 // Downsample to 4kHz sample rate. 85 // Downsample to 4kHz sample rate.
87 Downsample(input_channel, input_length_per_channel, expanded_channel, 86 Downsample(input_channel, input_length_per_channel, expanded_channel,
88 expanded_length); 87 expanded_length);
89 88
90 // Calculate the lag of the strongest correlation period. 89 // Calculate the lag of the strongest correlation period.
91 best_correlation_index = CorrelateAndPeakSearch( 90 best_correlation_index = CorrelateAndPeakSearch(
92 expanded_max, input_max, old_length, 91 old_length, input_length_per_channel, expand_period);
93 input_length_per_channel, expand_period);
94 } 92 }
95 93
96 static const int kTempDataSize = 3600; 94 static const int kTempDataSize = 3600;
97 int16_t temp_data[kTempDataSize]; // TODO(hlundin) Remove this. 95 int16_t temp_data[kTempDataSize]; // TODO(hlundin) Remove this.
98 int16_t* decoded_output = temp_data + best_correlation_index; 96 int16_t* decoded_output = temp_data + best_correlation_index;
99 97
100 // Mute the new decoded data if needed (and unmute it linearly). 98 // Mute the new decoded data if needed (and unmute it linearly).
101 // This is the overlapping part of expanded_signal. 99 // This is the overlapping part of expanded_signal.
102 size_t interpolation_length = std::min( 100 size_t interpolation_length = std::min(
103 kMaxCorrelationLength * fs_mult_, 101 kMaxCorrelationLength * fs_mult_,
(...skipping 93 matching lines...) Expand 10 before | Expand all | Expand 10 after
197 expanded_.PushBack(expanded_temp); 195 expanded_.PushBack(expanded_temp);
198 } 196 }
199 // Trim the length to exactly |required_length|. 197 // Trim the length to exactly |required_length|.
200 expanded_.PopBack(expanded_.Size() - required_length); 198 expanded_.PopBack(expanded_.Size() - required_length);
201 } 199 }
202 assert(expanded_.Size() >= required_length); 200 assert(expanded_.Size() >= required_length);
203 return required_length; 201 return required_length;
204 } 202 }
205 203
206 int16_t Merge::SignalScaling(const int16_t* input, size_t input_length, 204 int16_t Merge::SignalScaling(const int16_t* input, size_t input_length,
207 const int16_t* expanded_signal, 205 const int16_t* expanded_signal) const {
208 int16_t* expanded_max, int16_t* input_max) const {
209 // Adjust muting factor if new vector is more or less of the BGN energy. 206 // Adjust muting factor if new vector is more or less of the BGN energy.
210 const size_t mod_input_length = 207 const size_t mod_input_length =
211 std::min(static_cast<size_t>(64 * fs_mult_), input_length); 208 std::min(static_cast<size_t>(64 * fs_mult_), input_length);
212 *expanded_max = WebRtcSpl_MaxAbsValueW16(expanded_signal, mod_input_length); 209 const int16_t expanded_max =
213 *input_max = WebRtcSpl_MaxAbsValueW16(input, mod_input_length); 210 WebRtcSpl_MaxAbsValueW16(expanded_signal, mod_input_length);
211 const int16_t input_max = WebRtcSpl_MaxAbsValueW16(input, mod_input_length);
214 212
215 // Calculate energy of expanded signal. 213 // Calculate energy of expanded signal.
216 // |log_fs_mult| is log2(fs_mult_), but is not exact for 48000 Hz. 214 // |log_fs_mult| is log2(fs_mult_), but is not exact for 48000 Hz.
217 int log_fs_mult = 30 - WebRtcSpl_NormW32(fs_mult_); 215 int log_fs_mult = 30 - WebRtcSpl_NormW32(fs_mult_);
218 int expanded_shift = 6 + log_fs_mult 216 int expanded_shift = 6 + log_fs_mult
219 - WebRtcSpl_NormW32(*expanded_max * *expanded_max); 217 - WebRtcSpl_NormW32(expanded_max * expanded_max);
220 expanded_shift = std::max(expanded_shift, 0); 218 expanded_shift = std::max(expanded_shift, 0);
221 int32_t energy_expanded = WebRtcSpl_DotProductWithScale(expanded_signal, 219 int32_t energy_expanded = WebRtcSpl_DotProductWithScale(expanded_signal,
222 expanded_signal, 220 expanded_signal,
223 mod_input_length, 221 mod_input_length,
224 expanded_shift); 222 expanded_shift);
225 223
226 // Calculate energy of input signal. 224 // Calculate energy of input signal.
227 int input_shift = 6 + log_fs_mult - 225 int input_shift = 6 + log_fs_mult - WebRtcSpl_NormW32(input_max * input_max);
228 WebRtcSpl_NormW32(*input_max * *input_max);
229 input_shift = std::max(input_shift, 0); 226 input_shift = std::max(input_shift, 0);
230 int32_t energy_input = WebRtcSpl_DotProductWithScale(input, input, 227 int32_t energy_input = WebRtcSpl_DotProductWithScale(input, input,
231 mod_input_length, 228 mod_input_length,
232 input_shift); 229 input_shift);
233 230
234 // Align to the same Q-domain. 231 // Align to the same Q-domain.
235 if (input_shift > expanded_shift) { 232 if (input_shift > expanded_shift) {
236 energy_expanded = energy_expanded >> (input_shift - expanded_shift); 233 energy_expanded = energy_expanded >> (input_shift - expanded_shift);
237 } else { 234 } else {
238 energy_input = energy_input >> (expanded_shift - input_shift); 235 energy_input = energy_input >> (expanded_shift - input_shift);
(...skipping 61 matching lines...) Expand 10 before | Expand all | Expand 10 after
300 sizeof(int16_t) * (kInputDownsampLength - downsamp_temp_len)); 297 sizeof(int16_t) * (kInputDownsampLength - downsamp_temp_len));
301 } else { 298 } else {
302 WebRtcSpl_DownsampleFast(&input[signal_offset], 299 WebRtcSpl_DownsampleFast(&input[signal_offset],
303 input_length - signal_offset, input_downsampled_, 300 input_length - signal_offset, input_downsampled_,
304 kInputDownsampLength, filter_coefficients, 301 kInputDownsampLength, filter_coefficients,
305 num_coefficients, decimation_factor, 302 num_coefficients, decimation_factor,
306 kCompensateDelay); 303 kCompensateDelay);
307 } 304 }
308 } 305 }
309 306
310 size_t Merge::CorrelateAndPeakSearch(int16_t expanded_max, int16_t input_max, 307 size_t Merge::CorrelateAndPeakSearch(size_t start_position, size_t input_length,
311 size_t start_position, size_t input_length,
312 size_t expand_period) const { 308 size_t expand_period) const {
313 // Calculate correlation without any normalization. 309 // Calculate correlation without any normalization.
314 const size_t max_corr_length = kMaxCorrelationLength; 310 const size_t max_corr_length = kMaxCorrelationLength;
315 size_t stop_position_downsamp = 311 size_t stop_position_downsamp =
316 std::min(max_corr_length, expand_->max_lag() / (fs_mult_ * 2) + 1); 312 std::min(max_corr_length, expand_->max_lag() / (fs_mult_ * 2) + 1);
317 int correlation_shift = 0;
318 if (expanded_max * input_max > 26843546) {
319 correlation_shift = 3;
320 }
321 313
322 int32_t correlation[kMaxCorrelationLength]; 314 int32_t correlation[kMaxCorrelationLength];
323 WebRtcSpl_CrossCorrelation(correlation, input_downsampled_, 315 CrossCorrelationWithAutoShift(input_downsampled_, expanded_downsampled_,
324 expanded_downsampled_, kInputDownsampLength, 316 kInputDownsampLength, stop_position_downsamp, 1,
325 stop_position_downsamp, correlation_shift, 1); 317 correlation);
326 318
327 // Normalize correlation to 14 bits and copy to a 16-bit array. 319 // Normalize correlation to 14 bits and copy to a 16-bit array.
328 const size_t pad_length = expand_->overlap_length() - 1; 320 const size_t pad_length = expand_->overlap_length() - 1;
329 const size_t correlation_buffer_size = 2 * pad_length + kMaxCorrelationLength; 321 const size_t correlation_buffer_size = 2 * pad_length + kMaxCorrelationLength;
330 std::unique_ptr<int16_t[]> correlation16( 322 std::unique_ptr<int16_t[]> correlation16(
331 new int16_t[correlation_buffer_size]); 323 new int16_t[correlation_buffer_size]);
332 memset(correlation16.get(), 0, correlation_buffer_size * sizeof(int16_t)); 324 memset(correlation16.get(), 0, correlation_buffer_size * sizeof(int16_t));
333 int16_t* correlation_ptr = &correlation16[pad_length]; 325 int16_t* correlation_ptr = &correlation16[pad_length];
334 int32_t max_correlation = WebRtcSpl_MaxAbsValueW32(correlation, 326 int32_t max_correlation = WebRtcSpl_MaxAbsValueW32(correlation,
335 stop_position_downsamp); 327 stop_position_downsamp);
(...skipping 38 matching lines...) Expand 10 before | Expand all | Expand 10 after
374 } 366 }
375 return best_correlation_index; 367 return best_correlation_index;
376 } 368 }
377 369
378 size_t Merge::RequiredFutureSamples() { 370 size_t Merge::RequiredFutureSamples() {
379 return fs_hz_ / 100 * num_channels_; // 10 ms. 371 return fs_hz_ / 100 * num_channels_; // 10 ms.
380 } 372 }
381 373
382 374
383 } // namespace webrtc 375 } // namespace webrtc
OLDNEW
« no previous file with comments | « webrtc/modules/audio_coding/neteq/merge.h ('k') | webrtc/modules/audio_coding/neteq/neteq.gypi » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698