import numpy as np, tfutil as mu, aolib.util as ut, aolib.sound as sound, aolib.img as ig, sep_dset, tensorflow as tf, aolib.imtable as imtable, shift_net, gc, soundrep import tensorflow.contrib.slim as slim ed = tf.expand_dims shape = mu.shape add_n = mu.maybe_add_n pj = ut.pjoin cast_complex = mu.cast_complex cast_float = mu.cast_float cast_int = mu.cast_int def on_cpu(f): return mu.on_cpu(f) #return f() class NetClf: def __init__(self, pr, model_path, sess = None, gpu = None, restore_only_shift = False, input_sr = None): self.pr = pr self.sess = sess self.gpu = gpu self.model_path = model_path self.restore_only_shift = restore_only_shift self.input_sr = input_sr def init(self, reset = True): if self.sess is None: print 'Running on:', self.gpu with tf.device(self.gpu): if reset: tf.reset_default_graph() tf.Graph().as_default() pr = self.pr self.sess = tf.Session() self.ims_ph = tf.placeholder( tf.uint8, [1, pr.sampled_frames, pr.crop_im_dim, pr.crop_im_dim, 3]) self.samples_ph = tf.placeholder(tf.float32, (1, pr.num_samples, 2)) crop_spec = lambda x : x[:, :pr.spec_len] samples_trunc = self.samples_ph[:, :pr.sample_len] spec_mix, phase_mix = sep_module(pr).stft(samples_trunc[:, :, 0], pr) spec_mix = crop_spec(spec_mix) phase_mix = crop_spec(phase_mix) self.specgram_op, phase = map(crop_spec, sep_module(pr).stft(samples_trunc[:, :, 0], pr)) self.auto_op = sep_module(pr).istft(self.specgram_op, phase, pr) self.net = sep_module(pr).make_net( self.ims_ph, samples_trunc, spec_mix, phase_mix, pr, reuse = False, train = False) self.spec_pred_fg = self.net.pred_spec_fg self.spec_pred_bg = self.net.pred_spec_bg self.samples_pred_fg = self.net.pred_wav_fg self.samples_pred_bg = self.net.pred_wav_bg print 'Restoring from:', self.model_path if self.restore_only_shift: print 'restoring only shift' import tensorflow.contrib.slim as slim var_list = slim.get_variables_to_restore() var_list = [x for x in var_list if x.name.startswith('im/') or x.name.startswith('sf/') or x.name.startswith('joint/')] self.sess.run(tf.global_variables_initializer()) tf.train.Saver(var_list).restore(self.sess, self.model_path) else: tf.train.Saver().restore(self.sess, self.model_path) tf.get_default_graph().finalize() def predict(self, ims, samples): print 'predict' print 'samples shape:', samples.shape spec_mix = self.sess.run(self.specgram_op, {self.samples_ph : samples}) spec_pred, spec_pred_bg, samples_pred_fg, samples_pred_bg = self.sess.run( [self.spec_pred, self.spec_pred_bg, self.samples_pred_fg, self.samples_pred_bg], {self.ims_ph : ims, self.samples_ph : samples}) print 'samples pred shape:', samples.shape return dict(samples_pred_fg = samples_pred_fg, samples_pred_bg = samples_pred_bg, samples_mix = samples, spec_pred = spec_pred, spec_pred_bg = spec_pred_bg, spec_mix = spec_mix) def predict_unmixed(self, ims, samples0, samples1): # undo mixing samples_mix = samples0 + samples1 spec_pred_fg, samples_pred_fg, spec_pred_bg, samples_pred_bg = self.sess.run( [self.spec_pred_fg, self.samples_pred_fg, self.spec_pred_bg, self.samples_pred_bg], {self.ims_ph : ims[None], self.samples_ph : samples_mix[None]}) spec0 = self.sess.run(self.specgram_op, {self.samples_ph : samples0[None]}) spec1 = self.sess.run(self.specgram_op, {self.samples_ph : samples1[None]}) spec_mix = self.sess.run(self.specgram_op, {self.samples_ph : samples_mix[None]}) auto0 = self.sess.run(self.auto_op, {self.samples_ph : samples0[None]}) auto1 = self.sess.run(self.auto_op, {self.samples_ph : samples1[None]}) auto_mix = self.sess.run(self.auto_op, {self.samples_ph : samples_mix[None]}) return dict(samples_pred_fg = samples_pred_fg[0], samples_pred_bg = samples_pred_bg[0], spec_pred_fg = spec_pred_fg[0], spec_pred_bg = spec_pred_bg[0], spec0 = spec0[0], spec1 = spec1[0], spec_mix = spec_mix[0], auto_mix = auto_mix[0], auto0 = auto0[0], auto1 = auto1[0]) def predict_cam(self, ims, samples): cam = self.sess.run([self.net.vid_net.cam], {self.ims_ph : ims, self.samples_ph : samples}) return cam def read_data(pr, gpus): batch = ut.make_mod(pr.batch_size, len(gpus)) ims, samples, ytids = on_cpu( lambda : sep_dset.make_db_reader( pr.train_list, pr, batch, ['im', 'samples', 'ytid'], num_db_files = pr.num_dbs)) inputs = {'ims' : ims, 'samples' : samples, 'ytids' : ytids} splits = [{} for x in xrange(len(gpus))] for k, v in inputs.items(): if v is None: for i in xrange(len(gpus)): splits[i][k] = None else: s = tf.split(v, len(gpus)) for i in xrange(len(gpus)): splits[i][k] = s[i] return splits def make_opt(opt_method, lr_val, pr): if opt_method == 'adam': opt = tf.train.AdamOptimizer(lr_val) elif opt_method == 'momentum': opt = tf.train.MomentumOptimizer(lr_val, pr.momentum_rate) else: raise RuntimeError() return opt def make_mono(samples, tile = False): samples = tf.reduce_mean(samples, 2) if tile: samples = tf.tile(ed(samples, 2), (1, 1, 2)) return samples def pool3d(x, dim, stride, padding = 'SAME'): if np.ndim(stride) == 0: stride = [stride, stride, stride] if np.ndim(dim) == 0: dim = [dim, dim, dim] x = tf.nn.max_pool3d( x, ksize = [1] + list(dim) + [1], strides = [1] + list(stride) + [1], padding = padding) print 'pool ->', shape(x) return x def has_prefix(x, prefix): if type(prefix) == type(''): prefix = [prefix] return any(x.startswith(y) for y in prefix) def slim_losses_with_prefix(prefix, show = True): losses = tf.losses.get_regularization_losses() losses = [x for x in losses if prefix is None or x.name.startswith(prefix)] if show: print 'Collecting losses for prefix %s:' % prefix for x in losses: print x.name print return mu.maybe_add_n(losses) def vars_with_prefix(prefix): vs = [x for x in tf.trainable_variables() if has_prefix(x.name, prefix)] missing_vs = [x for x in tf.trainable_variables() if not has_prefix(x.name, prefix)] print print 'Variables included (prefix = "%s"):' % prefix for x in vs: print x.name print print 'Variables NOT included (prefix = "%s"):' % prefix for x in missing_vs: print x.name print return vs def slim_ups_with_prefix(prefix, show = True): ups = tf.get_collection(tf.GraphKeys.UPDATE_OPS) ups = [x for x in ups if prefix is None or x.name.startswith(prefix)] if show: print 'Collecting batch norm updates for prefix %s:' % prefix for x in ups: print x.name print return ups def show_results(ims, samples_mix, samples_gt, spec_mix, spec_gt, spec_pred_fg, spec_pred0, samples_pred, samples_pred0, samples_gt_auto, samples_mix_auto, ytids, pr = None, table = [], min_before_show = 1, n = None): def make_vid(ims, samples): samples = np.clip(samples, -1, 1).astype('float64') snd = sound.Sound(samples, pr.samp_sr) return imtable.Video(ims, pr.fps, snd) def vis_spec(spec): return ut.jet(spec.T, pr.spec_min, pr.spec_max * 0.75) for i in range(spec_mix.shape[0])[:n]: row = ['mix:', make_vid(ims[i], samples_mix[i]), 'pred:', make_vid(ims[i], samples_pred[i])] # if pr.use_decoder: # row += ['pred (before):', make_vid(ims[i], samples_pred0[i])] row += ['gt:', make_vid(ims[i], samples_gt[i]), 'gt autoencoded:', make_vid(ims[i], samples_gt_auto[i]), 'mix autoencoded:', make_vid(ims[i], samples_mix_auto[i]), ut.link('https://youtube.com/watch?v=%s' % ytids[i], ytids[i])] table.append(row) row = ['mix:', vis_spec(spec_mix[i]), 'pred:', vis_spec(spec_pred_fg[i])] # if pr.use_decoder: # row += ['pred (before):', make_vid(ims[i], samples_pred0[i])] row += ['gt:', vis_spec(spec_gt[i]), 'gt autoencoded', vis_spec(spec_gt[i]), 'mix autoencoded', vis_spec(spec_mix[i]), ''] table.append(row) if len(table) >= min_before_show*2: ig.show(table) table[:] = [] return np.array([1], np.int64) # def mix_sounds(samples0, pr, quiet_thresh_db = 40., samples1 = None): # # todo: for PIT # if pr.normalize_rms: # samples0 = mu.normalize_rms(samples0) # if samples1 is not None: # samples1 = mu.normalize_rms(samples1) # if samples1 is None: # n = shape(samples0, 0)/2 # samples0 = samples0[:, :pr.sample_len] # # samples1 = tf.concat( # # [samples0[n:, :pr.sample_len], # # samples0[:n, :pr.sample_len]], axis = 0) # samples1 = samples0[n:] # samples0 = samples0[:n] # else: # samples0 = samples0[:, :pr.sample_len] # samples1 = samples1[:, :pr.sample_len] # if pr.augment_rms: # print 'Augmenting rms' # scale0 = tf.random_uniform((shape(samples0, 0), 1, 1), 0.9, 1.1) # scale1 = tf.random_uniform((shape(samples1, 0), 1, 1), 0.9, 1.1) # samples0 = scale0 * samples0 # samples1 = scale1 * samples1 # samples_mix = samples0 + samples1 # spec_mix, phase_mix = stft(make_mono(samples_mix), pr) # spec0, phase0 = stft(make_mono(samples0), pr) # spec1, phase1 = stft(make_mono(samples1), pr) # print 'Before truncating specgram:', shape(spec_mix) # spec_mix = spec_mix[:, :pr.spec_len] # print 'After truncating specgram:', shape(spec_mix) # phase_mix = phase_mix[:, :pr.spec_len] # spec0 = spec0[:, :pr.spec_len] # spec1 = spec1[:, :pr.spec_len] # phase0 = phase0[:, :pr.spec_len] # phase1 = phase1[:, :pr.spec_len] # return ut.Struct( # samples = samples_mix, # phase = phase_mix, # spec = spec_mix, # sample_parts = [samples0, samples1], # spec_parts = [spec0, spec1], # phase_parts = [phase0, phase1]) def mix_sounds(samples0, pr, quiet_thresh_db = 40., samples1 = None): if pr.normalize_rms: samples0 = mu.normalize_rms(samples0) if samples1 is not None: samples1 = mu.normalize_rms(samples1) if samples1 is None: n = shape(samples0, 0)/2 samples0 = samples0[:, :pr.sample_len] if pr.both_videos_in_batch: print 'Using both videos' samples1 = tf.concat( [samples0[n:, :pr.sample_len], samples0[:n, :pr.sample_len]], axis = 0) else: print 'Only using first videos' samples1 = samples0[n:] samples0 = samples0[:n] else: samples0 = samples0[:, :pr.sample_len] samples1 = samples1[:, :pr.sample_len] if pr.augment_rms: print 'Augmenting rms' # scale0 = tf.random_uniform((shape(samples0, 0), 1, 1), 0.9, 1.1) # scale1 = tf.random_uniform((shape(samples1, 0), 1, 1), 0.9, 1.1) db = 0.25 scale0 = 2.**tf.random_uniform((shape(samples0, 0), 1, 1), -db, db) scale1 = 2.**tf.random_uniform((shape(samples1, 0), 1, 1), -db, db) samples0 = scale0 * samples0 samples1 = scale1 * samples1 samples_mix = samples0 + samples1 spec_mix, phase_mix = stft(make_mono(samples_mix), pr) spec0, phase0 = stft(make_mono(samples0), pr) spec1, phase1 = stft(make_mono(samples1), pr) print 'Before truncating specgram:', shape(spec_mix) spec_mix = spec_mix[:, :pr.spec_len] print 'After truncating specgram:', shape(spec_mix) phase_mix = phase_mix[:, :pr.spec_len] spec0 = spec0[:, :pr.spec_len] spec1 = spec1[:, :pr.spec_len] phase0 = phase0[:, :pr.spec_len] phase1 = phase1[:, :pr.spec_len] return ut.Struct( samples = samples_mix, phase = phase_mix, spec = spec_mix, sample_parts = [samples0, samples1], spec_parts = [spec0, spec1], phase_parts = [phase0, phase1]) def make_discrim_spec(spec_in, spec_out, phase_in, phase_out, pr, reuse = True, train = True): with slim.arg_scope(unet_arg_scope(pr, reuse = reuse, train = train)): spec_in = normalize_spec(spec_in, pr) spec_out = normalize_spec(spec_out, pr) spec_in = ed(spec_in, 3) spec_out = ed(spec_out, 3) phase_in = ed(phase_in, 3) phase_out = ed(phase_out, 3) net = tf.concat([spec_in, phase_in, spec_out, phase_out], 3) net = conv2d_same(net, 32, 4, scope = 'discrim/spec/conv1', stride = 2) net = conv2d_same(net, 64, 4, scope = 'discrim/spec/conv2', stride = 2) net = conv2d(net, 128, 4, scope = 'discrim/spec/conv3', stride = 2) net = conv2d(net, 128, 4, scope = 'discrim/spec/conv4', stride = 2) # net = conv2d(net, 256, 4, scope = 'discrim/spec/conv5', stride = 2) # net = conv2d(net, 256, 4, scope = 'discrim/spec/conv6', stride = [1, 2]) logits = conv2d(net, 1, 1, scope = 'discrim/spec/logits', stride = 1, normalizer_fn = None, activation_fn = None) return ut.Struct(logits = logits) def sigmoid_loss(logits, label): loss = tf.nn.sigmoid_cross_entropy_with_logits( logits = logits, labels = tf.zeros_like(logits) + label) ok = tf.equal(cast_int(logits >= 0.), label) acc = tf.stop_gradient(tf.reduce_mean(cast_float(ok))) return tf.reduce_mean(loss), acc def normalize_spec(spec, pr): return norm_range(spec, pr.spec_min, pr.spec_max) def unnormalize_spec(spec, pr): return unnorm_range(spec, pr.spec_min, pr.spec_max) def normalize_phase(phase, pr): return norm_range(phase, -np.pi, np.pi) def unnormalize_phase(phase, pr): return unnorm_range(phase, -np.pi, np.pi) def add_pred_losses(gen_loss, net, snd, pr): if 'fg-bg' in pr.loss_types: gt = normalize_spec(snd.spec_parts[0], pr) pred = normalize_spec(net.pred_spec_fg, pr) if 'fg': diff = pred - gt loss = pr.l1_weight*tf.reduce_mean(tf.abs(diff)) gen_loss.add_loss(loss, 'diff-fg') gt = normalize_phase(snd.phase_parts[0], pr) pred = normalize_phase(net.pred_phase_fg, pr) diff = pred - gt loss = pr.phase_weight*tf.reduce_mean(tf.abs(diff)) gen_loss.add_loss(loss, 'phase-fg') if pr.predict_bg: gt = normalize_spec(snd.spec_parts[1], pr) pred = normalize_spec(net.pred_spec_bg, pr) diff = pred - gt loss = pr.l1_weight*tf.reduce_mean(tf.abs(diff)) gen_loss.add_loss(loss, 'diff-bg') gt = normalize_phase(snd.phase_parts[1], pr) pred = normalize_phase(net.pred_phase_bg, pr) diff = pred - gt loss = pr.phase_weight*tf.reduce_mean(tf.abs(diff)) gen_loss.add_loss(loss, 'phase-bg') if 'pit' in pr.loss_types: print 'Using permutation loss' ns = lambda x : normalize_spec(x, pr) np = lambda x : normalize_phase(x, pr) gts_ = [[ns(snd.spec_parts[0]), np(snd.phase_parts[0])], [ns(snd.spec_parts[1]), np(snd.phase_parts[1])]] preds = [[ns(net.pred_spec_fg), np(net.pred_phase_fg)], [ns(net.pred_spec_bg), np(net.pred_phase_bg)]] l1 = lambda x, y : tf.reduce_mean(tf.abs(x - y), [1, 2]) losses = [] for i in xrange(2): gt = [gts_[i%2], gts_[(i+1)%2]] print 'preds[0][0] shape =', shape(preds[0][0]) fg_spec = pr.l1_weight * l1(preds[0][0], gt[0][0]) fg_phase = pr.phase_weight * l1(preds[0][1], gt[0][1]) bg_spec = pr.l1_weight * l1(preds[1][0], gt[1][0]) bg_phase = pr.phase_weight * l1(preds[1][1], gt[1][1]) losses.append(fg_spec + fg_phase + bg_spec + bg_phase) losses = tf.concat([ed(x, 0) for x in losses], 0) print 'losses shape =', shape(losses) loss_val = tf.reduce_min(losses, 0) print 'losses shape after min =', shape(losses) loss_val = pr.pit_weight * tf.reduce_mean(loss_val) #loss_val = tf.Print(loss_val, [losses]) gen_loss.add_loss(loss_val, 'pit') # else: # raise RuntimeError() def make_loss(net, snd, pr, reuse = True, train = True): assert set(pr.loss_types).issubset({'pit', 'fg-bg'}) gen_loss = mu.Loss('gen') gen_loss.add_loss(slim_losses_with_prefix('gen'), 'gen:reg') add_pred_losses(gen_loss, net, snd, pr) n = shape(net.pred_spec_fg, 1) if pr.gan_weight > 0: discrim_fake_spec = make_discrim_spec(snd.spec[:, :n], net.pred_spec_fg, snd.phase[:, :n], net.pred_phase_fg, pr, reuse = reuse, train = train) discrim_real_spec = make_discrim_spec(snd.spec[:, :n], snd.spec_parts[0][:, :n], snd.phase[:, :n], snd.phase_parts[0][:, :n], pr, reuse = True, train = train) discrim_loss = mu.Loss('discrim') discrim_loss.add_loss(slim_losses_with_prefix('discrim'), 'discrim:reg') tasks = [] if pr.gan_weight > 0: tasks.append(('spec', discrim_fake_spec, discrim_real_spec)) for name, discrim_fake, discrim_real in tasks: loss, acc = sigmoid_loss(discrim_fake.logits, 1) loss = loss * pr.gan_weight gen_loss.add_loss_acc((loss, acc), 'gen:gan_%s' % name) loss1, acc1 = sigmoid_loss(discrim_real.logits, 1) loss0, acc0 = sigmoid_loss(discrim_fake.logits, 0) loss, acc = (0.5*(loss0 + loss1), 0.5*(acc0 + acc1)) acc = tf.stop_gradient(acc) #loss = loss * pr.gan_weight discrim_loss.add_loss_acc((loss, acc), 'discrim:%s' % name) return gen_loss, discrim_loss class Model: def __init__(self, pr, sess, gpus, is_training = True, profile = False): self.pr = pr self.sess = sess self.gpus = gpus self.default_gpu = gpus[0] self.is_training = is_training self.profile = profile def make_model(self): with tf.device(self.default_gpu): pr = self.pr if self.is_training: self.make_train_ops() else: self.make_test_ops(reuse=False) self.coord = tf.train.Coordinator() self.saver_fast = tf.train.Saver() self.saver_slow = tf.train.Saver(max_to_keep = 1000) self.init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) self.sess.run(self.init_op) tf.train.start_queue_runners(sess = self.sess, coord = self.coord) print 'Initializing' self.merged_summary = tf.summary.merge_all() print 'Tensorboard command:' summary_dir = ut.mkdir(pj(pr.summary_dir, ut.simple_timestamp())) print 'tensorboard --logdir=%s' % summary_dir self.sum_writer = tf.summary.FileWriter(summary_dir, self.sess.graph) if self.profile: #self.run_meta = tf.RunMetadata() self.profiler = tf.profiler.Profiler(self.sess.graph) def make_train_ops(self): pr = self.pr # steps self.step = tf.get_variable( 'global_step', [], trainable = False, initializer = tf.constant_initializer(0), dtype = tf.int64) #self.lr = tf.constant(pr.base_lr) # model scale = pr.gamma ** tf.floor(cast_float(self.step) / float(pr.step_size)) self.lr = pr.base_lr * scale opt = make_opt(pr.opt_method, self.lr, pr) self.inputs = read_data(pr, self.gpus) gpu_grads, gpu_losses = {}, {} for i, gpu in enumerate(self.gpus): with tf.device(gpu): reuse = (i > 0) ims = self.inputs[i]['ims'] all_samples = self.inputs[i]['samples'] ytids = self.inputs[i]['ytids'] assert not pr.do_shift snd = mix_sounds(all_samples, pr) net = make_net(ims, snd.samples, snd.spec, snd.phase, pr, reuse = reuse, train = self.is_training) gen_loss, discrim_loss = make_loss(net, snd, pr, reuse = reuse, train = self.is_training) if pr.gan_weight <= 0: grads = opt.compute_gradients(gen_loss.total_loss()) else: # doesn't work with baselines, such as I3D #raise RuntimeError() print 'WARNING: DO NOT USE GAN WITH I3D' var_list = vars_with_prefix('gen') + vars_with_prefix('im') + vars_with_prefix('sf') grads = opt.compute_gradients(gen_loss.total_loss(), var_list = var_list) ut.add_dict_list(gpu_grads, 'gen', grads) ut.add_dict_list(gpu_losses, 'gen', gen_loss) var_list = vars_with_prefix('discrim') if pr.gan_weight <= 0: grads = [] else: grads = opt.compute_gradients(discrim_loss.total_loss(), var_list = var_list) ut.add_dict_list(gpu_grads, 'discrim', grads) ut.add_dict_list(gpu_losses, 'discrim', discrim_loss) if i == 0: self.net = net self.show_train = self.make_show_op(net, ims, snd, ytids) self.gen_loss = gpu_losses['gen'][0] self.discrim_loss = gpu_losses['discrim'][0] self.train_ops = {} self.loss_names = {} self.loss_vals = {} ops = [] for name in ['gen', 'discrim']: if pr.gan_weight <= 0. and name == 'discrim': op = tf.no_op() else: (gs, vs) = zip(*mu.average_grads(gpu_grads[name])) if pr.grad_clip is not None: gs, _ = tf.clip_by_global_norm(gs, pr.grad_clip) #gs = [mu.print_every(gs[0], 100, ['%s grad norm:' % name, tf.global_norm(gs)])] + list(gs[1:]) gvs = zip(gs, vs) #bn_ups = slim_ups_with_prefix(name) #bn_ups = slim_ups_with_prefix(None) if name == 'gen': bn_ups = tf.get_collection(tf.GraphKeys.UPDATE_OPS) else: bn_ups = slim_ups_with_prefix('discrim') print 'Number of batch norm ups for', name, len(bn_ups) with tf.control_dependencies(bn_ups): op = opt.apply_gradients(gvs) #op = tf.group(opt.apply_gradients(gvs, global_step = (self.step if name == 'discrim' else None)), *bn_ups) #op = tf.group(opt.apply_gradients(gvs), *bn_ups) ops.append(op) self.train_ops[name] = op loss = (self.gen_loss if name == 'gen' else self.discrim_loss) self.loss_names[name] = loss.get_loss_names() self.loss_vals[name] = loss.get_losses() self.update_step = self.step.assign(self.step + 1) if pr.gan_weight > 0: self.train_op = tf.group(*(ops + [self.update_step])) else: print 'Only using generator, because gan_weight = %.2f' % pr.gan_weight self.train_op = tf.group(ops[0], self.update_step) def make_show_op(self, net, ims, snd, ytids): pr = self.pr samples_gt_auto = istft(snd.spec_parts[0], snd.phase, pr) samples_mix_auto = istft(snd.spec, snd.phase, pr) return tf.py_func( lambda *args : show_results(*args, pr = pr), [ims, snd.samples, snd.sample_parts[0], snd.spec, snd.spec_parts[0], net.pred_spec_fg, net.pred_spec_fg, net.pred_wav_fg, net.pred_wav_fg, samples_gt_auto, samples_mix_auto, ytids], tf.int64) def checkpoint_fast(self): check_path = pj(ut.mkdir(self.pr.train_dir), 'net.tf') out = self.saver_fast.save(self.sess, check_path, global_step = self.step) print 'Checkpoint:', out def checkpoint_slow(self): check_path = pj(ut.mkdir(pj(self.pr.train_dir, 'slow')), 'net.tf') out = self.saver_slow.save(self.sess, check_path, global_step = self.step) print 'Checkpoint:', out def restore(self, path = None, restore_opt = True, init_type = None): if path is None: path = tf.train.latest_checkpoint(self.pr.train_dir) print 'Restoring from:', path var_list = slim.get_variables_to_restore() opt_names = ['Adam', 'beta1_power', 'beta2_power', 'Momentum', 'cache'] if init_type == 'shift': # gamma is reinitialized opt_names += ['gen/', 'discrim/', 'global_step', 'gamma'] elif init_type == 'sep': #opt_names += ['global_step'] opt_names += ['global_step', 'discrim'] elif init_type is None: pass else: raise RuntimeError() if not restore_opt or init_type is not None: var_list = [x for x in var_list if not any(name in x.name for name in opt_names)] print 'Restoring:' for x in var_list: print x.name print tf.train.Saver(var_list).restore(self.sess, path) def get_step(self): return self.sess.run([self.step, self.lr]) def train(self): val_hist = {} pr = self.pr num_steps = 0 while True: step, lr = self.get_step() first = (num_steps == 0) if not first and step % pr.check_iters == 0: self.checkpoint_fast() if not first and step % pr.slow_check_iters == 0: self.checkpoint_slow() if step >= pr.train_iters: break if pr.show_iters is not None and (first or step % pr.show_iters == 0): self.sess.run(self.show_train) loss_ops = self.gen_loss.get_losses() + self.discrim_loss.get_losses() loss_names = self.gen_loss.get_loss_names() + self.discrim_loss.get_loss_names() start = ut.now_sec() if pr.summary_iters is not None and step % pr.summary_iters == 0: ret = self.sess.run([self.train_op, self.merged_summary] + loss_ops) self.sum_writer.add_summary(ret[1], step) loss_vals = ret[2:] elif self.profile and (pr.profile_iters is not None and not first and step % pr.profile_iters == 0): run_meta = tf.RunMetadata() loss_vals = self.sess.run( [self.train_op] + loss_ops, options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE), run_metadata = run_meta)[1:] opts = tf.profiler.ProfileOptionBuilder.time_and_memory() self.profiler.add_step(step, run_meta) self.profiler.profile_operations(options = opts) self.profiler.profile_graph(options = opts) self.profiler.advise(self.sess.graph) else: loss_vals = self.sess.run([self.train_op] + loss_ops)[1:] if step % 100 == 0: gc.collect() ts = moving_avg('time', ut.now_sec() - start, val_hist) out = [] for name, val in zip(loss_names, loss_vals): out.append('%s: %.3f' % (name, moving_avg(name, val, val_hist))) out = ' '.join(out) if step < 10 or step % pr.print_iters == 0: print 'Iteration %d, lr = %.0e, %s, time: %.3f' % (step, lr, out, ts) num_steps += 1 def find_best_iter(pr, gpu, num_iters = 10, sample_rate = 10, dset_name = 'val'): [gpu] = mu.set_gpus([gpu]) best_iter = (np.inf, '') model_paths = sorted( ut.glob(pj(pr.train_dir, 'slow', 'net*.index')), key = lambda x : int(x.split('-')[-1].split('.')[0]))[-5:] model_paths = list(reversed(model_paths)) assert len(model_paths), 'no model paths at %s' % pj(pr.train_dir, 'slow', 'net*.index') for model_path in model_paths: model_path = model_path.split('.index')[0] print model_path clf = NetClf(pr, model_path, gpu = gpu) clf.init() if dset_name == 'train': print 'train' tf_files = sorted(ut.glob(pj(pr.train_list, '*.tf'))) elif dset_name == 'val': tf_files = sorted(ut.glob(pj(pr.val_list, '*.tf'))) else: raise RuntimeError() import sep_eval losses = [] for ims, _, pair in sep_eval.pair_data(tf_files, pr): if abs(hash(pair['ytid_gt'])) % sample_rate == 0: res = clf.predict_unmixed(ims, pair['samples_gt'], pair['samples_bg']) # loss = np.mean(np.abs(res['spec_pred_fg'] - res['spec0'])) # loss += np.mean(np.abs(res['spec_pred_bg'] - res['spec1'])) loss = 0. if 'pit' in pr.loss_types: loss += pit_loss( [res['spec0']], [res['spec1']], [res['spec_pred_fg']], [res['spec_pred_bg']], pr) if 'fg-bg' in pr.loss_types: loss += np.mean(np.abs(res['spec_pred_fg'] - res['spec0'])) loss += np.mean(np.abs(res['spec_pred_bg'] - res['spec1'])) losses.append(loss) print 'running:', np.mean(losses) loss = np.mean(losses) print model_path, 'Loss:', loss best_iter = min(best_iter, (loss, model_path)) ut.write_lines(pj(pr.resdir, 'model_path.txt'), [best_iter[1]]) def pit_loss(gt0, gt1, pred0, pred1, pr): losses = [] weights = np.array([pr.l1_weight, pr.phase_weight]) for i in xrange(2): gt = [gt0, gt1] if i == 0 else [gt1, gt0] loss = 0. for j in xrange(1): p = np.array([pred0[j], pred1[j]]) g = np.array([gt[0][j], gt[1][j]]) w = weights[j] loss += w * np.mean(np.abs(p - g)) losses.append(loss) print 'losses =', losses return np.min(losses) # def find_best_iter(pr, gpu, num_iters = 10, sample_rate = 10, dset_name = 'val'): # [gpu] = mu.set_gpus([gpu]) # best_iter = (np.inf, '') # model_paths = sorted( # ut.glob(pj(pr.train_dir, 'slow', 'net*.index')), # key = lambda x : int(x.split('-')[-1].split('.')[0]))[-5:] # model_paths = reversed(model_paths) # for model_path in model_paths: # model_path = model_path.split('.index')[0] # print model_path # clf = sep_eval.NetClf(pr, model_path, gpu = gpu) # clf.init() # if dset_name == 'train': # print 'train' # tf_files = sorted(ut.glob(pj(pr.train_list, '*.tf'))) # elif dset_name == 'val': # tf_files = sorted(ut.glob(pj(pr.val_list, '*.tf'))) # else: # raise RuntimeError() # losses = [] # for ims, _, pair in sep_eval.pair_data(tf_files, pr): # if abs(hash(pair['ytid_gt'])) % sample_rate == 0: # res = clf.predict_unmixed(ims, pair['samples_gt'], pair['samples_bg']) # loss = np.mean(np.abs(res['spec_pred_fg'] - res['spec0'])) # loss += np.mean(np.abs(res['spec_pred_bg'] - res['spec1'])) # losses.append(loss) # print 'running:', np.mean(losses) # loss = np.mean(losses) # print model_path, 'Loss:', loss # best_iter = min(best_iter, (loss, model_path)) # ut.write_lines(pj(pr.resdir, 'model_path.txt'), [best_iter[1]]) # def find_best_iter(pr, gpu, num_iters = 10, sample_rate = 10): # [gpu] = mu.set_gpus([gpu]) # def f((model_path, gpu_num)): # model_path = model_path.split('.index')[0] # print model_path # clf = sep_eval.NetClf(pr, model_path, gpu = gpu) # clf.init() # tf_files = sorted(ut.glob(pj(pr.val_list, '*.tf'))) # losses = [] # for ims, _, pair in sep_eval.pair_data(tf_files, pr): # if abs(hash(pair['ytid_gt'])) % sample_rate == 0: # res = clf.predict_unmixed(ims, pair['samples_gt'], pair['samples_bg']) # loss = np.mean(np.abs(res['spec_pred_fg'] - res['spec0'])) # loss += np.mean(np.abs(res['spec_pred_bg'] - res['spec1'])) # losses.append(loss) # loss = np.mean(losses) # print model_path, 'Loss:', loss # return (loss, model_path) # model_files = sorted( # ut.glob(pj(pr.train_dir, 'slow', 'net*.index')), # key = lambda x : int(x.split('-')[-1].split('.')[0]))[-5:] # for model_file in ut.model_files # ut.write_lines(pj(pr.resdir, 'model_path.txt'), [best_iter[1]]) def moving_avg(name, x, vals, avg_win_size = 100, p = 0.99): vals[name] = p*vals.get(name, x) + (1 - p)*x return vals[name] def conv2d(*args, **kwargs): out = slim.conv2d(*args, **kwargs) print kwargs['scope'], shape(args[0]), '->', shape(out) return out def conv2d_same(*args, **kwargs): out = mu.conv2d_same(*args, **kwargs) print kwargs['scope'], '->', shape(out) return out def deconv2d(*args, **kwargs): out = slim.conv2d_transpose(*args, **kwargs) print kwargs['scope'], shape(args[0]), '->', shape(out) return out def unet_arg_scope(pr, weight_decay = 1e-5, reuse = False, renorm = True, train = True, scale = True, center = True): batch_norm_params = { 'decay': 0.9997, 'epsilon': 1e-5, 'updates_collections': slim.ops.GraphKeys.UPDATE_OPS, 'scale' : scale, 'center' : center, 'is_training' : train, 'renorm' : renorm, 'param_initializers' : {'gamma' : tf.random_normal_initializer(1., 0.02)}, } normalizer_fn = slim.batch_norm normalizer_params = batch_norm_params with slim.arg_scope([slim.batch_norm], **batch_norm_params): with slim.arg_scope( [slim.conv2d, slim.conv2d_transpose], weights_regularizer = slim.regularizers.l2_regularizer(weight_decay), weights_initializer = tf.random_normal_initializer(0, 0.02), activation_fn = tf.nn.relu, normalizer_fn = normalizer_fn, reuse = reuse, normalizer_params = normalizer_params) as sc: return sc def norm_range(x, min_val, max_val): return 2.*(x - min_val)/float(max_val - min_val) - 1. def unnorm_range(y, min_val, max_val): return 0.5*float(max_val - min_val) * (y + 1) + min_val def print_vals(name, x): return tf.Print(x, [name, tf.reduce_min(x), tf.reduce_max(x)]) def stft(samples, pr): spec_complex = tf.contrib.signal.stft( samples, frame_length = soundrep.stft_frame_length(pr), frame_step = soundrep.stft_frame_step(pr), pad_end = pr.pad_stft) mag = tf.abs(spec_complex) #phase = tf.angle(spec_complex) phase = mu.angle(spec_complex) if pr.log_spec: mag = soundrep.db_from_amp(mag) return mag, phase def make_complex(mag, phase): mag = cast_complex(mag) phase = cast_complex(phase) j = tf.constant(1j, dtype = tf.complex64) return mag * (tf.cos(phase) + j*tf.sin(phase)) def istft(mag, phase, pr): if pr.log_spec: mag = soundrep.amp_from_db(mag) samples = tf.contrib.signal.inverse_stft( make_complex(mag, phase), frame_length = soundrep.stft_frame_length(pr), frame_step = soundrep.stft_frame_step(pr), fft_length = soundrep.stft_num_fft(pr)) return samples # def griffin_lim(mag, phase, pr): # import soundrep # if pr.log_spec: # mag = soundrep.amp_from_db(mag) # samples = soundrep.griffin_lim( # make_complex(mag, phase), # frame_length = soundrep.stft_frame_length(pr), # frame_step = soundrep.stft_frame_step(pr), # num_fft = soundrep.stft_num_fft(pr), # num_iters = 5) # return samples def make_net(ims, sfs, spec, phase, pr, reuse = True, train = True, vid_net_full = None): if pr.mono: print 'Using mono!' sfs = make_mono(sfs, tile = True) if vid_net_full is None: if pr.net_style == 'static': n = shape(ims, 1) if 0: ims_tile = tf.tile(ims[:, n/2:n/2+1], (1, n, 1, 1, 1)) else: ims = tf.cast(ims, tf.float32) ims_tile = tf.tile(ims[:, n/2:n/2+1], (1, n, 1, 1, 1)) vid_net_full = shift_net.make_net(ims_tile, sfs, pr, None, reuse, train) elif pr.net_style == 'no-im': vid_net_full = None elif pr.net_style == 'full': vid_net_full = shift_net.make_net(ims, sfs, pr, None, reuse, train) elif pr.net_style == 'i3d': with tf.variable_scope('RGB', reuse = reuse): import sep_i3d i3d_net = sep_i3d.InceptionI3d(1) vid_net_full = ut.Struct(scales = i3d_net(ims, is_training = train)) with slim.arg_scope(unet_arg_scope(pr, reuse = reuse, train = train)): acts = [] def conv(*args, **kwargs): out = conv2d(*args, activation_fn = None, **kwargs) acts.append(out) out = mu.lrelu(out, 0.2) return out def deconv(*args, **kwargs): args = list(args) if kwargs.get('do_pop', True): skip_layer = acts.pop() else: skip_layer = acts[-1] if 'do_pop' in kwargs: del kwargs['do_pop'] x = args[0] if kwargs.get('concat', True): x = tf.concat([x, skip_layer], 3) if 'concat' in kwargs: del kwargs['concat'] args[0] = tf.nn.relu(x) return deconv2d(*args, activation_fn = None, **kwargs) def merge_level(net, n): if vid_net_full is None: return net vid_net = tf.reduce_mean(vid_net_full.scales[n], [2, 3], keep_dims = True) vid_net = vid_net[:, :, 0, :, :]; s = shape(vid_net) if shape(net, 1) != s[1]: vid_net = tf.image.resize_images(vid_net, [shape(net, 1), 1]) print 'Video net before merge:', s, 'After:', shape(vid_net) else: print 'No need to resize:', s, shape(net) vid_net = tf.tile(vid_net, (1, 1, shape(net, 2), 1)) net = tf.concat([net, vid_net], 3) acts[-1] = net return net num_freq = shape(spec, 2) net = tf.concat( [ed(normalize_spec(spec, pr), 3), ed(normalize_phase(phase, pr), 3)], 3) net = net[:, :, :pr.freq_len, :] net = conv(net, 64, 4, scope = 'gen/conv1', stride = [1, 2]) net = conv(net, 128, 4, scope = 'gen/conv2', stride = [1, 2]) net = conv(net, 256, 4, scope = 'gen/conv3', stride = 2) net = merge_level(net, 0) net = conv(net, 512, 4, scope = 'gen/conv4', stride = 2) net = merge_level(net, 1) net = conv(net, 512, 4, scope = 'gen/conv5', stride = 2) net = merge_level(net, 2) net = conv(net, 512, 4, scope = 'gen/conv6', stride = 2) net = conv(net, 512, 4, scope = 'gen/conv7', stride = 2) net = conv(net, 512, 4, scope = 'gen/conv8', stride = 2) net = conv(net, 512, 4, scope = 'gen/conv9', stride = 2) net = deconv(net, 512, 4, scope = 'gen/deconv1', stride = 2, concat = False) net = deconv(net, 512, 4, scope = 'gen/deconv2', stride = 2) net = deconv(net, 512, 4, scope = 'gen/deconv3', stride = 2) net = deconv(net, 512, 4, scope = 'gen/deconv4', stride = 2) net = deconv(net, 512, 4, scope = 'gen/deconv5', stride = 2) net = deconv(net, 256, 4, scope = 'gen/deconv6', stride = 2) net = deconv(net, 128, 4, scope = 'gen/deconv7', stride = 2) net = deconv(net, 64, 4, scope = 'gen/deconv8', stride = [1, 2]) out_fg = deconv(net, 2, 4, scope = 'gen/fg', stride = [1, 2], normalizer_fn = None, do_pop = False) out_bg = deconv(net, 2, 4, scope = 'gen/bg', stride = [1, 2], normalizer_fn = None, do_pop = False) def process(out): pred_spec = out[..., 0] pred_spec = tf.tanh(pred_spec) pred_spec = unnormalize_spec(pred_spec, pr) pred_phase = out[..., 1] pred_phase = tf.tanh(pred_phase) pred_phase = unnormalize_phase(pred_phase, pr) val = soundrep.db_from_amp(0.) if pr.log_spec else 0. pred_spec = tf.pad(pred_spec, [(0, 0), (0, 0), (0, num_freq - shape(pred_spec, 2))], constant_values = val) if pr.phase_type == 'pred': pred_phase = tf.concat([pred_phase, phase[..., -1:]], 2) elif pr.phase_type == 'orig': pred_phase = phase else: raise RuntimeError() # if ut.hastrue(pr, 'griffin_lim'): # print 'using griffin-lim' # pred_wav = griffin_lim(pred_spec, pred_phase, pr) # else: pred_wav = istft(pred_spec, pred_phase, pr) return pred_spec, pred_phase, pred_wav pred_spec_fg, pred_phase_fg, pred_wav_fg = process(out_fg) pred_spec_bg, pred_phase_bg, pred_wav_bg = process(out_bg) return ut.Struct(pred_spec_fg = pred_spec_fg, pred_wav_fg = pred_wav_fg, pred_phase_fg = pred_phase_fg, pred_spec_bg = pred_spec_bg, pred_phase_bg = pred_phase_bg, pred_wav_bg = pred_wav_bg, vid_net = vid_net_full, ) def truncate_min(x, y): n = min(shape(x, 1), shape(y, 1)) x = x[:, :n] y = y[:, :n] return x, y def train(pr, gpus, restore = False, restore_opt = True, profile = False): print pr gpus = mu.set_gpus(gpus) with tf.Graph().as_default(): config = tf.ConfigProto(allow_soft_placement = True) sess = tf.InteractiveSession(config = config) model = Model(pr, sess, gpus, profile = profile) model.make_model() if restore: model.restore(restore_opt = restore_opt) elif pr.init_path is not None: if pr.init_type in ['shift', 'sep']: model.restore(pr.init_path, restore_opt = False, init_type = pr.init_type) elif pr.init_type == 'i3d': opt_names = ['Adam', 'beta1_power', 'beta2_power', 'Momentum'] rgb_variable_map = {} for variable in tf.global_variables(): if any(x in variable.name for x in opt_names): print 'Skipping:', variable.name continue if variable.name.split('/')[0] == 'RGB': rgb_variable_map[variable.name.replace(':0', '')] = variable print 'Restoring:', variable.name rgb_saver = tf.train.Saver(var_list = rgb_variable_map, reshape=True) rgb_saver.restore(sess, pr.init_path) elif pr.init_type == 'scratch': pass else: raise RuntimeError() tf.get_default_graph().finalize() model.train()