"""Launch distributed ES experiments on AWS.

For each experiment JSON file given on the command line this script:
  1. tars up the package containing this file and uploads it to S3,
  2. boots a master instance (on-demand or spot) that runs redis plus the
     ES master process,
  3. creates an autoscaling group of spot workers pointed at the master.

Requires the ``aws`` CLI on PATH (for upload/presign) and boto3.
"""
import datetime
import json
import os

import click

# Per-region AMI ids used for both master and workers.
# Fill in a real AMI before launching.
AMI_MAP = {
    "us-west-1": "FILL IN YOUR AMI HERE",
}


def highlight(x):
    """Echo *x* to the terminal in green; non-strings are pretty-printed as JSON."""
    if not isinstance(x, str):
        x = json.dumps(x, sort_keys=True, indent=2)
    click.secho(x, fg='green')


def upload_archive(exp_name, archive_excludes, s3_bucket):
    """Tar this package, upload it to S3, and return a presigned download URL.

    Args:
        exp_name: experiment name, used in the S3 object key.
        archive_excludes: iterable of tar ``--exclude`` patterns.
        s3_bucket: destination, e.g. ``s3://my-bucket/prefix``.

    Returns:
        A presigned HTTPS URL (valid 30 days) from which instances can
        download the code archive with plain ``wget``.
    """
    import hashlib
    import os.path as osp
    import subprocess
    import sys
    import tempfile
    import uuid

    # Locate the package: this file must live at <parent>/<pkg>/scripts/launch.py
    # so that tar-ing <pkg> from <parent> captures the whole package.
    thisfile_dir = osp.dirname(osp.abspath(__file__))
    pkg_parent_dir = osp.abspath(osp.join(thisfile_dir, '..', '..'))
    pkg_subdir = osp.basename(osp.abspath(osp.join(thisfile_dir, '..')))
    assert osp.abspath(__file__) == osp.join(pkg_parent_dir, pkg_subdir, 'scripts', 'launch.py'), 'You moved me!'

    # Context-manage the temp dir so the local archive is removed
    # deterministically instead of whenever the GC gets around to it.
    with tempfile.TemporaryDirectory() as tmpdir:
        local_archive_path = osp.join(tmpdir, '{}.tar.gz'.format(uuid.uuid4()))
        tar_cmd = ["tar", "-zcvf", local_archive_path, "-C", pkg_parent_dir]
        for pattern in archive_excludes:
            tar_cmd += ["--exclude", pattern]
        tar_cmd += ["-h", pkg_subdir]  # -h: follow symlinks
        highlight(" ".join(tar_cmd))

        if sys.platform == 'darwin':
            # Prevent Mac tar from adding ._* resource-fork files
            env = os.environ.copy()
            env['COPYFILE_DISABLE'] = '1'
            subprocess.check_call(tar_cmd, env=env)
        else:
            subprocess.check_call(tar_cmd)

        # Content-address the remote path so re-uploads of identical code
        # are cheap. Hash in chunks rather than slurping the whole archive.
        hasher = hashlib.sha224()
        with open(local_archive_path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                hasher.update(chunk)
        remote_archive_path = '{}/{}_{}.tar.gz'.format(s3_bucket, exp_name, hasher.hexdigest())

        # Upload via the aws CLI (credentials come from the environment)
        upload_cmd = ["aws", "s3", "cp", local_archive_path, remote_archive_path]
        highlight(" ".join(upload_cmd))
        subprocess.check_call(upload_cmd)

    # Presign so the instances can fetch the code without AWS credentials
    presign_cmd = ["aws", "s3", "presign", remote_archive_path, "--expires-in", str(60 * 60 * 24 * 30)]
    highlight(" ".join(presign_cmd))
    remote_url = subprocess.check_output(presign_cmd).decode("utf-8").strip()
    return remote_url


def make_disable_hyperthreading_script():
    """Return a bash snippet that takes every second hyperthread sibling offline."""
    return """
# disable hyperthreading
# https://forums.aws.amazon.com/message.jspa?messageID=189757
for cpunum in $(
    cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list |
    sed 's/-/,/g' | cut -s -d, -f2- | tr ',' '\n' | sort -un); do
        echo 0 > /sys/devices/system/cpu/cpu$cpunum/online
done
"""


def make_download_and_run_script(code_url, cmd):
    """Return a bash snippet that, as user ubuntu, fetches the code archive
    from *code_url*, unpacks it, and runs *cmd* inside the package directory."""
    return """su -l ubuntu <<'EOF'
set -x
cd ~
wget --quiet "{code_url}" -O code.tar.gz
tar xvaf code.tar.gz
rm code.tar.gz
cd es-distributed
{cmd}
EOF
""".format(code_url=code_url, cmd=cmd)


def make_master_script(code_url, exp_str):
    """Build the EC2 user-data script for the master instance.

    Configures redis (no snapshots; unix socket alongside TCP so workers and
    relays can still connect over the network), then downloads the code and
    starts the ES master with the experiment JSON written to ~/experiment.json.
    """
    cmd = """
cat > ~/experiment.json <<< '{exp_str}'
python -m es_distributed.main master \
    --master_socket_path /var/run/redis/redis.sock \
    --log_dir ~ \
    --exp_file ~/experiment.json
""".format(exp_str=exp_str)
    return """#!/bin/bash
{{
set -x

{disable_ht}

# Disable redis snapshots
echo 'save ""' >> /etc/redis/redis.conf
# Make the unix domain socket available for the master client
# (TCP is still enabled for workers/relays)
echo "unixsocket /var/run/redis/redis.sock" >> /etc/redis/redis.conf
echo "unixsocketperm 777" >> /etc/redis/redis.conf
mkdir -p /var/run/redis
chown ubuntu:ubuntu /var/run/redis
systemctl restart redis

{download_and_run}

}} >> /home/ubuntu/user_data.log 2>&1
""".format(disable_ht=make_disable_hyperthreading_script(),
           download_and_run=make_download_and_run_script(code_url, cmd))


def make_worker_script(code_url, master_private_ip):
    """Build the EC2 user-data script for a worker instance.

    Workers run redis purely as a local relay over a unix socket (TCP is
    disabled), single-threaded BLAS, and connect back to the master at
    *master_private_ip*.
    """
    cmd = ("MKL_NUM_THREADS=1 OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 "
           "python -m es_distributed.main workers "
           "--master_host {} "
           "--relay_socket_path /var/run/redis/redis.sock").format(master_private_ip)
    return """#!/bin/bash
{{
set -x

{disable_ht}

# Disable redis snapshots
echo 'save ""' >> /etc/redis/redis.conf
# Make redis use a unix domain socket and disable TCP sockets
sed -ie "s/port 6379/port 0/" /etc/redis/redis.conf
echo "unixsocket /var/run/redis/redis.sock" >> /etc/redis/redis.conf
echo "unixsocketperm 777" >> /etc/redis/redis.conf
mkdir -p /var/run/redis
chown ubuntu:ubuntu /var/run/redis
systemctl restart redis

{download_and_run}

}} >> /home/ubuntu/user_data.log 2>&1
""".format(disable_ht=make_disable_hyperthreading_script(),
           download_and_run=make_download_and_run_script(code_url, cmd))


@click.command()
@click.argument('exp_files', nargs=-1, type=click.Path(), required=True)
@click.option('--key_name', default=lambda: os.environ["KEY_NAME"])
@click.option('--aws_access_key_id', default=os.environ.get("AWS_ACCESS_KEY", None))
@click.option('--aws_secret_access_key', default=os.environ.get("AWS_ACCESS_SECRET", None))
@click.option('--archive_excludes', default=(".git", "__pycache__", ".idea", "scratch"))
@click.option('--s3_bucket')
@click.option('--spot_price')
@click.option('--region_name')
@click.option('--zone')
@click.option('--cluster_size', type=int, default=1)
@click.option('--spot_master', is_flag=True, help='Use a spot instance as the master')
@click.option('--master_instance_type')
@click.option('--worker_instance_type')
@click.option('--security_group')
@click.option('--yes', is_flag=True, help='Skip confirmation prompt')
def main(exp_files, key_name, aws_access_key_id, aws_secret_access_key,
         archive_excludes, s3_bucket, spot_price, region_name, zone,
         cluster_size, spot_master, master_instance_type,
         worker_instance_type, security_group, yes):
    """Launch one cluster (master + autoscaling worker group) per experiment file."""
    highlight('Launching:')
    highlight(locals())

    import boto3
    ec2 = boto3.resource(
        "ec2",
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )
    as_client = boto3.client(
        'autoscaling',
        region_name=region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )

    for i_exp_file, exp_file in enumerate(exp_files):
        with open(exp_file, 'r') as f:
            exp = json.loads(f.read())
        highlight('Experiment [{}/{}]:'.format(i_exp_file + 1, len(exp_files)))
        highlight(exp)
        if not yes:
            click.confirm('Continue?', abort=True)

        exp_prefix = exp['exp_prefix']
        exp_str = json.dumps(exp)
        # Timestamped name makes launch-config / ASG names unique per run
        exp_name = '{}_{}'.format(exp_prefix, datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))

        code_url = upload_archive(exp_name, archive_excludes, s3_bucket)
        highlight("code_url: " + code_url)

        image_id = AMI_MAP[region_name]
        highlight('Using AMI: {}'.format(image_id))

        if spot_master:
            import base64
            # request_spot_instances requires UserData to be base64-encoded
            # (create_instances does the encoding itself)
            requests = ec2.meta.client.request_spot_instances(
                SpotPrice=spot_price,
                InstanceCount=1,
                LaunchSpecification=dict(
                    ImageId=image_id,
                    KeyName=key_name,
                    InstanceType=master_instance_type,
                    EbsOptimized=True,
                    SecurityGroups=[security_group],
                    Placement=dict(
                        AvailabilityZone=zone,
                    ),
                    UserData=base64.b64encode(make_master_script(code_url, exp_str).encode()).decode()
                )
            )['SpotInstanceRequests']
            assert len(requests) == 1
            request_id = requests[0]['SpotInstanceRequestId']
            # Wait for fulfillment
            highlight('Waiting for spot request {} to be fulfilled'.format(request_id))
            ec2.meta.client.get_waiter('spot_instance_request_fulfilled').wait(SpotInstanceRequestIds=[request_id])
            req = ec2.meta.client.describe_spot_instance_requests(SpotInstanceRequestIds=[request_id])
            master_instance_id = req['SpotInstanceRequests'][0]['InstanceId']
            master_instance = ec2.Instance(master_instance_id)
        else:
            master_instance = ec2.create_instances(
                ImageId=image_id,
                KeyName=key_name,
                InstanceType=master_instance_type,
                EbsOptimized=True,
                SecurityGroups=[security_group],
                MinCount=1,
                MaxCount=1,
                Placement=dict(
                    AvailabilityZone=zone,
                ),
                UserData=make_master_script(code_url, exp_str)
            )[0]

        master_instance.create_tags(
            Tags=[
                dict(Key="Name", Value=exp_name + "-master"),
                dict(Key="es_dist_role", Value="master"),
                dict(Key="exp_prefix", Value=exp_prefix),
                dict(Key="exp_name", Value=exp_name),
            ]
        )
        highlight("Master created. IP: %s" % master_instance.public_ip_address)

        # Workers are always spot instances, managed by an autoscaling group
        config_resp = as_client.create_launch_configuration(
            ImageId=image_id,
            KeyName=key_name,
            InstanceType=worker_instance_type,
            EbsOptimized=True,
            SecurityGroups=[security_group],
            LaunchConfigurationName=exp_name,
            UserData=make_worker_script(code_url, master_instance.private_ip_address),
            SpotPrice=spot_price,
        )
        assert config_resp["ResponseMetadata"]["HTTPStatusCode"] == 200

        asg_resp = as_client.create_auto_scaling_group(
            AutoScalingGroupName=exp_name,
            LaunchConfigurationName=exp_name,
            MinSize=cluster_size,
            MaxSize=cluster_size,
            DesiredCapacity=cluster_size,
            AvailabilityZones=[zone],
            Tags=[
                dict(Key="Name", Value=exp_name + "-worker"),
                dict(Key="es_dist_role", Value="worker"),
                dict(Key="exp_prefix", Value=exp_prefix),
                dict(Key="exp_name", Value=exp_name),
            ]
            # todo: also try placement group to see if there is increased networking performance
        )
        assert asg_resp["ResponseMetadata"]["HTTPStatusCode"] == 200
        highlight("Scaling group created")

        highlight("%s launched successfully." % exp_name)
        highlight("Manage at %s" % (
            "https://%s.console.aws.amazon.com/ec2/v2/home?region=%s#Instances:sort=tag:Name" % (
                region_name, region_name)
        ))


if __name__ == '__main__':
    main()