OLD | NEW |
1 /* | 1 /* |
2 * Copyright (c) 2013 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2013 The WebRTC project authors. All Rights Reserved. |
3 * | 3 * |
4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
9 */ | 9 */ |
10 | 10 |
11 #include "webrtc/modules/audio_processing/transient/transient_suppressor.h" | 11 #include "webrtc/modules/audio_processing/transient/transient_suppressor.h" |
12 | 12 |
13 #include <stdlib.h> | 13 #include <stdlib.h> |
14 #include <stdio.h> | 14 #include <stdio.h> |
| 15 #include <string.h> |
15 | 16 |
16 #include <memory> | 17 #include <memory> |
17 #include <string> | 18 #include <string> |
18 | 19 |
19 #include "gflags/gflags.h" | |
20 #include "webrtc/common_audio/include/audio_util.h" | 20 #include "webrtc/common_audio/include/audio_util.h" |
21 #include "webrtc/modules/audio_processing/agc/agc.h" | 21 #include "webrtc/modules/audio_processing/agc/agc.h" |
22 #include "webrtc/modules/include/module_common_types.h" | 22 #include "webrtc/modules/include/module_common_types.h" |
| 23 #include "webrtc/rtc_base/flags.h" |
23 #include "webrtc/test/gtest.h" | 24 #include "webrtc/test/gtest.h" |
24 #include "webrtc/test/testsupport/fileutils.h" | 25 #include "webrtc/test/testsupport/fileutils.h" |
25 #include "webrtc/typedefs.h" | 26 #include "webrtc/typedefs.h" |
26 | 27 |
27 DEFINE_string(in_file_name, "", "PCM file that contains the signal."); | 28 DEFINE_string(in_file_name, "", "PCM file that contains the signal."); |
28 DEFINE_string(detection_file_name, | 29 DEFINE_string(detection_file_name, |
29 "", | 30 "", |
30 "PCM file that contains the detection signal."); | 31 "PCM file that contains the detection signal."); |
31 DEFINE_string(reference_file_name, | 32 DEFINE_string(reference_file_name, |
32 "", | 33 "", |
33 "PCM file that contains the reference signal."); | 34 "PCM file that contains the reference signal."); |
34 | 35 |
35 static bool ValidatePositiveInt(const char* flagname, int32_t value) { | 36 DEFINE_int(chunk_size_ms, |
36 if (value <= 0) { | 37 10, |
37 printf("%s must be a positive integer.\n", flagname); | 38 "Time between each chunk of samples in milliseconds."); |
38 return false; | |
39 } | |
40 return true; | |
41 } | |
42 DEFINE_int32(chunk_size_ms, | |
43 10, | |
44 "Time between each chunk of samples in milliseconds."); | |
45 static const bool chunk_size_ms_dummy = | |
46 google::RegisterFlagValidator(&FLAGS_chunk_size_ms, &ValidatePositiveInt); | |
47 | 39 |
48 DEFINE_int32(sample_rate_hz, | 40 DEFINE_int(sample_rate_hz, |
49 16000, | 41 16000, |
50 "Sampling frequency of the signal in Hertz."); | 42 "Sampling frequency of the signal in Hertz."); |
51 static const bool sample_rate_hz_dummy = | 43 DEFINE_int(detection_rate_hz, |
52 google::RegisterFlagValidator(&FLAGS_sample_rate_hz, &ValidatePositiveInt); | 44 0, |
53 DEFINE_int32(detection_rate_hz, | 45 "Sampling frequency of the detection signal in Hertz."); |
54 0, | |
55 "Sampling frequency of the detection signal in Hertz."); | |
56 | 46 |
57 DEFINE_int32(num_channels, 1, "Number of channels."); | 47 DEFINE_int(num_channels, 1, "Number of channels."); |
58 static const bool num_channels_dummy = | 48 |
59 google::RegisterFlagValidator(&FLAGS_num_channels, &ValidatePositiveInt); | 49 DEFINE_bool(help, false, "Print this message."); |
60 | 50 |
61 namespace webrtc { | 51 namespace webrtc { |
62 | 52 |
63 const char kUsage[] = | 53 const char kUsage[] = |
64 "\nDetects and suppresses transients from file.\n\n" | 54 "\nDetects and suppresses transients from file.\n\n" |
65 "This application loads the signal from the in_file_name with a specific\n" | 55 "This application loads the signal from the in_file_name with a specific\n" |
66 "num_channels and sample_rate_hz, the detection signal from the\n" | 56 "num_channels and sample_rate_hz, the detection signal from the\n" |
67 "detection_file_name with a specific detection_rate_hz, and the reference\n" | 57 "detection_file_name with a specific detection_rate_hz, and the reference\n" |
68 "signal from the reference_file_name with sample_rate_hz, divides them\n" | 58 "signal from the reference_file_name with sample_rate_hz, divides them\n" |
69 "into chunk_size_ms blocks, computes its voice value and depending on the\n" | 59 "into chunk_size_ms blocks, computes its voice value and depending on the\n" |
(...skipping 69 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
139 } | 129 } |
140 | 130 |
141 // This application tests the transient suppression by providing a processed | 131 // This application tests the transient suppression by providing a processed |
142 // PCM file, which has to be listened to in order to evaluate the | 132 // PCM file, which has to be listened to in order to evaluate the |
143 // performance. | 133 // performance. |
144 // It gets an audio file, and its voice gain information, and the suppressor | 134 // It gets an audio file, and its voice gain information, and the suppressor |
145 // process it giving the output file "suppressed_keystrokes.pcm". | 135 // process it giving the output file "suppressed_keystrokes.pcm". |
146 void void_main() { | 136 void void_main() { |
147 // TODO(aluebs): Remove all FileWrappers. | 137 // TODO(aluebs): Remove all FileWrappers. |
148 // Prepare the input file. | 138 // Prepare the input file. |
149 FILE* in_file = fopen(FLAGS_in_file_name.c_str(), "rb"); | 139 FILE* in_file = fopen(FLAG_in_file_name, "rb"); |
150 ASSERT_TRUE(in_file != NULL); | 140 ASSERT_TRUE(in_file != NULL); |
151 | 141 |
152 // Prepare the detection file. | 142 // Prepare the detection file. |
153 FILE* detection_file = NULL; | 143 FILE* detection_file = NULL; |
154 if (!FLAGS_detection_file_name.empty()) { | 144 if (strlen(FLAG_detection_file_name) > 0) { |
155 detection_file = fopen(FLAGS_detection_file_name.c_str(), "rb"); | 145 detection_file = fopen(FLAG_detection_file_name, "rb"); |
156 } | 146 } |
157 | 147 |
158 // Prepare the reference file. | 148 // Prepare the reference file. |
159 FILE* reference_file = NULL; | 149 FILE* reference_file = NULL; |
160 if (!FLAGS_reference_file_name.empty()) { | 150 if (strlen(FLAG_reference_file_name) > 0) { |
161 reference_file = fopen(FLAGS_reference_file_name.c_str(), "rb"); | 151 reference_file = fopen(FLAG_reference_file_name, "rb"); |
162 } | 152 } |
163 | 153 |
164 // Prepare the output file. | 154 // Prepare the output file. |
165 std::string out_file_name = test::OutputPath() + "suppressed_keystrokes.pcm"; | 155 std::string out_file_name = test::OutputPath() + "suppressed_keystrokes.pcm"; |
166 FILE* out_file = fopen(out_file_name.c_str(), "wb"); | 156 FILE* out_file = fopen(out_file_name.c_str(), "wb"); |
167 ASSERT_TRUE(out_file != NULL); | 157 ASSERT_TRUE(out_file != NULL); |
168 | 158 |
169 int detection_rate_hz = FLAGS_detection_rate_hz; | 159 int detection_rate_hz = FLAG_detection_rate_hz; |
170 if (detection_rate_hz == 0) { | 160 if (detection_rate_hz == 0) { |
171 detection_rate_hz = FLAGS_sample_rate_hz; | 161 detection_rate_hz = FLAG_sample_rate_hz; |
172 } | 162 } |
173 | 163 |
174 Agc agc; | 164 Agc agc; |
175 | 165 |
176 TransientSuppressor suppressor; | 166 TransientSuppressor suppressor; |
177 suppressor.Initialize( | 167 suppressor.Initialize( |
178 FLAGS_sample_rate_hz, detection_rate_hz, FLAGS_num_channels); | 168 FLAG_sample_rate_hz, detection_rate_hz, FLAG_num_channels); |
179 | 169 |
180 const size_t audio_buffer_size = | 170 const size_t audio_buffer_size = |
181 FLAGS_chunk_size_ms * FLAGS_sample_rate_hz / 1000; | 171 FLAG_chunk_size_ms * FLAG_sample_rate_hz / 1000; |
182 const size_t detection_buffer_size = | 172 const size_t detection_buffer_size = |
183 FLAGS_chunk_size_ms * detection_rate_hz / 1000; | 173 FLAG_chunk_size_ms * detection_rate_hz / 1000; |
184 | 174 |
185 // int16 and float variants of the same data. | 175 // int16 and float variants of the same data. |
186 std::unique_ptr<int16_t[]> audio_buffer_i( | 176 std::unique_ptr<int16_t[]> audio_buffer_i( |
187 new int16_t[FLAGS_num_channels * audio_buffer_size]); | 177 new int16_t[FLAG_num_channels * audio_buffer_size]); |
188 std::unique_ptr<float[]> audio_buffer_f( | 178 std::unique_ptr<float[]> audio_buffer_f( |
189 new float[FLAGS_num_channels * audio_buffer_size]); | 179 new float[FLAG_num_channels * audio_buffer_size]); |
190 | 180 |
191 std::unique_ptr<float[]> detection_buffer, reference_buffer; | 181 std::unique_ptr<float[]> detection_buffer, reference_buffer; |
192 | 182 |
193 if (detection_file) | 183 if (detection_file) |
194 detection_buffer.reset(new float[detection_buffer_size]); | 184 detection_buffer.reset(new float[detection_buffer_size]); |
195 if (reference_file) | 185 if (reference_file) |
196 reference_buffer.reset(new float[audio_buffer_size]); | 186 reference_buffer.reset(new float[audio_buffer_size]); |
197 | 187 |
198 while (ReadBuffers(in_file, | 188 while (ReadBuffers(in_file, |
199 audio_buffer_size, | 189 audio_buffer_size, |
200 FLAGS_num_channels, | 190 FLAG_num_channels, |
201 audio_buffer_i.get(), | 191 audio_buffer_i.get(), |
202 detection_file, | 192 detection_file, |
203 detection_buffer_size, | 193 detection_buffer_size, |
204 detection_buffer.get(), | 194 detection_buffer.get(), |
205 reference_file, | 195 reference_file, |
206 reference_buffer.get())) { | 196 reference_buffer.get())) { |
207 ASSERT_EQ(0, | 197 ASSERT_EQ(0, |
208 agc.Process(audio_buffer_i.get(), | 198 agc.Process(audio_buffer_i.get(), |
209 static_cast<int>(audio_buffer_size), | 199 static_cast<int>(audio_buffer_size), |
210 FLAGS_sample_rate_hz)) | 200 FLAG_sample_rate_hz)) |
211 << "The AGC could not process the frame"; | 201 << "The AGC could not process the frame"; |
212 | 202 |
213 for (size_t i = 0; i < FLAGS_num_channels * audio_buffer_size; ++i) { | 203 for (size_t i = 0; i < FLAG_num_channels * audio_buffer_size; ++i) { |
214 audio_buffer_f[i] = audio_buffer_i[i]; | 204 audio_buffer_f[i] = audio_buffer_i[i]; |
215 } | 205 } |
216 | 206 |
217 ASSERT_EQ(0, | 207 ASSERT_EQ(0, |
218 suppressor.Suppress(audio_buffer_f.get(), | 208 suppressor.Suppress(audio_buffer_f.get(), |
219 audio_buffer_size, | 209 audio_buffer_size, |
220 FLAGS_num_channels, | 210 FLAG_num_channels, |
221 detection_buffer.get(), | 211 detection_buffer.get(), |
222 detection_buffer_size, | 212 detection_buffer_size, |
223 reference_buffer.get(), | 213 reference_buffer.get(), |
224 audio_buffer_size, | 214 audio_buffer_size, |
225 agc.voice_probability(), | 215 agc.voice_probability(), |
226 true)) | 216 true)) |
227 << "The transient suppressor could not suppress the frame"; | 217 << "The transient suppressor could not suppress the frame"; |
228 | 218 |
229 // Write result to out file. | 219 // Write result to out file. |
230 WritePCM( | 220 WritePCM( |
231 out_file, audio_buffer_size, FLAGS_num_channels, audio_buffer_f.get()); | 221 out_file, audio_buffer_size, FLAG_num_channels, audio_buffer_f.get()); |
232 } | 222 } |
233 | 223 |
234 fclose(in_file); | 224 fclose(in_file); |
235 if (detection_file) { | 225 if (detection_file) { |
236 fclose(detection_file); | 226 fclose(detection_file); |
237 } | 227 } |
238 if (reference_file) { | 228 if (reference_file) { |
239 fclose(reference_file); | 229 fclose(reference_file); |
240 } | 230 } |
241 fclose(out_file); | 231 fclose(out_file); |
242 } | 232 } |
243 | 233 |
244 } // namespace webrtc | 234 } // namespace webrtc |
245 | 235 |
246 int main(int argc, char* argv[]) { | 236 int main(int argc, char* argv[]) { |
247 google::SetUsageMessage(webrtc::kUsage); | 237 if (rtc::FlagList::SetFlagsFromCommandLine(&argc, argv, true) || |
248 google::ParseCommandLineFlags(&argc, &argv, true); | 238 FLAG_help || argc != 1) { |
| 239 printf("%s", webrtc::kUsage); |
| 240 if (FLAG_help) { |
| 241 rtc::FlagList::Print(nullptr, false); |
| 242 return 0; |
| 243 } |
| 244 return 1; |
| 245 } |
| 246 RTC_CHECK_GT(FLAG_chunk_size_ms, 0); |
| 247 RTC_CHECK_GT(FLAG_sample_rate_hz, 0); |
| 248 RTC_CHECK_GT(FLAG_num_channels, 0); |
| 249 |
249 webrtc::void_main(); | 250 webrtc::void_main(); |
250 return 0; | 251 return 0; |
251 } | 252 } |
OLD | NEW |