Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(218)

Side by Side Diff: webrtc/modules/audio_processing/aec3/matched_filter.cc

Issue 2678423005: Finalization of the first version of EchoCanceller 3 (Closed)
Patch Set: Fixed compilation error Created 3 years, 10 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 /* 1 /*
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
3 * 3 *
4 * Use of this source code is governed by a BSD-style license 4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source 5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found 6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may 7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree. 8 * be found in the AUTHORS file in the root of the source tree.
9 */ 9 */
10 #include "webrtc/modules/audio_processing/aec3/matched_filter.h" 10 #include "webrtc/modules/audio_processing/aec3/matched_filter.h"
11 11
12 #include "webrtc/typedefs.h"
13 #if defined(WEBRTC_ARCH_X86_FAMILY)
14 #include <emmintrin.h>
15 #endif
12 #include <algorithm> 16 #include <algorithm>
13 #include <numeric> 17 #include <numeric>
14 18
15 #include "webrtc/modules/audio_processing/include/audio_processing.h" 19 #include "webrtc/modules/audio_processing/include/audio_processing.h"
16 #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" 20 #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
17 21
18 namespace webrtc { 22 namespace webrtc {
23 namespace aec3 {
24
25 #if defined(WEBRTC_ARCH_X86_FAMILY)
26
27 void MatchedFilterCore_SSE2(size_t x_start_index,
28 float x2_sum_threshold,
29 rtc::ArrayView<const float> x,
30 rtc::ArrayView<const float> y,
31 rtc::ArrayView<float> h,
32 bool* filters_updated,
33 float* error_sum) {
34 // Process for all samples in the sub-block.
35 for (size_t i = 0; i < kSubBlockSize; ++i) {
36 // Apply the matched filter as filter * x. and compute x * x.
37 float x2_sum = 0.f;
38 float s = 0;
39 size_t x_index = x_start_index;
40 RTC_DCHECK_EQ(0, h.size() % 4);
41
42 __m128 s_128 = _mm_set1_ps(0);
43 __m128 x2_sum_128 = _mm_set1_ps(0);
44
45 size_t k = 0;
46 if (h.size() > (x.size() - x_index)) {
47 const size_t limit = x.size() - x_index;
48 for (; (k + 3) < limit; k += 4, x_index += 4) {
49 const __m128 x_k = _mm_loadu_ps(&x[x_index]);
50 const __m128 h_k = _mm_loadu_ps(&h[k]);
51 const __m128 xx = _mm_mul_ps(x_k, x_k);
52 x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
53 const __m128 hx = _mm_mul_ps(h_k, x_k);
54 s_128 = _mm_add_ps(s_128, hx);
55 }
56
57 for (; k < limit; ++k, ++x_index) {
58 x2_sum += x[x_index] * x[x_index];
59 s += h[k] * x[x_index];
60 }
61 x_index = 0;
62 }
63
64 for (; k + 3 < h.size(); k += 4, x_index += 4) {
65 const __m128 x_k = _mm_loadu_ps(&x[x_index]);
66 const __m128 h_k = _mm_loadu_ps(&h[k]);
67 const __m128 xx = _mm_mul_ps(x_k, x_k);
68 x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
69 const __m128 hx = _mm_mul_ps(h_k, x_k);
70 s_128 = _mm_add_ps(s_128, hx);
71 }
72
73 for (; k < h.size(); ++k, ++x_index) {
74 x2_sum += x[x_index] * x[x_index];
75 s += h[k] * x[x_index];
76 }
77
78 float* v = reinterpret_cast<float*>(&x2_sum_128);
79 x2_sum += v[0] + v[1] + v[2] + v[3];
80 v = reinterpret_cast<float*>(&s_128);
81 s += v[0] + v[1] + v[2] + v[3];
82
83 // Compute the matched filter error.
84 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
85 (*error_sum) += e * e;
86
87 // Update the matched filter estimate in an NLMS manner.
88 if (x2_sum > x2_sum_threshold) {
89 RTC_DCHECK_LT(0.f, x2_sum);
90 const float alpha = 0.7f * e / x2_sum;
91
92 // filter = filter + 0.7 * (y - filter * x) / x * x.
93 size_t x_index = x_start_index;
94 for (size_t k = 0; k < h.size(); ++k) {
95 h[k] += alpha * x[x_index];
96 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
97 }
98 *filters_updated = true;
99 }
100
101 x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
102 }
103 }
104 #endif
105
106 void MatchedFilterCore(size_t x_start_index,
107 float x2_sum_threshold,
108 rtc::ArrayView<const float> x,
109 rtc::ArrayView<const float> y,
110 rtc::ArrayView<float> h,
111 bool* filters_updated,
112 float* error_sum) {
113 // Process for all samples in the sub-block.
114 for (size_t i = 0; i < kSubBlockSize; ++i) {
115 // Apply the matched filter as filter * x. and compute x * x.
116 float x2_sum = 0.f;
117 float s = 0;
118 size_t x_index = x_start_index;
119 for (size_t k = 0; k < h.size(); ++k) {
120 x2_sum += x[x_index] * x[x_index];
121 s += h[k] * x[x_index];
122 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
123 }
124
125 // Compute the matched filter error.
126 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
127 (*error_sum) += e * e;
128
129 // Update the matched filter estimate in an NLMS manner.
130 if (x2_sum > x2_sum_threshold) {
131 RTC_DCHECK_LT(0.f, x2_sum);
132 const float alpha = 0.7f * e / x2_sum;
133
134 // filter = filter + 0.7 * (y - filter * x) / x * x.
135 size_t x_index = x_start_index;
136 for (size_t k = 0; k < h.size(); ++k) {
137 h[k] += alpha * x[x_index];
138 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;
139 }
140 *filters_updated = true;
141 }
142
143 x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;
144 }
145 }
146
147 } // namespace aec3
19 148
20 MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { 149 MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) {
21 RTC_DCHECK_EQ(0, size % kSubBlockSize); 150 RTC_DCHECK_EQ(0, size % kSubBlockSize);
22 } 151 }
23 152
24 MatchedFilter::IndexedBuffer::~IndexedBuffer() = default; 153 MatchedFilter::IndexedBuffer::~IndexedBuffer() = default;
25 154
26 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, 155 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
156 Aec3Optimization optimization,
27 size_t window_size_sub_blocks, 157 size_t window_size_sub_blocks,
28 int num_matched_filters, 158 int num_matched_filters,
29 size_t alignment_shift_sub_blocks) 159 size_t alignment_shift_sub_blocks)
30 : data_dumper_(data_dumper), 160 : data_dumper_(data_dumper),
161 optimization_(optimization),
31 filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize), 162 filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize),
32 filters_(num_matched_filters, 163 filters_(num_matched_filters,
33 std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)), 164 std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)),
34 lag_estimates_(num_matched_filters), 165 lag_estimates_(num_matched_filters),
35 x_buffer_(kSubBlockSize * 166 x_buffer_(kSubBlockSize *
36 (alignment_shift_sub_blocks * num_matched_filters + 167 (alignment_shift_sub_blocks * num_matched_filters +
37 window_size_sub_blocks + 168 window_size_sub_blocks +
38 1)) { 169 1)) {
39 RTC_DCHECK(data_dumper); 170 RTC_DCHECK(data_dumper);
40 RTC_DCHECK_EQ(0, x_buffer_.data.size() % kSubBlockSize); 171 RTC_DCHECK_EQ(0, x_buffer_.data.size() % kSubBlockSize);
(...skipping 17 matching lines...) Expand all
58 189
59 // Apply all matched filters. 190 // Apply all matched filters.
60 size_t alignment_shift = 0; 191 size_t alignment_shift = 0;
61 for (size_t n = 0; n < filters_.size(); ++n) { 192 for (size_t n = 0; n < filters_.size(); ++n) {
62 float error_sum = 0.f; 193 float error_sum = 0.f;
63 bool filters_updated = false; 194 bool filters_updated = false;
64 size_t x_start_index = 195 size_t x_start_index =
65 (x_buffer_.index + alignment_shift + kSubBlockSize - 1) % 196 (x_buffer_.index + alignment_shift + kSubBlockSize - 1) %
66 x_buffer_.data.size(); 197 x_buffer_.data.size();
67 198
68 // Process for all samples in the sub-block. 199 switch (optimization_) {
69 for (size_t i = 0; i < kSubBlockSize; ++i) { 200 #if defined(WEBRTC_ARCH_X86_FAMILY)
70 // As x_buffer is a circular buffer, all of the processing is split into 201 case Aec3Optimization::kSse2:
71 // two loops around the wrapping of the buffer. 202 aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold,
72 const size_t loop_size_1 = 203 x_buffer_.data, y, filters_[n],
73 std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index); 204 &filters_updated, &error_sum);
74 const size_t loop_size_2 = filters_[n].size() - loop_size_1; 205 break;
75 RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2); 206 #endif
76 207 default:
77 // x * x. 208 aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, x_buffer_.data,
78 float x2_sum = std::inner_product( 209 y, filters_[n], &filters_updated, &error_sum);
79 x_buffer_.data.begin() + x_start_index,
80 x_buffer_.data.begin() + x_start_index + loop_size_1,
81 x_buffer_.data.begin() + x_start_index, 0.f);
82 // Apply the matched filter as filter * x.
83 float s = std::inner_product(filters_[n].begin(),
84 filters_[n].begin() + loop_size_1,
85 x_buffer_.data.begin() + x_start_index, 0.f);
86
87 if (loop_size_2 > 0) {
88 // Update the cumulative sum of x * x.
89 x2_sum = std::inner_product(x_buffer_.data.begin(),
90 x_buffer_.data.begin() + loop_size_2,
91 x_buffer_.data.begin(), x2_sum);
92
93 // Compute the matched filter output filter * x in a cumulative manner.
94 s = std::inner_product(x_buffer_.data.begin(),
95 x_buffer_.data.begin() + loop_size_2,
96 filters_[n].begin() + loop_size_1, s);
97 }
98
99 // Compute the matched filter error.
100 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
101 error_sum += e * e;
102
103 // Update the matched filter estimate in an NLMS manner.
104 if (x2_sum > x2_sum_threshold) {
105 filters_updated = true;
106 RTC_DCHECK_LT(0.f, x2_sum);
107 const float alpha = 0.7f * e / x2_sum;
108
109 // filter = filter + 0.7 * (y - filter * x) / x * x.
110 std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1,
111 x_buffer_.data.begin() + x_start_index,
112 filters_[n].begin(),
113 [&](float a, float b) { return a + alpha * b; });
114
115 if (loop_size_2 > 0) {
116 // filter = filter + 0.7 * (y - filter * x) / x * x.
117 std::transform(x_buffer_.data.begin(),
118 x_buffer_.data.begin() + loop_size_2,
119 filters_[n].begin() + loop_size_1,
120 filters_[n].begin() + loop_size_1,
121 [&](float a, float b) { return b + alpha * a; });
122 }
123 }
124
125 x_start_index =
126 x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1;
127 } 210 }
128 211
129 // Compute anchor for the matched filter error. 212 // Compute anchor for the matched filter error.
130 const float error_sum_anchor = 213 const float error_sum_anchor =
131 std::inner_product(y.begin(), y.end(), y.begin(), 0.f); 214 std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
132 215
133 // Estimate the lag in the matched filter as the distance to the portion in 216 // Estimate the lag in the matched filter as the distance to the portion in
134 // the filter that contributes the most to the matched filter output. This 217 // the filter that contributes the most to the matched filter output. This
135 // is detected as the peak of the matched filter. 218 // is detected as the peak of the matched filter.
136 const size_t lag_estimate = std::distance( 219 const size_t lag_estimate = std::distance(
137 filters_[n].begin(), 220 filters_[n].begin(),
138 std::max_element( 221 std::max_element(
139 filters_[n].begin(), filters_[n].end(), 222 filters_[n].begin(), filters_[n].end(),
140 [](float a, float b) -> bool { return a * a < b * b; })); 223 [](float a, float b) -> bool { return a * a < b * b; }));
141 224
142 // Update the lag estimates for the matched filter. 225 // Update the lag estimates for the matched filter.
143 const float kMatchingFilterThreshold = 0.3f; 226 const float kMatchingFilterThreshold = 0.1f;
144 lag_estimates_[n] = 227 lag_estimates_[n] = LagEstimate(
145 LagEstimate(error_sum_anchor - error_sum, 228 error_sum_anchor - error_sum,
146 error_sum < kMatchingFilterThreshold * error_sum_anchor, 229 (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
147 lag_estimate + alignment_shift, filters_updated); 230 error_sum < kMatchingFilterThreshold * error_sum_anchor),
231 lag_estimate + alignment_shift, filters_updated);
148 232
149 // TODO(peah): Remove once development of EchoCanceller3 is fully done. 233 // TODO(peah): Remove once development of EchoCanceller3 is fully done.
150 RTC_DCHECK_EQ(4, filters_.size()); 234 RTC_DCHECK_EQ(4, filters_.size());
151 switch (n) { 235 switch (n) {
152 case 0: 236 case 0:
153 data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]); 237 data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
154 break; 238 break;
155 case 1: 239 case 1:
156 data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]); 240 data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
157 break; 241 break;
158 case 2: 242 case 2:
159 data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]); 243 data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
160 break; 244 break;
161 case 3: 245 case 3:
162 data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]); 246 data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
163 break; 247 break;
164 default: 248 default:
165 RTC_DCHECK(false); 249 RTC_DCHECK(false);
166 } 250 }
167 251
168 alignment_shift += filter_intra_lag_shift_; 252 alignment_shift += filter_intra_lag_shift_;
169 } 253 }
170 } 254 }
171 255
172 } // namespace webrtc 256 } // namespace webrtc
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698