"""Asynchronous Distributed Adaptive Gradients (ADAG) Formerly known as ADAG. Performs asynchronous updates with update window. Author: Tommy Mulc """ from __future__ import print_function import tensorflow as tf import argparse import time import os FLAGS = None log_dir = '/logdir' def main(): # Configure config=tf.ConfigProto(log_device_placement=False) #Server Setup cluster_spec = { 'ps':['localhost:2222'], 'worker':['localhost:2223','localhost:2224'] } #allows this node know about all other nodes n_pss = len(cluster_spec['ps']) #the number of parameter servers n_workers = len(cluster_spec['worker']) #the number of worker nodes cluster = tf.train.ClusterSpec(cluster_spec) if FLAGS.job_name == 'ps': #checks if parameter server server = tf.train.Server(cluster, job_name="ps", task_index=FLAGS.task_index, config=config) server.join() else: #it must be a worker server is_chief = (FLAGS.task_index == 0) #checks if this is the chief node server = tf.train.Server(cluster, job_name="worker", task_index=FLAGS.task_index, config=config) # Graph # We must not use train.replicate_device_setter for normal operations # Local operations with tf.device("/job:worker/replica:0/task:%d" % FLAGS.task_index): a = tf.Variable(tf.constant(0.,shape=[2]),dtype=tf.float32, collections=[tf.GraphKeys.LOCAL_VARIABLES]) b = tf.Variable(tf.constant(0.,shape=[2]),dtype=tf.float32, collections=[tf.GraphKeys.LOCAL_VARIABLES]) c=a+b target = tf.constant(100.,shape=[2],dtype=tf.float32) loss = tf.reduce_mean(tf.square(c-target)) local_step = tf.Variable(0,dtype=tf.int32,trainable=False, name='local_step',collections=['local_non_trainable']) lr = .0001 # loptimizer = tf.train.GradientDescentOptimizer(lr) #local optimizer loptimizer = tf.train.AdamOptimizer(lr) #local optimizer # ADAG (simplest case since all batches are the same) update_window = 3 # T: update/communication window grad_list = [] # the array to store the gradients through the communication window for t in range(update_window): if t != 0: with tf.control_dependencies([opt_local]): #compute gradients only if the local opt was run grads, varss = zip(*loptimizer.compute_gradients(loss, var_list=tf.local_variables())) else: grads, varss = zip(*loptimizer.compute_gradients(loss, var_list=tf.local_variables())) grad_list.append(grads) #add gradients to the list opt_local = loptimizer.apply_gradients(zip(grads,varss), global_step=local_step) #update local parameters grads = tf.reduce_mean(grad_list,axis=0) grads = tuple([grads[i]for i in range(len(varss))]) # add these variables created by local optimizer to local collection lopt_vars = add_global_variables_to_local_collection() # delete the variables from the global collection clear_global_collection() with tf.device(tf.train.replica_device_setter(ps_tasks=n_pss, worker_device="/job:%s/task:%d" % (FLAGS.job_name,FLAGS.task_index))): global_step = tf.Variable(0,dtype=tf.int32,trainable=False,name='global_step') # optimizer for central variables optimizer = tf.train.AdamOptimizer(lr) # optimizer = tf.train.GradientDescentOptimizer(lr) #create global variables and/or references local_to_global, global_to_local = create_global_variables(lopt_vars) opt = optimizer.apply_gradients( zip(grads,[ local_to_global[v] for v in varss]) ,global_step=global_step) #apply the gradients to variables on ps # Pull param from global server with tf.control_dependencies([opt]): assign_locals = assign_global_to_local(global_to_local) # Init ops init_local = tf.variables_initializer(tf.local_variables() \ +tf.get_collection('local_non_trainable'))#for local variables init = tf.global_variables_initializer() # for global variables # Grab global state before training so all workers have same initialization grab_global_init = assign_global_to_local(global_to_local) # Assigns local values to global ones for chief to execute assign_global = assign_local_to_global(local_to_global) # Session stop_hook = tf.train.StopAtStepHook(last_step=40) hooks = [stop_hook] scaff = tf.train.Scaffold(init_op=init,local_init_op=init_local) #Monitored Training Session sess = tf.train.MonitoredTrainingSession(master=server.target, is_chief=is_chief, config=config, scaffold=scaff, hooks=hooks, save_checkpoint_secs=1, checkpoint_dir='logdir') if is_chief: sess.run(assign_global) #Assigns chief's initial values to ps time.sleep(10) #grace period to wait on other workers before starting training # Train until hook stops session print('Starting training on worker %d'%FLAGS.task_index) sess.run(grab_global_init) while not sess.should_stop(): _,_,r,gs,ls = sess.run([opt,assign_locals,c,global_step,local_step]) print(r,"global step: "+str(gs),"worker: "+str(FLAGS.task_index),"local step: "+str(ls)) time.sleep(1) print('Done',FLAGS.task_index) time.sleep(10) #grace period to wait before closing session sess.close() print('Session from worker %d closed cleanly'%FLAGS.task_index) def assign_global_to_local(global_to_local): """ global_to_local : dictionary with corresponding local variable for global key Assigns global variable value to local variables """ r = [] for v in global_to_local.keys(): r.append(tf.assign(global_to_local[v],v)) with tf.control_dependencies(r): a = tf.no_op() return a def assign_local_to_global(local_to_global): """Assigns global variable value to local variables. local_to_global : dictionary with corresponding global variable for local key """ r= [] for v in local_to_global.keys(): r.append(tf.assign(local_to_global[v],v)) with tf.control_dependencies(r): a = tf.no_op() return a def get_global_variable_by_name(name): """Returns the global variable of given name. name : the name of the global variable """ return [v for v in tf.global_variables() if v.name == name][0] def create_global_variables(local_optimizer_vars = []): """Creates global variables for local variables on the graph. Skips variables local variables that are created for local optimization. Returns dictionarys for local-to-global and global-to-local variable mappings. """ local_to_global = {} global_to_local = {} with tf.device('/job:ps/task:0'): for v in tf.local_variables(): if v not in local_optimizer_vars: v_g = tf.get_variable('g/'+v.op.name, shape = v.shape, dtype = v.dtype, trainable=True, collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES]) local_to_global[v] = v_g global_to_local[v_g] = v return local_to_global,global_to_local def add_global_variables_to_local_collection(): """Adds all variables from the global collection to the local collection. Returns the list of variables added. """ r =[] for var in tf.get_default_graph()._collections[tf.GraphKeys.GLOBAL_VARIABLES]: tf.add_to_collection(tf.GraphKeys.LOCAL_VARIABLES,var) r.append(var) return r def clear_global_collection(): """Removes all variables from global collection.""" g = tf.get_default_graph() for _ in range(len(g._collections[tf.GraphKeys.GLOBAL_VARIABLES])): del g._collections[tf.GraphKeys.GLOBAL_VARIABLES][0] if __name__ == '__main__': parser = argparse.ArgumentParser() # Flags for defining the tf.train.ClusterSpec parser.add_argument( "--job_name", type=str, default="", help="One of 'ps', 'worker'" ) # Flags for defining the tf.train.Server parser.add_argument( "--task_index", type=int, default=0, help="Index of task within the job" ) FLAGS, unparsed = parser.parse_known_args() print(FLAGS.task_index) main()