OLD | NEW |
1 /* | 1 /* |
2 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved. |
3 * | 3 * |
4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
9 */ | 9 */ |
10 | 10 |
11 #include "webrtc/modules/audio_processing/level_controller/signal_classifier.h" | 11 #include "webrtc/modules/audio_processing/level_controller/signal_classifier.h" |
12 | 12 |
13 #include <algorithm> | 13 #include <algorithm> |
14 #include <numeric> | 14 #include <numeric> |
15 #include <vector> | 15 #include <vector> |
16 | 16 |
17 #include "webrtc/base/array_view.h" | 17 #include "webrtc/base/array_view.h" |
18 #include "webrtc/base/constructormagic.h" | 18 #include "webrtc/base/constructormagic.h" |
19 #include "webrtc/modules/audio_processing/aec/aec_rdft.h" | |
20 #include "webrtc/modules/audio_processing/audio_buffer.h" | 19 #include "webrtc/modules/audio_processing/audio_buffer.h" |
21 #include "webrtc/modules/audio_processing/level_controller/down_sampler.h" | 20 #include "webrtc/modules/audio_processing/level_controller/down_sampler.h" |
22 #include "webrtc/modules/audio_processing/level_controller/noise_spectrum_estima
tor.h" | 21 #include "webrtc/modules/audio_processing/level_controller/noise_spectrum_estima
tor.h" |
23 #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" | 22 #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h" |
24 | 23 |
25 namespace webrtc { | 24 namespace webrtc { |
26 namespace { | 25 namespace { |
27 | 26 |
// Subtracts the arithmetic mean of x from every element of x, in place.
// x must be non-empty (DCHECKed below) because the mean is obtained by
// dividing the sum by x.size(). This function is unchanged between the OLD
// and NEW columns.
28 void RemoveDcLevel(rtc::ArrayView<float> x) { | 27 void RemoveDcLevel(rtc::ArrayView<float> x) { |
29 RTC_DCHECK_LT(0u, x.size()); | 28 RTC_DCHECK_LT(0u, x.size()); |
// The 0.f initial value makes std::accumulate sum in float.
30 float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f); | 29 float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f); |
31 mean /= x.size(); | 30 mean /= x.size(); |
32 | 31 |
// Elements are taken by reference so the subtraction modifies x itself.
33 for (float& v : x) { | 32 for (float& v : x) { |
34 v -= mean; | 33 v -= mean; |
35 } | 34 } |
36 } | 35 } |
37 | 36 |
38 void PowerSpectrum(rtc::ArrayView<const float> x, | 37 void PowerSpectrum(const OouraFft* ooura_fft, |
| 38 rtc::ArrayView<const float> x, |
39 rtc::ArrayView<float> spectrum) { | 39 rtc::ArrayView<float> spectrum) { |
40 RTC_DCHECK_EQ(65u, spectrum.size()); | 40 RTC_DCHECK_EQ(65u, spectrum.size()); |
41 RTC_DCHECK_EQ(128u, x.size()); | 41 RTC_DCHECK_EQ(128u, x.size()); |
42 float X[128]; | 42 float X[128]; |
43 std::copy(x.data(), x.data() + x.size(), X); | 43 std::copy(x.data(), x.data() + x.size(), X); |
44 aec_rdft_forward_128(X); | 44 ooura_fft->Fft(X); |
45 | 45 |
46 float* X_p = X; | 46 float* X_p = X; |
47 RTC_DCHECK_EQ(X_p, &X[0]); | 47 RTC_DCHECK_EQ(X_p, &X[0]); |
48 spectrum[0] = (*X_p) * (*X_p); | 48 spectrum[0] = (*X_p) * (*X_p); |
49 ++X_p; | 49 ++X_p; |
50 RTC_DCHECK_EQ(X_p, &X[1]); | 50 RTC_DCHECK_EQ(X_p, &X[1]); |
51 spectrum[64] = (*X_p) * (*X_p); | 51 spectrum[64] = (*X_p) * (*X_p); |
52 for (int k = 1; k < 64; ++k) { | 52 for (int k = 1; k < 64; ++k) { |
53 ++X_p; | 53 ++X_p; |
54 RTC_DCHECK_EQ(X_p, &X[2 * k]); | 54 RTC_DCHECK_EQ(X_p, &X[2 * k]); |
(...skipping 56 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
111 | 111 |
// Constructs the classifier: stores the data dumper pointer, hands it to the
// down sampler and the noise spectrum estimator, and then runs Initialize()
// with AudioProcessing::kSampleRate48kHz.
// NOTE(review): down_sampler_ and noise_spectrum_estimator_ are built from the
// member data_dumper_ rather than the constructor parameter, which relies on
// data_dumper_ being declared before them in the class - confirm in the header.
112 SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper) | 112 SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper) |
113 : data_dumper_(data_dumper), | 113 : data_dumper_(data_dumper), |
114 down_sampler_(data_dumper_), | 114 down_sampler_(data_dumper_), |
115 noise_spectrum_estimator_(data_dumper_) { | 115 noise_spectrum_estimator_(data_dumper_) { |
116 Initialize(AudioProcessing::kSampleRate48kHz); | 116 Initialize(AudioProcessing::kSampleRate48kHz); |
117 } | 117 } |
// Destructor with an empty body; no explicit teardown is performed here.
118 SignalClassifier::~SignalClassifier() {} | 118 SignalClassifier::~SignalClassifier() {} |
119 | 119 |
// Resets the classifier for the given sample rate: re-initializes the down
// sampler and the noise spectrum estimator, replaces the frame extender, and
// restores the counters and the last signal type to their starting values.
120 void SignalClassifier::Initialize(int sample_rate_hz) { | 120 void SignalClassifier::Initialize(int sample_rate_hz) { |
// OLD column only: aec_rdft_init() is called here. The NEW column drops the
// call (the right-hand cell of the next row is empty).
121 aec_rdft_init(); | |
122 down_sampler_.Initialize(sample_rate_hz); | 121 down_sampler_.Initialize(sample_rate_hz); |
123 noise_spectrum_estimator_.Initialize(); | 122 noise_spectrum_estimator_.Initialize(); |
// 80 and 128 match the sizes of the downsampled_frame and extended_frame
// buffers that Analyze() passes to ExtendFrame().
124 frame_extender_.reset(new FrameExtender(80, 128)); | 123 frame_extender_.reset(new FrameExtender(80, 128)); |
125 sample_rate_hz_ = sample_rate_hz; | 124 sample_rate_hz_ = sample_rate_hz; |
// Analyze() passes (initialization_frames_left_ > 0) to the noise estimator
// update and decrements this value once per call, so the flag is true for the
// first two calls after Initialize().
126 initialization_frames_left_ = 2; | 125 initialization_frames_left_ = 2; |
// Analyze() reports kNonStationary for as long as this counter is above zero.
127 consistent_classification_counter_ = 3; | 126 consistent_classification_counter_ = 3; |
128 last_signal_type_ = SignalClassifier::SignalType::kNonStationary; | 127 last_signal_type_ = SignalClassifier::SignalType::kNonStationary; |
129 } | 128 } |
130 | 129 |
// Classifies the supplied audio and writes the result to *signal_type.
//
// Processing order: downsample the first channel into 80 samples, extend that
// to a 128-sample frame, remove its mean, compute a 65-value power spectrum,
// classify it against the current noise spectrum estimate, update the noise
// estimate, and finally apply the consistency counter described further down.
//
// audio:       input buffer; only channels_const_f()[0] is read, and
//              num_frames() must equal sample_rate_hz_ / 100 (DCHECKed).
// signal_type: output; always assigned before returning.
131 void SignalClassifier::Analyze(const AudioBuffer& audio, | 130 void SignalClassifier::Analyze(const AudioBuffer& audio, |
132 SignalType* signal_type) { | 131 SignalType* signal_type) { |
133 RTC_DCHECK_EQ(audio.num_frames(), static_cast<size_t>(sample_rate_hz_ / 100)); | 132 RTC_DCHECK_EQ(audio.num_frames(), static_cast<size_t>(sample_rate_hz_ / 100)); |
134 | 133 |
135 // Compute the signal power spectrum. | 134 // Compute the signal power spectrum. |
136 float downsampled_frame[80]; | 135 float downsampled_frame[80]; |
137 down_sampler_.DownSample(rtc::ArrayView<const float>( | 136 down_sampler_.DownSample(rtc::ArrayView<const float>( |
138 audio.channels_const_f()[0], audio.num_frames()), | 137 audio.channels_const_f()[0], audio.num_frames()), |
139 downsampled_frame); | 138 downsampled_frame); |
140 float extended_frame[128]; | 139 float extended_frame[128]; |
141 frame_extender_->ExtendFrame(downsampled_frame, extended_frame); | 140 frame_extender_->ExtendFrame(downsampled_frame, extended_frame); |
142 RemoveDcLevel(extended_frame); | 141 RemoveDcLevel(extended_frame); |
143 float signal_spectrum[65]; | 142 float signal_spectrum[65]; |
// The only change to this function in the diff: the NEW column passes
// &ooura_fft_ as the first argument to PowerSpectrum(); the OLD column passes
// only the frame and the output spectrum.
144 PowerSpectrum(extended_frame, signal_spectrum); | 143 PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum); |
145 | 144 |
146 // Classify the signal based on the estimate of the noise spectrum and the | 145 // Classify the signal based on the estimate of the noise spectrum and the |
147 // signal spectrum estimate. | 146 // signal spectrum estimate. |
148 *signal_type = ClassifySignal(signal_spectrum, | 147 *signal_type = ClassifySignal(signal_spectrum, |
149 noise_spectrum_estimator_.GetNoiseSpectrum(), | 148 noise_spectrum_estimator_.GetNoiseSpectrum(), |
150 data_dumper_); | 149 data_dumper_); |
151 | 150 |
152 // Update the noise spectrum based on the signal spectrum. | 151 // Update the noise spectrum based on the signal spectrum. |
// The second argument is true only while initialization frames remain.
153 noise_spectrum_estimator_.Update(signal_spectrum, | 152 noise_spectrum_estimator_.Update(signal_spectrum, |
154 initialization_frames_left_ > 0); | 153 initialization_frames_left_ > 0); |
155 | 154 |
156 // Update the number of frames until a reliable signal spectrum is achieved. | 155 // Update the number of frames until a reliable signal spectrum is achieved. |
// std::max keeps the counter from dropping below zero.
157 initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1); | 156 initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1); |
158 | 157 |
// Consistency counter: a classification equal to the previous one decrements
// the counter (floored at zero); a different one is remembered and resets the
// counter to 3. The counter therefore reaches zero only after three
// consecutive calls that repeat the classification seen at the last change.
159 if (last_signal_type_ == *signal_type) { | 158 if (last_signal_type_ == *signal_type) { |
160 consistent_classification_counter_ = | 159 consistent_classification_counter_ = |
161 std::max(0, consistent_classification_counter_ - 1); | 160 std::max(0, consistent_classification_counter_ - 1); |
162 } else { | 161 } else { |
163 last_signal_type_ = *signal_type; | 162 last_signal_type_ = *signal_type; |
164 consistent_classification_counter_ = 3; | 163 consistent_classification_counter_ = 3; |
165 } | 164 } |
166 | 165 |
// While the counter is above zero the reported type is overridden with
// kNonStationary; last_signal_type_ still holds the raw classification.
167 if (consistent_classification_counter_ > 0) { | 166 if (consistent_classification_counter_ > 0) { |
168 *signal_type = SignalClassifier::SignalType::kNonStationary; | 167 *signal_type = SignalClassifier::SignalType::kNonStationary; |
169 } | 168 } |
170 } | 169 } |
171 | 170 |
172 } // namespace webrtc | 171 } // namespace webrtc |
OLD | NEW |