import shutil
import os.path
import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import is_module_available

from .backend_utils import set_audio_backend


class TempDirMixin:
    """Mixin to provide easy access to temp dir

    ``setUpClass`` chooses a base directory once per test class, and
    ``setUp`` derives a per-test directory path from the test id.
    Use ``get_temp_path`` to build file paths inside that directory.
    """
    # TemporaryDirectory object owned by this class; stays None when the base
    # directory comes from the environment variable instead.
    temp_dir_ = None
    # Base directory shared by all the tests of the class.
    base_temp_dir = None
    # Per-test directory path, assigned in ``setUp``.
    temp_dir = None

    @classmethod
    def setUpClass(cls):
        """Select the base directory used by every test in the class."""
        super().setUpClass()
        # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of a temporary
        # directory. This is handy for debugging.
        key = 'TORCHAUDIO_TEST_TEMP_DIR'
        if key in os.environ:
            cls.base_temp_dir = os.environ[key]
        else:
            cls.temp_dir_ = tempfile.TemporaryDirectory()
            cls.base_temp_dir = cls.temp_dir_.name

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in ``setUpClass``, if any.

        A directory supplied via the environment variable is left untouched.
        """
        try:
            super().tearDownClass()
        finally:
            # Clean up even when the parent tear-down raises, and drop the
            # reference so that a stale, already-removed directory object
            # is not left on the class.
            if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
                cls.temp_dir_.cleanup()
                cls.temp_dir_ = None

    def setUp(self):
        """Compute the per-test directory path (``base_temp_dir/<test id>``).

        The directory itself is not created here; ``get_temp_path`` creates
        directories on demand.
        """
        super().setUp()
        self.temp_dir = os.path.join(self.base_temp_dir, self.id())

    def get_temp_path(self, *paths):
        """Return a path under the per-test temp dir, creating its parent.

        Args:
            *paths: Path components joined onto ``self.temp_dir``.

        Returns:
            str: The joined path. Only its parent directory is created;
            the path itself is not.
        """
        path = os.path.join(self.temp_dir, *paths)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        return path


class TestBaseMixin:
    """Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
    # These attributes default to None; presumably concrete test classes
    # override them (confirm against the subclasses).
    dtype = None
    device = None
    backend = None

    def setUp(self):
        """Apply ``self.backend`` via ``set_audio_backend`` before each test."""
        super().setUp()
        set_audio_backend(self.backend)


class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
    """Test case combining ``TestBaseMixin`` with PyTorch's ``TestCase``."""
    pass


def skipIfNoExec(cmd):
    """Return a decorator that skips the test if `cmd` is not found by ``shutil.which``.

    Args:
        cmd (str): Name of the executable to look up.
    """
    return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available')


def skipIfNoModule(module, display_name=None):
    """Return a decorator that skips the test if `module` is not available.

    Args:
        module (str): Module name passed to ``is_module_available``.
        display_name (str or None): Name shown in the skip message.
            Falls back to ``module`` when not given (or empty).
    """
    display_name = display_name or module
    return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')


# Skip when 'sox' is not among the backends reported by torchaudio.
skipIfNoSoxBackend = unittest.skipIf(
    'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
# Skip when torch reports that CUDA is unavailable.
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
# Skip when the 'torchaudio._torchaudio' module is not available.
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension')