| OLD | NEW |
| 1 /* | 1 /* |
| 2 * Copyright (c) 2013 The WebRTC project authors. All Rights Reserved. | 2 * Copyright (c) 2013 The WebRTC project authors. All Rights Reserved. |
| 3 * | 3 * |
| 4 * Use of this source code is governed by a BSD-style license | 4 * Use of this source code is governed by a BSD-style license |
| 5 * that can be found in the LICENSE file in the root of the source | 5 * that can be found in the LICENSE file in the root of the source |
| 6 * tree. An additional intellectual property rights grant can be found | 6 * tree. An additional intellectual property rights grant can be found |
| 7 * in the file PATENTS. All contributing project authors may | 7 * in the file PATENTS. All contributing project authors may |
| 8 * be found in the AUTHORS file in the root of the source tree. | 8 * be found in the AUTHORS file in the root of the source tree. |
| 9 */ | 9 */ |
| 10 | 10 |
| 11 #include "webrtc/modules/audio_processing/transient/transient_suppressor.h" | 11 #include "webrtc/modules/audio_processing/transient/transient_suppressor.h" |
| 12 | 12 |
| 13 #include <stdlib.h> | 13 #include <stdlib.h> |
| 14 #include <stdio.h> | 14 #include <stdio.h> |
| 15 #include <string.h> |
| 15 | 16 |
| 16 #include <memory> | 17 #include <memory> |
| 17 #include <string> | 18 #include <string> |
| 18 | 19 |
| 19 #include "gflags/gflags.h" | |
| 20 #include "webrtc/common_audio/include/audio_util.h" | 20 #include "webrtc/common_audio/include/audio_util.h" |
| 21 #include "webrtc/modules/audio_processing/agc/agc.h" | 21 #include "webrtc/modules/audio_processing/agc/agc.h" |
| 22 #include "webrtc/modules/include/module_common_types.h" | 22 #include "webrtc/modules/include/module_common_types.h" |
| 23 #include "webrtc/rtc_base/flags.h" |
| 23 #include "webrtc/test/gtest.h" | 24 #include "webrtc/test/gtest.h" |
| 24 #include "webrtc/test/testsupport/fileutils.h" | 25 #include "webrtc/test/testsupport/fileutils.h" |
| 25 #include "webrtc/typedefs.h" | 26 #include "webrtc/typedefs.h" |
| 26 | 27 |
| 27 DEFINE_string(in_file_name, "", "PCM file that contains the signal."); | 28 DEFINE_string(in_file_name, "", "PCM file that contains the signal."); |
| 28 DEFINE_string(detection_file_name, | 29 DEFINE_string(detection_file_name, |
| 29 "", | 30 "", |
| 30 "PCM file that contains the detection signal."); | 31 "PCM file that contains the detection signal."); |
| 31 DEFINE_string(reference_file_name, | 32 DEFINE_string(reference_file_name, |
| 32 "", | 33 "", |
| 33 "PCM file that contains the reference signal."); | 34 "PCM file that contains the reference signal."); |
| 34 | 35 |
| 35 static bool ValidatePositiveInt(const char* flagname, int32_t value) { | 36 DEFINE_int(chunk_size_ms, |
| 36 if (value <= 0) { | 37 10, |
| 37 printf("%s must be a positive integer.\n", flagname); | 38 "Time between each chunk of samples in milliseconds."); |
| 38 return false; | |
| 39 } | |
| 40 return true; | |
| 41 } | |
| 42 DEFINE_int32(chunk_size_ms, | |
| 43 10, | |
| 44 "Time between each chunk of samples in milliseconds."); | |
| 45 static const bool chunk_size_ms_dummy = | |
| 46 google::RegisterFlagValidator(&FLAGS_chunk_size_ms, &ValidatePositiveInt); | |
| 47 | 39 |
| 48 DEFINE_int32(sample_rate_hz, | 40 DEFINE_int(sample_rate_hz, |
| 49 16000, | 41 16000, |
| 50 "Sampling frequency of the signal in Hertz."); | 42 "Sampling frequency of the signal in Hertz."); |
| 51 static const bool sample_rate_hz_dummy = | 43 DEFINE_int(detection_rate_hz, |
| 52 google::RegisterFlagValidator(&FLAGS_sample_rate_hz, &ValidatePositiveInt); | 44 0, |
| 53 DEFINE_int32(detection_rate_hz, | 45 "Sampling frequency of the detection signal in Hertz."); |
| 54 0, | |
| 55 "Sampling frequency of the detection signal in Hertz."); | |
| 56 | 46 |
| 57 DEFINE_int32(num_channels, 1, "Number of channels."); | 47 DEFINE_int(num_channels, 1, "Number of channels."); |
| 58 static const bool num_channels_dummy = | 48 |
| 59 google::RegisterFlagValidator(&FLAGS_num_channels, &ValidatePositiveInt); | 49 DEFINE_bool(help, false, "Print this message."); |
| 60 | 50 |
| 61 namespace webrtc { | 51 namespace webrtc { |
| 62 | 52 |
| 63 const char kUsage[] = | 53 const char kUsage[] = |
| 64 "\nDetects and suppresses transients from file.\n\n" | 54 "\nDetects and suppresses transients from file.\n\n" |
| 65 "This application loads the signal from the in_file_name with a specific\n" | 55 "This application loads the signal from the in_file_name with a specific\n" |
| 66 "num_channels and sample_rate_hz, the detection signal from the\n" | 56 "num_channels and sample_rate_hz, the detection signal from the\n" |
| 67 "detection_file_name with a specific detection_rate_hz, and the reference\n" | 57 "detection_file_name with a specific detection_rate_hz, and the reference\n" |
| 68 "signal from the reference_file_name with sample_rate_hz, divides them\n" | 58 "signal from the reference_file_name with sample_rate_hz, divides them\n" |
| 69 "into chunk_size_ms blocks, computes its voice value and depending on the\n" | 59 "into chunk_size_ms blocks, computes its voice value and depending on the\n" |
| (...skipping 69 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 139 } | 129 } |
| 140 | 130 |
| 141 // This application tests the transient suppression by providing a processed | 131 // This application tests the transient suppression by providing a processed |
| 142 // PCM file, which has to be listened to in order to evaluate the | 132 // PCM file, which has to be listened to in order to evaluate the |
| 143 // performance. | 133 // performance. |
| 144 // It gets an audio file, and its voice gain information, and the suppressor | 134 // It gets an audio file, and its voice gain information, and the suppressor |
| 145 // process it giving the output file "suppressed_keystrokes.pcm". | 135 // process it giving the output file "suppressed_keystrokes.pcm". |
| 146 void void_main() { | 136 void void_main() { |
| 147 // TODO(aluebs): Remove all FileWrappers. | 137 // TODO(aluebs): Remove all FileWrappers. |
| 148 // Prepare the input file. | 138 // Prepare the input file. |
| 149 FILE* in_file = fopen(FLAGS_in_file_name.c_str(), "rb"); | 139 FILE* in_file = fopen(FLAG_in_file_name, "rb"); |
| 150 ASSERT_TRUE(in_file != NULL); | 140 ASSERT_TRUE(in_file != NULL); |
| 151 | 141 |
| 152 // Prepare the detection file. | 142 // Prepare the detection file. |
| 153 FILE* detection_file = NULL; | 143 FILE* detection_file = NULL; |
| 154 if (!FLAGS_detection_file_name.empty()) { | 144 if (strlen(FLAG_detection_file_name) > 0) { |
| 155 detection_file = fopen(FLAGS_detection_file_name.c_str(), "rb"); | 145 detection_file = fopen(FLAG_detection_file_name, "rb"); |
| 156 } | 146 } |
| 157 | 147 |
| 158 // Prepare the reference file. | 148 // Prepare the reference file. |
| 159 FILE* reference_file = NULL; | 149 FILE* reference_file = NULL; |
| 160 if (!FLAGS_reference_file_name.empty()) { | 150 if (strlen(FLAG_reference_file_name) > 0) { |
| 161 reference_file = fopen(FLAGS_reference_file_name.c_str(), "rb"); | 151 reference_file = fopen(FLAG_reference_file_name, "rb"); |
| 162 } | 152 } |
| 163 | 153 |
| 164 // Prepare the output file. | 154 // Prepare the output file. |
| 165 std::string out_file_name = test::OutputPath() + "suppressed_keystrokes.pcm"; | 155 std::string out_file_name = test::OutputPath() + "suppressed_keystrokes.pcm"; |
| 166 FILE* out_file = fopen(out_file_name.c_str(), "wb"); | 156 FILE* out_file = fopen(out_file_name.c_str(), "wb"); |
| 167 ASSERT_TRUE(out_file != NULL); | 157 ASSERT_TRUE(out_file != NULL); |
| 168 | 158 |
| 169 int detection_rate_hz = FLAGS_detection_rate_hz; | 159 int detection_rate_hz = FLAG_detection_rate_hz; |
| 170 if (detection_rate_hz == 0) { | 160 if (detection_rate_hz == 0) { |
| 171 detection_rate_hz = FLAGS_sample_rate_hz; | 161 detection_rate_hz = FLAG_sample_rate_hz; |
| 172 } | 162 } |
| 173 | 163 |
| 174 Agc agc; | 164 Agc agc; |
| 175 | 165 |
| 176 TransientSuppressor suppressor; | 166 TransientSuppressor suppressor; |
| 177 suppressor.Initialize( | 167 suppressor.Initialize( |
| 178 FLAGS_sample_rate_hz, detection_rate_hz, FLAGS_num_channels); | 168 FLAG_sample_rate_hz, detection_rate_hz, FLAG_num_channels); |
| 179 | 169 |
| 180 const size_t audio_buffer_size = | 170 const size_t audio_buffer_size = |
| 181 FLAGS_chunk_size_ms * FLAGS_sample_rate_hz / 1000; | 171 FLAG_chunk_size_ms * FLAG_sample_rate_hz / 1000; |
| 182 const size_t detection_buffer_size = | 172 const size_t detection_buffer_size = |
| 183 FLAGS_chunk_size_ms * detection_rate_hz / 1000; | 173 FLAG_chunk_size_ms * detection_rate_hz / 1000; |
| 184 | 174 |
| 185 // int16 and float variants of the same data. | 175 // int16 and float variants of the same data. |
| 186 std::unique_ptr<int16_t[]> audio_buffer_i( | 176 std::unique_ptr<int16_t[]> audio_buffer_i( |
| 187 new int16_t[FLAGS_num_channels * audio_buffer_size]); | 177 new int16_t[FLAG_num_channels * audio_buffer_size]); |
| 188 std::unique_ptr<float[]> audio_buffer_f( | 178 std::unique_ptr<float[]> audio_buffer_f( |
| 189 new float[FLAGS_num_channels * audio_buffer_size]); | 179 new float[FLAG_num_channels * audio_buffer_size]); |
| 190 | 180 |
| 191 std::unique_ptr<float[]> detection_buffer, reference_buffer; | 181 std::unique_ptr<float[]> detection_buffer, reference_buffer; |
| 192 | 182 |
| 193 if (detection_file) | 183 if (detection_file) |
| 194 detection_buffer.reset(new float[detection_buffer_size]); | 184 detection_buffer.reset(new float[detection_buffer_size]); |
| 195 if (reference_file) | 185 if (reference_file) |
| 196 reference_buffer.reset(new float[audio_buffer_size]); | 186 reference_buffer.reset(new float[audio_buffer_size]); |
| 197 | 187 |
| 198 while (ReadBuffers(in_file, | 188 while (ReadBuffers(in_file, |
| 199 audio_buffer_size, | 189 audio_buffer_size, |
| 200 FLAGS_num_channels, | 190 FLAG_num_channels, |
| 201 audio_buffer_i.get(), | 191 audio_buffer_i.get(), |
| 202 detection_file, | 192 detection_file, |
| 203 detection_buffer_size, | 193 detection_buffer_size, |
| 204 detection_buffer.get(), | 194 detection_buffer.get(), |
| 205 reference_file, | 195 reference_file, |
| 206 reference_buffer.get())) { | 196 reference_buffer.get())) { |
| 207 ASSERT_EQ(0, | 197 ASSERT_EQ(0, |
| 208 agc.Process(audio_buffer_i.get(), | 198 agc.Process(audio_buffer_i.get(), |
| 209 static_cast<int>(audio_buffer_size), | 199 static_cast<int>(audio_buffer_size), |
| 210 FLAGS_sample_rate_hz)) | 200 FLAG_sample_rate_hz)) |
| 211 << "The AGC could not process the frame"; | 201 << "The AGC could not process the frame"; |
| 212 | 202 |
| 213 for (size_t i = 0; i < FLAGS_num_channels * audio_buffer_size; ++i) { | 203 for (size_t i = 0; i < FLAG_num_channels * audio_buffer_size; ++i) { |
| 214 audio_buffer_f[i] = audio_buffer_i[i]; | 204 audio_buffer_f[i] = audio_buffer_i[i]; |
| 215 } | 205 } |
| 216 | 206 |
| 217 ASSERT_EQ(0, | 207 ASSERT_EQ(0, |
| 218 suppressor.Suppress(audio_buffer_f.get(), | 208 suppressor.Suppress(audio_buffer_f.get(), |
| 219 audio_buffer_size, | 209 audio_buffer_size, |
| 220 FLAGS_num_channels, | 210 FLAG_num_channels, |
| 221 detection_buffer.get(), | 211 detection_buffer.get(), |
| 222 detection_buffer_size, | 212 detection_buffer_size, |
| 223 reference_buffer.get(), | 213 reference_buffer.get(), |
| 224 audio_buffer_size, | 214 audio_buffer_size, |
| 225 agc.voice_probability(), | 215 agc.voice_probability(), |
| 226 true)) | 216 true)) |
| 227 << "The transient suppressor could not suppress the frame"; | 217 << "The transient suppressor could not suppress the frame"; |
| 228 | 218 |
| 229 // Write result to out file. | 219 // Write result to out file. |
| 230 WritePCM( | 220 WritePCM( |
| 231 out_file, audio_buffer_size, FLAGS_num_channels, audio_buffer_f.get()); | 221 out_file, audio_buffer_size, FLAG_num_channels, audio_buffer_f.get()); |
| 232 } | 222 } |
| 233 | 223 |
| 234 fclose(in_file); | 224 fclose(in_file); |
| 235 if (detection_file) { | 225 if (detection_file) { |
| 236 fclose(detection_file); | 226 fclose(detection_file); |
| 237 } | 227 } |
| 238 if (reference_file) { | 228 if (reference_file) { |
| 239 fclose(reference_file); | 229 fclose(reference_file); |
| 240 } | 230 } |
| 241 fclose(out_file); | 231 fclose(out_file); |
| 242 } | 232 } |
| 243 | 233 |
| 244 } // namespace webrtc | 234 } // namespace webrtc |
| 245 | 235 |
| 246 int main(int argc, char* argv[]) { | 236 int main(int argc, char* argv[]) { |
| 247 google::SetUsageMessage(webrtc::kUsage); | 237 if (rtc::FlagList::SetFlagsFromCommandLine(&argc, argv, true) || |
| 248 google::ParseCommandLineFlags(&argc, &argv, true); | 238 FLAG_help || argc != 1) { |
| 239 printf("%s", webrtc::kUsage); |
| 240 if (FLAG_help) { |
| 241 rtc::FlagList::Print(nullptr, false); |
| 242 return 0; |
| 243 } |
| 244 return 1; |
| 245 } |
| 246 RTC_CHECK_GT(FLAG_chunk_size_ms, 0); |
| 247 RTC_CHECK_GT(FLAG_sample_rate_hz, 0); |
| 248 RTC_CHECK_GT(FLAG_num_channels, 0); |
| 249 |
| 249 webrtc::void_main(); | 250 webrtc::void_main(); |
| 250 return 0; | 251 return 0; |
| 251 } | 252 } |
| OLD | NEW |