OLD | NEW |
1 /* | 1 /* |
2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. |
3 * | 3 * |
4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
9 */ | 9 */ |
10 #include "webrtc/modules/audio_processing/aec3/matched_filter.h" | 10 #include "webrtc/modules/audio_processing/aec3/matched_filter.h" |
11 | 11 |
| 12 #include "webrtc/typedefs.h" |
| 13 #if defined(WEBRTC_ARCH_X86_FAMILY) |
| 14 #include <emmintrin.h> |
| 15 #endif |
12 #include <algorithm> | 16 #include <algorithm> |
13 #include <numeric> | 17 #include <numeric> |
14 | 18 |
15 #include "webrtc/modules/audio_processing/include/audio_processing.h" | 19 #include "webrtc/modules/audio_processing/include/audio_processing.h" |
16 #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" | 20 #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" |
17 | 21 |
18 namespace webrtc { | 22 namespace webrtc { |
| 23 namespace aec3 { |
| 24 |
| 25 #if defined(WEBRTC_ARCH_X86_FAMILY) |
| 26 |
| 27 void MatchedFilterCore_SSE2(size_t x_start_index, |
| 28 float x2_sum_threshold, |
| 29 rtc::ArrayView<const float> x, |
| 30 rtc::ArrayView<const float> y, |
| 31 rtc::ArrayView<float> h, |
| 32 bool* filters_updated, |
| 33 float* error_sum) { |
| 34 // Process for all samples in the sub-block. |
| 35 for (size_t i = 0; i < kSubBlockSize; ++i) { |
| 36 // Apply the matched filter as filter * x. and compute x * x. |
| 37 float x2_sum = 0.f; |
| 38 float s = 0; |
| 39 size_t x_index = x_start_index; |
| 40 RTC_DCHECK_EQ(0, h.size() % 4); |
| 41 |
| 42 __m128 s_128 = _mm_set1_ps(0); |
| 43 __m128 x2_sum_128 = _mm_set1_ps(0); |
| 44 |
| 45 size_t k = 0; |
| 46 if (h.size() > (x.size() - x_index)) { |
| 47 const size_t limit = x.size() - x_index; |
| 48 for (; (k + 3) < limit; k += 4, x_index += 4) { |
| 49 const __m128 x_k = _mm_loadu_ps(&x[x_index]); |
| 50 const __m128 h_k = _mm_loadu_ps(&h[k]); |
| 51 const __m128 xx = _mm_mul_ps(x_k, x_k); |
| 52 x2_sum_128 = _mm_add_ps(x2_sum_128, xx); |
| 53 const __m128 hx = _mm_mul_ps(h_k, x_k); |
| 54 s_128 = _mm_add_ps(s_128, hx); |
| 55 } |
| 56 |
| 57 for (; k < limit; ++k, ++x_index) { |
| 58 x2_sum += x[x_index] * x[x_index]; |
| 59 s += h[k] * x[x_index]; |
| 60 } |
| 61 x_index = 0; |
| 62 } |
| 63 |
| 64 for (; k + 3 < h.size(); k += 4, x_index += 4) { |
| 65 const __m128 x_k = _mm_loadu_ps(&x[x_index]); |
| 66 const __m128 h_k = _mm_loadu_ps(&h[k]); |
| 67 const __m128 xx = _mm_mul_ps(x_k, x_k); |
| 68 x2_sum_128 = _mm_add_ps(x2_sum_128, xx); |
| 69 const __m128 hx = _mm_mul_ps(h_k, x_k); |
| 70 s_128 = _mm_add_ps(s_128, hx); |
| 71 } |
| 72 |
| 73 for (; k < h.size(); ++k, ++x_index) { |
| 74 x2_sum += x[x_index] * x[x_index]; |
| 75 s += h[k] * x[x_index]; |
| 76 } |
| 77 |
| 78 float* v = reinterpret_cast<float*>(&x2_sum_128); |
| 79 x2_sum += v[0] + v[1] + v[2] + v[3]; |
| 80 v = reinterpret_cast<float*>(&s_128); |
| 81 s += v[0] + v[1] + v[2] + v[3]; |
| 82 |
| 83 // Compute the matched filter error. |
| 84 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); |
| 85 (*error_sum) += e * e; |
| 86 |
| 87 // Update the matched filter estimate in an NLMS manner. |
| 88 if (x2_sum > x2_sum_threshold) { |
| 89 RTC_DCHECK_LT(0.f, x2_sum); |
| 90 const float alpha = 0.7f * e / x2_sum; |
| 91 |
| 92 // filter = filter + 0.7 * (y - filter * x) / x * x. |
| 93 size_t x_index = x_start_index; |
| 94 for (size_t k = 0; k < h.size(); ++k) { |
| 95 h[k] += alpha * x[x_index]; |
| 96 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; |
| 97 } |
| 98 *filters_updated = true; |
| 99 } |
| 100 |
| 101 x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; |
| 102 } |
| 103 } |
| 104 #endif |
| 105 |
| 106 void MatchedFilterCore(size_t x_start_index, |
| 107 float x2_sum_threshold, |
| 108 rtc::ArrayView<const float> x, |
| 109 rtc::ArrayView<const float> y, |
| 110 rtc::ArrayView<float> h, |
| 111 bool* filters_updated, |
| 112 float* error_sum) { |
| 113 // Process for all samples in the sub-block. |
| 114 for (size_t i = 0; i < kSubBlockSize; ++i) { |
| 115 // Apply the matched filter as filter * x. and compute x * x. |
| 116 float x2_sum = 0.f; |
| 117 float s = 0; |
| 118 size_t x_index = x_start_index; |
| 119 for (size_t k = 0; k < h.size(); ++k) { |
| 120 x2_sum += x[x_index] * x[x_index]; |
| 121 s += h[k] * x[x_index]; |
| 122 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; |
| 123 } |
| 124 |
| 125 // Compute the matched filter error. |
| 126 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); |
| 127 (*error_sum) += e * e; |
| 128 |
| 129 // Update the matched filter estimate in an NLMS manner. |
| 130 if (x2_sum > x2_sum_threshold) { |
| 131 RTC_DCHECK_LT(0.f, x2_sum); |
| 132 const float alpha = 0.7f * e / x2_sum; |
| 133 |
| 134 // filter = filter + 0.7 * (y - filter * x) / x * x. |
| 135 size_t x_index = x_start_index; |
| 136 for (size_t k = 0; k < h.size(); ++k) { |
| 137 h[k] += alpha * x[x_index]; |
| 138 x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; |
| 139 } |
| 140 *filters_updated = true; |
| 141 } |
| 142 |
| 143 x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; |
| 144 } |
| 145 } |
| 146 |
| 147 } // namespace aec3 |
19 | 148 |
20 MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { | 149 MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { |
21 RTC_DCHECK_EQ(0, size % kSubBlockSize); | 150 RTC_DCHECK_EQ(0, size % kSubBlockSize); |
22 } | 151 } |
23 | 152 |
24 MatchedFilter::IndexedBuffer::~IndexedBuffer() = default; | 153 MatchedFilter::IndexedBuffer::~IndexedBuffer() = default; |
25 | 154 |
26 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, | 155 MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, |
| 156 Aec3Optimization optimization, |
27 size_t window_size_sub_blocks, | 157 size_t window_size_sub_blocks, |
28 int num_matched_filters, | 158 int num_matched_filters, |
29 size_t alignment_shift_sub_blocks) | 159 size_t alignment_shift_sub_blocks) |
30 : data_dumper_(data_dumper), | 160 : data_dumper_(data_dumper), |
| 161 optimization_(optimization), |
31 filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize), | 162 filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize), |
32 filters_(num_matched_filters, | 163 filters_(num_matched_filters, |
33 std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)), | 164 std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)), |
34 lag_estimates_(num_matched_filters), | 165 lag_estimates_(num_matched_filters), |
35 x_buffer_(kSubBlockSize * | 166 x_buffer_(kSubBlockSize * |
36 (alignment_shift_sub_blocks * num_matched_filters + | 167 (alignment_shift_sub_blocks * num_matched_filters + |
37 window_size_sub_blocks + | 168 window_size_sub_blocks + |
38 1)) { | 169 1)) { |
39 RTC_DCHECK(data_dumper); | 170 RTC_DCHECK(data_dumper); |
40 RTC_DCHECK_EQ(0, x_buffer_.data.size() % kSubBlockSize); | 171 RTC_DCHECK_EQ(0, x_buffer_.data.size() % kSubBlockSize); |
(...skipping 17 matching lines...) Expand all Loading... |
58 | 189 |
59 // Apply all matched filters. | 190 // Apply all matched filters. |
60 size_t alignment_shift = 0; | 191 size_t alignment_shift = 0; |
61 for (size_t n = 0; n < filters_.size(); ++n) { | 192 for (size_t n = 0; n < filters_.size(); ++n) { |
62 float error_sum = 0.f; | 193 float error_sum = 0.f; |
63 bool filters_updated = false; | 194 bool filters_updated = false; |
64 size_t x_start_index = | 195 size_t x_start_index = |
65 (x_buffer_.index + alignment_shift + kSubBlockSize - 1) % | 196 (x_buffer_.index + alignment_shift + kSubBlockSize - 1) % |
66 x_buffer_.data.size(); | 197 x_buffer_.data.size(); |
67 | 198 |
68 // Process for all samples in the sub-block. | 199 switch (optimization_) { |
69 for (size_t i = 0; i < kSubBlockSize; ++i) { | 200 #if defined(WEBRTC_ARCH_X86_FAMILY) |
70 // As x_buffer is a circular buffer, all of the processing is split into | 201 case Aec3Optimization::kSse2: |
71 // two loops around the wrapping of the buffer. | 202 aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, |
72 const size_t loop_size_1 = | 203 x_buffer_.data, y, filters_[n], |
73 std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index); | 204 &filters_updated, &error_sum); |
74 const size_t loop_size_2 = filters_[n].size() - loop_size_1; | 205 break; |
75 RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2); | 206 #endif |
76 | 207 default: |
77 // x * x. | 208 aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, x_buffer_.data, |
78 float x2_sum = std::inner_product( | 209 y, filters_[n], &filters_updated, &error_sum); |
79 x_buffer_.data.begin() + x_start_index, | |
80 x_buffer_.data.begin() + x_start_index + loop_size_1, | |
81 x_buffer_.data.begin() + x_start_index, 0.f); | |
82 // Apply the matched filter as filter * x. | |
83 float s = std::inner_product(filters_[n].begin(), | |
84 filters_[n].begin() + loop_size_1, | |
85 x_buffer_.data.begin() + x_start_index, 0.f); | |
86 | |
87 if (loop_size_2 > 0) { | |
88 // Update the cumulative sum of x * x. | |
89 x2_sum = std::inner_product(x_buffer_.data.begin(), | |
90 x_buffer_.data.begin() + loop_size_2, | |
91 x_buffer_.data.begin(), x2_sum); | |
92 | |
93 // Compute the matched filter output filter * x in a cumulative manner. | |
94 s = std::inner_product(x_buffer_.data.begin(), | |
95 x_buffer_.data.begin() + loop_size_2, | |
96 filters_[n].begin() + loop_size_1, s); | |
97 } | |
98 | |
99 // Compute the matched filter error. | |
100 const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); | |
101 error_sum += e * e; | |
102 | |
103 // Update the matched filter estimate in an NLMS manner. | |
104 if (x2_sum > x2_sum_threshold) { | |
105 filters_updated = true; | |
106 RTC_DCHECK_LT(0.f, x2_sum); | |
107 const float alpha = 0.7f * e / x2_sum; | |
108 | |
109 // filter = filter + 0.7 * (y - filter * x) / x * x. | |
110 std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1, | |
111 x_buffer_.data.begin() + x_start_index, | |
112 filters_[n].begin(), | |
113 [&](float a, float b) { return a + alpha * b; }); | |
114 | |
115 if (loop_size_2 > 0) { | |
116 // filter = filter + 0.7 * (y - filter * x) / x * x. | |
117 std::transform(x_buffer_.data.begin(), | |
118 x_buffer_.data.begin() + loop_size_2, | |
119 filters_[n].begin() + loop_size_1, | |
120 filters_[n].begin() + loop_size_1, | |
121 [&](float a, float b) { return b + alpha * a; }); | |
122 } | |
123 } | |
124 | |
125 x_start_index = | |
126 x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1; | |
127 } | 210 } |
128 | 211 |
129 // Compute anchor for the matched filter error. | 212 // Compute anchor for the matched filter error. |
130 const float error_sum_anchor = | 213 const float error_sum_anchor = |
131 std::inner_product(y.begin(), y.end(), y.begin(), 0.f); | 214 std::inner_product(y.begin(), y.end(), y.begin(), 0.f); |
132 | 215 |
133 // Estimate the lag in the matched filter as the distance to the portion in | 216 // Estimate the lag in the matched filter as the distance to the portion in |
134 // the filter that contributes the most to the matched filter output. This | 217 // the filter that contributes the most to the matched filter output. This |
135 // is detected as the peak of the matched filter. | 218 // is detected as the peak of the matched filter. |
136 const size_t lag_estimate = std::distance( | 219 const size_t lag_estimate = std::distance( |
137 filters_[n].begin(), | 220 filters_[n].begin(), |
138 std::max_element( | 221 std::max_element( |
139 filters_[n].begin(), filters_[n].end(), | 222 filters_[n].begin(), filters_[n].end(), |
140 [](float a, float b) -> bool { return a * a < b * b; })); | 223 [](float a, float b) -> bool { return a * a < b * b; })); |
141 | 224 |
142 // Update the lag estimates for the matched filter. | 225 // Update the lag estimates for the matched filter. |
143 const float kMatchingFilterThreshold = 0.3f; | 226 const float kMatchingFilterThreshold = 0.1f; |
144 lag_estimates_[n] = | 227 lag_estimates_[n] = LagEstimate( |
145 LagEstimate(error_sum_anchor - error_sum, | 228 error_sum_anchor - error_sum, |
146 error_sum < kMatchingFilterThreshold * error_sum_anchor, | 229 (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) && |
147 lag_estimate + alignment_shift, filters_updated); | 230 error_sum < kMatchingFilterThreshold * error_sum_anchor), |
| 231 lag_estimate + alignment_shift, filters_updated); |
148 | 232 |
149 // TODO(peah): Remove once development of EchoCanceller3 is fully done. | 233 // TODO(peah): Remove once development of EchoCanceller3 is fully done. |
150 RTC_DCHECK_EQ(4, filters_.size()); | 234 RTC_DCHECK_EQ(4, filters_.size()); |
151 switch (n) { | 235 switch (n) { |
152 case 0: | 236 case 0: |
153 data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]); | 237 data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]); |
154 break; | 238 break; |
155 case 1: | 239 case 1: |
156 data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]); | 240 data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]); |
157 break; | 241 break; |
158 case 2: | 242 case 2: |
159 data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]); | 243 data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]); |
160 break; | 244 break; |
161 case 3: | 245 case 3: |
162 data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]); | 246 data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]); |
163 break; | 247 break; |
164 default: | 248 default: |
165 RTC_DCHECK(false); | 249 RTC_DCHECK(false); |
166 } | 250 } |
167 | 251 |
168 alignment_shift += filter_intra_lag_shift_; | 252 alignment_shift += filter_intra_lag_shift_; |
169 } | 253 } |
170 } | 254 } |
171 | 255 |
172 } // namespace webrtc | 256 } // namespace webrtc |
OLD | NEW |