Index: webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
index 9b29555c1ac1eaf35a21b2dab429c8450f552d31..dae6a7cd8c73661463dce0f96062868af3574448 100644
--- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
+++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py
@@ -9,17 +9,76 @@
"""Unit tests for the eval_scores module. |
""" |
+import os |
+import shutil |
+import tempfile |
import unittest |
+import pydub |
+ |
+from . import data_access |
from . import eval_scores |
+from . import eval_scores_factory |
+from . import signal_processing |
class TestEvalScores(unittest.TestCase): |
"""Unit tests for the eval_scores module. |
""" |
+  def setUp(self):
+    """Create temporary output folder and two audio track files."""
+    self._output_path = tempfile.mkdtemp()
+
+    # Create fake reference and tested (i.e., APM output) audio track files.
+    silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
+    fake_reference_signal = (
+        signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
+    fake_tested_signal = (
+        signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
+
+    # Save fake audio tracks.
+    self._fake_reference_signal_filepath = os.path.join(
+        self._output_path, 'fake_ref.wav')
+    signal_processing.SignalProcessingUtils.SaveWav(
+        self._fake_reference_signal_filepath, fake_reference_signal)
+    self._fake_tested_signal_filepath = os.path.join(
+        self._output_path, 'fake_test.wav')
+    signal_processing.SignalProcessingUtils.SaveWav(
+        self._fake_tested_signal_filepath, fake_tested_signal)
+
+  def tearDown(self):
+    """Recursively delete temporary folder."""
+    shutil.rmtree(self._output_path)
+
   def test_registered_classes(self):
+    # Preliminary check.
+    self.assertTrue(os.path.exists(self._output_path))
+
     # Check that there is at least one registered evaluation score worker.
-    classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
-    self.assertIsInstance(classes, dict)
-    self.assertGreater(len(classes), 0)
+    registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
+    self.assertIsInstance(registered_classes, dict)
+    self.assertGreater(len(registered_classes), 0)
+
+    # Instantiate a factory of evaluation score workers with fake dependencies.
+    eval_score_workers_factory = (
+        eval_scores_factory.EvaluationScoreWorkerFactory(
+            polqa_tool_bin_path=os.path.join(
+                os.path.dirname(os.path.abspath(__file__)), 'fake_polqa')))
+
+    # Try each registered evaluation score worker.
+    for eval_score_name in registered_classes:
+      # Instantiate the evaluation score worker.
+      eval_score_worker = eval_score_workers_factory.GetInstance(
+          registered_classes[eval_score_name])
+
+      # Set the reference and tested signal file paths, then run.
+      eval_score_worker.SetReferenceSignalFilepath(
+          self._fake_reference_signal_filepath)
+      eval_score_worker.SetTestedSignalFilepath(
+          self._fake_tested_signal_filepath)
+      eval_score_worker.Run(self._output_path)
+
+      # Check output: the loaded score must be a float.
+      score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
+      self.assertIsInstance(score, float)
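Note (not part of the patch): a minimal sketch of how the new test can be run locally, assuming the fake_polqa helper referenced above sits next to the module. Because the file uses relative imports, it must be run as part of the quality_assessment package:

  # Assumed working directory: webrtc/modules/audio_processing/test/py_quality_assessment/
  # so that "from . import ..." resolves against the quality_assessment package.
  python -m unittest quality_assessment.eval_scores_unittest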