import re
from functools import reduce

import numpy as np
import tensorflow as tf

from baselines.acktr.kfac_utils import *

KFAC_OPS = ['MatMul', 'Conv2D', 'BiasAdd']
KFAC_DEBUG = False


class KfacOptimizer():

    def __init__(self, learning_rate=0.01, momentum=0.9, clip_kl=0.01,
                 kfac_update=2, stats_accum_iter=60, full_stats_init=False,
                 cold_iter=100, cold_lr=None, async_=False, async_stats=False,
                 epsilon=1e-2, stats_decay=0.95, blockdiag_bias=False,
                 channel_fac=False, factored_damping=False, approxT2=False,
                 use_float64=False, weight_decay_dict=None, max_grad_norm=0.5):
        # NOTE: this keyword was originally named `async`, which became a
        # reserved word in Python 3.7; it is renamed `async_` here.
        self.max_grad_norm = max_grad_norm
        self._lr = learning_rate
        self._momentum = momentum
        self._clip_kl = clip_kl
        self._channel_fac = channel_fac
        self._kfac_update = kfac_update
        self._async = async_
        self._async_stats = async_stats
        self._epsilon = epsilon
        self._stats_decay = stats_decay
        self._blockdiag_bias = blockdiag_bias
        self._approxT2 = approxT2
        self._use_float64 = use_float64
        self._factored_damping = factored_damping
        self._cold_iter = cold_iter
        if cold_lr is None:
            # good heuristics
            self._cold_lr = self._lr  # * 3.
        else:
            self._cold_lr = cold_lr
        self._stats_accum_iter = stats_accum_iter
        # avoid the shared-mutable-default-argument pitfall
        self._weight_decay_dict = weight_decay_dict if weight_decay_dict is not None else {}
        self._diag_init_coeff = 0.
        self._full_stats_init = full_stats_init
        if not self._full_stats_init:
            self._stats_accum_iter = self._cold_iter

        self.sgd_step = tf.Variable(0, name='KFAC/sgd_step', trainable=False)
        self.global_step = tf.Variable(
            0, name='KFAC/global_step', trainable=False)
        self.cold_step = tf.Variable(0, name='KFAC/cold_step', trainable=False)
        self.factor_step = tf.Variable(
            0, name='KFAC/factor_step', trainable=False)
        self.stats_step = tf.Variable(
            0, name='KFAC/stats_step', trainable=False)
        self.vFv = tf.Variable(0., name='KFAC/vFv', trainable=False)

        self.factors = {}
        self.param_vars = []
        self.stats = {}
        self.stats_eigen = {}
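    # Rough lifecycle of the optimizer (a summary note, not part of the
    # original source): for the first `cold_iter` updates plain momentum SGD
    # is applied (`sgd_step` counts these) while curvature statistics are
    # accumulated into running averages (`stats_step` counts stats updates).
    # Once enough statistics are gathered, the Kronecker factors are
    # eigendecomposed every `kfac_update` stats steps (`factor_step` counts
    # decompositions) and gradients are preconditioned with the resulting
    # K-FAC approximate Fisher inverse.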
    def getFactors(self, g, varlist):
        graph = tf.get_default_graph()
        factorTensors = {}
        fpropTensors = []
        bpropTensors = []
        opTypes = []
        fops = []

        def searchFactors(gradient, graph):
            # hard-coded search strategy
            bpropOp = gradient.op
            bpropOp_name = bpropOp.name

            bTensors = []
            fTensors = []

            # combining additive gradients; assume they are the same op type
            # and independent
            if 'AddN' in bpropOp_name:
                factors = []
                for g in gradient.op.inputs:
                    factors.append(searchFactors(g, graph))
                op_names = [item['opName'] for item in factors]
                # TO-DO: need to check all the attributes of the ops as well
                print(gradient.name)
                print(op_names)
                print(len(np.unique(op_names)))
                assert len(np.unique(op_names)) == 1, gradient.name + \
                    ' is shared among different computation OPs'

                bTensors = reduce(lambda x, y: x + y,
                                  [item['bpropFactors'] for item in factors])
                if len(factors[0]['fpropFactors']) > 0:
                    fTensors = reduce(
                        lambda x, y: x + y, [item['fpropFactors'] for item in factors])
                fpropOp_name = op_names[0]
                fpropOp = factors[0]['op']
            else:
                fpropOp_name = re.search(
                    'gradientsSampled(_[0-9]+|)/(.+?)_grad', bpropOp_name).group(2)
                fpropOp = graph.get_operation_by_name(fpropOp_name)
                if fpropOp.op_def.name in KFAC_OPS:
                    # Known OPs
                    ###
                    bTensor = [
                        i for i in bpropOp.inputs if 'gradientsSampled' in i.name][-1]
                    bTensorShape = fpropOp.outputs[0].get_shape()
                    if bTensor.get_shape()[0].value is None:
                        bTensor.set_shape(bTensorShape)
                    bTensors.append(bTensor)
                    ###
                    if fpropOp.op_def.name == 'BiasAdd':
                        fTensors = []
                    else:
                        fTensors.append(
                            [i for i in fpropOp.inputs if param.op.name not in i.name][0])
                    fpropOp_name = fpropOp.op_def.name
                else:
                    # unknown OPs, block approximation used
                    bInputsList = [i for i in bpropOp.inputs[
                        0].op.inputs if 'gradientsSampled' in i.name if 'Shape' not in i.name]
                    if len(bInputsList) > 0:
                        bTensor = bInputsList[0]
                        bTensorShape = fpropOp.outputs[0].get_shape()
                        if len(bTensor.get_shape()) > 0 and bTensor.get_shape()[0].value is None:
                            bTensor.set_shape(bTensorShape)
                        bTensors.append(bTensor)
                    # BUG FIX: list.append returns None, so the original
                    # `fpropOp_name = opTypes.append(...)` left fpropOp_name
                    # set to None
                    fpropOp_name = 'UNK-' + fpropOp.op_def.name
                    opTypes.append(fpropOp_name)

            return {'opName': fpropOp_name, 'op': fpropOp,
                    'fpropFactors': fTensors, 'bpropFactors': bTensors}

        for t, param in zip(g, varlist):
            if KFAC_DEBUG:
                print(('get factor for ' + param.name))
            factors = searchFactors(t, graph)
            factorTensors[param] = factors

        ########
        # check associated weights and bias for homogeneous coordinate
        # representation and check redundant factors
        # TO-DO: there may be a bug in detecting associated bias and weights
        # for forking layers, e.g. in inception models.
        for param in varlist:
            factorTensors[param]['assnWeights'] = None
            factorTensors[param]['assnBias'] = None
        for param in varlist:
            if factorTensors[param]['opName'] == 'BiasAdd':
                factorTensors[param]['assnWeights'] = None
                for item in varlist:
                    if len(factorTensors[item]['bpropFactors']) > 0:
                        if (set(factorTensors[item]['bpropFactors']) == set(factorTensors[param]['bpropFactors'])) and (len(factorTensors[item]['fpropFactors']) > 0):
                            factorTensors[param]['assnWeights'] = item
                            factorTensors[item]['assnBias'] = param
                            factorTensors[param]['bpropFactors'] = factorTensors[
                                item]['bpropFactors']
        ########

        ########
        # concatenate the additive gradients along the batch dimension, i.e.
        # assuming independence structure
        for key in ['fpropFactors', 'bpropFactors']:
            for i, param in enumerate(varlist):
                if len(factorTensors[param][key]) > 0:
                    if (key + '_concat') not in factorTensors[param]:
                        name_scope = factorTensors[param][key][0].name.split(':')[0]
                        with tf.name_scope(name_scope):
                            factorTensors[param][
                                key + '_concat'] = tf.concat(factorTensors[param][key], 0)
                else:
                    factorTensors[param][key + '_concat'] = None
                for j, param2 in enumerate(varlist[(i + 1):]):
                    if (len(factorTensors[param][key]) > 0) and (set(factorTensors[param2][key]) == set(factorTensors[param][key])):
                        factorTensors[param2][key] = factorTensors[param][key]
                        factorTensors[param2][
                            key + '_concat'] = factorTensors[param][key + '_concat']
        ########

        if KFAC_DEBUG:
            for items in zip(varlist, fpropTensors, bpropTensors, opTypes):
                # BUG FIX: the original indexed an undefined name `item`
                print((items[0].name, factorTensors[items[0]]))
        self.factors = factorTensors
        return factorTensors
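    # Background (added note): K-FAC approximates each layer's Fisher block
    # by a Kronecker product F ~= A (x) G, where for a dense layer s = W a
    # the factors are A = E[a a^T] (input activations, the "fprop" factor)
    # and G = E[(dL/ds)(dL/ds)^T] (pre-activation gradients, the "bprop"
    # factor). getFactors() above recovers exactly those a and dL/ds tensors
    # by walking the sampled-gradient graph; getStats() below allocates the
    # running covariance variables for them.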
    def getStats(self, factors, varlist):
        if len(self.stats) == 0:
            # initialize stats variables on CPU because eigen decomp is
            # computed on CPU
            with tf.device('/cpu'):
                tmpStatsCache = {}

                # search for tensor factors and
                # use block diag approx for the bias units
                for var in varlist:
                    fpropFactor = factors[var]['fpropFactors_concat']
                    bpropFactor = factors[var]['bpropFactors_concat']
                    opType = factors[var]['opName']
                    if opType == 'Conv2D':
                        Kh = var.get_shape()[0]
                        Kw = var.get_shape()[1]
                        C = fpropFactor.get_shape()[-1]

                        Oh = bpropFactor.get_shape()[1]
                        Ow = bpropFactor.get_shape()[2]
                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels does not
                            # support homogeneous coordinates
                            var_assnBias = factors[var]['assnBias']
                            if var_assnBias:
                                factors[var]['assnBias'] = None
                                factors[var_assnBias]['assnWeights'] = None
                ##

                for var in varlist:
                    fpropFactor = factors[var]['fpropFactors_concat']
                    bpropFactor = factors[var]['bpropFactors_concat']
                    opType = factors[var]['opName']
                    self.stats[var] = {'opName': opType,
                                       'fprop_concat_stats': [],
                                       'bprop_concat_stats': [],
                                       'assnWeights': factors[var]['assnWeights'],
                                       'assnBias': factors[var]['assnBias'],
                                       }
                    if fpropFactor is not None:
                        if fpropFactor not in tmpStatsCache:
                            if opType == 'Conv2D':
                                Kh = var.get_shape()[0]
                                Kw = var.get_shape()[1]
                                C = fpropFactor.get_shape()[-1]

                                Oh = bpropFactor.get_shape()[1]
                                Ow = bpropFactor.get_shape()[2]
                                if Oh == 1 and Ow == 1 and self._channel_fac:
                                    # factorization along the channels
                                    # assume independence between input
                                    # channels and spatial:
                                    # 2K-1 x 2K-1 covariance matrix and
                                    # C x C covariance matrix
                                    # channel factorization does not support
                                    # homogeneous coordinates, assnBias is
                                    # always None
                                    fpropFactor2_size = Kh * Kw
                                    slot_fpropFactor_stats2 = tf.Variable(tf.diag(tf.ones(
                                        [fpropFactor2_size])) * self._diag_init_coeff,
                                        name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
                                    self.stats[var]['fprop_concat_stats'].append(
                                        slot_fpropFactor_stats2)

                                    fpropFactor_size = C
                                else:
                                    # 2K-1 x 2K-1 x C x C covariance matrix
                                    # assume BHWC
                                    fpropFactor_size = Kh * Kw * C
                            else:
                                # D x D covariance matrix
                                fpropFactor_size = fpropFactor.get_shape()[-1]

                            # use homogeneous coordinates
                            if not self._blockdiag_bias and self.stats[var]['assnBias']:
                                fpropFactor_size += 1

                            slot_fpropFactor_stats = tf.Variable(tf.diag(tf.ones(
                                [fpropFactor_size])) * self._diag_init_coeff,
                                name='KFAC_STATS/' + fpropFactor.op.name, trainable=False)
                            self.stats[var]['fprop_concat_stats'].append(
                                slot_fpropFactor_stats)
                            if opType != 'Conv2D':
                                tmpStatsCache[fpropFactor] = self.stats[
                                    var]['fprop_concat_stats']
                        else:
                            self.stats[var][
                                'fprop_concat_stats'] = tmpStatsCache[fpropFactor]

                    if bpropFactor is not None:
                        # no need to collect backward stats for bias vectors
                        # if using homogeneous coordinates
                        if not((not self._blockdiag_bias) and self.stats[var]['assnWeights']):
                            if bpropFactor not in tmpStatsCache:
                                slot_bpropFactor_stats = tf.Variable(tf.diag(tf.ones(
                                    [bpropFactor.get_shape()[-1]])) * self._diag_init_coeff,
                                    name='KFAC_STATS/' + bpropFactor.op.name, trainable=False)
                                self.stats[var]['bprop_concat_stats'].append(
                                    slot_bpropFactor_stats)
                                tmpStatsCache[bpropFactor] = self.stats[
                                    var]['bprop_concat_stats']
                            else:
                                self.stats[var][
                                    'bprop_concat_stats'] = tmpStatsCache[bpropFactor]

        return self.stats

    def compute_and_apply_stats(self, loss_sampled, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        stats = self.compute_stats(loss_sampled, var_list=varlist)
        return self.apply_stats(stats)
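    # Note on `loss_sampled` (added): to estimate the true Fisher rather
    # than the empirical Fisher, the loss should be evaluated against
    # targets sampled from the model's own predictive distribution. The
    # gradients of that sampled loss are what compute_stats() uses below
    # (tf.gradients is given name='gradientsSampled', which the op-name
    # regex in getFactors() relies on).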
    def compute_stats(self, loss_sampled, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()

        gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
        self.gs = gs
        factors = self.getFactors(gs, varlist)
        stats = self.getStats(factors, varlist)

        updateOps = []
        statsUpdates = {}
        statsUpdates_cache = {}
        for var in varlist:
            opType = factors[var]['opName']
            fops = factors[var]['op']
            fpropFactor = factors[var]['fpropFactors_concat']
            fpropStats_vars = stats[var]['fprop_concat_stats']
            bpropFactor = factors[var]['bpropFactors_concat']
            bpropStats_vars = stats[var]['bprop_concat_stats']
            SVD_factors = {}
            for stats_var in fpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_fpropFactor = fpropFactor
                    B = (tf.shape(fpropFactor)[0])  # batch size
                    if opType == 'Conv2D':
                        strides = fops.get_attr("strides")
                        padding = fops.get_attr("padding")
                        convkernel_size = var.get_shape()[0:3]

                        KH = int(convkernel_size[0])
                        KW = int(convkernel_size[1])
                        C = int(convkernel_size[2])
                        flatten_size = int(KH * KW * C)

                        Oh = int(bpropFactor.get_shape()[1])
                        Ow = int(bpropFactor.get_shape()[2])

                        if Oh == 1 and Ow == 1 and self._channel_fac:
                            # factorization along the channels
                            # assume independence among input channels
                            # factor = B x 1 x 1 x (KH x KW x C)
                            # patches = B x Oh x Ow x (KH x KW x C)
                            if len(SVD_factors) == 0:
                                if KFAC_DEBUG:
                                    print(('approx %s act factor with rank-1 SVD factors' % (var.name)))
                                # find closest rank-1 approx to the feature map.
                                # NOTE: the original called tf.batch_svd, which
                                # was folded into tf.svd in TF 1.0; singular
                                # values come back batched as a [B, P] matrix,
                                # so the top one per example is S[:, 0]
                                S, U, V = tf.svd(tf.reshape(
                                    fpropFactor, [-1, KH * KW, C]))
                                # get rank-1 approx slides
                                sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0]), 1)
                                patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                                full_factor_shape = fpropFactor.get_shape()
                                patches_k.set_shape(
                                    [full_factor_shape[0], KH * KW])
                                patches_c = V[:, :, 0] * sqrtS1  # B x C
                                patches_c.set_shape([full_factor_shape[0], C])
                                SVD_factors[C] = patches_c
                                SVD_factors[KH * KW] = patches_k
                            fpropFactor = SVD_factors[stats_var_dim]
                        else:
                            # poor memory usage implementation
                            patches = tf.extract_image_patches(fpropFactor,
                                                               ksizes=[1, convkernel_size[0],
                                                                       convkernel_size[1], 1],
                                                               strides=strides,
                                                               rates=[1, 1, 1, 1],
                                                               padding=padding)

                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 act fisher for %s' % (var.name)))
                                # T^2 terms * 1/T^2, size: B x C
                                fpropFactor = tf.reduce_mean(patches, [1, 2])
                            else:
                                # size: (B x Oh x Ow) x C
                                fpropFactor = tf.reshape(
                                    patches, [-1, flatten_size]) / Oh / Ow
                    fpropFactor_size = int(fpropFactor.get_shape()[-1])
                    if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
                        if opType == 'Conv2D' and not self._approxT2:
                            # correct padding for numerical stability (we
                            # divided out OhxOw from activations for T1 approx)
                            fpropFactor = tf.concat([fpropFactor, tf.ones(
                                [tf.shape(fpropFactor)[0], 1]) / Oh / Ow], 1)
                        else:
                            # use homogeneous coordinates
                            fpropFactor = tf.concat(
                                [fpropFactor, tf.ones([tf.shape(fpropFactor)[0], 1])], 1)

                    # average over the number of data points in a batch
                    # divided by B
                    cov = tf.matmul(fpropFactor, fpropFactor,
                                    transpose_a=True) / tf.cast(B, tf.float32)
                    updateOps.append(cov)
                    statsUpdates[stats_var] = cov
                    if opType != 'Conv2D':
                        # HACK: for convolution we recompute fprop stats for
                        # every layer including forking layers
                        statsUpdates_cache[stats_var] = cov
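            # Added note: each stats update above is a Monte-Carlo second
            # moment, cov = F^T F / B, i.e. an estimate of E[a a^T]. With
            # homogeneous coordinates a column of ones is appended to the
            # activations so the bias is absorbed into the same Kronecker
            # factor (W and b share one block instead of two).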
            for stats_var in bpropStats_vars:
                stats_var_dim = int(stats_var.get_shape()[0])
                if stats_var not in statsUpdates_cache:
                    old_bpropFactor = bpropFactor
                    bpropFactor_shape = bpropFactor.get_shape()
                    B = tf.shape(bpropFactor)[0]  # batch size
                    C = int(bpropFactor_shape[-1])  # num channels
                    if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                        if fpropFactor is not None:
                            if self._approxT2:
                                if KFAC_DEBUG:
                                    print(('approxT2 grad fisher for %s' % (var.name)))
                                bpropFactor = tf.reduce_sum(
                                    bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                            else:
                                bpropFactor = tf.reshape(
                                    bpropFactor, [-1, C]) * Oh * Ow  # T * 1/T terms
                        else:
                            # just doing block diag approx. spatially
                            # independent structure does not apply here.
                            # summing over spatial locations
                            if KFAC_DEBUG:
                                print(('block diag approx fisher for %s' % (var.name)))
                            bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                    # assume sampled loss is averaged. TO-DO: figure out a
                    # better way to handle this
                    bpropFactor *= tf.to_float(B)
                    ##

                    cov_b = tf.matmul(
                        bpropFactor, bpropFactor, transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])

                    updateOps.append(cov_b)
                    statsUpdates[stats_var] = cov_b
                    statsUpdates_cache[stats_var] = cov_b

        if KFAC_DEBUG:
            aKey = list(statsUpdates.keys())[0]
            statsUpdates[aKey] = tf.Print(statsUpdates[aKey],
                                          [tf.convert_to_tensor('step:'),
                                           self.global_step,
                                           tf.convert_to_tensor('computing stats'),
                                           ])
        self.statsUpdates = statsUpdates
        return statsUpdates

    def apply_stats(self, statsUpdates):
        """ compute stats and update/apply the new stats to the running average
        """

        def updateAccumStats():
            if self._full_stats_init:
                return tf.cond(tf.greater(self.sgd_step, self._cold_iter),
                               lambda: tf.group(*self._apply_stats(statsUpdates, accumulate=True,
                                                                   accumulateCoeff=1. / self._stats_accum_iter)),
                               tf.no_op)
            else:
                return tf.group(*self._apply_stats(statsUpdates, accumulate=True,
                                                   accumulateCoeff=1. / self._stats_accum_iter))

        def updateRunningAvgStats(statsUpdates, fac_iter=1):
            # return tf.cond(tf.greater_equal(self.factor_step,
            # tf.convert_to_tensor(fac_iter)), lambda:
            # tf.group(*self._apply_stats(stats_list, varlist)), tf.no_op)
            return tf.group(*self._apply_stats(statsUpdates))

        if self._async_stats:
            # asynchronous stats update
            update_stats = self._apply_stats(statsUpdates)

            queue = tf.FIFOQueue(1, [item.dtype for item in update_stats],
                                 shapes=[item.get_shape() for item in update_stats])
            enqueue_op = queue.enqueue(update_stats)

            def dequeue_stats_op():
                return queue.dequeue()
            self.qr_stats = tf.train.QueueRunner(queue, [enqueue_op])
            update_stats_op = tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(
                0)), tf.no_op, lambda: tf.group(*[dequeue_stats_op(), ]))
        else:
            # synchronous stats update
            update_stats_op = tf.cond(tf.greater_equal(
                self.stats_step, self._stats_accum_iter),
                lambda: updateRunningAvgStats(statsUpdates), updateAccumStats)
        self._update_stats_op = update_stats_op
        return update_stats_op
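    # Added note: _apply_stats below implements either a plain average over
    # the warm-up superbatch (accumulate=True, each batch weighted by
    # 1/stats_accum_iter) or an exponential moving average,
    #   S <- rho * S + (1 - rho) * S_batch,   rho = stats_decay,
    # expressed as an assign followed by an assign_add so both writes share
    # one lock.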
    def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
        updateOps = []
        # obtain the stats var list
        for stats_var in statsUpdates:
            stats_new = statsUpdates[stats_var]
            if accumulate:
                # simple superbatch averaging
                update_op = tf.assign_add(
                    stats_var, accumulateCoeff * stats_new, use_locking=True)
            else:
                # exponential running averaging
                update_op = tf.assign(
                    stats_var, stats_var * self._stats_decay, use_locking=True)
                update_op = tf.assign_add(
                    update_op, (1. - self._stats_decay) * stats_new, use_locking=True)
            updateOps.append(update_op)

        with tf.control_dependencies(updateOps):
            stats_step_op = tf.assign_add(self.stats_step, 1)

        if KFAC_DEBUG:
            stats_step_op = (tf.Print(stats_step_op,
                                      [tf.convert_to_tensor('step:'),
                                       self.global_step,
                                       tf.convert_to_tensor('fac step:'),
                                       self.factor_step,
                                       tf.convert_to_tensor('sgd step:'),
                                       self.sgd_step,
                                       tf.convert_to_tensor('Accum:'),
                                       tf.convert_to_tensor(accumulate),
                                       tf.convert_to_tensor('Accum coeff:'),
                                       tf.convert_to_tensor(accumulateCoeff),
                                       tf.convert_to_tensor('stat step:'),
                                       self.stats_step, updateOps[0], updateOps[1]]))
        return [stats_step_op, ]

    def getStatsEigen(self, stats=None):
        if len(self.stats_eigen) == 0:
            stats_eigen = {}
            if stats is None:
                stats = self.stats

            tmpEigenCache = {}
            with tf.device('/cpu:0'):
                for var in stats:
                    for key in ['fprop_concat_stats', 'bprop_concat_stats']:
                        for stats_var in stats[var][key]:
                            if stats_var not in tmpEigenCache:
                                stats_dim = stats_var.get_shape()[1].value
                                e = tf.Variable(tf.ones(
                                    [stats_dim]), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/e',
                                    trainable=False)
                                Q = tf.Variable(tf.diag(tf.ones(
                                    [stats_dim])), name='KFAC_FAC/' + stats_var.name.split(':')[0] + '/Q',
                                    trainable=False)
                                stats_eigen[stats_var] = {'e': e, 'Q': Q}
                                tmpEigenCache[
                                    stats_var] = stats_eigen[stats_var]
                            else:
                                stats_eigen[stats_var] = tmpEigenCache[
                                    stats_var]
            self.stats_eigen = stats_eigen
        return self.stats_eigen

    def computeStatsEigen(self):
        """ compute the eigen decomp using copied var stats to avoid concurrent read/write from other queue """
        # TO-DO: figure out why this op has delays (possibly moving
        # eigenvectors around?)
        with tf.device('/cpu:0'):
            def removeNone(tensor_list):
                local_list = []
                for item in tensor_list:
                    if item is not None:
                        local_list.append(item)
                return local_list

            def copyStats(var_list):
                print("copying stats to buffer tensors before eigen decomp")
                redundant_stats = {}
                copied_list = []
                for item in var_list:
                    if item is not None:
                        if item not in redundant_stats:
                            if self._use_float64:
                                redundant_stats[item] = tf.cast(
                                    tf.identity(item), tf.float64)
                            else:
                                redundant_stats[item] = tf.identity(item)
                        copied_list.append(redundant_stats[item])
                    else:
                        copied_list.append(None)
                return copied_list
            #stats = [copyStats(self.fStats), copyStats(self.bStats)]
            #stats = [self.fStats, self.bStats]

            stats_eigen = self.stats_eigen
            computedEigen = {}
            eigen_reverse_lookup = {}
            updateOps = []
            # sync copied stats
            # with tf.control_dependencies(removeNone(stats[0]) +
            # removeNone(stats[1])):
            with tf.control_dependencies([]):
                for stats_var in stats_eigen:
                    if stats_var not in computedEigen:
                        eigens = tf.self_adjoint_eig(stats_var)
                        e = eigens[0]
                        Q = eigens[1]
                        if self._use_float64:
                            e = tf.cast(e, tf.float32)
                            Q = tf.cast(Q, tf.float32)
                        updateOps.append(e)
                        updateOps.append(Q)
                        computedEigen[stats_var] = {'e': e, 'Q': Q}
                        eigen_reverse_lookup[e] = stats_eigen[stats_var]['e']
                        eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q']

            self.eigen_reverse_lookup = eigen_reverse_lookup
            self.eigen_update_list = updateOps

            if KFAC_DEBUG:
                self.eigen_update_list = [item for item in updateOps]
                with tf.control_dependencies(updateOps):
                    updateOps.append(tf.Print(tf.constant(
                        0.), [tf.convert_to_tensor('computed factor eigen')]))

        return updateOps
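    # Added note: with eigendecompositions A = Q_A diag(e_A) Q_A^T and
    # G = Q_G diag(e_G) Q_G^T, the damped Kronecker inverse applied to a
    # gradient matrix V is
    #   Q_A [ (Q_A^T V Q_G) / (e_A e_G^T + damping) ] Q_G^T,
    # which is exactly the project / rescale / project-back sequence carried
    # out in getKfacPrecondUpdates() below.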
    def applyStatsEigen(self, eigen_list):
        updateOps = []
        print(('updating %d eigenvalue/vectors' % len(eigen_list)))
        for i, (tensor, mark) in enumerate(zip(eigen_list, self.eigen_update_list)):
            stats_eigen_var = self.eigen_reverse_lookup[mark]
            updateOps.append(
                tf.assign(stats_eigen_var, tensor, use_locking=True))

        with tf.control_dependencies(updateOps):
            factor_step_op = tf.assign_add(self.factor_step, 1)
            updateOps.append(factor_step_op)
            if KFAC_DEBUG:
                updateOps.append(tf.Print(tf.constant(
                    0.), [tf.convert_to_tensor('updated kfac factors')]))
        return updateOps

    def getKfacPrecondUpdates(self, gradlist, varlist):
        updatelist = []
        vg = 0.

        assert len(self.stats) > 0
        assert len(self.stats_eigen) > 0
        assert len(self.factors) > 0
        counter = 0

        grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

        for grad, var in zip(gradlist, varlist):
            GRAD_RESHAPE = False
            GRAD_TRANSPOSE = False

            fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
            bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

            if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
                counter += 1
                GRAD_SHAPE = grad.get_shape()
                if len(grad.get_shape()) > 2:
                    # reshape conv kernel parameters
                    # (HWCN kernel layout: height first, then width)
                    KH = int(grad.get_shape()[0])
                    KW = int(grad.get_shape()[1])
                    C = int(grad.get_shape()[2])
                    D = int(grad.get_shape()[3])

                    if len(fpropFactoredFishers) > 1 and self._channel_fac:
                        # reshape conv kernel parameters into a tensor
                        grad = tf.reshape(grad, [KH * KW, C, D])
                    else:
                        # reshape conv kernel parameters into 2D grad
                        grad = tf.reshape(grad, [-1, D])
                    GRAD_RESHAPE = True
                elif len(grad.get_shape()) == 1:
                    # reshape bias or 1D parameters
                    D = int(grad.get_shape()[0])

                    grad = tf.expand_dims(grad, 0)
                    GRAD_RESHAPE = True
                else:
                    # 2D parameters
                    C = int(grad.get_shape()[0])
                    D = int(grad.get_shape()[1])

                if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                    # use homogeneous coordinates; only works for 2D grad.
                    # TO-DO: figure out how to factorize bias grad
                    # stack bias grad
                    var_assnBias = self.stats[var]['assnBias']
                    grad = tf.concat(
                        [grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

                # project gradient to eigen space and reshape the eigenvalues
                # for broadcasting
                eigVals = []
                for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats][
                                     'e'], var, name='act', debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
                    eigVals.append(e)
                    grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

                for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    e = detectMinVal(self.stats_eigen[stats][
                                     'e'], var, name='grad', debug=KFAC_DEBUG)

                    Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
                    eigVals.append(e)
                    grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)
                ##
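                # Added note: the block below applies (optionally factored)
                # Tikhonov damping. With k factors the damping term lambda is
                # split as lambda^(1/k) per factor, and each share is scaled
                # by the ratio of the factors' average eigenvalue magnitudes
                # (the pi-adjustment of Martens & Grosse, 2015), so that
                # prod_i (e_i + pi_i * lambda^(1/k)) ~= prod_i e_i + lambda.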
                #####
                # whiten using eigenvalues
                weightDecayCoeff = 0.
                if var in self._weight_decay_dict:
                    weightDecayCoeff = self._weight_decay_dict[var]
                    if KFAC_DEBUG:
                        print(('weight decay coeff for %s is %f' % (var.name, weightDecayCoeff)))

                if self._factored_damping:
                    if KFAC_DEBUG:
                        print(('use factored damping for %s' % (var.name)))
                    coeffs = 1.
                    num_factors = len(eigVals)
                    # compute the ratio of the two trace norms of the left and
                    # right KFac matrices, and their generalization
                    if len(eigVals) == 1:
                        damping = self._epsilon + weightDecayCoeff
                    else:
                        damping = tf.pow(
                            self._epsilon + weightDecayCoeff, 1. / num_factors)
                    eigVals_tnorm_avg = [tf.reduce_mean(
                        tf.abs(e)) for e in eigVals]
                    for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
                        eig_tnorm_negList = [
                            item for item in eigVals_tnorm_avg if item != e_tnorm]
                        if len(eigVals) == 1:
                            adjustment = 1.
                        elif len(eigVals) == 2:
                            adjustment = tf.sqrt(
                                e_tnorm / eig_tnorm_negList[0])
                        else:
                            eig_tnorm_negList_prod = reduce(
                                lambda x, y: x * y, eig_tnorm_negList)
                            adjustment = tf.pow(
                                tf.pow(e_tnorm, num_factors - 1.) / eig_tnorm_negList_prod, 1. / num_factors)
                        coeffs *= (e + adjustment * damping)
                else:
                    coeffs = 1.
                    damping = (self._epsilon + weightDecayCoeff)
                    for e in eigVals:
                        coeffs *= e
                    coeffs += damping

                #grad = tf.Print(grad, [tf.convert_to_tensor('1'), tf.convert_to_tensor(var.name), grad.get_shape()])

                grad /= coeffs

                #grad = tf.Print(grad, [tf.convert_to_tensor('2'), tf.convert_to_tensor(var.name), grad.get_shape()])
                #####

                #####
                # project gradient back to euclidean space
                for idx, stats in enumerate(self.stats[var]['fprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

                for idx, stats in enumerate(self.stats[var]['bprop_concat_stats']):
                    Q = self.stats_eigen[stats]['Q']
                    grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)
                ##

                #grad = tf.Print(grad, [tf.convert_to_tensor('3'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                    # use homogeneous coordinates; only works for 2D grad.
                    # TO-DO: figure out how to factorize bias grad
                    # un-stack bias grad
                    var_assnBias = self.stats[var]['assnBias']
                    C_plus_one = int(grad.get_shape()[0])
                    grad_assnBias = tf.reshape(tf.slice(grad,
                                                        begin=[C_plus_one - 1, 0],
                                                        size=[1, -1]), var_assnBias.get_shape())
                    grad_assnWeights = tf.slice(grad,
                                                begin=[0, 0],
                                                size=[C_plus_one - 1, -1])
                    grad_dict[var_assnBias] = grad_assnBias
                    grad = grad_assnWeights

                #grad = tf.Print(grad, [tf.convert_to_tensor('4'), tf.convert_to_tensor(var.name), grad.get_shape()])
                if GRAD_RESHAPE:
                    grad = tf.reshape(grad, GRAD_SHAPE)

                grad_dict[var] = grad

        print(('projecting %d gradient matrices' % counter))

        for g, var in zip(gradlist, varlist):
            grad = grad_dict[var]
            ### clipping ###
            if KFAC_DEBUG:
                print(('apply clipping to %s' % (var.name)))
                tf.Print(grad, [tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))],
                         "Euclidean norm of new grad")
            local_vg = tf.reduce_sum(grad * g * (self._lr * self._lr))
            vg += local_vg

        # rescale everything
        if KFAC_DEBUG:
            print('apply vFv clipping')

        scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
        if KFAC_DEBUG:
            scaling = tf.Print(scaling, [tf.convert_to_tensor(
                'clip: '), scaling, tf.convert_to_tensor(' vFv: '), vg])
        with tf.control_dependencies([tf.assign(self.vFv, vg)]):
            updatelist = [grad_dict[var] for var in varlist]
            for i, item in enumerate(updatelist):
                updatelist[i] = scaling * item

        return updatelist

    def compute_gradients(self, loss, var_list=None):
        varlist = var_list
        if varlist is None:
            varlist = tf.trainable_variables()
        g = tf.gradients(loss, varlist)

        return [(a, b) for a, b in zip(g, varlist)]
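    # Added note: the clipping in getKfacPrecondUpdates above implements the
    # K-FAC trust region. For the update v = -lr * F^{-1} g, the accumulated
    # quantity vg = lr^2 * sum_i g_i^T F^{-1} g_i ~= v^T F v estimates the KL
    # change induced by the update, and every preconditioned gradient is then
    # rescaled by min(1, sqrt(clip_kl / vFv)).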
    def apply_gradients_kfac(self, grads):
        g, varlist = list(zip(*grads))

        if len(self.stats_eigen) == 0:
            self.getStatsEigen()

        qr = None
        # launch eigen-decomp on a queue thread
        if self._async:
            print('Use async eigen decomp')
            # get a list of factor loading tensors
            factorOps_dummy = self.computeStatsEigen()

            # define a queue for the list of factor loading tensors
            queue = tf.FIFOQueue(1, [item.dtype for item in factorOps_dummy], shapes=[
                                 item.get_shape() for item in factorOps_dummy])
            enqueue_op = tf.cond(tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update),
                                                         tf.convert_to_tensor(0)),
                                                tf.greater_equal(self.stats_step, self._stats_accum_iter)),
                                 lambda: queue.enqueue(self.computeStatsEigen()), tf.no_op)

            def dequeue_op():
                return queue.dequeue()

            qr = tf.train.QueueRunner(queue, [enqueue_op])

        updateOps = []
        global_step_op = tf.assign_add(self.global_step, 1)
        updateOps.append(global_step_op)

        with tf.control_dependencies([global_step_op]):
            # compute updates
            assert self._update_stats_op is not None
            updateOps.append(self._update_stats_op)
            dependency_list = []
            if not self._async:
                dependency_list.append(self._update_stats_op)

            with tf.control_dependencies(dependency_list):
                def no_op_wrapper():
                    return tf.group(*[tf.assign_add(self.cold_step, 1)])

                if not self._async:
                    # synchronous eigen-decomp updates
                    updateFactorOps = tf.cond(tf.logical_and(tf.equal(tf.mod(self.stats_step, self._kfac_update),
                                                                      tf.convert_to_tensor(0)),
                                                             tf.greater_equal(self.stats_step, self._stats_accum_iter)),
                                              lambda: tf.group(
                                                  *self.applyStatsEigen(self.computeStatsEigen())),
                                              no_op_wrapper)
                else:
                    # asynchronous eigen-decomp updates using the queue
                    updateFactorOps = tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter),
                                              lambda: tf.cond(tf.equal(queue.size(), tf.convert_to_tensor(0)),
                                                              tf.no_op,
                                                              lambda: tf.group(
                                                                  *self.applyStatsEigen(dequeue_op())),
                                                              ),
                                              no_op_wrapper)

                updateOps.append(updateFactorOps)

                with tf.control_dependencies([updateFactorOps]):
                    def gradOp():
                        return list(g)

                    def getKfacGradOp():
                        return self.getKfacPrecondUpdates(g, varlist)

                    u = tf.cond(tf.greater(self.factor_step,
                                           tf.convert_to_tensor(0)), getKfacGradOp, gradOp)

                    optim = tf.train.MomentumOptimizer(
                        self._lr * (1. - self._momentum), self._momentum)
                    #optim = tf.train.AdamOptimizer(self._lr, epsilon=0.01)

                    def optimOp():
                        def updateOptimOp():
                            if self._full_stats_init:
                                return tf.cond(tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                                               lambda: optim.apply_gradients(list(zip(u, varlist))), tf.no_op)
                            else:
                                return optim.apply_gradients(list(zip(u, varlist)))
                        if self._full_stats_init:
                            return tf.cond(tf.greater_equal(self.stats_step, self._stats_accum_iter),
                                           updateOptimOp, tf.no_op)
                        else:
                            return tf.cond(tf.greater_equal(self.sgd_step, self._cold_iter),
                                           updateOptimOp, tf.no_op)
                    updateOps.append(optimOp())

        return tf.group(*updateOps), qr

    def apply_gradients(self, grads):
        coldOptim = tf.train.MomentumOptimizer(
            self._cold_lr, self._momentum)

        def coldSGDstart():
            sgd_grads, sgd_var = zip(*grads)

            if self.max_grad_norm is not None:
                sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(sgd_grads, self.max_grad_norm)

            sgd_grads = list(zip(sgd_grads, sgd_var))

            sgd_step_op = tf.assign_add(self.sgd_step, 1)
            coldOptim_op = coldOptim.apply_gradients(sgd_grads)
            if KFAC_DEBUG:
                with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                    sgd_step_op = tf.Print(
                        sgd_step_op, [self.sgd_step, tf.convert_to_tensor('doing cold sgd step')])
            return tf.group(*[sgd_step_op, coldOptim_op])

        kfacOptim_op, qr = self.apply_gradients_kfac(grads)

        def warmKFACstart():
            return kfacOptim_op

        return tf.cond(tf.greater(self.sgd_step, self._cold_iter), warmKFACstart, coldSGDstart), qr

    def minimize(self, loss, loss_sampled, var_list=None):
        grads = self.compute_gradients(loss, var_list=var_list)
        # builds and stores self._update_stats_op as a side effect
        self.compute_and_apply_stats(loss_sampled, var_list=var_list)
        return self.apply_gradients(grads)
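
# ---------------------------------------------------------------------------
# Minimal usage sketch (added; not part of the original module). It wires the
# optimizer to a single dense layer in TF1 graph mode; the names (`x`, `y`,
# `w`, `b`) and the unit-variance-Gaussian sampled-loss construction are
# illustrative assumptions, mirroring how baselines' ACKTR drives this class.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    x = tf.placeholder(tf.float32, [None, 4])
    y = tf.placeholder(tf.float32, [None, 2])
    w = tf.get_variable('w', [4, 2])
    b = tf.get_variable('b', [2])
    # MatMul + BiasAdd: both ops are in KFAC_OPS, so factors are recovered
    logits = tf.nn.bias_add(tf.matmul(x, w), b)

    loss = tf.reduce_mean(tf.square(logits - y))
    # "sampled" loss: targets drawn from the model's own predictive
    # distribution, so its gradients estimate the true (not empirical) Fisher
    y_sampled = tf.stop_gradient(logits + tf.random_normal(tf.shape(logits)))
    loss_sampled = tf.reduce_mean(tf.square(logits - y_sampled))

    optim = KfacOptimizer(learning_rate=0.01, cold_iter=10, kfac_update=2)
    # q_runner is None in the default synchronous mode
    train_op, q_runner = optim.minimize(loss, loss_sampled)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(20):
            xb = np.random.randn(32, 4).astype(np.float32)
            yb = np.random.randn(32, 2).astype(np.float32)
            sess.run(train_op, feed_dict={x: xb, y: yb})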