Index: webrtc/modules/audio_processing/aec3/matched_filter.cc |
diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc |
index f187159911295aaf18771401d3f3cc71fed8c44b..64596b53c09a5fba195a87ed34aff98d64e40bd7 100644 |
--- a/webrtc/modules/audio_processing/aec3/matched_filter.cc |
+++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc |
@@ -9,6 +9,10 @@ |
*/ |
#include "webrtc/modules/audio_processing/aec3/matched_filter.h" |
+#include "webrtc/typedefs.h" |
+#if defined(WEBRTC_ARCH_X86_FAMILY) |
+#include <emmintrin.h> |
+#endif |
#include <algorithm> |
#include <numeric> |
@@ -16,6 +20,131 @@ |
#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" |
namespace webrtc { |
+namespace aec3 { |
+ |
+#if defined(WEBRTC_ARCH_X86_FAMILY) |
+/* SSE2 matched filter core: estimates y[i] as h * x, reading x circularly, and adapts h. */ |
+void MatchedFilterCore_SSE2(size_t x_start_index, /* Index in x paired with h[0] for sample 0. */ |
+ float x2_sum_threshold, /* h is only updated when x * x exceeds this. */ |
+ rtc::ArrayView<const float> x, /* Read as a circular buffer. */ |
+ rtc::ArrayView<const float> y, /* Indexed by the sub-block sample index. */ |
+ rtc::ArrayView<float> h, /* Filter taps; adapted in place. */ |
+ bool* filters_updated, /* Set to true when h is updated; never cleared here. */ |
+ float* error_sum) { /* Accumulates the squared error e * e. */ |
+ // Process for all samples in the sub-block. |
+ for (size_t i = 0; i < kSubBlockSize; ++i) { |
+ // Apply the matched filter as filter * x, and compute x * x. |
+ float x2_sum = 0.f; /* Scalar part of x * x. */ |
+ float s = 0; /* Scalar part of filter * x. */ |
+ size_t x_index = x_start_index; |
+ RTC_DCHECK_EQ(0, h.size() % 4); |
+ |
+ __m128 s_128 = _mm_set1_ps(0); /* Four partial sums of filter * x. */ |
+ __m128 x2_sum_128 = _mm_set1_ps(0); /* Four partial sums of x * x. */ |
+ |
+ size_t k = 0; |
+ if (h.size() > (x.size() - x_index)) { /* The filter span wraps past the end of x. */ |
+ const size_t limit = x.size() - x_index; /* Samples left before the wrap. */ |
+ for (; (k + 3) < limit; k += 4, x_index += 4) { /* Four taps per iteration. */ |
+ const __m128 x_k = _mm_loadu_ps(&x[x_index]); /* Unaligned load. */ |
+ const __m128 h_k = _mm_loadu_ps(&h[k]); |
+ const __m128 xx = _mm_mul_ps(x_k, x_k); |
+ x2_sum_128 = _mm_add_ps(x2_sum_128, xx); |
+ const __m128 hx = _mm_mul_ps(h_k, x_k); |
+ s_128 = _mm_add_ps(s_128, hx); |
+ } |
+ |
+ for (; k < limit; ++k, ++x_index) { /* Leftover taps before the wrap. */ |
+ x2_sum += x[x_index] * x[x_index]; |
+ s += h[k] * x[x_index]; |
+ } |
+ x_index = 0; /* Continue from the start of x. */ |
+ } |
+ |
+ for (; k + 3 < h.size(); k += 4, x_index += 4) { /* Remaining taps, four at a time. */ |
+ const __m128 x_k = _mm_loadu_ps(&x[x_index]); |
+ const __m128 h_k = _mm_loadu_ps(&h[k]); |
+ const __m128 xx = _mm_mul_ps(x_k, x_k); |
+ x2_sum_128 = _mm_add_ps(x2_sum_128, xx); |
+ const __m128 hx = _mm_mul_ps(h_k, x_k); |
+ s_128 = _mm_add_ps(s_128, hx); |
+ } |
+ |
+ for (; k < h.size(); ++k, ++x_index) { /* Any taps left over. */ |
+ x2_sum += x[x_index] * x[x_index]; |
+ s += h[k] * x[x_index]; |
+ } |
+ |
+ float* v = reinterpret_cast<float*>(&x2_sum_128); |
+ x2_sum += v[0] + v[1] + v[2] + v[3]; /* Fold the vector lanes into the scalar sum. */ |
+ v = reinterpret_cast<float*>(&s_128); |
+ s += v[0] + v[1] + v[2] + v[3]; |
+ |
+ // Compute the matched filter error, limited to [-32768, 32767]. |
+ const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); |
+ (*error_sum) += e * e; |
+ |
+ // Update the matched filter estimate in an NLMS manner, if x * x is large enough. |
+ if (x2_sum > x2_sum_threshold) { |
+ RTC_DCHECK_LT(0.f, x2_sum); |
+ const float alpha = 0.7f * e / x2_sum; |
+ |
+ // filter = filter + 0.7 * (y - filter * x) * x / (x * x). |
+ size_t x_index = x_start_index; |
+ for (size_t k = 0; k < h.size(); ++k) { |
+ h[k] += alpha * x[x_index]; |
+ x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; /* Advance with wrap-around. */ |
+ } |
+ *filters_updated = true; |
+ } |
+ |
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; /* Step back with wrap-around. */ |
+ } |
+} |
+#endif |
+/* Generic matched filter core: estimates y[i] as h * x, reading x circularly, and adapts h. */ |
+void MatchedFilterCore(size_t x_start_index, /* Index in x paired with h[0] for sample 0. */ |
+ float x2_sum_threshold, /* h is only updated when x * x exceeds this. */ |
+ rtc::ArrayView<const float> x, /* Read as a circular buffer. */ |
+ rtc::ArrayView<const float> y, /* Indexed by the sub-block sample index. */ |
+ rtc::ArrayView<float> h, /* Filter taps; adapted in place. */ |
+ bool* filters_updated, /* Set to true when h is updated; never cleared here. */ |
+ float* error_sum) { /* Accumulates the squared error e * e. */ |
+ // Process for all samples in the sub-block. |
+ for (size_t i = 0; i < kSubBlockSize; ++i) { |
+ // Apply the matched filter as filter * x, and compute x * x. |
+ float x2_sum = 0.f; |
+ float s = 0; |
+ size_t x_index = x_start_index; |
+ for (size_t k = 0; k < h.size(); ++k) { |
+ x2_sum += x[x_index] * x[x_index]; |
+ s += h[k] * x[x_index]; |
+ x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; /* Advance with wrap-around. */ |
+ } |
+ |
+ // Compute the matched filter error, limited to [-32768, 32767]. |
+ const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); |
+ (*error_sum) += e * e; |
+ |
+ // Update the matched filter estimate in an NLMS manner, if x * x is large enough. |
+ if (x2_sum > x2_sum_threshold) { |
+ RTC_DCHECK_LT(0.f, x2_sum); |
+ const float alpha = 0.7f * e / x2_sum; |
+ |
+ // filter = filter + 0.7 * (y - filter * x) * x / (x * x). |
+ size_t x_index = x_start_index; |
+ for (size_t k = 0; k < h.size(); ++k) { |
+ h[k] += alpha * x[x_index]; |
+ x_index = x_index < (x.size() - 1) ? x_index + 1 : 0; /* Advance with wrap-around. */ |
+ } |
+ *filters_updated = true; |
+ } |
+ |
+ x_start_index = x_start_index > 0 ? x_start_index - 1 : x.size() - 1; /* Step back with wrap-around. */ |
+ } |
+} |
+ |
+} // namespace aec3 |
MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { |
RTC_DCHECK_EQ(0, size % kSubBlockSize); |
@@ -24,10 +153,12 @@ MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) { |
MatchedFilter::IndexedBuffer::~IndexedBuffer() = default; |
MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper, |
+ Aec3Optimization optimization, |
size_t window_size_sub_blocks, |
int num_matched_filters, |
size_t alignment_shift_sub_blocks) |
: data_dumper_(data_dumper), |
+ optimization_(optimization), |
filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize), |
filters_(num_matched_filters, |
std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)), |
@@ -65,65 +196,17 @@ void MatchedFilter::Update(const std::array<float, kSubBlockSize>& render, |
(x_buffer_.index + alignment_shift + kSubBlockSize - 1) % |
x_buffer_.data.size(); |
- // Process for all samples in the sub-block. |
- for (size_t i = 0; i < kSubBlockSize; ++i) { |
- // As x_buffer is a circular buffer, all of the processing is split into |
- // two loops around the wrapping of the buffer. |
- const size_t loop_size_1 = |
- std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index); |
- const size_t loop_size_2 = filters_[n].size() - loop_size_1; |
- RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2); |
- |
- // x * x. |
- float x2_sum = std::inner_product( |
- x_buffer_.data.begin() + x_start_index, |
- x_buffer_.data.begin() + x_start_index + loop_size_1, |
- x_buffer_.data.begin() + x_start_index, 0.f); |
- // Apply the matched filter as filter * x. |
- float s = std::inner_product(filters_[n].begin(), |
- filters_[n].begin() + loop_size_1, |
- x_buffer_.data.begin() + x_start_index, 0.f); |
- |
- if (loop_size_2 > 0) { |
- // Update the cumulative sum of x * x. |
- x2_sum = std::inner_product(x_buffer_.data.begin(), |
- x_buffer_.data.begin() + loop_size_2, |
- x_buffer_.data.begin(), x2_sum); |
- |
- // Compute the matched filter output filter * x in a cumulative manner. |
- s = std::inner_product(x_buffer_.data.begin(), |
- x_buffer_.data.begin() + loop_size_2, |
- filters_[n].begin() + loop_size_1, s); |
- } |
- |
- // Compute the matched filter error. |
- const float e = std::min(32767.f, std::max(-32768.f, y[i] - s)); |
- error_sum += e * e; |
- |
- // Update the matched filter estimate in an NLMS manner. |
- if (x2_sum > x2_sum_threshold) { |
- filters_updated = true; |
- RTC_DCHECK_LT(0.f, x2_sum); |
- const float alpha = 0.7f * e / x2_sum; |
- |
- // filter = filter + 0.7 * (y - filter * x) / x * x. |
- std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1, |
- x_buffer_.data.begin() + x_start_index, |
- filters_[n].begin(), |
- [&](float a, float b) { return a + alpha * b; }); |
- |
- if (loop_size_2 > 0) { |
- // filter = filter + 0.7 * (y - filter * x) / x * x. |
- std::transform(x_buffer_.data.begin(), |
- x_buffer_.data.begin() + loop_size_2, |
- filters_[n].begin() + loop_size_1, |
- filters_[n].begin() + loop_size_1, |
- [&](float a, float b) { return b + alpha * a; }); |
- } |
- } |
- |
- x_start_index = |
- x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1; |
+ switch (optimization_) { |
+#if defined(WEBRTC_ARCH_X86_FAMILY) |
+ case Aec3Optimization::kSse2: |
+ aec3::MatchedFilterCore_SSE2(x_start_index, x2_sum_threshold, |
+ x_buffer_.data, y, filters_[n], |
+ &filters_updated, &error_sum); |
+ break; |
+#endif |
+ default: |
+ aec3::MatchedFilterCore(x_start_index, x2_sum_threshold, x_buffer_.data, |
+ y, filters_[n], &filters_updated, &error_sum); |
} |
// Compute anchor for the matched filter error. |
@@ -140,11 +223,12 @@ void MatchedFilter::Update(const std::array<float, kSubBlockSize>& render, |
[](float a, float b) -> bool { return a * a < b * b; })); |
// Update the lag estimates for the matched filter. |
- const float kMatchingFilterThreshold = 0.3f; |
- lag_estimates_[n] = |
- LagEstimate(error_sum_anchor - error_sum, |
- error_sum < kMatchingFilterThreshold * error_sum_anchor, |
- lag_estimate + alignment_shift, filters_updated); |
+ const float kMatchingFilterThreshold = 0.1f; |
+ lag_estimates_[n] = LagEstimate( |
+ error_sum_anchor - error_sum, |
+ (lag_estimate > 2 && lag_estimate < (filters_[n].size() - 10) && |
+ error_sum < kMatchingFilterThreshold * error_sum_anchor), |
+ lag_estimate + alignment_shift, filters_updated); |
// TODO(peah): Remove once development of EchoCanceller3 is fully done. |
RTC_DCHECK_EQ(4, filters_.size()); |