Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(97)

Side by Side Diff: webrtc/modules/audio_coding/neteq/expand.cc

Issue 1908623002: Avoiding overflow in cross correlation in NetEq. (Closed) Base URL: https://chromium.googlesource.com/external/webrtc.git@master
Patch Set: Created 4 years, 8 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 /* 1 /*
2 * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved. 2 * Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
3 * 3 *
4 * Use of this source code is governed by a BSD-style license 4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source 5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found 6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may 7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree. 8 * be found in the AUTHORS file in the root of the source tree.
9 */ 9 */
10 10
11 #include "webrtc/modules/audio_coding/neteq/expand.h" 11 #include "webrtc/modules/audio_coding/neteq/expand.h"
12 12
13 #include <assert.h> 13 #include <assert.h>
14 #include <string.h> // memset 14 #include <string.h> // memset
15 15
16 #include <algorithm> // min, max 16 #include <algorithm> // min, max
17 #include <limits> // numeric_limits<T> 17 #include <limits> // numeric_limits<T>
18 18
19 #include "webrtc/base/safe_conversions.h" 19 #include "webrtc/base/safe_conversions.h"
20 #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h" 20 #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
21 #include "webrtc/modules/audio_coding/neteq/background_noise.h" 21 #include "webrtc/modules/audio_coding/neteq/background_noise.h"
22 #include "webrtc/modules/audio_coding/neteq/dsp_helper.h" 22 #include "webrtc/modules/audio_coding/neteq/dsp_helper.h"
23 #include "webrtc/modules/audio_coding/neteq/random_vector.h" 23 #include "webrtc/modules/audio_coding/neteq/random_vector.h"
24 #include "webrtc/modules/audio_coding/neteq/statistics_calculator.h" 24 #include "webrtc/modules/audio_coding/neteq/statistics_calculator.h"
25 #include "webrtc/modules/audio_coding/neteq/sync_buffer.h" 25 #include "webrtc/modules/audio_coding/neteq/sync_buffer.h"
26 26
27 namespace webrtc { 27 namespace webrtc {
28 28
29 namespace {
30
31 // This function decides the overflow-protecting scaling and call
hlundin-webrtc 2016/04/21 08:52:01 call -> calls
hlundin-webrtc 2016/04/22 06:48:56 Still, call -> calls
minyue-webrtc 2016/04/22 07:39:49 Yes, I realized it when I uploaded the patch.
32 // WebRtcSpl_CrossCorrelation.s
hlundin-webrtc 2016/04/21 08:52:01 Remove the 's' at the end of the line.
minyue-webrtc 2016/04/21 11:06:50 oh, sorry, typo
33 void CrossCorrelation(int32_t* cross_correlation,
34 const int16_t* sequence_1,
35 const int16_t* sequence_2,
36 size_t sequence_1_length,
37 size_t cross_correlation_length,
38 int* right_shifts,
39 int cross_correlation_step) {
hlundin-webrtc 2016/04/21 08:52:01 You are always calling this function with cross_co
minyue-webrtc 2016/04/21 11:06:49 I am about to discuss with you whether we'd put th
hlundin-webrtc 2016/04/22 06:48:56 OK. Keep it like this.
minyue-webrtc 2016/04/22 07:39:49 Do you think it is good that we try to replace Web
hlundin-webrtc 2016/04/22 07:48:20 Replace all in NetEq now. Leave iLBC as it is.
40 // Find the maximum absolute value of sequence_1 and 2.
41 const int16_t max_1 = WebRtcSpl_MaxAbsValueW16(sequence_1, sequence_1_length);
42 const int sequence_2_shift =
43 cross_correlation_step * (cross_correlation_length - 1);
44 const int16_t* sequence_2_start =
45 sequence_2_shift >= 0 ? sequence_2 : sequence_2 + sequence_2_shift;
46 const size_t sequence_2_length = sequence_1_length + abs(sequence_2_shift);
47 const int16_t max_2 =
48 WebRtcSpl_MaxAbsValueW16(sequence_2_start, sequence_2_length);
49
50 // In order to avoid overflow when computing the sum we should scale the
51 // samples so that (in_vector_length * max_1 * max_2) will not overflow.
52 int scaling;
53 const int32_t factor = WEBRTC_SPL_MUL(max_1, max_2) /
hlundin-webrtc 2016/04/21 08:52:01 I think you can omit WEBRTC_SPL_MUL and simply do
minyue-webrtc 2016/04/21 11:06:50 Yes, I should.
54 (WEBRTC_SPL_WORD32_MAX / (int32_t)sequence_1_length);
hlundin-webrtc 2016/04/21 08:52:00 Use c++-style cast.
hlundin-webrtc 2016/04/21 08:52:01 Use std::numeric_limits<int32_t>::max() instead of
minyue-webrtc 2016/04/21 11:06:50 Done.
minyue-webrtc 2016/04/21 11:06:50 Done.
55 scaling = factor == 0 ? 0 : 31 - WebRtcSpl_NormW32(factor);
56
57 assert((double)max_1 * max_2 * sequence_1_length / (1 << scaling) <=
minyue-webrtc 2016/04/20 17:58:59 test code, will remove
hlundin-webrtc 2016/04/21 08:52:00 Acknowledged.
58 WEBRTC_SPL_WORD32_MAX);
59 assert(scaling == 0 ||
60 (double)max_1 * max_2 * sequence_1_length /(1 << scaling) * 2 >
61 WEBRTC_SPL_WORD32_MAX);
62
63 WebRtcSpl_CrossCorrelation(cross_correlation, sequence_1, sequence_2,
64 sequence_1_length, cross_correlation_length,
65 scaling, cross_correlation_step);
66 if (right_shifts)
hlundin-webrtc 2016/04/21 08:52:00 You are not calling this function without a valid
minyue-webrtc 2016/04/21 11:06:50 ok, I was thinking of reducing some uses of the re
hlundin-webrtc 2016/04/22 06:48:56 I see. But then I suggest you let the scaling fact
minyue-webrtc 2016/04/22 07:39:49 Good idea.
67 *right_shifts = scaling;
68 }
69
70 } // namespace
71
29 Expand::Expand(BackgroundNoise* background_noise, 72 Expand::Expand(BackgroundNoise* background_noise,
30 SyncBuffer* sync_buffer, 73 SyncBuffer* sync_buffer,
31 RandomVector* random_vector, 74 RandomVector* random_vector,
32 StatisticsCalculator* statistics, 75 StatisticsCalculator* statistics,
33 int fs, 76 int fs,
34 size_t num_channels) 77 size_t num_channels)
35 : random_vector_(random_vector), 78 : random_vector_(random_vector),
36 sync_buffer_(sync_buffer), 79 sync_buffer_(sync_buffer),
37 first_expand_(true), 80 first_expand_(true),
38 fs_hz_(fs), 81 fs_hz_(fs),
(...skipping 404 matching lines...) Expand 10 before | Expand all | Expand 10 after
443 std::max(std::min(distortion_lag + 10, fs_mult_120), 486 std::max(std::min(distortion_lag + 10, fs_mult_120),
444 static_cast<size_t>(60 * fs_mult)); 487 static_cast<size_t>(60 * fs_mult));
445 488
446 size_t start_index = std::min(distortion_lag, correlation_lag); 489 size_t start_index = std::min(distortion_lag, correlation_lag);
447 size_t correlation_lags = static_cast<size_t>( 490 size_t correlation_lags = static_cast<size_t>(
448 WEBRTC_SPL_ABS_W16((distortion_lag-correlation_lag)) + 1); 491 WEBRTC_SPL_ABS_W16((distortion_lag-correlation_lag)) + 1);
449 assert(correlation_lags <= static_cast<size_t>(99 * fs_mult + 1)); 492 assert(correlation_lags <= static_cast<size_t>(99 * fs_mult + 1));
450 493
451 for (size_t channel_ix = 0; channel_ix < num_channels_; ++channel_ix) { 494 for (size_t channel_ix = 0; channel_ix < num_channels_; ++channel_ix) {
452 ChannelParameters& parameters = channel_parameters_[channel_ix]; 495 ChannelParameters& parameters = channel_parameters_[channel_ix];
453 // Calculate suitable scaling.
454 int16_t signal_max = WebRtcSpl_MaxAbsValueW16(
455 &audio_history[signal_length - correlation_length - start_index
456 - correlation_lags],
457 correlation_length + start_index + correlation_lags - 1);
458 correlation_scale = (31 - WebRtcSpl_NormW32(signal_max * signal_max)) +
459 (31 - WebRtcSpl_NormW32(static_cast<int32_t>(correlation_length))) - 31;
460 correlation_scale = std::max(0, correlation_scale);
461 496
462 // Calculate the correlation, store in |correlation_vector2|. 497 // Calculate the correlation, store in |correlation_vector2|.
463 WebRtcSpl_CrossCorrelation( 498 CrossCorrelation(
464 correlation_vector2, 499 correlation_vector2,
465 &(audio_history[signal_length - correlation_length]), 500 &(audio_history[signal_length - correlation_length]),
466 &(audio_history[signal_length - correlation_length - start_index]), 501 &(audio_history[signal_length - correlation_length - start_index]),
467 correlation_length, correlation_lags, correlation_scale, -1); 502 correlation_length, correlation_lags, &correlation_scale, -1);
468 503
469 // Find maximizing index. 504 // Find maximizing index.
470 best_index = WebRtcSpl_MaxIndexW32(correlation_vector2, correlation_lags); 505 best_index = WebRtcSpl_MaxIndexW32(correlation_vector2, correlation_lags);
471 int32_t max_correlation = correlation_vector2[best_index]; 506 int32_t max_correlation = correlation_vector2[best_index];
472 // Compensate index with start offset. 507 // Compensate index with start offset.
473 best_index = best_index + start_index; 508 best_index = best_index + start_index;
474 509
475 // Calculate energies. 510 // Calculate energies.
476 int32_t energy1 = WebRtcSpl_DotProductWithScale( 511 int32_t energy1 = WebRtcSpl_DotProductWithScale(
477 &(audio_history[signal_length - correlation_length]), 512 &(audio_history[signal_length - correlation_length]),
(...skipping 97 matching lines...) Expand 10 before | Expand all | Expand 10 after
575 expand_lags_[1] = (distortion_lag + correlation_lag) / 2; 610 expand_lags_[1] = (distortion_lag + correlation_lag) / 2;
576 // Third lag is the average again, but rounding towards |correlation_lag|. 611 // Third lag is the average again, but rounding towards |correlation_lag|.
577 if (distortion_lag > correlation_lag) { 612 if (distortion_lag > correlation_lag) {
578 expand_lags_[2] = (distortion_lag + correlation_lag - 1) / 2; 613 expand_lags_[2] = (distortion_lag + correlation_lag - 1) / 2;
579 } else { 614 } else {
580 expand_lags_[2] = (distortion_lag + correlation_lag + 1) / 2; 615 expand_lags_[2] = (distortion_lag + correlation_lag + 1) / 2;
581 } 616 }
582 } 617 }
583 618
584 // Calculate the LPC and the gain of the filters. 619 // Calculate the LPC and the gain of the filters.
585 // Calculate scale value needed for auto-correlation.
586 correlation_scale = WebRtcSpl_MaxAbsValueW16(
587 &(audio_history[signal_length - fs_mult_lpc_analysis_len]),
588 fs_mult_lpc_analysis_len);
589
590 correlation_scale = std::min(16 - WebRtcSpl_NormW32(correlation_scale), 0);
591 correlation_scale = std::max(correlation_scale * 2 + 7, 0);
592 620
593 // Calculate kUnvoicedLpcOrder + 1 lags of the auto-correlation function. 621 // Calculate kUnvoicedLpcOrder + 1 lags of the auto-correlation function.
594 size_t temp_index = signal_length - fs_mult_lpc_analysis_len - 622 size_t temp_index = signal_length - fs_mult_lpc_analysis_len -
595 kUnvoicedLpcOrder; 623 kUnvoicedLpcOrder;
596 // Copy signal to temporary vector to be able to pad with leading zeros. 624 // Copy signal to temporary vector to be able to pad with leading zeros.
597 int16_t* temp_signal = new int16_t[fs_mult_lpc_analysis_len 625 int16_t* temp_signal = new int16_t[fs_mult_lpc_analysis_len
598 + kUnvoicedLpcOrder]; 626 + kUnvoicedLpcOrder];
599 memset(temp_signal, 0, 627 memset(temp_signal, 0,
600 sizeof(int16_t) * (fs_mult_lpc_analysis_len + kUnvoicedLpcOrder)); 628 sizeof(int16_t) * (fs_mult_lpc_analysis_len + kUnvoicedLpcOrder));
601 memcpy(&temp_signal[kUnvoicedLpcOrder], 629 memcpy(&temp_signal[kUnvoicedLpcOrder],
602 &audio_history[temp_index + kUnvoicedLpcOrder], 630 &audio_history[temp_index + kUnvoicedLpcOrder],
603 sizeof(int16_t) * fs_mult_lpc_analysis_len); 631 sizeof(int16_t) * fs_mult_lpc_analysis_len);
604 WebRtcSpl_CrossCorrelation(auto_correlation, 632 CrossCorrelation(auto_correlation,
605 &temp_signal[kUnvoicedLpcOrder], 633 &temp_signal[kUnvoicedLpcOrder],
606 &temp_signal[kUnvoicedLpcOrder], 634 &temp_signal[kUnvoicedLpcOrder],
607 fs_mult_lpc_analysis_len, kUnvoicedLpcOrder + 1, 635 fs_mult_lpc_analysis_len, kUnvoicedLpcOrder + 1,
608 correlation_scale, -1); 636 &correlation_scale, -1);
609 delete [] temp_signal; 637 delete [] temp_signal;
610 638
611 // Verify that variance is positive. 639 // Verify that variance is positive.
612 if (auto_correlation[0] > 0) { 640 if (auto_correlation[0] > 0) {
613 // Estimate AR filter parameters using Levinson-Durbin algorithm; 641 // Estimate AR filter parameters using Levinson-Durbin algorithm;
614 // kUnvoicedLpcOrder + 1 filter coefficients. 642 // kUnvoicedLpcOrder + 1 filter coefficients.
615 int16_t stability = WebRtcSpl_LevinsonDurbin(auto_correlation, 643 int16_t stability = WebRtcSpl_LevinsonDurbin(auto_correlation,
616 parameters.ar_filter, 644 parameters.ar_filter,
617 reflection_coeff, 645 reflection_coeff,
618 kUnvoicedLpcOrder); 646 kUnvoicedLpcOrder);
(...skipping 188 matching lines...) Expand 10 before | Expand all | Expand 10 after
807 downsampling_factor, kFilterDelay); 835 downsampling_factor, kFilterDelay);
808 836
809 // Normalize |downsampled_input| to using all 16 bits. 837 // Normalize |downsampled_input| to using all 16 bits.
810 int16_t max_value = WebRtcSpl_MaxAbsValueW16(downsampled_input, 838 int16_t max_value = WebRtcSpl_MaxAbsValueW16(downsampled_input,
811 kDownsampledLength); 839 kDownsampledLength);
812 int16_t norm_shift = 16 - WebRtcSpl_NormW32(max_value); 840 int16_t norm_shift = 16 - WebRtcSpl_NormW32(max_value);
813 WebRtcSpl_VectorBitShiftW16(downsampled_input, kDownsampledLength, 841 WebRtcSpl_VectorBitShiftW16(downsampled_input, kDownsampledLength,
814 downsampled_input, norm_shift); 842 downsampled_input, norm_shift);
815 843
816 int32_t correlation[kNumCorrelationLags]; 844 int32_t correlation[kNumCorrelationLags];
817 static const int kCorrelationShift = 6; 845 int correlation_shift;
818 WebRtcSpl_CrossCorrelation( 846 CrossCorrelation(
819 correlation, 847 correlation,
820 &downsampled_input[kDownsampledLength - kCorrelationLength], 848 &downsampled_input[kDownsampledLength - kCorrelationLength],
821 &downsampled_input[kDownsampledLength - kCorrelationLength 849 &downsampled_input[kDownsampledLength - kCorrelationLength
822 - kCorrelationStartLag], 850 - kCorrelationStartLag],
823 kCorrelationLength, kNumCorrelationLags, kCorrelationShift, -1); 851 kCorrelationLength, kNumCorrelationLags, &correlation_shift, -1);
824 852
825 // Normalize and move data from 32-bit to 16-bit vector. 853 // Normalize and move data from 32-bit to 16-bit vector.
826 int32_t max_correlation = WebRtcSpl_MaxAbsValueW32(correlation, 854 int32_t max_correlation = WebRtcSpl_MaxAbsValueW32(correlation,
827 kNumCorrelationLags); 855 kNumCorrelationLags);
828 int16_t norm_shift2 = static_cast<int16_t>( 856 int16_t norm_shift2 = static_cast<int16_t>(
829 std::max(18 - WebRtcSpl_NormW32(max_correlation), 0)); 857 std::max(18 - WebRtcSpl_NormW32(max_correlation), 0));
830 WebRtcSpl_VectorBitShiftW32ToW16(output, kNumCorrelationLags, correlation, 858 WebRtcSpl_VectorBitShiftW32ToW16(output, kNumCorrelationLags, correlation,
831 norm_shift2); 859 norm_shift2);
832 // Total scale factor (right shifts) of correlation value. 860 // Total scale factor (right shifts) of correlation value.
833 *output_scale = 2 * norm_shift + kCorrelationShift + norm_shift2; 861 *output_scale = 2 * norm_shift + correlation_shift + norm_shift2;
834 } 862 }
835 863
836 void Expand::UpdateLagIndex() { 864 void Expand::UpdateLagIndex() {
837 current_lag_index_ = current_lag_index_ + lag_index_direction_; 865 current_lag_index_ = current_lag_index_ + lag_index_direction_;
838 // Change direction if needed. 866 // Change direction if needed.
839 if (current_lag_index_ <= 0) { 867 if (current_lag_index_ <= 0) {
840 lag_index_direction_ = 1; 868 lag_index_direction_ = 1;
841 } 869 }
842 if (current_lag_index_ >= kNumLags - 1) { 870 if (current_lag_index_ >= kNumLags - 1) {
843 lag_index_direction_ = -1; 871 lag_index_direction_ = -1;
(...skipping 109 matching lines...) Expand 10 before | Expand all | Expand 10 after
953 const size_t kMaxRandSamples = RandomVector::kRandomTableSize; 981 const size_t kMaxRandSamples = RandomVector::kRandomTableSize;
954 while (samples_generated < length) { 982 while (samples_generated < length) {
955 size_t rand_length = std::min(length - samples_generated, kMaxRandSamples); 983 size_t rand_length = std::min(length - samples_generated, kMaxRandSamples);
956 random_vector_->IncreaseSeedIncrement(seed_increment); 984 random_vector_->IncreaseSeedIncrement(seed_increment);
957 random_vector_->Generate(rand_length, &random_vector[samples_generated]); 985 random_vector_->Generate(rand_length, &random_vector[samples_generated]);
958 samples_generated += rand_length; 986 samples_generated += rand_length;
959 } 987 }
960 } 988 }
961 989
962 } // namespace webrtc 990 } // namespace webrtc
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698