OLD | NEW |
1 # Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. | 1 # Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. |
2 # | 2 # |
3 # Use of this source code is governed by a BSD-style license | 3 # Use of this source code is governed by a BSD-style license |
4 # that can be found in the LICENSE file in the root of the source | 4 # that can be found in the LICENSE file in the root of the source |
5 # tree. An additional intellectual property rights grant can be found | 5 # tree. An additional intellectual property rights grant can be found |
6 # in the file PATENTS. All contributing project authors may | 6 # in the file PATENTS. All contributing project authors may |
7 # be found in the AUTHORS file in the root of the source tree. | 7 # be found in the AUTHORS file in the root of the source tree. |
8 | 8 |
9 """Unit tests for the eval_scores module. | 9 """Unit tests for the eval_scores module. |
10 """ | 10 """ |
(...skipping 34 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
45 self._fake_tested_signal_filepath = os.path.join( | 45 self._fake_tested_signal_filepath = os.path.join( |
46 self._output_path, 'fake_test.wav') | 46 self._output_path, 'fake_test.wav') |
47 signal_processing.SignalProcessingUtils.SaveWav( | 47 signal_processing.SignalProcessingUtils.SaveWav( |
48 self._fake_tested_signal_filepath, fake_tested_signal) | 48 self._fake_tested_signal_filepath, fake_tested_signal) |
49 | 49 |
50 def tearDown(self): | 50 def tearDown(self): |
51 """Recursively delete temporary folder.""" | 51 """Recursively delete temporary folder.""" |
52 shutil.rmtree(self._output_path) | 52 shutil.rmtree(self._output_path) |
53 | 53 |
54 def testRegisteredClasses(self): | 54 def testRegisteredClasses(self): |
| 55 # Evaluation score names to exclude (tested separately). |
| 56 exceptions = ['thd'] |
| 57 |
55 # Preliminary check. | 58 # Preliminary check. |
56 self.assertTrue(os.path.exists(self._output_path)) | 59 self.assertTrue(os.path.exists(self._output_path)) |
57 | 60 |
58 # Check that there is at least one registered evaluation score worker. | 61 # Check that there is at least one registered evaluation score worker. |
59 registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES | 62 registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES |
60 self.assertIsInstance(registered_classes, dict) | 63 self.assertIsInstance(registered_classes, dict) |
61 self.assertGreater(len(registered_classes), 0) | 64 self.assertGreater(len(registered_classes), 0) |
62 | 65 |
63 # Instance evaluation score workers factory with fake dependencies. | 66 # Instance evaluation score workers factory with fake dependencies. |
64 eval_score_workers_factory = ( | 67 eval_score_workers_factory = ( |
65 eval_scores_factory.EvaluationScoreWorkerFactory( | 68 eval_scores_factory.EvaluationScoreWorkerFactory( |
66 score_filename_prefix='scores-', | 69 score_filename_prefix='scores-', |
67 polqa_tool_bin_path=os.path.join( | 70 polqa_tool_bin_path=os.path.join( |
68 os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'))) | 71 os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'))) |
69 | 72 |
70 # Try each registered evaluation score worker. | 73 # Try each registered evaluation score worker. |
71 for eval_score_name in registered_classes: | 74 for eval_score_name in registered_classes: |
| 75 if eval_score_name in exceptions: |
| 76 continue |
| 77 |
72 # Instance evaluation score worker. | 78 # Instance evaluation score worker. |
73 eval_score_worker = eval_score_workers_factory.GetInstance( | 79 eval_score_worker = eval_score_workers_factory.GetInstance( |
74 registered_classes[eval_score_name]) | 80 registered_classes[eval_score_name]) |
75 | 81 |
76 # Set reference and test, then run. | 82 # Set fake input metadata and reference and test file paths, then run. |
77 eval_score_worker.SetReferenceSignalFilepath( | 83 eval_score_worker.SetReferenceSignalFilepath( |
78 self._fake_reference_signal_filepath) | 84 self._fake_reference_signal_filepath) |
79 eval_score_worker.SetTestedSignalFilepath( | 85 eval_score_worker.SetTestedSignalFilepath( |
80 self._fake_tested_signal_filepath) | 86 self._fake_tested_signal_filepath) |
81 eval_score_worker.Run(self._output_path) | 87 eval_score_worker.Run(self._output_path) |
82 | 88 |
83 # Check output. | 89 # Check output. |
84 score = data_access.ScoreFile.Load(eval_score_worker.output_filepath) | 90 score = data_access.ScoreFile.Load(eval_score_worker.output_filepath) |
85 self.assertTrue(isinstance(score, float)) | 91 self.assertTrue(isinstance(score, float)) |
| 92 |
| 93 def testTotalHarmonicDistorsionScore(self): |
| 94 # Init. |
| 95 pure_tone_freq = 5000.0 |
| 96 eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-') |
| 97 eval_score_worker.SetInputSignalMetadata({ |
| 98 'signal': 'pure_tone', |
| 99 'frequency': pure_tone_freq, |
| 100 'test_data_gen_name': 'identity', |
| 101 'test_data_gen_config': 'default', |
| 102 }) |
| 103 template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000) |
| 104 |
| 105 # Create 3 test signals: pure tone, pure tone + white noise, white noise |
| 106 # only. |
| 107 pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone( |
| 108 template, pure_tone_freq) |
| 109 white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise( |
| 110 template) |
| 111 noisy_tone = signal_processing.SignalProcessingUtils.MixSignals( |
| 112 pure_tone, white_noise) |
| 113 |
| 114 # Compute scores for increasingly distorted pure tone signals. |
| 115 scores = [None, None, None] |
| 116 for index, tested_signal in enumerate([pure_tone, noisy_tone, white_noise]): |
| 117 # Save signal. |
| 118 tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav') |
| 119 signal_processing.SignalProcessingUtils.SaveWav( |
| 120 tmp_filepath, tested_signal) |
| 121 |
| 122 # Compute score. |
| 123 eval_score_worker.SetTestedSignalFilepath(tmp_filepath) |
| 124 eval_score_worker.Run(self._output_path) |
| 125 scores[index] = eval_score_worker.score |
| 126 |
| 127 # Remove output file to avoid caching. |
| 128 os.remove(eval_score_worker.output_filepath) |
| 129 |
| 130 # Validate scores (lowest score with a pure tone). |
| 131 self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)])) |
OLD | NEW |