Chromium Code Reviews| Index: webrtc/modules/audio_processing/aec3/matched_filter.cc |
| diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc |
| index f187159911295aaf18771401d3f3cc71fed8c44b..20ceb916a97c85f350fc80f104498548be0c8532 100644 |
| --- a/webrtc/modules/audio_processing/aec3/matched_filter.cc |
| +++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc |
| @@ -9,6 +9,10 @@ |
| */ |
| #include "webrtc/modules/audio_processing/aec3/matched_filter.h" |
| +#include "webrtc/typedefs.h" |
| +#if defined(WEBRTC_ARCH_X86_FAMILY) |
| +#include <emmintrin.h> |
| +#endif |
| #include <algorithm> |
| #include <numeric> |
| @@ -16,6 +20,131 @@ |
| #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" |
| namespace webrtc { |
| +namespace aec3 { |
| + |
| +#if defined(WEBRTC_ARCH_X86_FAMILY) |
| + |
| +void MatchedFilterCore_SSE2(size_t x_start_index, |
| + float x2_sum_threshold, |
| + rtc::ArrayView<const float> x, |
| + rtc::ArrayView<const float> y, |
| + rtc::ArrayView<float> h, |
| + bool* filters_updated, |
| + float* error_sum) { |
| + // Process for all samples in the sub-block. |
| + for (size_t i = 0; i < kSubBlockSize; ++i) { |
| + // Apply the matched filter as filter * x. and compute x * x. |
| + float x2_sum = 0.f; |
| + float s = 0; |
| + size_t x_index = x_start_index; |
| + RTC_DCHECK_EQ(h.size(), (h.size() / 4) * 4); |
|
ivoc
2017/02/22 17:08:13
I would prefer a check for h.size() % 4 == 0.
peah-webrtc
2017/02/22 23:51:39
Done.
|
| + |
| + __m128 s_128 = _mm_set1_ps(0); |
| + __m128 x2_sum_128 = _mm_set1_ps(0); |
| + |
| + size_t k = 0; |
| + if (h.size() > (x.size() - x_index)) { |
| + const size_t limit = x.size() - x_index; |
| + for (; (k + 3) < limit; k += 4, x_index += 4) { |
| + const __m128 x_k = _mm_loadu_ps(&x[x_index]); |
| + const __m128 h_k = _mm_loadu_ps(&h[k]); |
| + const __m128 xx = _mm_mul_ps(x_k, x_k); |
| + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); |
| + const __m128 hx = _mm_mul_ps(h_k, x_k); |
| + s_128 = _mm_add_ps(s_128, hx); |
| + } |
| + |
| + for (; k < limit; ++k, ++x_index) { |
| + x2_sum += x[x_index] * x[x_index]; |
| + s += h[k] * x[x_index]; |
| + } |
| + x_index = 0; |
| + } |
| + |
| + for (; k + 3 < h.size(); k += 4, x_index += 4) { |
| + const __m128 x_k = _mm_loadu_ps(&x[x_index]); |
| + const __m128 h_k = _mm_loadu_ps(&h[k]); |
| + const __m128 xx = _mm_mul_ps(x_k, x_k); |
| + x2_sum_128 = _mm_add_ps(x2_sum_128, xx); |
| + const __m128 hx = _mm_mul_ps(h_k, x_k); |
| + s_128 = _mm_add_ps(s_128, hx); |
| + } |
| + |
| + for (; k < h.size(); ++k, ++x_index) { |
| + x2_sum += x[x_index] * x[x_index]; |
| + s += h[k] * x[x_index]; |
| + } |
| + |
| + float* v = reinterpret_cast<float*>(&x2_sum_128); |
| + x2_sum += v[0] + v[1] + v[2] + v[3]; |
| + v = reinterpret_cast<float*>(&s_128); |
| + s += v[0] + v[1] + v[2] + v[3]; |
| + |
| + // Compute the matched filter error. |
| + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); |
| + (*error_sum) += e * e; |
| + |
| + // Update the matched filter estimate in an NLMS manner. |
| + if (x2_sum > x2_sum_threshold) { |
| + RTC_DCHECK_LT(0.f, x2_sum); |
| + const float alpha = 0.7f * e / x2_sum; |
| + |
| + // filter = filter + 0.7 * (y - filter * x) / x * x. |
| + size_t x_index = x_start_index; |
| + for (size_t k = 0; k < h.size(); ++k) { |
| + h[k] += alpha * x[x_index]; |
| + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; |
| + } |
| + *filters_updated = true; |
| + } |
| + |
| + x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; |
| + } |
| +} |
| +#endif |
| + |
| +void MatchedFilterCore(size_t x_start_index, |
| + float x2_sum_threshold, |
| + rtc::ArrayView<const float> x, |
| + rtc::ArrayView<const float> y, |
| + rtc::ArrayView<float> h, |
| + bool* filters_updated, |
| + float* error_sum) { |
| + // Process for all samples in the sub-block. |
| + for (size_t i = 0; i < kSubBlockSize; ++i) { |
| + // Apply the matched filter as filter * x. and compute x * x. |
| + float x2_sum = 0.f; |
| + float s = 0; |
| + size_t x_index = x_start_index; |
| + for (size_t k = 0; k < h.size(); ++k) { |
| + x2_sum += x[x_index] * x[x_index]; |
| + s += h[k] * x[x_index]; |
| + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; |
| + } |
| + |
| + // Compute the matched filter error. |
| + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); |
| + (*error_sum) += e * e; |
| + |
| + // Update the matched filter estimate in an NLMS manner. |
| + if (x2_sum > x2_sum_threshold) { |
| + RTC_DCHECK_LT(0.f, x2_sum); |
| + const float alpha = 0.7f * e / x2_sum; |
| + |
| + // filter = filter + 0.7 * (y - filter * x) / x * x. |
| + size_t x_index = x_start_index; |
| + for (size_t k = 0; k < h.size(); ++k) { |
| + h[k] += alpha * x[x_index]; |
| + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; |
| + } |
| + *filters_updated = true; |
| + } |
| + |
| + x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; |
| + } |
| +} |
| + |
| +} // namespace aec3 |
| MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { |
| RTC_DCHECK_EQ(0, size % kSubBlockSize); |
| @@ -24,10 +153,12 @@ MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { |
| MatchedFilter::IndexedBuffer::~IndexedBuffer() = default; |
| MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, |
| + Aec3Optimization optimization, |
| size_t window_size_sub_blocks, |
| int num_matched_filters, |
| size_t alignment_shift_sub_blocks) |
| : data_dumper_(data_dumper), |
| + optimization_(optimization), |
| filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize), |
| filters_(num_matched_filters, |
| std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)), |
| @@ -65,65 +196,17 @@ void MatchedFilter::Update(const std::array<float, kSubBlockSize>& render, |
| (x_buffer_.index + alignment_shift + kSubBlockSize - 1) % |
| x_buffer_.data.size(); |
| - // Process for all samples in the sub-block. |
| - for (size_t i = 0; i < kSubBlockSize; ++i) { |
| - // As x_buffer is a circular buffer, all of the processing is split into |
| - // two loops around the wrapping of the buffer. |
| - const size_t loop_size_1 = |
| - std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index); |
| - const size_t loop_size_2 = filters_[n].size() - loop_size_1; |
| - RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2); |
| - |
| - // x * x. |
| - float x2_sum = std::inner_product( |
| - x_buffer_.data.begin() + x_start_index, |
| - x_buffer_.data.begin() + x_start_index + loop_size_1, |
| - x_buffer_.data.begin() + x_start_index, 0.f); |
| - // Apply the matched filter as filter * x. |
| - float s = std::inner_product(filters_[n].begin(), |
| - filters_[n].begin() + loop_size_1, |
| - x_buffer_.data.begin() + x_start_index, 0.f); |
| - |
| - if (loop_size_2 > 0) { |
| - // Update the cumulative sum of x * x. |
| - x2_sum = std::inner_product(x_buffer_.data.begin(), |
| - x_buffer_.data.begin() + loop_size_2, |
| - x_buffer_.data.begin(), x2_sum); |
| - |
| - // Compute the matched filter output filter * x in a cumulative manner. |
| - s = std::inner_product(x_buffer_.data.begin(), |
| - x_buffer_.data.begin() + loop_size_2, |
| - filters_[n].begin() + loop_size_1, s); |
| - } |
| - |
| - // Compute the matched filter error. |
| - const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); |
| - error_sum += e * e; |
| - |
| - // Update the matched filter estimate in an NLMS manner. |
| - if (x2_sum > x2_sum_threshold) { |
| - filters_updated = true; |
| - RTC_DCHECK_LT(0.f, x2_sum); |
| - const float alpha = 0.7f * e / x2_sum; |
| - |
| - // filter = filter + 0.7 * (y - filter * x) / x * x. |
| - std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1, |
| - x_buffer_.data.begin() + x_start_index, |
| - filters_[n].begin(), |
| - [&](float a, float b) { return a + alpha * b; }); |
| - |
| - if (loop_size_2 > 0) { |
| - // filter = filter + 0.7 * (y - filter * x) / x * x. |
| - std::transform(x_buffer_.data.begin(), |
| - x_buffer_.data.begin() + loop_size_2, |
| - filters_[n].begin() + loop_size_1, |
| - filters_[n].begin() + loop_size_1, |
| - [&](float a, float b) { return b + alpha * a; }); |
| - } |
| - } |
| - |
| - x_start_index = |
| - x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1; |
| + switch (optimization_) { |
| +#if defined(WEBRTC_ARCH_X86_FAMILY) |
| + case Aec3Optimization::kSse2: |
| + aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, |
| + x_buffer_.data, y, filters_[n], |
| + &filters_updated, &error_sum); |
| + break; |
| +#endif |
| + default: |
| + aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, x_buffer_.data, |
| + y, filters_[n], &filters_updated, &error_sum); |
| } |
| // Compute anchor for the matched filter error. |
| @@ -140,11 +223,12 @@ void MatchedFilter::Update(const std::array<float, kSubBlockSize>& render, |
| [](float a, float b) -> bool { return a * a < b * b; })); |
| // Update the lag estimates for the matched filter. |
| - const float kMatchingFilterThreshold = 0.3f; |
| - lag_estimates_[n] = |
| - LagEstimate(error_sum_anchor - error_sum, |
| - error_sum < kMatchingFilterThreshold * error_sum_anchor, |
| - lag_estimate + alignment_shift, filters_updated); |
| + const float kMatchingFilterThreshold = 0.1f; |
| + lag_estimates_[n] = LagEstimate( |
| + error_sum_anchor - error_sum, |
| + (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) && |
| + error_sum < kMatchingFilterThreshold * error_sum_anchor), |
| + lag_estimate + alignment_shift, filters_updated); |
| // TODO(peah): Remove once development of EchoCanceller3 is fully done. |
| RTC_DCHECK_EQ(4, filters_.size()); |