| Index: webrtc/modules/audio_processing/aec3/matched_filter.cc
|
| diff --git a/webrtc/modules/audio_processing/aec3/matched_filter.cc b/webrtc/modules/audio_processing/aec3/matched_filter.cc
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..f187159911295aaf18771401d3f3cc71fed8c44b
|
| --- /dev/null
|
| +++ b/webrtc/modules/audio_processing/aec3/matched_filter.cc
|
| @@ -0,0 +1,172 @@
|
| +/*
|
| + * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
| + *
|
| + * Use of this source code is governed by a BSD-style license
|
| + * that can be found in the LICENSE file in the root of the source
|
| + * tree. An additional intellectual property rights grant can be found
|
| + * in the file PATENTS. All contributing project authors may
|
| + * be found in the AUTHORS file in the root of the source tree.
|
| + */
|
| +#include "webrtc/modules/audio_processing/aec3/matched_filter.h"
|
| +
|
| +#include <algorithm>
|
| +#include <numeric>
|
| +
|
| +#include "webrtc/modules/audio_processing/include/audio_processing.h"
|
| +#include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
|
| +
|
| +namespace webrtc {
|
| +
|
| +MatchedFilter::IndexedBuffer::IndexedBuffer(size_t size) : data(size, 0.f) {
|
| + RTC_DCHECK_EQ(0, size % kSubBlockSize);
|
| +}
|
| +
|
| +MatchedFilter::IndexedBuffer::~IndexedBuffer() = default;
|
| +
|
| +MatchedFilter::MatchedFilter(ApmDataDumper* data_dumper,
|
| + size_t window_size_sub_blocks,
|
| + int num_matched_filters,
|
| + size_t alignment_shift_sub_blocks)
|
| + : data_dumper_(data_dumper),
|
| + filter_intra_lag_shift_(alignment_shift_sub_blocks * kSubBlockSize),
|
| + filters_(num_matched_filters,
|
| + std::vector<float>(window_size_sub_blocks * kSubBlockSize, 0.f)),
|
| + lag_estimates_(num_matched_filters),
|
| + x_buffer_(kSubBlockSize *
|
| + (alignment_shift_sub_blocks * num_matched_filters +
|
| + window_size_sub_blocks +
|
| + 1)) {
|
| + RTC_DCHECK(data_dumper);
|
| + RTC_DCHECK_EQ(0, x_buffer_.data.size() % kSubBlockSize);
|
| + RTC_DCHECK_LT(0, window_size_sub_blocks);
|
| +}
|
| +
|
| +MatchedFilter::~MatchedFilter() = default;
|
| +
|
| +void MatchedFilter::Update(const std::array<float, kSubBlockSize>& render,
|
| + const std::array<float, kSubBlockSize>& capture) {
|
| + const std::array<float, kSubBlockSize>& x = render;
|
| + const std::array<float, kSubBlockSize>& y = capture;
|
| +
|
| + const float x2_sum_threshold = filters_[0].size() * 150.f * 150.f;
|
| +
|
| + // Insert the new subblock into x_buffer.
|
| + x_buffer_.index = (x_buffer_.index - kSubBlockSize + x_buffer_.data.size()) %
|
| + x_buffer_.data.size();
|
| + RTC_DCHECK_LE(kSubBlockSize, x_buffer_.data.size() - x_buffer_.index);
|
| + std::copy(x.rbegin(), x.rend(), x_buffer_.data.begin() + x_buffer_.index);
|
| +
|
| + // Apply all matched filters.
|
| + size_t alignment_shift = 0;
|
| + for (size_t n = 0; n < filters_.size(); ++n) {
|
| + float error_sum = 0.f;
|
| + bool filters_updated = false;
|
| + size_t x_start_index =
|
| + (x_buffer_.index + alignment_shift + kSubBlockSize - 1) %
|
| + x_buffer_.data.size();
|
| +
|
| + // Process for all samples in the sub-block.
|
| + for (size_t i = 0; i < kSubBlockSize; ++i) {
|
| + // As x_buffer is a circular buffer, all of the processing is split into
|
| + // two loops around the wrapping of the buffer.
|
| + const size_t loop_size_1 =
|
| + std::min(filters_[n].size(), x_buffer_.data.size() - x_start_index);
|
| + const size_t loop_size_2 = filters_[n].size() - loop_size_1;
|
| + RTC_DCHECK_EQ(filters_[n].size(), loop_size_1 + loop_size_2);
|
| +
|
| + // x * x.
|
| + float x2_sum = std::inner_product(
|
| + x_buffer_.data.begin() + x_start_index,
|
| + x_buffer_.data.begin() + x_start_index + loop_size_1,
|
| + x_buffer_.data.begin() + x_start_index, 0.f);
|
| + // Apply the matched filter as filter * x.
|
| + float s = std::inner_product(filters_[n].begin(),
|
| + filters_[n].begin() + loop_size_1,
|
| + x_buffer_.data.begin() + x_start_index, 0.f);
|
| +
|
| + if (loop_size_2 > 0) {
|
| + // Update the cumulative sum of x * x.
|
| + x2_sum = std::inner_product(x_buffer_.data.begin(),
|
| + x_buffer_.data.begin() + loop_size_2,
|
| + x_buffer_.data.begin(), x2_sum);
|
| +
|
| + // Compute the matched filter output filter * x in a cumulative manner.
|
| + s = std::inner_product(x_buffer_.data.begin(),
|
| + x_buffer_.data.begin() + loop_size_2,
|
| + filters_[n].begin() + loop_size_1, s);
|
| + }
|
| +
|
| + // Compute the matched filter error.
|
| + const float e = std::min(32767.f, std::max(-32768.f, y[i] - s));
|
| + error_sum += e * e;
|
| +
|
| + // Update the matched filter estimate in an NLMS manner.
|
| + if (x2_sum > x2_sum_threshold) {
|
| + filters_updated = true;
|
| + RTC_DCHECK_LT(0.f, x2_sum);
|
| + const float alpha = 0.7f * e / x2_sum;
|
| +
|
| + // filter = filter + 0.7 * (y - filter * x) / x * x.
|
| + std::transform(filters_[n].begin(), filters_[n].begin() + loop_size_1,
|
| + x_buffer_.data.begin() + x_start_index,
|
| + filters_[n].begin(),
|
| + [&](float a, float b) { return a + alpha * b; });
|
| +
|
| + if (loop_size_2 > 0) {
|
| + // filter = filter + 0.7 * (y - filter * x) / x * x.
|
| + std::transform(x_buffer_.data.begin(),
|
| + x_buffer_.data.begin() + loop_size_2,
|
| + filters_[n].begin() + loop_size_1,
|
| + filters_[n].begin() + loop_size_1,
|
| + [&](float a, float b) { return b + alpha * a; });
|
| + }
|
| + }
|
| +
|
| + x_start_index =
|
| + x_start_index > 0 ? x_start_index - 1 : x_buffer_.data.size() - 1;
|
| + }
|
| +
|
| + // Compute anchor for the matched filter error.
|
| + const float error_sum_anchor =
|
| + std::inner_product(y.begin(), y.end(), y.begin(), 0.f);
|
| +
|
| + // Estimate the lag in the matched filter as the distance to the portion in
|
| + // the filter that contributes the most to the matched filter output. This
|
| + // is detected as the peak of the matched filter.
|
| + const size_t lag_estimate = std::distance(
|
| + filters_[n].begin(),
|
| + std::max_element(
|
| + filters_[n].begin(), filters_[n].end(),
|
| + [](float a, float b) -> bool { return a * a < b * b; }));
|
| +
|
| + // Update the lag estimates for the matched filter.
|
| + const float kMatchingFilterThreshold = 0.3f;
|
| + lag_estimates_[n] =
|
| + LagEstimate(error_sum_anchor - error_sum,
|
| + error_sum < kMatchingFilterThreshold * error_sum_anchor,
|
| + lag_estimate + alignment_shift, filters_updated);
|
| +
|
| + // TODO(peah): Remove once development of EchoCanceller3 is fully done.
|
| + RTC_DCHECK_EQ(4, filters_.size());
|
| + switch (n) {
|
| + case 0:
|
| + data_dumper_->DumpRaw("aec3_correlator_0_h", filters_[0]);
|
| + break;
|
| + case 1:
|
| + data_dumper_->DumpRaw("aec3_correlator_1_h", filters_[1]);
|
| + break;
|
| + case 2:
|
| + data_dumper_->DumpRaw("aec3_correlator_2_h", filters_[2]);
|
| + break;
|
| + case 3:
|
| + data_dumper_->DumpRaw("aec3_correlator_3_h", filters_[3]);
|
| + break;
|
| + default:
|
| + RTC_DCHECK(false);
|
| + }
|
| +
|
| + alignment_shift += filter_intra_lag_shift_;
|
| + }
|
| +}
|
| +
|
| +} // namespace webrtc
|
|
|