#!/usr/bin/env python
# coding: utf8

"""
    A ModelProvider backed by Github Release feature.

    :Example:

    >>> from spleeter.model.provider import github
    >>> provider = github.GithubModelProvider(
    ...     'https://github.com',
    ...     'Deezer/spleeter',
    ...     'latest')
    >>> provider.download('2stems', '/path/to/local/storage')
"""

import hashlib
import tarfile
import os

from tempfile import NamedTemporaryFile

import requests

from . import ModelProvider
from ...utils.logging import get_logger

__email__ = 'research@deezer.com'
__author__ = 'Deezer Research'
__license__ = 'MIT License'


def compute_file_checksum(path):
    """ Computes given path file sha256.

    The file is read in 4096 byte chunks so that it is never loaded
    entirely in memory.

    :param path: Path of the file to compute checksum for.
    :returns: File checksum, as an hexadecimal digest string.
    """
    sha256 = hashlib.sha256()
    with open(path, 'rb') as stream:
        for chunk in iter(lambda: stream.read(4096), b''):
            sha256.update(chunk)
    return sha256.hexdigest()


def _is_inside(root, candidate):
    """ Tells whether candidate resolves to root or to a path below it.

    :param root: Already resolved (realpath) reference directory.
    :param candidate: Path to test, resolved before comparison.
    :returns: True if candidate is root or lies under it.
    """
    resolved = os.path.realpath(candidate)
    # os.path.join(root, '') appends exactly one trailing separator, so
    # that '/data/model-evil' is not mistaken for a child of '/data/model'.
    return resolved == root or resolved.startswith(os.path.join(root, ''))


def _safe_extract(archive_path, destination):
    """ Extracts a tar archive after checking every member stays inside
    the destination directory.

    Member names, symbolic link targets and hard link targets are all
    validated before anything is written to disk.

    :param archive_path: Path of the tar archive to extract.
    :param destination: Directory to extract the archive into.
    :raise IOError: If a member would be written or linked outside of
                    the destination directory.
    """
    root = os.path.realpath(destination)
    with tarfile.open(name=archive_path) as tar:
        for member in tar.getmembers():
            target = os.path.join(root, member.name)
            checked = [target]
            if member.issym():
                # Symbolic link targets are relative to the link location.
                checked.append(
                    os.path.join(os.path.dirname(target), member.linkname))
            elif member.islnk():
                # Hard link targets are relative to the archive root.
                checked.append(os.path.join(root, member.linkname))
            for candidate in checked:
                if not _is_inside(root, candidate):
                    raise IOError(
                        'Unsafe path in model archive: {}'.format(
                            member.name))
        tar.extractall(path=destination)


class GithubModelProvider(ModelProvider):
    """ A ModelProvider implementation backed on Github for remote storage. """

    LATEST_RELEASE = 'v1.4.0'
    RELEASE_PATH = 'releases/download'
    CHECKSUM_INDEX = 'checksum.json'

    def __init__(self, host, repository, release):
        """ Default constructor.

        :param host: Host to the Github instance to reach. It is used
                     verbatim as URL prefix, hence it must include the
                     scheme (i.e. 'https://github.com').
        :param repository: Repository path within target Github.
        :param release: Release name to get models from.
        """
        self._host = host
        self._repository = repository
        self._release = release

    def checksum(self, name):
        """ Downloads and returns reference checksum for the given model name.

        The checksum index is fetched from the configured release on
        every call.

        :param name: Name of the model to get checksum for.
        :returns: Checksum of the required model.
        :raise ValueError: If the given model name is not indexed.
        """
        url = '{}/{}/{}/{}/{}'.format(
            self._host,
            self._repository,
            self.RELEASE_PATH,
            self._release,
            self.CHECKSUM_INDEX)
        response = requests.get(url)
        response.raise_for_status()
        index = response.json()
        if name not in index:
            raise ValueError('No checksum for model {}'.format(name))
        return index[name]

    def download(self, name, path):
        """ Download model denoted by the given name to disk.

        The release archive is streamed into a temporary file, checked
        against the reference checksum, then extracted into the given
        directory. The temporary file is always removed.

        :param name: Name of the model to download.
        :param path: Path of the directory to save model into.
        :raise IOError: If the downloaded archive checksum does not match
                        the reference one, or if the archive holds a
                        member that would be extracted outside of path.
        """
        url = '{}/{}/{}/{}/{}.tar.gz'.format(
            self._host,
            self._repository,
            self.RELEASE_PATH,
            self._release,
            name)
        get_logger().info('Downloading model archive %s', url)
        # delete=False as the file is reopened by path once closed ; the
        # finally clause below is responsible for removing it.
        archive = NamedTemporaryFile(delete=False)
        try:
            with archive as stream:
                # The HTTP response is released as soon as the archive is
                # written, rather than being held during extraction.
                with requests.get(url, stream=True) as response:
                    response.raise_for_status()
                    # Note: check for chunk size parameters ?
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            stream.write(chunk)
            get_logger().info('Validating archive checksum')
            if compute_file_checksum(archive.name) != self.checksum(name):
                raise IOError('Downloaded file is corrupted, please retry')
            get_logger().info('Extracting downloaded %s archive', name)
            _safe_extract(archive.name, path)
        finally:
            os.unlink(archive.name)
        get_logger().info('%s model file(s) extracted', name)