Python tempfile.gettempdir() Examples

The following are 30 code examples showing how to use tempfile.gettempdir(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may want to check out the right sidebar which shows the related API usage.

You may also want to check out all available functions/classes of the module tempfile, or try the search function.

Example 1
Project: wechat-alfred-workflow   Author: TKkk-iOSer   File: update.py    License: MIT License 6 votes vote down vote up
def download_workflow(url):
    """Download workflow at ``url`` to a local temporary file.

    The file is saved in the system temp dir under the last path component
    of ``url``.

    :param url: URL to .alfredworkflow file in GitHub repo
    :returns: path to downloaded file
    :raises ValueError: if the URL's file name does not end in
        ``.alfredworkflow`` or ``.alfred3workflow``

    """
    filename = url.split('/')[-1]

    if not filename.endswith(('.alfredworkflow', '.alfred3workflow')):
        raise ValueError('attachment not a workflow: {0}'.format(filename))

    local_path = os.path.join(tempfile.gettempdir(), filename)

    wf().logger.debug(
        'downloading updated workflow from `%s` to `%s` ...', url, local_path)

    response = web.get(url)
    # Fail on an HTTP error status instead of saving the error page to disk
    # as if it were the workflow archive.
    response.raise_for_status()

    with open(local_path, 'wb') as output:
        output.write(response.content)

    return local_path
Example 2
Project: neural-fingerprinting   Author: StephanZheng   File: utils_mnist.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def maybe_download_mnist_file(file_name, datadir=None, force=False):
    """Download an MNIST file into ``datadir`` unless it is already there.

    :param file_name: file name on the MNIST server; also used as the local
        file name
    :param datadir: destination directory; the system temp dir when falsy
    :param force: download even if the destination file already exists
    :return: path of the local file, ``os.path.join(datadir, file_name)``
    """
    try:
        from urllib.parse import urljoin
        from urllib.request import urlretrieve
    except ImportError:  # Python 2
        from urlparse import urljoin
        from urllib import urlretrieve

    if not datadir:
        datadir = tempfile.gettempdir()
    dest_file = os.path.join(datadir, file_name)

    # Check the *destination* path: the cached copy lives in ``datadir``,
    # not in the current working directory.
    if force or not os.path.isfile(dest_file):
        url = urljoin('http://yann.lecun.com/exdb/mnist/', file_name)
        urlretrieve(url, dest_file)
    return dest_file
Example 3
Project: CAMISIM   Author: CAMI-challenge   File: defaultvalues.py    License: Apache License 2.0 6 votes vote down vote up
def __init__(self, label="DefaultValues", logfile=None, verbose=False, debug=False):
        """Set up default values, then load them from the pipeline config.

        Settings come from ``default_config.ini`` in the pipeline directory
        when the validator accepts that file; otherwise ``_from_hardcoded``
        is used.

        :param label: label passed to the base class
        :param logfile: log file passed to the base class and the validator
        :param verbose: verbosity flag passed to the base class and the validator
        :param debug: debug flag passed to the base class and the validator
        """
        super(DefaultValues, self).__init__(label=label, logfile=logfile, verbose=verbose, debug=debug)
        self._validator = Validator(logfile=logfile, verbose=verbose, debug=debug)
        pipeline_dir = os.path.dirname(self._validator.get_full_path(os.path.dirname(scripts.__file__)))

        self._DEFAULT_seed = random.randint(0, 2147483640)
        self._DEFAULT_tmp_dir = tempfile.gettempdir()
        self._DEFAULT_directory_pipeline = pipeline_dir

        # Config loading runs with the pipeline directory as cwd. Always
        # restore the caller's working directory, even if loading raises.
        original_wd = os.getcwd()
        os.chdir(pipeline_dir)
        try:
            file_path_config = os.path.join(pipeline_dir, "default_config.ini")
            if self._validator.validate_file(file_path_config, silent=True):
                self._from_config(file_path_config)
            else:
                self._from_hardcoded(pipeline_dir)
        finally:
            os.chdir(original_wd)
Example 4
Project: dynamic-training-with-apache-mxnet-on-aws   Author: awslabs   File: build.py    License: Apache License 2.0 6 votes vote down vote up
def default_ccache_dir() -> str:
    """Pick the ccache directory for this machine.

    Resolution order:

    1. ``$CCACHE_DIR`` (resolved with ``realpath`` and created if missing),
       unless creating it raises ``PermissionError``;
    2. ``/tmp/_mxnet_ccache`` on macOS (created if missing);
    3. ``<system temp dir>/ci_ccache`` otherwise (not created here).

    :return: ccache directory for the current platform
    """
    # An explicit setting lets the cache be shared across containers.
    if 'CCACHE_DIR' in os.environ:
        requested = os.path.realpath(os.environ['CCACHE_DIR'])
        try:
            os.makedirs(requested, exist_ok=True)
        except PermissionError:
            logging.info('Unable to make dirs at %s, falling back to local temp dir', requested)
        else:
            return requested

    import platform
    if platform.system() != 'Darwin':
        return os.path.join(tempfile.gettempdir(), "ci_ccache")

    # The macOS temp dir is not mountable by default, so use a fixed /tmp path.
    mac_dir = "/tmp/_mxnet_ccache"
    os.makedirs(mac_dir, exist_ok=True)
    return mac_dir
Example 5
Project: CyberTK-Self   Author: CyberTKR   File: LineApi.py    License: GNU General Public License v2.0 6 votes vote down vote up
def sendImageWithUrl(self, to_, url):
        """Download an image from ``url`` and send it to ``to_``.

        The image is saved to ``<tempdir>/pythonLine-<N>.data`` (N is a random
        digit 0-9) and then passed to ``self.sendImage``.

        :param to_: recipient, passed through to ``sendImage``
        :param url: image url to send
        :raises Exception: if the HTTP response status is not 200
        """
        # '%i' keeps the '.data' extension; the old '%1.data' consumed '.d'
        # as a format spec and produced 'pythonLine-Nata'.
        path = '%s/pythonLine-%i.data' % (tempfile.gettempdir(), randint(0, 9))

        r = requests.get(url, stream=True)
        if r.status_code != 200:
            raise Exception('Download image failure.')
        # r.raw yields bytes, so the file must be opened in binary mode
        # (text mode raises TypeError on Python 3).
        with open(path, 'wb') as f:
            shutil.copyfileobj(r.raw, f)

        self.sendImage(to_, path)
Example 6
Project: wuy   Author: manatlan   File: freeze.py    License: GNU General Public License v2.0 6 votes vote down vote up
def build(path, inConsole=False, addWeb=False):
    """Freeze the script at ``path`` into a one-file executable with PyInstaller.

    The work dir is ``<system temp dir>/.build``; the executable is written
    to the directory containing ``path``.

    :param path: script to freeze
    :param inConsole: keep a console window (omits ``--noconsole``)
    :param addWeb: bundle the sibling ``web`` directory when it exists
    """
    script_dir = os.path.dirname(path)

    params = [path, "--noupx", "--onefile"]
    if not inConsole:
        params.append("--noconsole")

    web_dir = os.path.join(script_dir, "web")
    if addWeb and os.path.isdir(web_dir):
        # Source/destination separator for --add-data: ';' on Windows, ':' elsewhere.
        sep = ";" if os.name == 'nt' else ":"
        params.append("--add-data=%s%sweb" % (web_dir, sep))

    params.extend(["--workpath", os.path.join(tempfile.gettempdir(), ".build")])
    params.extend(["--distpath", script_dir])

    print("PYINSTALLER:", params)
    pyi.run(params)
Example 7
Project: fine-lm   Author: akzaidi   File: cloud_mlengine.py    License: MIT License 6 votes vote down vote up
def _tar_and_copy(src_dir, target_dir):
  """Tar and gzip src_dir and copy the archive to GCS target_dir.

  The archive is first built as ``<tmp>/<basename(src_dir)>.tar.gz`` in the
  system temp dir and then uploaded with ``gsutil cp``.

  Args:
    src_dir: local directory to archive (trailing "/" ignored).
    target_dir: destination directory (trailing "/" ignored).

  Returns:
    The destination path ``<target_dir>/<basename(src_dir)>.tar.gz``.
  """
  src_dir = src_dir.rstrip("/")
  target_dir = target_dir.rstrip("/")
  tmp_dir = tempfile.gettempdir().rstrip("/")
  src_base = os.path.basename(src_dir)
  final_destination = "%s/%s.tar.gz" % (target_dir, src_base)

  # Archive the *contents* of src_dir (-C src_dir .) into the temp dir.
  cloud.shell_run("tar -zcf {tmp_dir}/{src_base}.tar.gz -C {src_dir} .",
                  src_dir=src_dir, src_base=src_base, tmp_dir=tmp_dir)
  # Upload the local archive to its final location.
  cloud.shell_run("gsutil cp {tmp_dir}/{src_base}.tar.gz {final_destination}",
                  tmp_dir=tmp_dir, src_base=src_base,
                  final_destination=final_destination)
  return final_destination
Example 8
Project: fine-lm   Author: akzaidi   File: cloud_mlengine.py    License: MIT License 6 votes vote down vote up
def tar_and_copy_usr_dir(usr_dir, train_dir):
  """Package, tar, and copy usr_dir to GCS train_dir.

  The user directory is copied into a fresh ``t2t_usr_container`` staging
  directory under the system temp dir (any previous one is removed first).
  A generated ``setup.py`` is written at the staging root, overwriting any
  existing file. The staging directory is then archived and uploaded with
  ``_tar_and_copy``.

  Args:
    usr_dir: local user directory; ``~`` is expanded.
    train_dir: destination directory passed to ``_tar_and_copy``.

  Returns:
    The value returned by ``_tar_and_copy`` (the uploaded archive path).
  """
  tf.logging.info("Tarring and pushing t2t_usr_dir.")
  source_dir = os.path.abspath(os.path.expanduser(usr_dir))

  staging_root = os.path.join(tempfile.gettempdir(), "t2t_usr_container")
  staged_pkg = os.path.join(staging_root, usr_dir_lib.INTERNAL_USR_DIR_PACKAGE)
  # Start from a clean staging area so stale files are never shipped.
  shutil.rmtree(staging_root, ignore_errors=True)
  shutil.copytree(source_dir, staged_pkg)

  setup_path = os.path.join(staging_root, "setup.py")
  setup_contents = get_setup_file(
      name="DummyUsrDirPackage",
      packages=get_requirements(source_dir),
  )
  with tf.gfile.Open(setup_path, "w") as f:
    f.write(setup_contents)

  return _tar_and_copy(staging_root, train_dir)
Example 9
Project: lirpg   Author: Hwhitetooth   File: logger.py    License: MIT License 6 votes vote down vote up
def configure(dir=None, format_strs=None):
    """Install a new global ``Logger`` that writes to ``dir``.

    :param dir: log directory. Falls back to ``$OPENAI_LOGDIR``, then to a
        timestamped ``openai-YYYY-MM-DD-HH-MM-SS-ffffff`` folder in the
        system temp dir. Created if missing.
    :param format_strs: output format names. When ``None``, the
        comma-separated ``$OPENAI_LOG_FORMAT`` is used, or
        ``LOG_OUTPUT_FORMATS`` when that variable is unset or empty.
    """
    log_dir = dir
    if log_dir is None:
        log_dir = os.getenv('OPENAI_LOGDIR')
    if log_dir is None:
        log_dir = osp.join(tempfile.gettempdir(),
                           datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"))
    assert isinstance(log_dir, str)
    os.makedirs(log_dir, exist_ok=True)

    if format_strs is None:
        env_formats = os.getenv('OPENAI_LOG_FORMAT')
        format_strs = env_formats.split(',') if env_formats else LOG_OUTPUT_FORMATS

    outputs = [make_output_format(fmt, log_dir) for fmt in format_strs]
    Logger.CURRENT = Logger(dir=log_dir, output_formats=outputs)
    log('Logging to %s' % log_dir)
Example 10
Project: calmjs   Author: calmjs   File: test_toolchain.py    License: GNU General Public License v2.0 6 votes vote down vote up
def test_toolchain_standard_not_implemented(self):
        """Each base toolchain stage must raise NotImplementedError.

        Calling the toolchain itself, ``assemble`` and ``link`` all raise.
        Afterwards the spec's ``build_dir`` must be under the system temp dir
        and must already have been removed.
        """
        spec = Spec()

        for stage in (self.toolchain, self.toolchain.assemble, self.toolchain.link):
            with self.assertRaises(NotImplementedError):
                stage(spec)

        build_dir = spec['build_dir']
        # The build_dir is derived from tempfile; compare against the resolved path.
        self.assertTrue(build_dir.startswith(realpath(tempfile.gettempdir())))
        # It must also have been cleaned up.
        self.assertFalse(exists(build_dir))
Example 11
Project: HardRLWithYoutube   Author: MaxSobolMark   File: logger.py    License: MIT License 6 votes vote down vote up
def configure(dir=None, format_strs=None):
    """Install a new global ``Logger`` that writes to ``dir``.

    :param dir: log directory. Falls back to ``$OPENAI_LOGDIR``, then to a
        timestamped ``openai-...`` folder in the system temp dir. Created if
        missing.
    :param format_strs: output format names. When ``None``, rank 0 reads the
        comma-separated ``$OPENAI_LOG_FORMAT`` (default ``stdout,log,csv``)
        and other MPI ranks read ``$OPENAI_LOG_FORMAT_MPI`` (default ``log``).
        Empty names are ignored.
    """
    if dir is None:
        dir = os.getenv('OPENAI_LOGDIR')
    if dir is None:
        dir = osp.join(tempfile.gettempdir(),
            datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"))
    assert isinstance(dir, str)
    os.makedirs(dir, exist_ok=True)

    # mpi4py is optional. Without it, behave as a single rank-0 process
    # instead of failing with ImportError.
    try:
        from mpi4py import MPI
    except ImportError:
        rank = 0
    else:
        rank = MPI.COMM_WORLD.Get_rank()
    # Non-zero ranks get a suffix, presumably so each rank writes distinct files.
    log_suffix = "-rank%03i" % rank if rank > 0 else ''

    if format_strs is None:
        if rank == 0:
            format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',')
        else:
            format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',')
    output_formats = [make_output_format(f, dir, log_suffix)
                      for f in format_strs if f]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
    log('Logging to %s' % dir)
Example 12
Project: HardRLWithYoutube   Author: MaxSobolMark   File: mnist_env.py    License: MIT License 6 votes vote down vote up
def __init__(
            self,
            seed=0,
            episode_len=None,
            no_images=None
    ):
        """Build the MNIST environment.

        Observations are a ``Box`` in [0, 1] of shape (28, 28, 1); actions
        are ``Discrete(10)``.

        :param seed: seed for the environment's own ``np.random.RandomState``
        :param episode_len: stored as ``self.episode_len``
        :param no_images: stored as ``self.no_images``
        """
        from tensorflow.examples.tutorials.mnist import input_data

        # A TemporaryDirectory context manager would clean up, but then every
        # test using MNIST would re-download the data. A fixed folder in the
        # system temp dir is never cleaned up, but the data is downloaded only
        # once per machine. The file lock guards that shared folder against
        # concurrent readers/downloaders.
        mnist_path = osp.join(tempfile.gettempdir(), 'MNIST_data')
        with filelock.FileLock(mnist_path + '.lock'):
            self.mnist = input_data.read_data_sets(mnist_path)

        rng = np.random.RandomState()
        rng.seed(seed)
        self.np_random = rng

        self.observation_space = Box(low=0.0, high=1.0, shape=(28, 28, 1))
        self.action_space = Discrete(10)
        self.episode_len = episode_len
        self.time = 0
        self.no_images = no_images

        self.train_mode()
        self.reset()
Example 13
Project: iAI   Author: aimuch   File: sample.py    License: MIT License 6 votes vote down vote up
def main():
    """Run one MNIST test case through an engine built from a Caffe model.

    The engine is cached at ``<system temp dir>/mnist.engine``. The
    fully-connected plugin is destroyed on exit, including when any step
    raises.
    """
    try:
        # Get data files for the model.
        data_paths, [deploy_file, model_file, mean_proto] = common.find_sample_data(description="Runs an MNIST network using a Caffe model file", subfolder="mnist", find_files=["mnist.prototxt", "mnist.caffemodel", "mnist_mean.binaryproto"])

        # Cache the engine in a temporary directory.
        engine_path = os.path.join(tempfile.gettempdir(), "mnist.engine")
        with get_engine(deploy_file, model_file, engine_path) as engine, engine.create_execution_context() as context:
            # Build an engine, allocate buffers and create a stream.
            # For more information on buffer allocation, refer to the introductory samples.
            inputs, outputs, bindings, stream = common.allocate_buffers(engine)
            mean = retrieve_mean(mean_proto)
            # For more information on performing inference, refer to the introductory samples.
            inputs[0].host, case_num = load_normalized_test_case(data_paths, mean)
            # The common.do_inference function will return a list of outputs - we only have one in this case.
            [output] = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
            pred = np.argmax(output)
            print("Test Case: " + str(case_num))
            print("Prediction: " + str(pred))
    finally:
        # Runs after the `with` block has released the engine, and also on
        # failure, so the plugin is never leaked. This function is exposed
        # through the binding code in plugin/pyFullyConnected.cpp.
        fc_factory.destroy_plugin()
Example 14
Project: keras-bert   Author: CyberZHG   File: test_bert.py    License: MIT License 6 votes vote down vote up
def test_task_embed(self):
        """Save and reload a model built with ``use_task_embed=True``.

        The model is written to a randomly named ``.h5`` file in the system
        temp dir, reloaded with the package's custom objects, and summarised.
        The temporary file is removed once it has been loaded.
        """
        inputs, outputs = get_model(
            token_num=20,
            embed_dim=12,
            head_num=3,
            transformer_num=2,
            use_task_embed=True,
            task_num=10,
            training=False,
            dropout_rate=0.0,
        )
        model = keras.models.Model(inputs, outputs)
        model_path = os.path.join(tempfile.gettempdir(), 'keras_bert_%f.h5' % np.random.random())
        model.save(model_path)
        try:
            model = keras.models.load_model(
                model_path,
                custom_objects=get_custom_objects(),
            )
        finally:
            # Don't leave a stray .h5 file in the temp dir on every run.
            os.remove(model_path)
        model.summary(line_length=200)
Example 15
Project: gist-alfred   Author: danielecook   File: update.py    License: MIT License 6 votes vote down vote up
def retrieve_download(dl):
    """Save a download to a temporary file and return its path.

    .. versionadded:: 1.37

    Args:
        dl: download to fetch; its ``filename`` attribute names the local
            file (in the system temp dir) and its ``url`` attribute is
            fetched.

    Returns:
        unicode: path to downloaded file

    Raises:
        ValueError: if ``dl.filename`` is not accepted by ``match_workflow``.
        Any error raised by ``raise_for_status()`` on the HTTP response
        propagates unchanged.

    """
    if not match_workflow(dl.filename):
        raise ValueError('attachment not a workflow: ' + dl.filename)

    path = os.path.join(tempfile.gettempdir(), dl.filename)
    wf().logger.debug('downloading update from '
                      '%r to %r ...', dl.url, path)

    r = web.get(dl.url)
    r.raise_for_status()

    r.save_to_path(path)

    return path
Example 16
Project: misp42splunk   Author: remg427   File: api_documenter.py    License: GNU Lesser General Public License v3.0 6 votes vote down vote up
def api_get_spec(context, method_list):
    '''Generate the API spec and return it serialised as JSON.

    Runs ``_generate_documentation`` and then reads back ``spec.yaml`` from
    the system temp dir.

    :param context: Dictionary with app, session, version and api fields
    :type context: dict
    :param method_list: List of API methods to call
    :type method_list: list
    :return: the generated spec, as a JSON string
    :rtype: basestring
    :raises Exception: if the generated YAML cannot be parsed
    '''
    _generate_documentation(context, method_list)
    spec_path = tempfile.gettempdir() + op.sep + 'spec.yaml'
    with open(spec_path) as stream:
        try:
            parsed = yaml.safe_load(stream)
        except yaml.YAMLError as ex:
            raise Exception("Please try again. Exception: {}".format(ex))
    return json.dumps(parsed)
Example 17
Project: misp42splunk   Author: remg427   File: api_documenter.py    License: GNU Lesser General Public License v3.0 6 votes vote down vote up
def write_temp(self):
        '''
        Store the current spec in ``temp.yaml`` in the system temp dir.

        The keys are dumped one at a time in the order given by
        ``self.order``, so the file keeps that key order. The file is closed
        (and therefore flushed) before this method returns.
        '''
        spec_keys = ('swagger', 'info', 'host', 'schemes', 'consumes',
                     'produces', 'paths', 'definitions')
        spec = {key: getattr(self.api, key) for key in spec_keys}

        temp_path = tempfile.gettempdir() + op.sep + 'temp.yaml'
        # ``file()`` exists only on Python 2, and its handle was never closed.
        # ``with open`` works on both and guarantees the data is flushed.
        with open(temp_path, 'w') as stream:
            for x in self.order:
                yaml.dump({x: spec[x]}, stream, default_flow_style=False)
Example 18
Project: misp42splunk   Author: remg427   File: api_documenter.py    License: GNU Lesser General Public License v3.0 6 votes vote down vote up
def update_spec(self):
        '''
        Promote the temp spec file to be the current spec.

        Renames ``temp.yaml`` to ``spec.yaml``, both in the system temp dir.

        :raises Exception: if the rename fails, e.g. ``temp.yaml`` is missing
        '''
        # NOTE(review): on Windows os.rename fails if spec.yaml already exists.
        tmp_root = tempfile.gettempdir()
        source = tmp_root + op.sep + 'temp.yaml'
        target = tmp_root + op.sep + 'spec.yaml'
        try:
            os.rename(source, target)
        except Exception as e:
            raise Exception(
                "Spec file not found, please try again."
                " Exception: {}".format(e))
Example 19
Project: misp42splunk   Author: remg427   File: api_documenter.py    License: GNU Lesser General Public License v3.0 6 votes vote down vote up
def api_get_spec(context, method_list):
    '''Generate the API spec and return it serialised as JSON.

    Runs ``_generate_documentation`` and then reads back ``spec.yaml`` from
    the system temp dir.

    :param context: Dictionary with app, session, version and api fields
    :type context: dict
    :param method_list: List of API methods to call
    :type method_list: list
    :return: the generated spec, as a JSON string
    :rtype: basestring
    :raises Exception: if the generated YAML cannot be parsed
    '''
    _generate_documentation(context, method_list)
    spec_path = tempfile.gettempdir() + op.sep + 'spec.yaml'
    with open(spec_path) as stream:
        try:
            parsed = yaml.safe_load(stream)
        except yaml.YAMLError as ex:
            raise Exception("Please try again. Exception: {}".format(ex))
    return json.dumps(parsed)
Example 20
Project: misp42splunk   Author: remg427   File: api_documenter.py    License: GNU Lesser General Public License v3.0 6 votes vote down vote up
def write_temp(self):
        '''
        Store the current spec in ``temp.yaml`` in the system temp dir.

        The keys are dumped one at a time in the order given by
        ``self.order``, so the file keeps that key order. The file is closed
        (and therefore flushed) before this method returns.
        '''
        spec_keys = ('swagger', 'info', 'host', 'schemes', 'consumes',
                     'produces', 'paths', 'definitions')
        spec = {key: getattr(self.api, key) for key in spec_keys}

        temp_path = tempfile.gettempdir() + op.sep + 'temp.yaml'
        # ``file()`` exists only on Python 2, and its handle was never closed.
        # ``with open`` works on both and guarantees the data is flushed.
        with open(temp_path, 'w') as stream:
            for x in self.order:
                yaml.dump({x: spec[x]}, stream, default_flow_style=False)
Example 21
Project: misp42splunk   Author: remg427   File: api_documenter.py    License: GNU Lesser General Public License v3.0 6 votes vote down vote up
def update_spec(self):
        '''
        Promote the temp spec file to be the current spec.

        Renames ``temp.yaml`` to ``spec.yaml``, both in the system temp dir.

        :raises Exception: if the rename fails, e.g. ``temp.yaml`` is missing
        '''
        # NOTE(review): on Windows os.rename fails if spec.yaml already exists.
        tmp_root = tempfile.gettempdir()
        source = tmp_root + op.sep + 'temp.yaml'
        target = tmp_root + op.sep + 'spec.yaml'
        try:
            os.rename(source, target)
        except Exception as e:
            raise Exception(
                "Spec file not found, please try again."
                " Exception: {}".format(e))
Example 22
Project: pySmartDL   Author: iTaybb   File: test_pySmartDL.py    License: The Unlicense 6 votes vote down vote up
def setUp(self):
        """Choose a fresh, not-yet-existing download dir and set test fixtures.

        ``self.dl_dir`` is ``<system temp dir>/<8 random alphanumerics>/``
        (note the trailing separator). It is only checked to not exist here;
        it is not created.
        """
        warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*<ssl.SSLSocket.*>")

        alphabet = string.ascii_letters + string.digits

        def random_dir():
            # The trailing '' component makes the path end with a separator.
            base = tempfile.gettempdir()
            name = "".join(random.choice(alphabet) for _ in range(8))
            return os.path.join(base, name, '')

        candidate = random_dir()
        while os.path.exists(candidate):
            candidate = random_dir()
        self.dl_dir = candidate

        self.res_7za920_mirrors = [
            "https://github.com/iTaybb/pySmartDL/raw/master/test/7za920.zip",
            "https://sourceforge.mirrorservice.org/s/se/sevenzip/7-Zip/9.20/7za920.zip",
            "http://www.bevc.net/dl/7za920.zip",
            "http://ftp.psu.ru/tools/7-zip/stable/7za920.zip",
            "http://www.mirrorservice.org/sites/downloads.sourceforge.net/s/se/sevenzip/7-Zip/9.20/7za920.zip",
        ]
        self.res_7za920_hash = '2a3afe19c180f8373fa02ff00254d5394fec0349f5804e0ad2f6067854ff28ac'
        self.res_testfile_1gb = 'http://www.ovh.net/files/1Gio.dat'
        self.res_testfile_100mb = 'http://www.ovh.net/files/100Mio.dat'
        # Verbose logging is opt-in via a -vvv command-line flag.
        self.enable_logging = "-vvv" in sys.argv
Example 23
Project: snowflake-connector-python   Author: snowflakedb   File: ocsp_snowflake.py    License: Apache License 2.0 6 votes vote down vote up
def reset_cache_dir():
        """(Re)compute ``OCSPCache.CACHE_DIR`` and try to create it.

        ``$SF_OCSP_RESPONSE_CACHE_DIR`` takes precedence. Otherwise a
        platform-specific folder is chosen under the user's home directory,
        or under the system temp dir when the home directory cannot be
        resolved:

        * Windows: ``<root>/AppData/Local/Snowflake/Caches``
        * Darwin:  ``<root>/Library/Caches/Snowflake``
        * other:   ``<root>/.cache/snowflake``

        A missing directory is created with mode 0700. If that fails, the
        error is logged at debug level and ``CACHE_DIR`` is set to ``None``.
        """
        # Cache directory
        OCSPCache.CACHE_DIR = os.getenv('SF_OCSP_RESPONSE_CACHE_DIR')
        if OCSPCache.CACHE_DIR is None:
            # expanduser() never returns an empty string. When the home
            # directory cannot be determined it returns "~" unchanged, so the
            # old ``expanduser("~") or gettempdir()`` fallback could never run
            # and a relative "~/..." path was used instead.
            cache_root_dir = expanduser("~")
            if not cache_root_dir or cache_root_dir == "~":
                cache_root_dir = tempfile.gettempdir()
            system = platform.system()
            if system == 'Windows':
                OCSPCache.CACHE_DIR = path.join(cache_root_dir, 'AppData', 'Local', 'Snowflake',
                                                'Caches')
            elif system == 'Darwin':
                OCSPCache.CACHE_DIR = path.join(cache_root_dir, 'Library', 'Caches', 'Snowflake')
            else:
                OCSPCache.CACHE_DIR = path.join(cache_root_dir, '.cache', 'snowflake')
        logger.debug("cache directory: %s", OCSPCache.CACHE_DIR)

        if not path.exists(OCSPCache.CACHE_DIR):
            try:
                os.makedirs(OCSPCache.CACHE_DIR, mode=0o700)
            except Exception as ex:
                # Best effort: give up on the directory rather than raise.
                logger.debug('cannot create a cache directory: [%s], err=[%s]',
                             OCSPCache.CACHE_DIR, ex)
                OCSPCache.CACHE_DIR = None
Example 24
Project: mlbv   Author: kmac   File: util.py    License: GNU General Public License v3.0 5 votes vote down vote up
def get_tempdir():
    """Return this program's scratch directory, creating it if needed.

    Uses the ``tempdir`` config value when it is set (non-empty). Any
    ``<timestamp>`` placeholder in it is replaced by the current local time
    formatted as ``YYYY-MM-DD-HHMM``. Otherwise the directory is
    ``<system temp dir>/<script name without extension>``.
    """
    configured = config.CONFIG.parser.get('tempdir', None)
    if not configured:
        script_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]
        tempdir = os.path.join(tempfile.gettempdir(), script_name)
    elif '<timestamp>' in configured:
        tempdir = configured.replace('<timestamp>', time.strftime('%Y-%m-%d-%H%M'))
    else:
        tempdir = configured

    if not os.path.exists(tempdir):
        os.makedirs(tempdir)
    return tempdir
Example 25
Project: multibootusb   Author: mbusb   File: osdriver.py    License: GNU General Public License v2.0 5 votes vote down vote up
def multibootusb_host_dir(self):
        """Return the path ``<system temp dir>/multibootusb`` (not created here)."""
        temp_root = tempfile.gettempdir()
        return os.path.join(temp_root, "multibootusb")
Example 26
Project: CAMISIM   Author: CAMI-challenge   File: fastaanonymizer.py    License: Apache License 2.0 5 votes vote down vote up
def __init__(self, logfile=None, verbose=True, debug=False, seed=None, tmp_dir=None):
		"""
			Anonymize fasta sequences

			@attention: 'shuf' is used which loads everything into memory!

			@param logfile: file handler or file path to a log file
			@type logfile: file | io.FileIO | StringIO.StringIO | str | unicode
			@param verbose: Not verbose means that only warnings and errors will be passed to stream
			@type verbose: bool
			@param debug: more output and files are kept, manual clean up required
			@type debug: bool
			@param seed: The seed written to the random_source file used by the 'shuf' command.
				In this constructor it is passed to random.seed() when not None.
			@type seed: long | int | float | str | unicode
			@param tmp_dir: directory for temporary files, like the random_source file for 'shuf'.
				Defaults to the system temp dir (tempfile.gettempdir()).
			@type tmp_dir: str | unicode

			@return: None
			@rtype: None
		"""
		# Argument checks (Python 2 names: long, basestring).
		# NOTE(review): these are asserts, so they are skipped under python -O.
		assert isinstance(verbose, bool)
		assert isinstance(debug, bool)
		assert seed is None or isinstance(seed, (long, int, float, basestring))
		assert tmp_dir is None or isinstance(tmp_dir, basestring)
		# Use the caller's tmp_dir if it validates, else the system temp dir.
		# NOTE(review): validate_dir is called before the base-class __init__
		# below; confirm it does not rely on state set there.
		if tmp_dir is not None:
			assert self.validate_dir(tmp_dir)
		else:
			tmp_dir = tempfile.gettempdir()
		self._tmp_dir = tmp_dir
		super(FastaAnonymizer, self).__init__(logfile, verbose, debug, label="FastaAnonymizer")

		# Seeds the global `random` module, which affects every other user of it.
		if seed is not None:
			random.seed(seed)

		# Helper scripts are expected in the same directory as this file.
		script_dir = os.path.dirname(self.get_full_path(__file__))
		self._anonymizer = os.path.join(script_dir, "anonymizer.py")
		self._fastastreamer = os.path.join(script_dir, "fastastreamer.py")
		assert self.validate_file(self._anonymizer)
		assert self.validate_file(self._fastastreamer)
Example 27
Project: CyberTK-Self   Author: CyberTKR   File: LineApi.py    License: GNU General Public License v2.0 5 votes vote down vote up
def sendVideoWithURL(self, to_, url):
        """Download a video from ``url`` and send it to ``to_``.

        The file is saved to ``<tempdir>/pythonLine-<N>.data`` (N is a random
        digit 0-9) and then passed to ``self.sendVideo``.

        :param to_: recipient, passed through to ``sendVideo``
        :param url: video url to send
        :raises Exception: if the HTTP response status is not 200
        """
        path = '%s/pythonLine-%i.data' % (tempfile.gettempdir(), randint(0, 9))

        r = requests.get(url, stream=True)
        if r.status_code != 200:
            raise Exception('Download video failure.')
        # r.raw yields bytes, so the file must be opened in binary mode
        # (text mode raises TypeError on Python 3).
        with open(path, 'wb') as f:
            shutil.copyfileobj(r.raw, f)

        self.sendVideo(to_, path)
Example 28
Project: CyberTK-Self   Author: CyberTKR   File: LineApi.py    License: GNU General Public License v2.0 5 votes vote down vote up
def sendAudioWithUrl(self, to_, url):
        """Download audio from ``url`` and send it to ``to_``.

        The file is saved to ``<tempdir>/pythonLine-<N>.data`` (N is a random
        digit 0-9) and then passed to ``self.sendAudio``.

        :param to_: recipient, passed through to ``sendAudio``
        :param url: audio url to send
        :raises Exception: if the HTTP response status is not 200
        """
        # '%i' keeps the '.data' extension; the old '%1.data' consumed '.d'
        # as a format spec and produced 'pythonLine-Nata'.
        path = '%s/pythonLine-%i.data' % (tempfile.gettempdir(), randint(0, 9))

        r = requests.get(url, stream=True)
        if r.status_code != 200:
            raise Exception('Download audio failure.')
        # r.raw yields bytes, so the file must be opened in binary mode
        # (text mode raises TypeError on Python 3).
        with open(path, 'wb') as f:
            shutil.copyfileobj(r.raw, f)

        self.sendAudio(to_, path)
Example 29
Project: keras-gpt-2   Author: CyberZHG   File: test_model.py    License: MIT License 5 votes vote down vote up
def test_save_load(self):
        """Save a GPT-2-sized model to disk and load it back.

        The model is written to a randomly named ``.h5`` file in the system
        temp dir, reloaded with the package's custom objects, and summarised.
        The temporary file is removed once it has been loaded.
        """
        model = get_model(
            n_vocab=50257,
            n_ctx=1024,
            n_embd=768,
            n_head=12,
            n_layer=12,
        )
        model_path = os.path.join(tempfile.gettempdir(), 'test_gpt_2_%f.h5' % np.random.random())
        model.save(model_path)
        try:
            model = keras.models.load_model(model_path, custom_objects=get_custom_objects())
        finally:
            # Don't leave a large stray .h5 file in the temp dir on every run.
            os.remove(model_path)
        model.summary()
Example 30
Project: keras-gpt-2   Author: CyberZHG   File: test_model.py    License: MIT License 5 votes vote down vote up
def test_fixed_input_shape(self):
        """Save and reload a model built with ``fixed_input_shape=True``.

        The model is written to a randomly named ``.h5`` file in the system
        temp dir, reloaded with the package's custom objects, and summarised.
        The temporary file is removed once it has been loaded.
        """
        model = get_model(
            n_vocab=50257,
            n_ctx=1024,
            n_embd=768,
            n_head=12,
            n_layer=12,
            fixed_input_shape=True,
        )
        model_path = os.path.join(tempfile.gettempdir(), 'test_gpt_2_%f.h5' % np.random.random())
        model.save(model_path)
        try:
            model = keras.models.load_model(model_path, custom_objects=get_custom_objects())
        finally:
            # Don't leave a large stray .h5 file in the temp dir on every run.
            os.remove(model_path)
        model.summary()