Index: webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py |
diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py |
similarity index 52% |
rename from webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_unittest.py |
rename to webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py |
index c5dfed2796a42a4d4b2f8c1c638c298ef40fb36f..5261dd25f4cae1ec3393379ed47ca242ab1af37e 100644 |
--- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/noise_generation_unittest.py |
+++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py |
@@ -6,7 +6,7 @@ |
# in the file PATENTS. All contributing project authors may |
# be found in the AUTHORS file in the root of the source tree. |
-"""Unit tests for the noise_generation module. |
+"""Unit tests for the test_data_generation module. |
""" |
import os |
@@ -14,13 +14,13 @@ import shutil |
import tempfile |
import unittest |
-from . import noise_generation |
-from . import noise_generation_factory |
+from . import test_data_generation |
+from . import test_data_generation_factory |
from . import signal_processing |
-class TestNoiseGen(unittest.TestCase): |
- """Unit tests for the noise_generation module. |
+class TestTestDataGenerators(unittest.TestCase): |
+ """Unit tests for the test_data_generation module. |
""" |
def setUp(self): |
@@ -33,22 +33,25 @@ class TestNoiseGen(unittest.TestCase): |
shutil.rmtree(self._base_output_path) |
shutil.rmtree(self._input_noise_cache_path) |
- def testNoiseGenerators(self): |
+ def testTestDataGenerators(self): |
# Preliminary check. |
self.assertTrue(os.path.exists(self._base_output_path)) |
self.assertTrue(os.path.exists(self._input_noise_cache_path)) |
- # Check that there is at least one registered noise generator. |
- registered_classes = noise_generation.NoiseGenerator.REGISTERED_CLASSES |
+ # Check that there is at least one registered test data generator. |
+ registered_classes = ( |
+ test_data_generation.TestDataGenerator.REGISTERED_CLASSES) |
self.assertIsInstance(registered_classes, dict) |
self.assertGreater(len(registered_classes), 0) |
- # Instance noise generator factory. |
- noise_generator_factory = noise_generation_factory.NoiseGeneratorFactory( |
- aechen_ir_database_path='') |
- # TODO(alessiob): Replace with a mock of NoiseGeneratorFactory that takes |
- # no arguments in the ctor. For those generators that need parameters, it |
- # will return a mock generator (see the first comment in the next for loop). |
+ # Instance generators factory. |
+ generators_factory = ( |
+ test_data_generation_factory.TestDataGeneratorFactory( |
+ aechen_ir_database_path='')) |
+ # TODO(alessiob): Replace with a mock of TestDataGeneratorFactory that |
+ # takes no arguments in the ctor. For those generators that need parameters, |
+ # it will return a mock generator (see the first comment in the next for |
+ # loop). |
# Use a sample input file as clean input signal. |
input_signal_filepath = os.path.join( |
@@ -59,64 +62,62 @@ class TestNoiseGen(unittest.TestCase): |
input_signal = signal_processing.SignalProcessingUtils.LoadWav( |
input_signal_filepath) |
- # Try each registered noise generator. |
- for noise_generator_name in registered_classes: |
- # Exclude EchoNoiseGenerator. |
- # TODO(alessiob): Mock EchoNoiseGenerator, the mock should rely on |
- # hard-coded impulse responses. This requires a mock for |
- # NoiseGeneratorFactory. The latter knows whether returning the actual |
- # generator or a mock object (as in the case of EchoNoiseGenerator). |
- if noise_generator_name == 'echo': |
+ # Try each registered test data generator. |
+ for generator_name in registered_classes: |
+ # Exclude ReverberationTestDataGenerator. |
+ # TODO(alessiob): Mock ReverberationTestDataGenerator, the mock |
+ # should rely on hard-coded impulse responses. This requires a mock for |
+ # TestDataGeneratorFactory. The latter knows whether returning the |
+ # actual generator or a mock object (as in the case of |
+ # ReverberationTestDataGenerator). |
+ if generator_name == ( |
+ test_data_generation.ReverberationTestDataGenerator.NAME): |
continue |
- # Instance noise generator. |
- noise_generator = noise_generator_factory.GetInstance( |
- registered_classes[noise_generator_name]) |
+ # Instance test data generator. |
+ generator = generators_factory.GetInstance( |
+ registered_classes[generator_name]) |
# Generate the noisy input - reference pairs. |
- noise_generator.Generate( |
+ generator.Generate( |
input_signal_filepath=input_signal_filepath, |
input_noise_cache_path=self._input_noise_cache_path, |
base_output_path=self._base_output_path) |
# Perform checks. |
- self._CheckNoiseGeneratorPairsListSizes(noise_generator) |
- self._CheckNoiseGeneratorPairsSignalDurations( |
- noise_generator, input_signal) |
- self._CheckNoiseGeneratorPairsOutputPaths(noise_generator) |
+ self._CheckGeneratedPairsListSizes(generator) |
+ self._CheckGeneratedPairsSignalDurations(generator, input_signal) |
+ self._CheckGeneratedPairsOutputPaths(generator) |
- def _CheckNoiseGeneratorPairsListSizes(self, noise_generator): |
- # Noise configuration names. |
- noise_config_names = noise_generator.config_names |
- number_of_pairs = len(noise_config_names) |
- |
- # Check. |
+ def _CheckGeneratedPairsListSizes(self, generator): |
+ config_names = generator.config_names |
+ number_of_pairs = len(config_names) |
self.assertEqual(number_of_pairs, |
- len(noise_generator.noisy_signal_filepaths)) |
+ len(generator.noisy_signal_filepaths)) |
self.assertEqual(number_of_pairs, |
- len(noise_generator.apm_output_paths)) |
+ len(generator.apm_output_paths)) |
self.assertEqual(number_of_pairs, |
- len(noise_generator.reference_signal_filepaths)) |
+ len(generator.reference_signal_filepaths)) |
- def _CheckNoiseGeneratorPairsSignalDurations( |
- self, noise_generator, input_signal): |
- """Check duration of the signals generated by a noise generator. |
+ def _CheckGeneratedPairsSignalDurations( |
+ self, generator, input_signal): |
+ """Checks duration of the generated signals. |
Checks that the noisy input and the reference tracks are audio files |
with duration equal to or greater than that of the input signal. |
Args: |
- noise_generator: NoiseGenerator instance. |
+ generator: TestDataGenerator instance. |
input_signal: AudioSegment instance. |
""" |
input_signal_length = ( |
signal_processing.SignalProcessingUtils.CountSamples(input_signal)) |
# Iterate over the noisy signal - reference pairs. |
- for noise_config_name in noise_generator.config_names: |
+ for config_name in generator.config_names: |
# Load the noisy input file. |
- noisy_signal_filepath = noise_generator.noisy_signal_filepaths[ |
- noise_config_name] |
+ noisy_signal_filepath = generator.noisy_signal_filepaths[ |
+ config_name] |
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav( |
noisy_signal_filepath) |
@@ -126,8 +127,8 @@ class TestNoiseGen(unittest.TestCase): |
self.assertGreaterEqual(noisy_signal_length, input_signal_length) |
# Load the reference file. |
- reference_signal_filepath = ( |
- noise_generator.reference_signal_filepaths[noise_config_name]) |
+ reference_signal_filepath = generator.reference_signal_filepaths[ |
+ config_name] |
reference_signal = signal_processing.SignalProcessingUtils.LoadWav( |
reference_signal_filepath) |
@@ -137,13 +138,13 @@ class TestNoiseGen(unittest.TestCase): |
reference_signal)) |
self.assertGreaterEqual(reference_signal_length, input_signal_length) |
- def _CheckNoiseGeneratorPairsOutputPaths(self, noise_generator): |
+ def _CheckGeneratedPairsOutputPaths(self, generator): |
"""Checks that the output path created by the generator exists. |
Args: |
- noise_generator: NoiseGenerator instance. |
+ generator: TestDataGenerator instance. |
""" |
# Iterate over the noisy signal - reference pairs. |
- for noise_config_name in noise_generator.config_names: |
- output_path = noise_generator.apm_output_paths[noise_config_name] |
+ for config_name in generator.config_names: |
+ output_path = generator.apm_output_paths[config_name] |
self.assertTrue(os.path.exists(output_path)) |