Index: webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py |
diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py |
index 5261dd25f4cae1ec3393379ed47ca242ab1af37e..909d7bad39c266c30473c192bf517d6f3e28cbf8 100644 |
--- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py |
+++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py |
@@ -14,6 +14,9 @@ import shutil |
import tempfile |
import unittest |
+import numpy as np |
+import scipy.io |
+ |
from . import test_data_generation |
from . import test_data_generation_factory |
from . import signal_processing |
@@ -27,11 +30,27 @@ class TestTestDataGenerators(unittest.TestCase): |
"""Create temporary folders.""" |
self._base_output_path = tempfile.mkdtemp() |
self._input_noise_cache_path = tempfile.mkdtemp() |
+ self._fake_air_db_path = tempfile.mkdtemp() |
+ |
+ # Fake AIR DB impulse responses. |
+ # TODO(alessiob): ReverberationTestDataGenerator will change to allow custom |
+ # impulse responses. When changed, the coupling below between |
+ # impulse_response_mat_file_names and |
+ # ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed. |
+ impulse_response_mat_file_names = [ |
+ 'air_binaural_lecture_0_0_1.mat', |
+ 'air_binaural_booth_0_0_1.mat', |
+ ] |
+ for impulse_response_mat_file_name in impulse_response_mat_file_names: |
+ data = {'h_air': np.random.rand(1, 1000).astype('<f8')} |
+ scipy.io.savemat(os.path.join( |
+ self._fake_air_db_path, impulse_response_mat_file_name), data) |
def tearDown(self): |
"""Recursively delete temporary folders.""" |
shutil.rmtree(self._base_output_path) |
shutil.rmtree(self._input_noise_cache_path) |
+ shutil.rmtree(self._fake_air_db_path) |
def testTestDataGenerators(self): |
# Preliminary check. |
@@ -47,11 +66,7 @@ class TestTestDataGenerators(unittest.TestCase): |
# Instance generators factory. |
generators_factory = ( |
test_data_generation_factory.TestDataGeneratorFactory( |
- aechen_ir_database_path='')) |
- # TODO(alessiob): Replace with a mock of TestDataGeneratorFactory that |
- # takes no arguments in the ctor. For those generators that need parameters, |
- # it will return a mock generator (see the first comment in the next for |
- # loop). |
+ aechen_ir_database_path=self._fake_air_db_path)) |
# Use a sample input file as clean input signal. |
input_signal_filepath = os.path.join( |
@@ -64,16 +79,6 @@ class TestTestDataGenerators(unittest.TestCase): |
# Try each registered test data generator. |
for generator_name in registered_classes: |
- # Exclude ReverberationTestDataGenerator. |
- # TODO(alessiob): Mock ReverberationTestDataGenerator, the mock |
- # should rely on hard-coded impulse responses. This requires a mock for |
- # TestDataGeneratorFactory. The latter knows whether returning the |
- # actual generator or a mock object (as in the case of |
- # ReverberationTestDataGenerator). |
- if generator_name == ( |
- test_data_generation.ReverberationTestDataGenerator.NAME): |
- continue |
- |
# Instance test data generator. |
generator = generators_factory.GetInstance( |
registered_classes[generator_name]) |