import tensorflow as tf from os.path import join import getopt import sys import constants as c from LSTMModel import LSTMModel from data_reader import DataReader class LyricGenRunner: def __init__(self, model_load_path, artist_name, test, prime_text): """ Initializes the Lyric Generation Runner. @param model_load_path: The path from which to load a previously-saved model. Default = None. @param artist_name: The name of the artist on which to train. (Used to grab data). Default = 'kanye_west' @param test: Whether to test or train the model. Testing generates a sequence from the provided model and artist. Default = False. @param prime_text: The text with which to start the test sequence. """ self.sess = tf.Session() self.artist_name = artist_name print 'Process data...' self.data_reader = DataReader(self.artist_name) self.vocab = self.data_reader.get_vocab() print 'Init model...' self.model = LSTMModel(self.sess, self.vocab, c.BATCH_SIZE, c.SEQ_LEN, c.CELL_SIZE, c.NUM_LAYERS, test=test) print 'Init variables...' self.saver = tf.train.Saver(max_to_keep=None) self.sess.run(tf.global_variables_initializer()) # if load path specified, load a saved model if model_load_path is not None: self.saver.restore(self.sess, model_load_path) print 'Model restored from ' + model_load_path if test: self.test(prime_text) else: self.train() def train(self): """ Runs a training loop on the model. """ while True: inputs, targets = self.data_reader.get_train_batch(c.BATCH_SIZE, c.SEQ_LEN) print 'Training model...' feed_dict = {self.model.inputs: inputs, self.model.targets: targets} global_step, loss, _ = self.sess.run([self.model.global_step, self.model.loss, self.model.train_op], feed_dict=feed_dict) print 'Step: %d | loss: %f' % (global_step, loss) if global_step % c.MODEL_SAVE_FREQ == 0: print 'Saving model...' self.saver.save(self.sess, join(c.MODEL_SAVE_DIR, self.artist_name + '.ckpt'), global_step=global_step) def test(self, prime_text): """ Generates a text sequence. """ # generate and save sample sequence sample = self.model.generate(prime=prime_text) print sample def main(): load_path = None artist_name = 'kanye_west' test = False prime_text = None try: opts, _ = getopt.getopt(sys.argv[1:], 'l:m:a:p:s:t', ['load_path=', 'model_name=', 'artist_name=', 'prime=', 'seq_len', 'test', 'save_freq=']) except getopt.GetoptError: sys.exit(2) for opt, arg in opts: if opt in ('-l', '--load_path'): load_path = arg if opt in ('-m', '--model_name'): c.set_save_name(arg) if opt in ('-a', '--artist_name'): artist_name = arg if opt in ('-p', '--prime'): prime_text = arg if opt in ('-s', '--seq_len'): c.SEQ_LEN = arg if opt in ('-t', '--test'): test = True if opt == '--save_freq': c.MODEL_SAVE_FREQ = int(arg) LyricGenRunner(load_path, artist_name, test, prime_text) if __name__ == '__main__': main()