Python os.environ Examples

The following are 30 code examples of os.environ. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module os, or try the search function.
Example #1
Source File: runtime.py    From godot-mono-builds with MIT License 9 votes vote down vote up
def setup_runtime_cross_template(env: dict, opts: RuntimeOpts, product: str, target: str, host_triple: str,
                target_triple: str, device_target: str, llvm: str, offsets_dumper_abi: str):
    """Configure a cross-compiling runtime template and generate its offsets header.

    Records the cross configure flags in ``env``, then runs mono's offsets tool
    to write ``<build_dir>/<target_triple>.h``. When the new tool location does
    not exist, the old ``offsets-tool-py`` is set up with ``make ... setup`` and
    run inside its virtualenv. Finally delegates to ``setup_runtime_template``.
    """
    source_root = opts.mono_source_root

    env['_cross-runtime_%s-%s_CONFIGURE_FLAGS' % (product, target)] = [
        '--target=%s' % target_triple,
        '--with-cross-offsets=%s.h' % target_triple,
        '--with-llvm=%s/llvm-%s' % (opts.install_dir, llvm)
    ]

    legacy_tool_dir = '%s/tools/offsets-tool-py' % source_root
    legacy_tool = '%s/offsets-tool.py' % legacy_tool_dir
    current_tool = '%s/mono/tools/offsets-tool/offsets-tool.py' % source_root

    use_legacy_tool = not os.path.isfile(current_tool)

    tool_env = None
    if use_legacy_tool:
        # Only the legacy tool needs a setup step and a virtualenv to run in
        run_command('make', ['-C', legacy_tool_dir, 'setup'], name='make offsets-tool-py')
        activated_vars = source('%s/offtool/bin/activate' % legacy_tool_dir)
        tool_env = os.environ.copy()
        tool_env.update(activated_vars)

    build_dir = '%s/%s-%s-%s' % (opts.configure_dir, product, target, opts.configuration)
    mkdir_p(build_dir)

    base_args = [
        legacy_tool if use_legacy_tool else current_tool,
        '--targetdir=%s/%s-%s-%s' % (opts.configure_dir, product, device_target, opts.configuration),
        '--abi=%s' % offsets_dumper_abi,
        '--monodir=%s' % source_root,
        '--outfile=%s/%s.h' % (build_dir, target_triple)
    ]
    extra_args = env['_%s-%s_OFFSETS_DUMPER_ARGS' % (product, target)]
    run_command('python3', base_args + extra_args, env=tool_env, name='offsets-tool')

    # Runtime template
    setup_runtime_template(env, opts, product, target, host_triple)
Example #2
Source File: lambda_handler.py    From aws-auto-remediate with GNU General Public License v3.0 8 votes vote down vote up
def get_settings(self):
        """Return the DynamoDB aws-auto-remediate-settings table in a Python dict format

        The table name comes from the ``SETTINGSTABLE`` environment variable.
        Each scanned item is decoded with ``dynamodb_json.loads`` and its
        ``"key"`` attribute is mapped to its ``"value"`` attribute. Only the
        items returned by a single ``scan`` call are read (no pagination).

        Any failure (missing environment variable, AWS error, malformed record)
        is logged, and whatever was collected up to that point is returned.

        Returns:
            dict -- aws-auto-remediate-settings table
        """
        settings = {}
        try:
            for record in boto3.client("dynamodb").scan(
                TableName=os.environ["SETTINGSTABLE"]
            )["Items"]:
                record_json = dynamodb_json.loads(record, True)
                settings[record_json["key"]] = record_json["value"]
        except Exception as error:
            # .get(): if SETTINGSTABLE itself is missing, indexing os.environ
            # again here would raise a new KeyError out of the handler.
            self.logging.error(
                f"Could not read DynamoDB table '{os.environ.get('SETTINGSTABLE')}'."
            )
            self.logging.error(error)

        return settings
Example #3
Source File: secrets.py    From aegea with Apache License 2.0 6 votes vote down vote up
def put(args):
    """Create or update the Secrets Manager secret named ``args.secret_name``.

    The secret value is taken from, in order of preference: a freshly
    generated SSH private key (when ``args.generate_ssh_key``), the environment
    variable with the same name as the secret, or standard input. If the
    secret already exists, a new value is stored with ``put_secret_value``.
    When ``parse_principal(args)`` yields a principal, ``ensure_policy`` is
    applied for it on the secret's ARN.

    Returns:
        dict with ``ssh_public_key`` and ``ssh_key_fingerprint`` when a key was
        generated; otherwise None.
    """
    if args.generate_ssh_key:
        ssh_key = new_ssh_key()
        buf = StringIO()
        ssh_key.write_private_key(buf)
        secret_value = buf.getvalue()
    elif args.secret_name in os.environ:
        secret_value = os.environ[args.secret_name]
    else:
        secret_value = sys.stdin.read()
    try:
        res = clients.secretsmanager.create_secret(Name=args.secret_name, SecretString=secret_value)
    except clients.secretsmanager.exceptions.ResourceExistsException:
        res = clients.secretsmanager.put_secret_value(SecretId=args.secret_name, SecretString=secret_value)
    # Parse the principal once and reuse the result for the check and the call.
    principal = parse_principal(args)
    if principal:
        ensure_policy(principal, res["ARN"])
    if args.generate_ssh_key:
        return dict(ssh_public_key=hostkey_line(hostnames=[], key=ssh_key).strip(),
                    ssh_key_fingerprint=key_fingerprint(ssh_key))
Example #4
Source File: views.py    From MPContribs with MIT License 6 votes vote down vote up
def index(request):
    """Render the portal home page.

    Lists the notebook templates found under ``<template_dir>/<PORTAL_CNAME>/``
    and the project entries returned by the API client, then renders
    ``home.html``. ``PORTAL_CNAME`` must be set in the environment.
    """
    ctx = get_context(request)
    cname = os.environ["PORTAL_CNAME"]
    template_dir = get_app_template_dirs("templates/notebooks")[0]
    htmls = os.path.join(template_dir, cname, "*.html")
    suffix = ".html"
    # glob only yields files ending in ".html". Strip exactly that suffix from
    # the base name: this works with any path separator and keeps an inner
    # ".html" in a file name intact.
    ctx["notebooks"] = [os.path.basename(p)[: -len(suffix)] for p in glob(htmls)]
    ctx["PORTAL_CNAME"] = cname
    ctx["landing_pages"] = []
    mask = ["project", "title", "authors", "is_public", "description", "urls"]
    client = Client(headers=get_consumer(request))  # sets/returns global variable
    entries = client.projects.get_entries(_fields=mask).result()["data"]
    for entry in entries:
        # Split off the first author; the remainder stays as one stripped string.
        authors = entry["authors"].strip().split(",", 1)
        if len(authors) > 1:
            authors[1] = authors[1].strip()
        entry["authors"] = authors
        # Truncate the description at its first "." (keeping the period).
        entry["description"] = entry["description"].split(".", 1)[0] + "."
        # visibility governed by is_public flag and X-Consumer-Groups header
        ctx["landing_pages"].append(entry)
    return render(request, "home.html", ctx.flatten())
Example #5
Source File: lambda_handler.py    From aws-auto-remediate with GNU General Public License v3.0 6 votes vote down vote up
def get_settings(self):
        """Return the DynamoDB aws-auto-remediate-settings table in a Python dict format

        The table name comes from the ``SETTINGSTABLE`` environment variable and
        is scanned with ``self.client_dynamodb``. Each item is decoded with
        ``dynamodb_json.loads``. Items having both ``"key"`` and ``"value"``
        attributes contribute ``key -> value``; all other items are skipped.
        Only the items returned by a single ``scan`` call are read.

        Any failure is logged, and whatever was collected so far is returned.

        Returns:
            dict -- aws-auto-remediate-settings table
        """
        settings = {}
        try:
            for record in self.client_dynamodb.scan(
                TableName=os.environ["SETTINGSTABLE"]
            )["Items"]:
                record_json = dynamodb_json.loads(record, True)

                # Skip items that don't follow the key/value schema.
                if "key" in record_json and "value" in record_json:
                    settings[record_json["key"]] = record_json["value"]
        except Exception as error:
            # .get(): if SETTINGSTABLE itself is missing, indexing os.environ
            # again here would raise a new KeyError out of the handler.
            self.logging.error(
                f"Could not read DynamoDB table '{os.environ.get('SETTINGSTABLE')}'."
            )
            self.logging.error(error)

        return settings
Example #6
Source File: main.py    From friendly-telegram with GNU Affero General Public License v3.0 6 votes vote down vote up
def get_phones(arguments):
    """Collect the phone numbers and auth tokens to start sessions for.

    Phones come from ``arguments.phone`` plus every
    ``friendly-telegram-<phone>.session`` file in the parent directory of
    ``utils.get_base_dir()``. Auth tokens come from the JSON in the
    ``authorization_strings`` environment variable (ignored with
    ``arguments.setup``) and from ``arguments.tokens``. Each CLI token is bound
    to the lowest remaining phone, and that phone is removed from the returned set.

    Returns:
        (set of phone strings, auth-token dict or False)
    """
    prefix = "friendly-telegram-"
    suffix = ".session"

    phones = set(arguments.phone or [])
    session_dir = os.path.dirname(utils.get_base_dir())
    phones.update(
        fname[len(prefix):-len(suffix)]
        for fname in os.listdir(session_dir)
        if fname.startswith(prefix) and fname.endswith(suffix)
    )

    authtoken = os.environ.get("authorization_strings", False)  # for heroku
    if authtoken and not arguments.setup:
        try:
            authtoken = json.loads(authtoken)
        except json.decoder.JSONDecodeError:
            logging.warning("authtoken invalid")
            authtoken = False

    if arguments.setup or (arguments.tokens and not authtoken):
        authtoken = {}

    for token in arguments.tokens or []:
        phone = sorted(phones)[0]
        # Phones that get a token are handled by the authtoken logic instead.
        phones.remove(phone)
        authtoken.update({phone: token})

    return phones, authtoken
Example #7
Source File: updater.py    From friendly-telegram with GNU Affero General Public License v3.0 6 votes vote down vote up
async def update_complete(self, client):
        """Tell the chat that requested a self-update how it went.

        Must be a coroutine because it awaits ``client`` calls. As a plain
        ``def`` containing ``await``, it would not even compile.

        The message is ``strings["heroku_warning"]`` when ``DYNO`` is in the
        environment but ``heroku_api_token`` is not. Otherwise it is
        ``strings["success"]``, or ``strings["success_meme"]`` 1 time in 11.
        With the ``AUDIO`` config option, the message is sent as a voice-note
        caption and the stored message is deleted. Otherwise the stored message
        is edited in place. Chat and message ids are read from ``self._db``.
        """
        logger.debug("Self update successful! Edit message")
        heroku_key = os.environ.get("heroku_api_token")
        herokufail = ("DYNO" in os.environ) and (heroku_key is None)
        if herokufail:
            logger.warning("heroku token not set")
            msg = self.strings["heroku_warning"]
        else:
            # randint's bounds are inclusive: 0..10 is 11 outcomes, one of them 0.
            msg = self.strings["success"] if random.randint(0, 10) != 0 else self.strings["success_meme"]
        if self.config["AUDIO"]:
            await client.send_file(self._db.get(__name__, "selfupdatechat"), STARTUP, caption=msg, voice_note=True)
            await client.delete_messages(self._db.get(__name__, "selfupdatechat"),
                                         [self._db.get(__name__, "selfupdatemsg")])
        else:
            await client.edit_message(self._db.get(__name__, "selfupdatechat"),
                                      self._db.get(__name__, "selfupdatemsg"), msg)
Example #8
Source File: test_setup.py    From aws-auto-remediate with GNU General Public License v3.0 6 votes vote down vote up
def test_invalid_table_schema(self, setup):
        """get_settings() ignores items that lack the expected schema.

        Creates a settings table keyed on "id", stores a single item that has
        no "key"/"value" attributes, and expects an empty settings dict.

        Arguments:
            setup {class} -- Instance of Setup class
        """
        table = "settings-table"
        os.environ["SETTINGSTABLE"] = table
        dynamodb = setup.client_dynamodb

        dynamodb.create_table(
            TableName=table,
            KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}],
            AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
            ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
        )
        dynamodb.put_item(TableName=table, Item={"id": {"S": "123"}})

        # No item follows the key/value schema, so nothing should be returned.
        assert setup.get_settings() == {}
Example #9
Source File: lambda_handler.py    From aws-auto-remediate with GNU General Public License v3.0 6 votes vote down vote up
def send_to_missing_remediation_topic(self, config_rule_name, config_payload):
        """Publishes a message onto the missing remediation SNS Topic. The topic should be subscribed to
        by administrators to be aware when their security remediations are not fully covered.

        The topic ARN is read from the ``MISSINGREMEDIATIONTOPIC`` environment
        variable. A missing variable raises ``KeyError`` before anything is
        published. Publish failures, including ``json.dumps`` errors, are
        logged and not raised.

        Arguments:
            config_rule_name {string} -- AWS Config Rule name
            config_payload {dictionary} -- AWS Config Rule payload (serialised with json.dumps)
        """
        client = boto3.client("sns")
        topic_arn = os.environ["MISSINGREMEDIATIONTOPIC"]

        try:
            client.publish(
                TopicArn=topic_arn,
                Message=json.dumps(config_payload),
                Subject=f"No remediation available for Config Rule '{config_rule_name}'",
            )
        except Exception as error:
            # Interpolate the actual ARN (the braces were previously missing).
            self.logging.error(f"Could not publish to SNS Topic '{topic_arn}'.")
            self.logging.error(error)
Example #10
Source File: test_grid.py    From indras_net with GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, methodName, prop_file="models/grid_for_test.props"):
        super().__init__(methodName=methodName)

        self.pa = props.read_props(MODEL_NM, prop_file)

        # Web runs take their base directory from the environment.
        if self.pa["user_type"] == props.WEB:
            self.pa["base_dir"] = os.environ["base_dir"]

        # A bounded (non-torus) grid sized from the props, for agents to act within.
        width = self.pa["grid_width"]
        height = self.pa["grid_height"]
        self.env = ge.GridEnv("Test grid env", width, height,
                              torus=False,
                              model_nm=MODEL_NM,
                              preact=True,
                              postact=True,
                              props=self.pa)

        goal = "taking up a grid space!"
        for idx in range(self.pa["num_agents"]):
            self.env.add_agent(gm.TestGridAgent(name="agent{}".format(idx), goal=goal))

        # One extra agent that tests can track by name.
        self.env.add_agent(gm.TestGridAgent(name="agent for tracking", goal=goal))
Example #11
Source File: os_utils.py    From godot-mono-builds with MIT License 6 votes vote down vote up
def find_executable(name) -> str:
    """Return the full path of executable ``name``, or '' if it is not found.

    Every PATH entry is searched, followed by the current working directory.
    A candidate must be a regular file that the process may execute. On
    Windows, each PATHEXT extension is appended in turn. If ``name`` already
    ends with one of those extensions (e.g. ``cmake.exe``), it is also tried
    verbatim first.
    """
    is_windows = os.name == 'nt'
    windows_exts = os.environ['PATHEXT'].split(ENV_PATH_SEP) if is_windows else None
    path_dirs = os.environ['PATH'].split(ENV_PATH_SEP)

    search_dirs = path_dirs + [os.getcwd()]  # cwd is last in the list

    def is_executable_file(candidate):
        return os.path.isfile(candidate) and os.access(candidate, os.X_OK)

    has_known_ext = False
    if is_windows:
        lowered_name = name.lower()
        has_known_ext = any(lowered_name.endswith(ext.lower()) for ext in windows_exts if ext)

    for directory in search_dirs:  # not `dir`: avoid shadowing the builtin
        path = os.path.join(directory, name)

        if is_windows:
            candidates = ([path] if has_known_ext else []) + [path + ext for ext in windows_exts]
        else:
            candidates = [path]

        for candidate in candidates:
            if is_executable_file(candidate):
                return candidate

    return ''
Example #12
Source File: core.py    From BASS with GNU General Public License v2.0 6 votes vote down vote up
def get_VT_name(hashes):
    """Derive a common malware name for ``hashes`` from VirusTotal reports.

    Requires the ``VIRUSTOTAL_API_KEY`` environment variable. A specific name
    is returned only when there are at least two hashes and every report yields
    the same unique name that is not flagged ``pup``. In every other case,
    including any error, ``GENERIC_CLAMAV_MALWARE_NAME`` is returned.

    Returns:
        "<platform>.<category>.<unique_name>" or GENERIC_CLAMAV_MALWARE_NAME.
    """
    # Look the key up outside the try block. That way, a KeyError raised while
    # reading report fields is logged as an error, not misreported as a
    # missing API key.
    api_key = os.environ.get("VIRUSTOTAL_API_KEY")
    if api_key is None:
        log.warning("No VIRUSTOTAL_API_KEY specified. Falling back to generic name.")
        return GENERIC_CLAMAV_MALWARE_NAME
    try:
        vt = PrivateApi(api_key=api_key)
        generator = ComputeVtUniqueName()
        names = [generator.build_unique_name(vt.get_file_report(hash_) or "") for hash_ in hashes]
        if len(names) >= 2 and all(names[0] == name for name in names[1:]):
            name = names[0]
            if name["pup"]:
                log.error("PUA signatures are not implemented yet. Expected name was: %s", str(name))
            else:
                return "{}.{}.{}".format(name["platform"], name["category"], name["unique_name"])
    except Exception:
        log.exception("While trying to compute VT name. Falling back to generic name.")
    return GENERIC_CLAMAV_MALWARE_NAME
Example #13
Source File: configparserwrapper.py    From CAMISIM with Apache License 2.0 6 votes vote down vote up
def _get_full_path(value):
		"""
			convert string to absolute normpath.

			@param value: some string to be converted
			@type value: basestring

			@return: absolute normpath
			@rtype: basestring
		"""
		assert isinstance(value, basestring)
		parent_directory, filename = os.path.split(value)
		if not parent_directory and not os.path.isfile(value):
			for path in os.environ["PATH"].split(os.pathsep):
				path = path.strip('"')
				exe_file = os.path.join(path, filename)
				if os.path.isfile(exe_file):
					value = exe_file
					break
		value = os.path.expanduser(value)
		value = os.path.normpath(value)
		value = os.path.abspath(value)
		return value 
Example #14
Source File: env.py    From drydock with Apache License 2.0 6 votes vote down vote up
def run_migrations_online():
    """Run migrations in 'online' mode.

    Builds an Engine from the ``sqlalchemy.``-prefixed keys of the alembic
    config section, with ``NullPool`` and the URL taken from the
    ``DRYDOCK_DB_URL`` environment variable. Then runs the migrations on one
    of its connections inside a transaction.
    """
    database_url = os.environ['DRYDOCK_DB_URL']
    ini_section = config.get_section(config.config_ini_section)
    engine = engine_from_config(ini_section, prefix='sqlalchemy.', poolclass=pool.NullPool, url=database_url)

    with engine.connect() as conn:
        context.configure(connection=conn, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
Example #15
Source File: _device.py    From multibootusb with GNU General Public License v2.0 6 votes vote down vote up
def from_environment(cls, context):
        """
        Create a new device from the process environment (as in
        :data:`os.environ`).

        This only works reliably if the current process is called from a
        udev rule. It is usually used for tools executed from ``IMPORT=``
        rules. Use this method to create device objects in Python scripts
        called from udev rules.

        ``context`` is the library :class:`Context`.

        Return a :class:`Device` object constructed from the environment.
        Raise :exc:`DeviceNotFoundInEnvironmentError` if no device could be
        created from the environment.

        .. udevversion:: 152

        .. versionadded:: 0.18
        """
        handle = context._libudev.udev_device_new_from_environment(context)
        if handle:
            return Device(context, handle)
        # A falsy handle means libudev could not build a device from the environment.
        raise DeviceNotFoundInEnvironmentError()
Example #16
Source File: lambda_handler.py    From aws-auto-remediate with GNU General Public License v3.0 5 votes vote down vote up
def send_to_dead_letter_queue(self, config_payload, try_count):
        """Sends the AWS Config payload to an SQS Queue (DLQ) if after incrementing
        the "try_count" variable it is below the user defined "RETRYCOUNT" setting.

        ``RETRYCOUNT`` is read from the environment (default 3) and the queue URL
        from ``DEADLETTERQUEUE``. The new count is attached as the ``try_count``
        message attribute. Send failures are logged, not raised. Once the retry
        limit is reached, a warning is logged and the payload is not re-queued.

        Arguments:
            config_payload {dictionary} -- AWS Config payload (serialised with json.dumps)
            try_count {string|int} -- Number of previous remediation attempts for this AWS Config payload
        """
        client = boto3.client("sqs")

        try_count = int(try_count) + 1
        if try_count < int(os.environ.get("RETRYCOUNT", 3)):
            try:
                client.send_message(
                    QueueUrl=os.environ["DEADLETTERQUEUE"],
                    MessageBody=json.dumps(config_payload),
                    MessageAttributes={
                        "try_count": {
                            "StringValue": str(try_count),
                            "DataType": "Number",
                        }
                    },
                )

                self.logging.debug(
                    f"Remediation failed. Payload has been sent to SQS DLQ '{os.environ['DEADLETTERQUEUE']}'."
                )
            except Exception as error:
                # .get(): if DEADLETTERQUEUE is what is missing, indexing it again
                # here would raise a new KeyError out of the handler.
                self.logging.error(
                    f"Could not send payload to SQS DLQ '{os.environ.get('DEADLETTERQUEUE')}'."
                )
                self.logging.error(error)
        else:
            self.logging.warning(
                f"Could not remediate Config change within an "
                f"acceptable number of retries for payload '{config_payload}'."
            )
Example #17
Source File: train.py    From Sound-Recognition-Tutorial with Apache License 2.0 5 votes vote down vote up
def use_gpu():
    """Configure the Keras TensorFlow session for the first GPU.

    Restricts CUDA to device 0, caps the process at half of the GPU memory,
    and lets the allocation grow on demand.
    """
    from keras.backend.tensorflow_backend import set_session
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # expose only the first GPU
    session_config = tf.ConfigProto()
    gpu_opts = session_config.gpu_options
    gpu_opts.per_process_gpu_memory_fraction = 0.5  # use at most 50% of GPU memory
    gpu_opts.allow_growth = True  # allocate GPU memory incrementally as needed
    session = tf.InteractiveSession(config=session_config)
    set_session(session)
Example #18
Source File: loader.py    From friendly-telegram with GNU Affero General Public License v3.0 5 votes vote down vote up
def send_config_one(self, mod, db, babel=None, skip_hook=False):  # pylint: disable=R0201
        """Send config to single instance.

        Each key of ``mod.config`` is resolved from the stored
        ``(module, "__config__")`` db entry, else the ``<module>.<key>``
        environment variable, else ``mod.config.getdef(key)``. When ``babel``
        is given, ``mod.strings`` is replaced by a translated copy. Unless
        ``skip_hook`` is set, ``mod.babel`` is assigned and
        ``mod.config_complete()`` is called, logging any exception it raises.
        """
        module_name = mod.__module__
        if hasattr(mod, "config"):
            stored = db.get(module_name, "__config__", {})
            logging.debug(stored)
            for key in mod.config.keys():
                logging.debug(key)
                if key in stored:
                    mod.config[key] = stored[key]
                    continue
                env_name = module_name + "." + key
                try:
                    mod.config[key] = os.environ[env_name]
                    logging.debug("Loaded config key %s from environment", key)
                except KeyError:
                    logging.debug("No config value for %s", key)
                    mod.config[key] = mod.config.getdef(key)
            logging.debug(mod.config)

        if babel is not None and hasattr(mod, "strings"):
            # Work on a copy so accounts with different translations don't share one dict.
            localized = mod.strings.copy()
            mod.strings = localized
            for key in list(localized):
                replacement = babel.getkey(module_name + "." + key)
                if replacement is not False:
                    localized[key] = replacement

        if not skip_hook:
            mod.babel = babel
            try:
                mod.config_complete()
            except Exception:
                logging.exception("Failed to send mod config complete signal")
Example #19
Source File: heroku.py    From friendly-telegram with GNU Affero General Public License v3.0 5 votes vote down vote up
def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "heroku_api_token" not in os.environ:
            return
        # Runs before asyncio is set up, so only synchronous calls are usable here.
        token_type = collections.namedtuple("api_token", ["ID", "HASH"])
        api_token = token_type(os.environ["api_id"], os.environ["api_hash"])
        app, _config = heroku.get_app([entry[1] for entry in self.client_data],
                                      os.environ["heroku_api_token"], api_token, False, True)
        if os.environ["DYNO"].startswith("web."):
            # On a "web." dyno, scale the worker formation down to zero.
            app.scale_formation_process("worker-DO-NOT-TURN-ON-OR-THINGS-WILL-BREAK", 0)
        atexit.register(functools.partial(exit_handler, app))
Example #20
Source File: updater.py    From friendly-telegram with GNU Affero General Public License v3.0 5 votes vote down vote up
async def updatecmd(self, message):
        """Downloads userbot updates"""
        # The docstring above is left verbatim: for *cmd handlers it is
        # presumably shown to users as help text (TODO confirm in the loader).
        # This must be a coroutine because the body awaits. As a plain `def`
        # it would not compile.
        # We don't really care about asyncio at this point, as we are shutting down
        msgs = await utils.answer(message, self.strings["downloading"])
        req_update = await self.download_common()
        if self.config["AUDIO"]:
            message = await message.client.send_file(message.chat_id, SHUTDOWN,
                                                     caption=self.strings["installing"], voice_note=True)
            await asyncio.gather(*[msg.delete() for msg in msgs])
        else:
            message = (await utils.answer(message, self.strings["installing"]))[0]
        heroku_key = os.environ.get("heroku_api_token")
        if heroku_key:
            from .. import heroku
            await self.prerestart_common(message)
            heroku.publish(self.allclients, heroku_key)
            # If we pushed, this won't return. If the push failed, we will get thrown at.
            # So this only happens when remote is already up to date (remote is heroku, where we are running)
            self._db.set(__name__, "selfupdatechat", None)
            self._db.set(__name__, "selfupdatemsg", None)
            if self.config["AUDIO"]:
                await message.client.send_file(message.chat_id, STARTUP, voice_note=True,
                                               caption=self.strings["already_updated"])
                await message.delete()
            else:
                await utils.answer(message, self.strings["already_updated"])
        else:
            if req_update:
                self.req_common()
            await self.restart_common(message)
Example #21
Source File: main.py    From friendly-telegram with GNU Affero General Public License v3.0 5 votes vote down vote up
def get_api_token():
    """Get API Token from disk or environment.

    Prefers a sibling ``api_token`` module. If it cannot be imported, builds an
    ``api_token(ID, HASH)`` namedtuple from the ``api_id`` and ``api_hash``
    environment variables (as strings).

    Returns:
        The ``api_token`` module or namedtuple, or None when neither source is available.
    """
    # Every path returns on its first pass, so no retry loop is needed.
    try:
        from . import api_token
    except ImportError:
        try:
            return collections.namedtuple("api_token", ("ID", "HASH"))(os.environ["api_id"],
                                                                        os.environ["api_hash"])
        except KeyError:
            return None
    return api_token
Example #22
Source File: views.py    From MPContribs with MIT License 5 votes vote down vote up
def get_context(request):
    """Build a RequestContext carrying the API location and trademark.

    ``API_CNAME`` and ``API_PORT`` must be set (KeyError otherwise).
    ``TRADEMARK`` defaults to an empty string.
    """
    ctx = RequestContext(request)
    for required in ("API_CNAME", "API_PORT"):
        ctx[required] = os.environ[required]
    ctx["TRADEMARK"] = os.environ.get("TRADEMARK", "")
    return ctx
Example #23
Source File: runtests.py    From django-json-widget with MIT License 5 votes vote down vote up
def run_tests(*test_args):
    """Run the Django test suite with ``tests.settings``.

    Uses the given test labels, or ``['tests']`` when none are given, and
    exits the process with ``bool(failures)`` as the status.
    """
    labels = test_args or ['tests']

    os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings'
    django.setup()
    runner_cls = get_runner(settings)
    failures = runner_cls().run_tests(labels)
    sys.exit(bool(failures))
Example #24
Source File: easy_deploy.py    From reinvent2014-scalable-site-management with Apache License 2.0 5 votes vote down vote up
def cli(ctx, profile, opsworks_region, elb_region):
    # Export the profile via BOTO_DEFAULT_PROFILE (presumably read by boto) when given.
    if profile is not None:
        os.environ['BOTO_DEFAULT_PROFILE'] = profile
    # Stash the regions on the shared context object for subcommands.
    ctx.obj.update(OPSWORKS_REGION=opsworks_region, ELB_REGION=elb_region)
Example #25
Source File: util.py    From text-rank with MIT License 5 votes vote down vote up
def debug(*args):
    """Print ``args`` space-separated when the ``DEBUG`` env var equals '1'.

    The environment is consulted only on the first call. The result is cached
    in the module-level ``__DEBUG`` flag (expected to start as None).
    """
    global __DEBUG
    if __DEBUG is None:
        # .get() replaces a bare `except:` that only ever guarded a missing key.
        __DEBUG = os.environ.get('DEBUG') == '1'
    if __DEBUG:
        print(' '.join(str(arg) for arg in args))
Example #26
Source File: train.py    From mmdetection with Apache License 2.0 5 votes vote down vote up
def parse_args():
    """Parse the command-line options for detector training.

    Also exports ``--local_rank`` as the ``LOCAL_RANK`` environment variable
    unless it is already set (e.g. by the launcher).
    """
    ap = argparse.ArgumentParser(description='Train a detector')
    ap.add_argument('config', help='train config file path')
    ap.add_argument('--work-dir', help='the dir to save logs and models')
    ap.add_argument('--resume-from', help='the checkpoint file to resume from')
    ap.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # --gpus and --gpu-ids are alternatives; argparse rejects passing both.
    non_dist_note = '(only applicable to non-distributed training)'
    gpu_choice = ap.add_mutually_exclusive_group()
    gpu_choice.add_argument('--gpus', type=int, help='number of gpus to use ' + non_dist_note)
    gpu_choice.add_argument('--gpu-ids', type=int, nargs='+', help='ids of gpus to use ' + non_dist_note)

    ap.add_argument('--seed', type=int, default=None, help='random seed')
    ap.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    ap.add_argument('--options', nargs='+', action=DictAction, help='arguments in dict')
    ap.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    ap.add_argument('--local_rank', type=int, default=0)

    parsed = ap.parse_args()
    os.environ.setdefault('LOCAL_RANK', str(parsed.local_rank))
    return parsed
Example #27
Source File: __init__.py    From django-template with MIT License 5 votes vote down vote up
def get_env_variable(var_name, default=None):
    """
    Return the value of environment variable ``var_name``.

    When the variable is unset, return ``default`` if it is not None;
    otherwise raise ``ImproperlyConfigured``.
    """
    value = os.environ.get(var_name)
    if value is not None:
        return value
    if default is None:
        raise ImproperlyConfigured("Set the %s environment variable" % var_name)
    return default
Example #28
Source File: views.py    From MPContribs with MIT License 5 votes vote down vote up
def notebooks(request, nb):
    # Render notebooks/<PORTAL_CNAME>/<nb>.html
    template_name = os.path.join("notebooks", os.environ["PORTAL_CNAME"], nb + ".html")
    return render(request, template_name)
Example #29
Source File: views.py    From MPContribs with MIT License 5 votes vote down vote up
def download(request, project):
    # Redirect to <S3_DOWNLOAD_URL><PORTAL_CNAME>/<project>.json.gz
    # TODO check if exists, generate if not, progressbar...
    # (and return HttpResponse(status=404) when it cannot be provided)
    cname = os.environ["PORTAL_CNAME"]
    return redirect("{}{}/{}.json.gz".format(S3_DOWNLOAD_URL, cname, project))
Example #30
Source File: build.py    From DDPAE-video-prediction with MIT License 5 votes vote down vote up
def build(is_train, tb_dir=None):
  '''
  Parse arguments, set up the logger and the tensorboardX directory.

  Returns (opt, logger, vis); vis is None when tb_dir is not given.
  '''
  arg_source = args.TrainArgs() if is_train else args.TestArgs()
  opt, log = arg_source.parse()

  os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus
  os.makedirs(opt.ckpt_path, exist_ok=True)

  # Fixed seed for torch (CPU and all CUDA devices), numpy and random
  seed = 666
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  np.random.seed(seed)
  random.seed(seed)

  logger = Logger(opt.ckpt_path, opt.split)
  vis = None if tb_dir is None else Visualizer(os.path.join(opt.ckpt_path, tb_dir))

  logger.print(log)

  return opt, logger, vis