from __future__ import print_function
import os
import time

import numpy as np
import tensorflow as tf
from six.moves import xrange


def model(x, nlogits, train=False):
  # NOTE: reads the module-level `args` (parsed under __main__) for --infer_legacy.
  b = tf.shape(x)[0]

  # Compute a 128-frame, 513-bin magnitude STFT of the 16384-sample waveform.
  if args.infer_legacy:
    # Manually-windowed STFT compatible with TF 1.1 (predates tf.contrib.signal.stft).
    window = tf.contrib.signal.hann_window(1024)
    windows = []
    for i in xrange(0, 16384, 128):
      x_window = x[:, i:i+1024]
      x_padded = tf.pad(x_window, [[0, 0], [0, max(0, i + 1024 - 16384)]])
      x_windowed = x_padded * window
      windows.append(x_windowed)
    windows = tf.stack(windows, axis=1)
    X = tf.spectral.rfft(windows)
  else:
    X = tf.contrib.signal.stft(x, 1024, 128, pad_end=True)

  # Project onto 128 mel bins and take the log to form a log-mel spectrogram.
  X_mag = tf.abs(X)
  W = tf.contrib.signal.linear_to_mel_weight_matrix(
      num_mel_bins=128,
      num_spectrogram_bins=513,
      sample_rate=16000,
      lower_edge_hertz=40.,
      upper_edge_hertz=7800.,
      dtype=tf.float32)
  X_mag = tf.reshape(X_mag, [-1, 513])
  X_mel = tf.matmul(X_mag, W)
  X_mel = tf.reshape(X_mel, [b, 128, 128])
  X_lmel = tf.log(X_mel + 1e-6)

  # The spectrogram frontend has no trainable parameters.
  x = tf.stop_gradient(X_lmel)

  # Four conv/pool blocks, each preceded by batch normalization.
  x = tf.layers.batch_normalization(x, training=train)
  x = tf.expand_dims(x, axis=3)

  x = tf.layers.conv2d(x, 128, (5, 5), padding='same', activation=tf.nn.relu)
  x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))

  x = tf.layers.batch_normalization(x, training=train)
  x = tf.layers.conv2d(x, 128, (5, 5), padding='same', activation=tf.nn.relu)
  x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))

  x = tf.layers.batch_normalization(x, training=train)
  x = tf.layers.conv2d(x, 128, (5, 5), padding='same', activation=tf.nn.relu)
  x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))

  x = tf.layers.batch_normalization(x, training=train)
  x = tf.layers.conv2d(x, 128, (5, 5), padding='same', activation=tf.nn.relu)
  x = tf.layers.max_pooling2d(x, (2, 2), (2, 2))

  # 128x128 input halved by four poolings -> 8x8x128 feature map.
  x = tf.reshape(x, [b, 8 * 8 * 128])
  x = tf.layers.batch_normalization(x, training=train)
  x = tf.layers.dense(x, nlogits)

  if train:
    # Ensure batch norm moving averages update with each training step
    # (two update ops per batch_normalization call, five calls).
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    assert len(update_ops) == 10
    with tf.control_dependencies(update_ops):
      x = tf.identity(x)

  return x


def record_to_xy(example_proto, labels):
  features = {
      'samples': tf.FixedLenSequenceFeature([1], tf.float32, allow_missing=True),
      'label': tf.FixedLenSequenceFeature([], tf.string, allow_missing=True)
  }
  example = tf.parse_single_example(example_proto, features)

  # Pad or truncate the waveform to exactly one second at 16 kHz.
  wav = example['samples'][:, 0]
  wav = wav[:16384]
  wav = tf.pad(wav, [[0, 16384 - tf.shape(wav)[0]]])
  wav.set_shape([16384])

  label_chars = example['label']
  # Truncate labels for TIMIT
  label_lens = [len(l) for l in labels]
  if len(set(label_lens)) == 1:
    label_chars = label_chars[:label_lens[0]]
  label = tf.reduce_join(label_chars, 0)

  # Map the string label to its integer id; assert exactly one candidate matches.
  label_id = tf.constant(0, dtype=tf.int32)
  nmatches = tf.constant(0)
  for i, label_candidate in enumerate(labels):
    match = tf.cast(tf.equal(label, label_candidate), tf.int32)
    label_id += i * match
    nmatches += match
  with tf.control_dependencies([tf.assert_equal(nmatches, 1)]):
    # tf.identity is required here; returning the tensors directly would not
    # attach the assert as a dependency of anything that gets run.
    label_id = tf.identity(label_id)

  return wav, label_id
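
# A quick, hypothetical way to eyeball the decoder above (not part of the
# pipeline, so left commented out): build a tiny dataset from one tfrecord and
# print a few decoded (wav, label_id) pairs. The filepath and label list here
# are placeholders.
#
#   dataset = tf.data.TFRecordDataset(['./data/train-0.tfrecord'])
#   dataset = dataset.map(lambda p: record_to_xy(p, ['yes', 'no']))
#   wav_t, y_t = dataset.make_one_shot_iterator().get_next()
#   with tf.Session() as sess:
#     for _ in xrange(3):
#       _wav, _y = sess.run([wav_t, y_t])
#       print(_wav.shape, _wav.dtype, _y)
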
def eval(fps, args):
  eval_dir = os.path.join(args.train_dir, 'eval_{}'.format(args.eval_split))
  if not os.path.isdir(eval_dir):
    os.makedirs(eval_dir)

  with tf.name_scope('eval_loader'):
    dataset = tf.data.TFRecordDataset(fps)
    dataset = dataset.map(lambda x: record_to_xy(x, args.data_labels))
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(args.eval_batch_size))
    iterator = dataset.make_one_shot_iterator()
    x, y = iterator.get_next()

  with tf.variable_scope('classifier'):
    logits = model(x, len(args.data_labels))

  xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
  xent_mean = tf.reduce_mean(xent)

  preds = tf.argmax(logits, axis=1, output_type=tf.int32)
  acc = tf.reduce_mean(tf.cast(tf.equal(preds, y), tf.float32))

  step = tf.train.get_or_create_global_step()

  summaries = [
      tf.summary.scalar('acc', acc),
      tf.summary.scalar('xent', xent_mean)
  ]
  summaries = tf.summary.merge(summaries)
  summary_writer = tf.summary.FileWriter(eval_dir)

  saver = tf.train.Saver(max_to_keep=1)

  def eval_ckpt_fp(sess, ckpt_fp):
    saver.restore(sess, ckpt_fp)

    # Run through the entire eval split once.
    _xents = []
    _accs = []
    while True:
      try:
        _xent, _acc = sess.run([xent_mean, acc])
      except tf.errors.OutOfRangeError:
        break
      _xents.append(_xent)
      _accs.append(_acc)

    # Feed the split-level means back through the summary ops.
    _step = sess.run(step)
    _summaries = sess.run(summaries, {acc: np.mean(_accs), xent_mean: np.mean(_xents)})
    summary_writer.add_summary(_summaries, _step)

    return _step, np.mean(_accs)

  if args.eval_ckpt_fp is not None:
    # Eval one checkpoint
    with tf.Session() as sess:
      eval_ckpt_fp(sess, args.eval_ckpt_fp)
  else:
    # Loop, waiting for new checkpoints
    ckpt_fp = None
    _best_acc = 0.
    while True:
      latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
      if latest_ckpt_fp != ckpt_fp:
        print('Evaluating: {}'.format(latest_ckpt_fp))

        with tf.Session() as sess:
          _step, _acc = eval_ckpt_fp(sess, latest_ckpt_fp)
          if _acc > _best_acc:
            saver.save(sess, os.path.join(eval_dir, 'best_acc'), _step)
            _best_acc = _acc

        print('Done')
        ckpt_fp = latest_ckpt_fp
      time.sleep(1)


def infer(args):
  infer_dir = os.path.join(args.train_dir, 'infer')
  if not os.path.isdir(infer_dir):
    os.makedirs(infer_dir)

  # Placeholders for the inference stage
  x = tf.placeholder(tf.float32, [None, 16384], name='x')
  labels = tf.constant(args.data_labels, name='labels')

  with tf.variable_scope('classifier'):
    logits = model(x, len(args.data_labels))

  # Named ops exported so consumers of the MetaGraph can look them up.
  scores = tf.nn.softmax(logits, name='scores')
  pred = tf.argmax(logits, axis=1, output_type=tf.int32, name='pred')
  pred_label = tf.gather(labels, pred, name='pred_label')

  # Create saver
  all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  global_step = tf.train.get_or_create_global_step()
  saver = tf.train.Saver(all_vars + [global_step])

  # Export graph
  tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

  # Export MetaGraph
  infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
  tf.train.export_meta_graph(
      filename=infer_metagraph_fp,
      clear_devices=True,
      saver_def=saver.as_saver_def())

  # Reset graph (in case training afterwards)
  tf.reset_default_graph()
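
# A minimal sketch of consuming the exported MetaGraph, assuming training has
# produced a checkpoint in train_dir (the paths and the `wavs` array below are
# placeholders). The op names ('x', 'scores', 'pred_label') match those defined
# in infer() above.
#
#   tf.reset_default_graph()
#   saver = tf.train.import_meta_graph('./train_dir/infer/infer.meta')
#   graph = tf.get_default_graph()
#   with tf.Session() as sess:
#     saver.restore(sess, tf.train.latest_checkpoint('./train_dir'))
#     _scores, _label = sess.run(
#         [graph.get_tensor_by_name('scores:0'),
#          graph.get_tensor_by_name('pred_label:0')],
#         {graph.get_tensor_by_name('x:0'): wavs})  # wavs: [n, 16384] float32
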
def train(fps, args):
  with tf.name_scope('loader'):
    dataset = tf.data.TFRecordDataset(fps)
    dataset = dataset.map(lambda x: record_to_xy(x, args.data_labels))
    dataset = dataset.shuffle(buffer_size=8192)
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(args.train_batch_size))
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    x, y = iterator.get_next()

  with tf.variable_scope('classifier'):
    logits = model(x, len(args.data_labels), train=True)

  for v in tf.global_variables():
    print(v.get_shape(), v.name)

  xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
  xent_mean = tf.reduce_mean(xent)

  opt = tf.train.GradientDescentOptimizer(1e-4)
  train_op = opt.minimize(xent_mean, global_step=tf.train.get_or_create_global_step())

  preds = tf.argmax(logits, axis=1, output_type=tf.int32)
  acc = tf.reduce_mean(tf.cast(tf.equal(preds, y), tf.float32))

  tf.summary.audio('x', tf.expand_dims(x, axis=2), 16000)
  tf.summary.scalar('xent', xent_mean)
  tf.summary.scalar('acc', acc)
  tf.summary.histogram('xent_batch', xent)

  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_save_secs,
      save_summaries_secs=args.train_summary_secs) as sess:
    while True:
      _, _acc = sess.run([train_op, acc])


if __name__ == '__main__':
  import argparse
  import glob

  parser = argparse.ArgumentParser()

  parser.add_argument('mode', type=str, choices=['train', 'eval'])
  parser.add_argument('train_dir', type=str, help='Training directory')

  data_args = parser.add_argument_group('Data')
  data_args.add_argument('--data_dir', type=str, help='Data directory')
  data_args.add_argument('--data_labels', type=str, help='Comma-separated list of labels')

  train_args = parser.add_argument_group('Train')
  train_args.add_argument('--train_batch_size', type=int, help='Batch size')
  train_args.add_argument('--train_save_secs', type=int, help='How often to save model')
  train_args.add_argument('--train_summary_secs', type=int, help='How often to report summaries')

  eval_args = parser.add_argument_group('Eval')
  eval_args.add_argument('--eval_batch_size', type=int, help='Batch size')
  eval_args.add_argument('--eval_split', type=str, help='Eval split')
  eval_args.add_argument('--eval_ckpt_fp', type=str, help='If set, evaluate this checkpoint once')

  infer_args = parser.add_argument_group('Infer')
  infer_args.add_argument('--infer_legacy', action='store_true',
      help='If set, create graph compatible with TF 1.1')

  parser.set_defaults(
      data_dir=None,
      data_labels=None,
      train_batch_size=64,
      train_save_secs=300,
      train_summary_secs=120,
      eval_batch_size=64,
      eval_split='valid',
      eval_ckpt_fp=None,
      infer_legacy=False)

  args = parser.parse_args()

  labels = [l.strip() for l in args.data_labels.split(',')]
  setattr(args, 'data_labels', labels)

  # Make train dir
  if not os.path.isdir(args.train_dir):
    os.makedirs(args.train_dir)

  # Save args
  with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f:
    f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))

  # Assign appropriate split for mode
  if args.mode == 'train':
    split = 'train'
  elif args.mode == 'eval':
    split = args.eval_split
  else:
    raise NotImplementedError()

  # Find tfrecord filepaths for the chosen split
  fps = glob.glob(os.path.join(args.data_dir, split) + '*.tfrecord')

  if args.mode == 'train':
    infer(args)
    train(fps, args)
  elif args.mode == 'eval':
    eval(fps, args)
  else:
    raise NotImplementedError()
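
# Example invocations, assuming tfrecords named like train*.tfrecord and
# valid*.tfrecord in --data_dir (the script filename and label list below are
# illustrative placeholders):
#
#   python train.py train ./train_dir \
#       --data_dir ./data \
#       --data_labels zero,one,two,three,four,five,six,seven,eight,nine
#
#   python train.py eval ./train_dir \
#       --data_dir ./data \
#       --data_labels zero,one,two,three,four,five,six,seven,eight,nine \
#       --eval_split valid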