| Index: webrtc/modules/audio_processing/aec3/matched_filter.cc
|
| diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc
|
| index f187159911295aaf18771401d3f3cc71fed8c44b..64596b53c09a5fba195a87ed34aff98d64e40bd7 100644
|
| --- a/webrtc/modules/audio_processing/aec3/matched_filter.cc
|
| +++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc
|
| @@ -9,6 +9,10 @@
|
| */
|
| #include "webrtc/modules/audio_processing/aec3/matched_filter.h"
|
|
|
| +#include "webrtc/typedefs.h"
|
| +#if defined(WEBRTC_ARCH_X86_FAMILY)
|
| +#include <emmintrin.h>
|
| +#endif
|
| #include <algorithm>
|
| #include <numeric>
|
|
|
| @@ -16,6 +20,131 @@
|
| #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
|
|
|
| namespace webrtc {
|
| +namespace aec3 {
|
| +
|
| +#if defined(WEBRTC_ARCH_X86_FAMILY)
|
| + // SSE2 path: per sub-block sample, filters circular x with h, adds e * e to *error_sum, NLMS-updates h.
|
| +void MatchedFilterCore_SSE2(size_t x_start_index,
|
| + float x2_sum_threshold,
|
| + rtc::ArrayView<const float> x,
|
| + rtc::ArrayView<const float> y,
|
| + rtc::ArrayView<float> h,
|
| + bool* filters_updated,
|
| + float* error_sum) {
|
| + // Handle each of the kSubBlockSize samples; x_start_index steps back by one (with wrap) per sample.
|
| + for (size_t i = 0; i < kSubBlockSize; ++i) {
|
| + // Compute s = sum(h[k] * x[...]) and x2_sum = sum(x[...] * x[...]) over the h.size() taps.
|
| + float x2_sum = 0.f;
|
| + float s = 0;
|
| + size_t x_index = x_start_index;
|
| + RTC_DCHECK_EQ(0, h.size() % 4);  // Asserts that h.size() is a multiple of 4.
|
| +
|
| + __m128 s_128 = _mm_set1_ps(0);  // Four-lane partial sums of h * x.
|
| + __m128 x2_sum_128 = _mm_set1_ps(0);  // Four-lane partial sums of x * x.
|
| +
|
| + size_t k = 0;
|
| + if (h.size() > (x.size() - x_index)) {  // The taps would run past the end of x: split at the wrap.
|
| + const size_t limit = x.size() - x_index;  // Number of taps that fit before the end of x.
|
| + for (; (k + 3) < limit; k += 4, x_index += 4) {
|
| + const __m128 x_k = _mm_loadu_ps(&x[x_index]);
|
| + const __m128 h_k = _mm_loadu_ps(&h[k]);
|
| + const __m128 xx = _mm_mul_ps(x_k, x_k);
|
| + x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
|
| + const __m128 hx = _mm_mul_ps(h_k, x_k);
|
| + s_128 = _mm_add_ps(s_128, hx);
|
| + }
|
| +
|
| + for (; k < limit; ++k, ++x_index) {  // Scalar remainder up to the end of x.
|
| + x2_sum += x[x_index] * x[x_index];
|
| + s += h[k] * x[x_index];
|
| + }
|
| + x_index = 0;  // Continue reading from the start of x.
|
| + }
|
| +
|
| + for (; k + 3 < h.size(); k += 4, x_index += 4) {  // Remaining taps, four at a time (unaligned loads).
|
| + const __m128 x_k = _mm_loadu_ps(&x[x_index]);
|
| + const __m128 h_k = _mm_loadu_ps(&h[k]);
|
| + const __m128 xx = _mm_mul_ps(x_k, x_k);
|
| + x2_sum_128 = _mm_add_ps(x2_sum_128, xx);
|
| + const __m128 hx = _mm_mul_ps(h_k, x_k);
|
| + s_128 = _mm_add_ps(s_128, hx);
|
| + }
|
| +
|
| + for (; k < h.size(); ++k, ++x_index) {  // Scalar tail for any taps left over.
|
| + x2_sum += x[x_index] * x[x_index];
|
| + s += h[k] * x[x_index];
|
| + }
|
| +
|
| + float* v = reinterpret_cast<float*>(&x2_sum_128);  // View the vector sum as four floats.
|
| + x2_sum += v[0] + v[1] + v[2] + v[3];  // Fold the four lanes into the scalar sum.
|
| + v = reinterpret_cast<float*>(&s_128);
|
| + s += v[0] + v[1] + v[2] + v[3];
|
| +
|
| + // Error between y[i] and the filter output s, clamped to [-32768, 32767].
|
| + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
|
| + (*error_sum) += e * e;
|
| +
|
| + // NLMS-style update of h, skipped unless x2_sum exceeds x2_sum_threshold.
|
| + if (x2_sum > x2_sum_threshold) {
|
| + RTC_DCHECK_LT(0.f, x2_sum);
|
| + const float alpha = 0.7f * e / x2_sum;
|
| +
|
| + // h[k] += 0.7 * e / x2_sum * x[...], with e the clamped error and x2_sum = sum(x * x).
|
| + size_t x_index = x_start_index;
|
| + for (size_t k = 0; k < h.size(); ++k) {
|
| + h[k] += alpha * x[x_index];
|
| + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;  // Advance with wrap-around.
|
| + }
|
| + *filters_updated = true;
|
| + }
|
| +
|
| + x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;  // Step back one, wrapping to the end.
|
| + }
|
| +}
|
| +#endif
|
| + // Scalar path: per sub-block sample, filters circular x with h, adds e * e to *error_sum, NLMS-updates h.
|
| +void MatchedFilterCore(size_t x_start_index,
|
| + float x2_sum_threshold,
|
| + rtc::ArrayView<const float> x,
|
| + rtc::ArrayView<const float> y,
|
| + rtc::ArrayView<float> h,
|
| + bool* filters_updated,
|
| + float* error_sum) {
|
| + // Handle each of the kSubBlockSize samples; x_start_index steps back by one (with wrap) per sample.
|
| + for (size_t i = 0; i < kSubBlockSize; ++i) {
|
| + // Compute s = sum(h[k] * x[...]) and x2_sum = sum(x[...] * x[...]) over the h.size() taps.
|
| + float x2_sum = 0.f;
|
| + float s = 0;
|
| + size_t x_index = x_start_index;
|
| + for (size_t k = 0; k < h.size(); ++k) {
|
| + x2_sum += x[x_index] * x[x_index];
|
| + s += h[k] * x[x_index];
|
| + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;  // Advance with wrap-around.
|
| + }
|
| +
|
| + // Error between y[i] and the filter output s, clamped to [-32768, 32767].
|
| + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
|
| + (*error_sum) += e * e;
|
| +
|
| + // NLMS-style update of h, skipped unless x2_sum exceeds x2_sum_threshold.
|
| + if (x2_sum > x2_sum_threshold) {
|
| + RTC_DCHECK_LT(0.f, x2_sum);
|
| + const float alpha = 0.7f * e / x2_sum;
|
| +
|
| + // h[k] += 0.7 * e / x2_sum * x[...], with e the clamped error and x2_sum = sum(x * x).
|
| + size_t x_index = x_start_index;
|
| + for (size_t k = 0; k < h.size(); ++k) {
|
| + h[k] += alpha * x[x_index];
|
| + x_index = x_index < (x.size() - 1) ? x_index + 1 : 0;  // Advance with wrap-around.
|
| + }
|
| + *filters_updated = true;
|
| + }
|
| +
|
| + x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1;  // Step back one, wrapping to the end.
|
| + }
|
| +}
|
| +
|
| +} // namespace aec3
|
|
|
| MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) {
|
| RTC_DCHECK_EQ(0, size % kSubBlockSize);
|
| @@ -24,10 +153,12 @@ MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) {
|
| MatchedFilter::IndexedBuffer::~IndexedBuffer() = default;
|
|
|
| MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
|
| + Aec3Optimization optimization,
|
| size_t window_size_sub_blocks,
|
| int num_matched_filters,
|
| size_t alignment_shift_sub_blocks)
|
| : data_dumper_(data_dumper),
|
| + optimization_(optimization),
|
| filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize),
|
| filters_(num_matched_filters,
|
| std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)),
|
| @@ -65,65 +196,17 @@ void MatchedFilter::Update(const std::array<float, kSubBlockSize>& render,
|
| (x_buffer_.index + alignment_shift + kSubBlockSize - 1) %
|
| x_buffer_.data.size();
|
|
|
| - // Process for all samples in the sub-block.
|
| - for (size_t i = 0; i < kSubBlockSize; ++i) {
|
| - // As x_buffer is a circular buffer, all of the processing is split into
|
| - // two loops around the wrapping of the buffer.
|
| - const size_t loop_size_1 =
|
| - std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index);
|
| - const size_t loop_size_2 = filters_[n].size() - loop_size_1;
|
| - RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2);
|
| -
|
| - // x * x.
|
| - float x2_sum = std::inner_product(
|
| - x_buffer_.data.begin() + x_start_index,
|
| - x_buffer_.data.begin() + x_start_index + loop_size_1,
|
| - x_buffer_.data.begin() + x_start_index, 0.f);
|
| - // Apply the matched filter as filter * x.
|
| - float s = std::inner_product(filters_[n].begin(),
|
| - filters_[n].begin() + loop_size_1,
|
| - x_buffer_.data.begin() + x_start_index, 0.f);
|
| -
|
| - if (loop_size_2 > 0) {
|
| - // Update the cumulative sum of x * x.
|
| - x2_sum = std::inner_product(x_buffer_.data.begin(),
|
| - x_buffer_.data.begin() + loop_size_2,
|
| - x_buffer_.data.begin(), x2_sum);
|
| -
|
| - // Compute the matched filter output filter * x in a cumulative manner.
|
| - s = std::inner_product(x_buffer_.data.begin(),
|
| - x_buffer_.data.begin() + loop_size_2,
|
| - filters_[n].begin() + loop_size_1, s);
|
| - }
|
| -
|
| - // Compute the matched filter error.
|
| - const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
|
| - error_sum += e * e;
|
| -
|
| - // Update the matched filter estimate in an NLMS manner.
|
| - if (x2_sum > x2_sum_threshold) {
|
| - filters_updated = true;
|
| - RTC_DCHECK_LT(0.f, x2_sum);
|
| - const float alpha = 0.7f * e / x2_sum;
|
| -
|
| - // filter = filter + 0.7 * (y - filter * x) / x * x.
|
| - std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1,
|
| - x_buffer_.data.begin() + x_start_index,
|
| - filters_[n].begin(),
|
| - [&](float a, float b) { return a + alpha * b; });
|
| -
|
| - if (loop_size_2 > 0) {
|
| - // filter = filter + 0.7 * (y - filter * x) / x * x.
|
| - std::transform(x_buffer_.data.begin(),
|
| - x_buffer_.data.begin() + loop_size_2,
|
| - filters_[n].begin() + loop_size_1,
|
| - filters_[n].begin() + loop_size_1,
|
| - [&](float a, float b) { return b + alpha * a; });
|
| - }
|
| - }
|
| -
|
| - x_start_index =
|
| - x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1;
|
| + switch (optimization_) {
|
| +#if defined(WEBRTC_ARCH_X86_FAMILY)
|
| + case Aec3Optimization::kSse2:
|
| + aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold,
|
| + x_buffer_.data, y, filters_[n],
|
| + &filters_updated, &error_sum);
|
| + break;
|
| +#endif
|
| + default:
|
| + aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, x_buffer_.data,
|
| + y, filters_[n], &filters_updated, &error_sum);
|
| }
|
|
|
| // Compute anchor for the matched filter error.
|
| @@ -140,11 +223,12 @@ void MatchedFilter::Update(const std::array<float, kSubBlockSize>& render,
|
| [](float a, float b) -> bool { return a * a < b * b; }));
|
|
|
| // Update the lag estimates for the matched filter.
|
| - const float kMatchingFilterThreshold = 0.3f;
|
| - lag_estimates_[n] =
|
| - LagEstimate(error_sum_anchor - error_sum,
|
| - error_sum < kMatchingFilterThreshold * error_sum_anchor,
|
| - lag_estimate + alignment_shift, filters_updated);
|
| + const float kMatchingFilterThreshold = 0.1f;
|
| + lag_estimates_[n] = LagEstimate(
|
| + error_sum_anchor - error_sum,
|
| + (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) &&
|
| + error_sum < kMatchingFilterThreshold * error_sum_anchor),
|
| + lag_estimate + alignment_shift, filters_updated);
|
|
|
| // TODO(peah): Remove once development of EchoCanceller3 is fully done.
|
| RTC_DCHECK_EQ(4, filters_.size());
|
|
|