OLD | NEW |
1 /* | 1 /* |
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved. |
3 * | 3 * |
4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
9 */ | 9 */ |
10 | 10 |
(...skipping 62 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
73 const float kMaskTargetThreshold = 0.3f; | 73 const float kMaskTargetThreshold = 0.3f; |
74 // Time in seconds after which the data is considered interference if the mask | 74 // Time in seconds after which the data is considered interference if the mask |
75 // does not pass |kMaskTargetThreshold|. | 75 // does not pass |kMaskTargetThreshold|. |
76 const float kHoldTargetSeconds = 0.25f; | 76 const float kHoldTargetSeconds = 0.25f; |
77 | 77 |
78 // Does conjugate(|norm_mat|) * |mat| * transpose(|norm_mat|). No extra space is | 78 // Does conjugate(|norm_mat|) * |mat| * transpose(|norm_mat|). No extra space is |
79 // used; to accomplish this, we compute both multiplications in the same loop. | 79 // used; to accomplish this, we compute both multiplications in the same loop. |
80 // The returned norm is clamped to be non-negative. | 80 // The returned norm is clamped to be non-negative. |
81 float Norm(const ComplexMatrix<float>& mat, | 81 float Norm(const ComplexMatrix<float>& mat, |
82 const ComplexMatrix<float>& norm_mat) { | 82 const ComplexMatrix<float>& norm_mat) { |
83 CHECK_EQ(norm_mat.num_rows(), 1); | 83 RTC_CHECK_EQ(norm_mat.num_rows(), 1); |
84 CHECK_EQ(norm_mat.num_columns(), mat.num_rows()); | 84 RTC_CHECK_EQ(norm_mat.num_columns(), mat.num_rows()); |
85 CHECK_EQ(norm_mat.num_columns(), mat.num_columns()); | 85 RTC_CHECK_EQ(norm_mat.num_columns(), mat.num_columns()); |
86 | 86 |
87 complex<float> first_product = complex<float>(0.f, 0.f); | 87 complex<float> first_product = complex<float>(0.f, 0.f); |
88 complex<float> second_product = complex<float>(0.f, 0.f); | 88 complex<float> second_product = complex<float>(0.f, 0.f); |
89 | 89 |
90 const complex<float>* const* mat_els = mat.elements(); | 90 const complex<float>* const* mat_els = mat.elements(); |
91 const complex<float>* const* norm_mat_els = norm_mat.elements(); | 91 const complex<float>* const* norm_mat_els = norm_mat.elements(); |
92 | 92 |
93 for (int i = 0; i < norm_mat.num_columns(); ++i) { | 93 for (int i = 0; i < norm_mat.num_columns(); ++i) { |
94 for (int j = 0; j < norm_mat.num_columns(); ++j) { | 94 for (int j = 0; j < norm_mat.num_columns(); ++j) { |
95 first_product += conj(norm_mat_els[0][j]) * mat_els[j][i]; | 95 first_product += conj(norm_mat_els[0][j]) * mat_els[j][i]; |
96 } | 96 } |
97 second_product += first_product * norm_mat_els[0][i]; | 97 second_product += first_product * norm_mat_els[0][i]; |
98 first_product = 0.f; | 98 first_product = 0.f; |
99 } | 99 } |
100 return std::max(second_product.real(), 0.f); | 100 return std::max(second_product.real(), 0.f); |
101 } | 101 } |
102 | 102 |
103 // Does conjugate(|lhs|) * |rhs| for row vectors |lhs| and |rhs|. | 103 // Does conjugate(|lhs|) * |rhs| for row vectors |lhs| and |rhs|. |
104 complex<float> ConjugateDotProduct(const ComplexMatrix<float>& lhs, | 104 complex<float> ConjugateDotProduct(const ComplexMatrix<float>& lhs, |
105 const ComplexMatrix<float>& rhs) { | 105 const ComplexMatrix<float>& rhs) { |
106 CHECK_EQ(lhs.num_rows(), 1); | 106 RTC_CHECK_EQ(lhs.num_rows(), 1); |
107 CHECK_EQ(rhs.num_rows(), 1); | 107 RTC_CHECK_EQ(rhs.num_rows(), 1); |
108 CHECK_EQ(lhs.num_columns(), rhs.num_columns()); | 108 RTC_CHECK_EQ(lhs.num_columns(), rhs.num_columns()); |
109 | 109 |
110 const complex<float>* const* lhs_elements = lhs.elements(); | 110 const complex<float>* const* lhs_elements = lhs.elements(); |
111 const complex<float>* const* rhs_elements = rhs.elements(); | 111 const complex<float>* const* rhs_elements = rhs.elements(); |
112 | 112 |
113 complex<float> result = complex<float>(0.f, 0.f); | 113 complex<float> result = complex<float>(0.f, 0.f); |
114 for (int i = 0; i < lhs.num_columns(); ++i) { | 114 for (int i = 0; i < lhs.num_columns(); ++i) { |
115 result += conj(lhs_elements[0][i]) * rhs_elements[0][i]; | 115 result += conj(lhs_elements[0][i]) * rhs_elements[0][i]; |
116 } | 116 } |
117 | 117 |
118 return result; | 118 return result; |
(...skipping 25 matching lines...) Expand all Loading... |
144 float abs_value = std::abs(mat_els[i][j]); | 144 float abs_value = std::abs(mat_els[i][j]); |
145 sum_squares += abs_value * abs_value; | 145 sum_squares += abs_value * abs_value; |
146 } | 146 } |
147 } | 147 } |
148 return sum_squares; | 148 return sum_squares; |
149 } | 149 } |
150 | 150 |
151 // Does |out| = |in|.' * conj(|in|) for row vector |in|. | 151 // Does |out| = |in|.' * conj(|in|) for row vector |in|. |
152 void TransposedConjugatedProduct(const ComplexMatrix<float>& in, | 152 void TransposedConjugatedProduct(const ComplexMatrix<float>& in, |
153 ComplexMatrix<float>* out) { | 153 ComplexMatrix<float>* out) { |
154 CHECK_EQ(in.num_rows(), 1); | 154 RTC_CHECK_EQ(in.num_rows(), 1); |
155 CHECK_EQ(out->num_rows(), in.num_columns()); | 155 RTC_CHECK_EQ(out->num_rows(), in.num_columns()); |
156 CHECK_EQ(out->num_columns(), in.num_columns()); | 156 RTC_CHECK_EQ(out->num_columns(), in.num_columns()); |
157 const complex<float>* in_elements = in.elements()[0]; | 157 const complex<float>* in_elements = in.elements()[0]; |
158 complex<float>* const* out_elements = out->elements(); | 158 complex<float>* const* out_elements = out->elements(); |
159 for (int i = 0; i < out->num_rows(); ++i) { | 159 for (int i = 0; i < out->num_rows(); ++i) { |
160 for (int j = 0; j < out->num_columns(); ++j) { | 160 for (int j = 0; j < out->num_columns(); ++j) { |
161 out_elements[i][j] = in_elements[i] * conj(in_elements[j]); | 161 out_elements[i][j] = in_elements[i] * conj(in_elements[j]); |
162 } | 162 } |
163 } | 163 } |
164 } | 164 } |
165 | 165 |
166 std::vector<Point> GetCenteredArray(std::vector<Point> array_geometry) { | 166 std::vector<Point> GetCenteredArray(std::vector<Point> array_geometry) { |
(...skipping 33 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
200 // These bin indexes determine the regions over which a mean is taken. This | 200 // These bin indexes determine the regions over which a mean is taken. This |
201 // is applied as a constant value over the adjacent end "frequency correction" | 201 // is applied as a constant value over the adjacent end "frequency correction" |
202 // regions. | 202 // regions. |
203 // | 203 // |
204 // low_mean_start_bin_ high_mean_start_bin_ | 204 // low_mean_start_bin_ high_mean_start_bin_ |
205 // v v constant | 205 // v v constant |
206 // |----------------|--------|----------------|-------|----------------| | 206 // |----------------|--------|----------------|-------|----------------| |
207 // constant ^ ^ | 207 // constant ^ ^ |
208 // low_mean_end_bin_ high_mean_end_bin_ | 208 // low_mean_end_bin_ high_mean_end_bin_ |
209 // | 209 // |
210 DCHECK_GT(low_mean_start_bin_, 0U); | 210 RTC_DCHECK_GT(low_mean_start_bin_, 0U); |
211 DCHECK_LT(low_mean_start_bin_, low_mean_end_bin_); | 211 RTC_DCHECK_LT(low_mean_start_bin_, low_mean_end_bin_); |
212 DCHECK_LT(low_mean_end_bin_, high_mean_end_bin_); | 212 RTC_DCHECK_LT(low_mean_end_bin_, high_mean_end_bin_); |
213 DCHECK_LT(high_mean_start_bin_, high_mean_end_bin_); | 213 RTC_DCHECK_LT(high_mean_start_bin_, high_mean_end_bin_); |
214 DCHECK_LT(high_mean_end_bin_, kNumFreqBins - 1); | 214 RTC_DCHECK_LT(high_mean_end_bin_, kNumFreqBins - 1); |
215 | 215 |
216 high_pass_postfilter_mask_ = 1.f; | 216 high_pass_postfilter_mask_ = 1.f; |
217 is_target_present_ = false; | 217 is_target_present_ = false; |
218 hold_target_blocks_ = kHoldTargetSeconds * 2 * sample_rate_hz / kFftSize; | 218 hold_target_blocks_ = kHoldTargetSeconds * 2 * sample_rate_hz / kFftSize; |
219 interference_blocks_count_ = hold_target_blocks_; | 219 interference_blocks_count_ = hold_target_blocks_; |
220 | 220 |
221 | 221 |
222 lapped_transform_.reset(new LappedTransform(num_input_channels_, | 222 lapped_transform_.reset(new LappedTransform(num_input_channels_, |
223 1, | 223 1, |
224 chunk_length_, | 224 chunk_length_, |
(...skipping 80 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
305 // Average matrices. | 305 // Average matrices. |
306 uniform_cov_mat.Scale(1 - kBalance); | 306 uniform_cov_mat.Scale(1 - kBalance); |
307 angled_cov_mat.Scale(kBalance); | 307 angled_cov_mat.Scale(kBalance); |
308 interf_cov_mats_[i].Add(uniform_cov_mat, angled_cov_mat); | 308 interf_cov_mats_[i].Add(uniform_cov_mat, angled_cov_mat); |
309 reflected_interf_cov_mats_[i].PointwiseConjugate(interf_cov_mats_[i]); | 309 reflected_interf_cov_mats_[i].PointwiseConjugate(interf_cov_mats_[i]); |
310 } | 310 } |
311 } | 311 } |
312 | 312 |
313 void NonlinearBeamformer::ProcessChunk(const ChannelBuffer<float>& input, | 313 void NonlinearBeamformer::ProcessChunk(const ChannelBuffer<float>& input, |
314 ChannelBuffer<float>* output) { | 314 ChannelBuffer<float>* output) { |
315 DCHECK_EQ(input.num_channels(), num_input_channels_); | 315 RTC_DCHECK_EQ(input.num_channels(), num_input_channels_); |
316 DCHECK_EQ(input.num_frames_per_band(), chunk_length_); | 316 RTC_DCHECK_EQ(input.num_frames_per_band(), chunk_length_); |
317 | 317 |
318 float old_high_pass_mask = high_pass_postfilter_mask_; | 318 float old_high_pass_mask = high_pass_postfilter_mask_; |
319 lapped_transform_->ProcessChunk(input.channels(0), output->channels(0)); | 319 lapped_transform_->ProcessChunk(input.channels(0), output->channels(0)); |
320 // Ramp up/down for smoothing. 1 mask per 10ms results in audible | 320 // Ramp up/down for smoothing. 1 mask per 10ms results in audible |
321 // discontinuities. | 321 // discontinuities. |
322 const float ramp_increment = | 322 const float ramp_increment = |
323 (high_pass_postfilter_mask_ - old_high_pass_mask) / | 323 (high_pass_postfilter_mask_ - old_high_pass_mask) / |
324 input.num_frames_per_band(); | 324 input.num_frames_per_band(); |
325 // Apply delay and sum and post-filter in the time domain. WARNING: only works | 325 // Apply delay and sum and post-filter in the time domain. WARNING: only works |
326 // because delay-and-sum is not frequency dependent. | 326 // because delay-and-sum is not frequency dependent. |
(...skipping 18 matching lines...) Expand all Loading... |
345 // you are out of the beam. | 345 // you are out of the beam. |
346 return fabs(spherical_point.azimuth() - kTargetAngleRadians) < | 346 return fabs(spherical_point.azimuth() - kTargetAngleRadians) < |
347 kHalfBeamWidthRadians; | 347 kHalfBeamWidthRadians; |
348 } | 348 } |
349 | 349 |
350 void NonlinearBeamformer::ProcessAudioBlock(const complex_f* const* input, | 350 void NonlinearBeamformer::ProcessAudioBlock(const complex_f* const* input, |
351 int num_input_channels, | 351 int num_input_channels, |
352 size_t num_freq_bins, | 352 size_t num_freq_bins, |
353 int num_output_channels, | 353 int num_output_channels, |
354 complex_f* const* output) { | 354 complex_f* const* output) { |
355 CHECK_EQ(num_freq_bins, kNumFreqBins); | 355 RTC_CHECK_EQ(num_freq_bins, kNumFreqBins); |
356 CHECK_EQ(num_input_channels, num_input_channels_); | 356 RTC_CHECK_EQ(num_input_channels, num_input_channels_); |
357 CHECK_EQ(num_output_channels, 1); | 357 RTC_CHECK_EQ(num_output_channels, 1); |
358 | 358 |
359 // Calculating the post-filter masks. Note that we need two for each | 359 // Calculating the post-filter masks. Note that we need two for each |
360 // frequency bin to account for the positive and negative interferer | 360 // frequency bin to account for the positive and negative interferer |
361 // angle. | 361 // angle. |
362 for (size_t i = low_mean_start_bin_; i <= high_mean_end_bin_; ++i) { | 362 for (size_t i = low_mean_start_bin_; i <= high_mean_end_bin_; ++i) { |
363 eig_m_.CopyFromColumn(input, i, num_input_channels_); | 363 eig_m_.CopyFromColumn(input, i, num_input_channels_); |
364 float eig_m_norm_factor = std::sqrt(SumSquares(eig_m_)); | 364 float eig_m_norm_factor = std::sqrt(SumSquares(eig_m_)); |
365 if (eig_m_norm_factor != 0.f) { | 365 if (eig_m_norm_factor != 0.f) { |
366 eig_m_.Scale(1.f / eig_m_norm_factor); | 366 eig_m_.Scale(1.f / eig_m_norm_factor); |
367 } | 367 } |
(...skipping 118 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
486 // high_pass_postfilter_mask_ to use for the high frequency time-domain bands. | 486 // high_pass_postfilter_mask_ to use for the high frequency time-domain bands. |
487 void NonlinearBeamformer::ApplyHighFrequencyCorrection() { | 487 void NonlinearBeamformer::ApplyHighFrequencyCorrection() { |
488 high_pass_postfilter_mask_ = | 488 high_pass_postfilter_mask_ = |
489 MaskRangeMean(high_mean_start_bin_, high_mean_end_bin_ + 1); | 489 MaskRangeMean(high_mean_start_bin_, high_mean_end_bin_ + 1); |
490 std::fill(time_smooth_mask_ + high_mean_end_bin_ + 1, | 490 std::fill(time_smooth_mask_ + high_mean_end_bin_ + 1, |
491 time_smooth_mask_ + kNumFreqBins, high_pass_postfilter_mask_); | 491 time_smooth_mask_ + kNumFreqBins, high_pass_postfilter_mask_); |
492 } | 492 } |
493 | 493 |
494 // Compute mean over the given range of time_smooth_mask_, [first, last). | 494 // Compute mean over the given range of time_smooth_mask_, [first, last). |
495 float NonlinearBeamformer::MaskRangeMean(size_t first, size_t last) { | 495 float NonlinearBeamformer::MaskRangeMean(size_t first, size_t last) { |
496 DCHECK_GT(last, first); | 496 RTC_DCHECK_GT(last, first); |
497 const float sum = std::accumulate(time_smooth_mask_ + first, | 497 const float sum = std::accumulate(time_smooth_mask_ + first, |
498 time_smooth_mask_ + last, 0.f); | 498 time_smooth_mask_ + last, 0.f); |
499 return sum / (last - first); | 499 return sum / (last - first); |
500 } | 500 } |
501 | 501 |
502 void NonlinearBeamformer::EstimateTargetPresence() { | 502 void NonlinearBeamformer::EstimateTargetPresence() { |
503 const size_t quantile = static_cast<size_t>( | 503 const size_t quantile = static_cast<size_t>( |
504 (high_mean_end_bin_ - low_mean_start_bin_) * kMaskQuantile + | 504 (high_mean_end_bin_ - low_mean_start_bin_) * kMaskQuantile + |
505 low_mean_start_bin_); | 505 low_mean_start_bin_); |
506 std::nth_element(new_mask_ + low_mean_start_bin_, new_mask_ + quantile, | 506 std::nth_element(new_mask_ + low_mean_start_bin_, new_mask_ + quantile, |
507 new_mask_ + high_mean_end_bin_ + 1); | 507 new_mask_ + high_mean_end_bin_ + 1); |
508 if (new_mask_[quantile] > kMaskTargetThreshold) { | 508 if (new_mask_[quantile] > kMaskTargetThreshold) { |
509 is_target_present_ = true; | 509 is_target_present_ = true; |
510 interference_blocks_count_ = 0; | 510 interference_blocks_count_ = 0; |
511 } else { | 511 } else { |
512 is_target_present_ = interference_blocks_count_++ < hold_target_blocks_; | 512 is_target_present_ = interference_blocks_count_++ < hold_target_blocks_; |
513 } | 513 } |
514 } | 514 } |
515 | 515 |
516 } // namespace webrtc | 516 } // namespace webrtc |
OLD | NEW |