# Adversarial learning for event-based music generation with SeqGAN
# Reference:
#   "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient."
#   (Yu, Lantao, et al.).
# ... Honestly, it's too hard to train ;(

import argparse
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from progress.bar import Bar
from torch.distributions import Categorical

import config
import utils
from config import device
from data import Dataset
from model import PerformanceRNN
from sequence import EventSeq, ControlSeq

# pylint: disable=E1101

#========================================================================
# Discriminator
#========================================================================

# Default constructor arguments for EventSequenceEncoder. A saved
# discriminator session stores its own copy under 'discriminator_config'.
discriminator_config = {
    'event_dim': EventSeq.dim(),
    'hidden_dim': 512,
    'gru_layers': 3,
    'gru_dropout': 0.3,
}


class EventSequenceEncoder(nn.Module):
    """GRU discriminator that scores a whole event-index sequence.

    Events are embedded, run through a multi-layer GRU, pooled over time
    with a learned (unnormalized) attention vector, and projected to one
    scalar per batch item.
    """

    def __init__(self, event_dim=EventSeq.dim(), hidden_dim=512,
                 gru_layers=3, gru_dropout=0.3):
        super().__init__()
        self.event_embedding = nn.Embedding(event_dim, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim,
                          num_layers=gru_layers, dropout=gru_dropout)
        self.attn = nn.Parameter(torch.randn(hidden_dim), requires_grad=True)
        self.output_fc = nn.Linear(hidden_dim, 1)
        self.output_fc_activation = nn.Sigmoid()

    def forward(self, events, hidden=None, output_logits=False):
        """Score a batch of sequences.

        Args:
            events: index tensor of shape [steps, batch_size].
            hidden: optional initial GRU hidden state.
            output_logits: when true, return the raw linear output instead
                of its sigmoid.

        Returns:
            Tensor of shape [batch_size].
        """
        # events: [steps, batch_size]
        events = self.event_embedding(events)
        outputs, _ = self.gru(events, hidden)  # [t, b, h]
        # One scalar weight per (step, batch item); not softmax-normalized.
        weights = (outputs * self.attn).sum(-1, keepdim=True)
        output = (outputs * weights).mean(0)  # [b, h]
        output = self.output_fc(output).squeeze(-1)  # [b]
        if output_logits:
            return output
        output = self.output_fc_activation(output)
        return output


#========================================================================
# Pretrain Discriminator
#========================================================================

def pretrain_discriminator(model_sess_path,  # load
                           discriminator_sess_path,  # load + save
                           batch_data_generator,  # Dataset(...).batches(...)
                           discriminator_config_overwrite=None,
                           gradient_clipping=False,
                           control_ratio=1.0,
                           num_iter=-1,
                           save_interval=60.0,
                           discriminator_lr=0.001,
                           enable_logging=False,
                           auto_sample_factor=False,
                           sample_factor=1.0):
    """Train the discriminator against a frozen, pretrained generator.

    Args:
        model_sess_path: generator session file to load (read-only).
        discriminator_sess_path: discriminator session file; loaded when it
            exists and always used as the save target.
        batch_data_generator: iterable yielding (events, controls) batches,
            with events shaped [steps, batch_size].
        discriminator_config_overwrite: dict merged into the default
            discriminator config when a new session is started. None means
            no overrides.
        gradient_clipping: max gradient norm; a falsy value disables it.
        control_ratio: probability that a batch's controls are given to the
            generator.
        num_iter: stop once this many batches were processed; -1 never
            stops on its own.
        save_interval: minimum number of seconds between periodic saves.
        discriminator_lr: Adam learning rate.
        enable_logging: write scalars through tensorboardX when true.
        auto_sample_factor: draw a random sampling temperature per batch.
        sample_factor: generator sampling temperature when
            auto_sample_factor is off.

    The session is saved periodically, on KeyboardInterrupt, and when the
    loop finishes normally.
    """
    # None instead of a shared mutable `{}` default.
    if discriminator_config_overwrite is None:
        discriminator_config_overwrite = {}

    print('-' * 70)
    print('model_sess_path:', model_sess_path)
    print('discriminator_sess_path:', discriminator_sess_path)
    print('discriminator_config_overwrite:', discriminator_config_overwrite)
    print('sample_factor:', sample_factor)
    print('auto_sample_factor:', auto_sample_factor)
    print('discriminator_lr:', discriminator_lr)
    print('gradient_clipping:', gradient_clipping)
    print('control_ratio:', control_ratio)
    print('num_iter:', num_iter)
    print('save_interval:', save_interval)
    print('enable_logging:', enable_logging)
    print('-' * 70)

    # Load generator
    model_sess = torch.load(model_sess_path)
    model_config = model_sess['model_config']
    model = PerformanceRNN(**model_config).to(device)
    model.load_state_dict(model_sess['model_state'])

    print(f'Generator from "{model_sess_path}"')
    print(model)
    print('-' * 70)

    # Load discriminator and optimizer
    global discriminator_config
    # Explicit existence check rather than a bare `except:`: a corrupt or
    # incompatible session file now raises instead of being silently
    # replaced (and later overwritten) by a fresh discriminator.
    if os.path.isfile(discriminator_sess_path):
        discriminator_sess = torch.load(discriminator_sess_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        discriminator_optimizer_state = discriminator_sess['discriminator_optimizer_state']
        print(f'Discriminator from "{discriminator_sess_path}"')
        discriminator_loaded = True
    else:
        print(f'New discriminator session at "{discriminator_sess_path}"')
        discriminator_config.update(discriminator_config_overwrite)
        discriminator_loaded = False

    discriminator = EventSequenceEncoder(**discriminator_config).to(device)
    optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr)
    if discriminator_loaded:
        discriminator.load_state_dict(discriminator_state)
        optimizer.load_state_dict(discriminator_optimizer_state)

    print(discriminator)
    print(optimizer)
    print('-' * 70)

    def save_discriminator():
        print(f'Saving to "{discriminator_sess_path}"')
        torch.save({
            'discriminator_config': discriminator_config,
            'discriminator_state': discriminator.state_dict(),
            'discriminator_optimizer_state': optimizer.state_dict()
        }, discriminator_sess_path)
        print('Done saving')

    # Disable gradient for generator
    for parameter in model.parameters():
        parameter.requires_grad_(False)

    model.eval()
    discriminator.train()

    loss_func = nn.BCEWithLogitsLoss()
    last_save_time = time.time()

    writer = None
    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    try:
        for i, (events, controls) in enumerate(batch_data_generator):
            if i == num_iter:
                break

            steps, batch_size = events.shape

            # Prepare inputs
            events = torch.LongTensor(events).to(device)
            if np.random.random() <= control_ratio:
                controls = torch.FloatTensor(controls).to(device)
            else:
                controls = None

            init = torch.randn(batch_size, model.init_dim).to(device)

            # Predict for real event sequence
            real_events = events
            real_logit = discriminator(real_events, output_logits=True)
            real_target = torch.ones_like(real_logit)

            if auto_sample_factor:
                sample_factor = np.random.choice([
                    0.1, 0.4, 0.6, 0.7, 0.8, 0.9, 1.0,
                    1.1, 1.2, 1.4, 1.6, 2.0, 4.0, 10.0])

            # Predict for fake event sequence from the generator.
            # The generator is frozen, so no graph is needed while sampling.
            with torch.no_grad():
                fake_events = model.generate(init, steps, None, controls,
                                             greedy=0, output_type='index',
                                             temperature=sample_factor)
            fake_logit = discriminator(fake_events, output_logits=True)
            fake_target = torch.zeros_like(fake_logit)

            # Compute loss
            loss = (loss_func(real_logit, real_target) +
                    loss_func(fake_logit, fake_target)) / 2

            # Backprop
            discriminator.zero_grad()
            loss.backward()

            # Gradient clipping (the norm is measured before clipping)
            norm = utils.compute_gradient_norm(discriminator.parameters())
            if gradient_clipping:
                nn.utils.clip_grad_norm_(discriminator.parameters(), gradient_clipping)

            optimizer.step()

            # Logging
            loss = loss.item()
            norm = norm.item()
            print(f'{i} loss: {loss}, norm: {norm}, sf: {sample_factor}')
            if writer is not None:
                writer.add_scalar('pretrain/D/loss/all', loss, i)
                writer.add_scalar(f'pretrain/D/loss/{sample_factor}', loss, i)
                writer.add_scalar(f'pretrain/D/norm/{sample_factor}', norm, i)

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save_discriminator()

    except KeyboardInterrupt:
        save_discriminator()
    else:
        # Reaching num_iter (or exhausting the data) must not throw away
        # everything trained since the last periodic save.
        save_discriminator()
    finally:
        if writer is not None:
            writer.close()


#========================================================================
# Adversarial Learning
#========================================================================

def train_adversarial(sess_path, batch_data_generator,
                      model_load_path, model_optimizer_class,
                      model_gradient_clipping, discriminator_gradient_clipping,
                      model_learning_rate, reset_model_optimizer,
                      discriminator_load_path, discriminator_optimizer_class,
                      discriminator_learning_rate, reset_discriminator_optimizer,
                      g_max_q_mean, g_min_q_mean, d_min_loss,
                      g_max_steps, d_max_steps,
                      mc_sample_size, mc_sample_factor,
                      first_to_train, save_interval, control_ratio,
                      enable_logging):
    """Alternate generator (policy gradient) and discriminator updates.

    When sess_path exists, both networks and their optimizer states are
    restored from it; otherwise the generator is loaded from
    model_load_path and the discriminator from discriminator_load_path.
    sess_path is always the save target.

    Schedule:
        * G-steps run until the mean Q value exceeds g_max_q_mean or
          g_max_steps G-steps were taken (0 disables the step limit).
        * D-steps run until fake_loss <= real_loss < d_min_loss or
          d_max_steps D-steps were taken (0 disables the step limit).
        * Training aborts (after saving) when the mean Q value drops below
          g_min_q_mean.

    mc_sample_size is the number of Monte Carlo rollouts per partial
    sequence; mc_sample_factor divides the generator logits before
    sampling. The session is saved every save_interval seconds, on
    KeyboardInterrupt, and when the batch generator is exhausted.
    """
    writer = None
    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    if os.path.isfile(sess_path):
        adv_state = torch.load(sess_path)
        model_config = adv_state['model_config']
        model_state = adv_state['model_state']
        model_optimizer_state = adv_state['model_optimizer_state']
        discriminator_config = adv_state['discriminator_config']
        discriminator_state = adv_state['discriminator_state']
        discriminator_optimizer_state = adv_state['discriminator_optimizer_state']
        print('-' * 70)
        print('Session is loaded from', sess_path)
        loaded_from_session = True

    else:
        model_sess = torch.load(model_load_path)
        model_config = model_sess['model_config']
        model_state = model_sess['model_state']
        discriminator_sess = torch.load(discriminator_load_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        loaded_from_session = False

    model = PerformanceRNN(**model_config)
    model.load_state_dict(model_state)
    model.to(device).train()
    model_optimizer = model_optimizer_class(model.parameters(), lr=model_learning_rate)

    discriminator = EventSequenceEncoder(**discriminator_config)
    discriminator.load_state_dict(discriminator_state)
    discriminator.to(device).train()
    discriminator_optimizer = discriminator_optimizer_class(discriminator.parameters(),
                                                            lr=discriminator_learning_rate)

    if loaded_from_session:
        if not reset_model_optimizer:
            model_optimizer.load_state_dict(model_optimizer_state)
        if not reset_discriminator_optimizer:
            discriminator_optimizer.load_state_dict(discriminator_optimizer_state)

    # Only printed below; `reduction='none'` replaces the deprecated
    # `reduce=False`.
    g_loss_func = nn.CrossEntropyLoss()
    d_loss_func = nn.BCEWithLogitsLoss(reduction='none')

    print('-' * 70)
    print('Options')
    print('sess_path:', sess_path)
    print('save_interval:', save_interval)
    print('batch_data_generator:', batch_data_generator)
    print('control_ratio:', control_ratio)
    print('g_max_q_mean:', g_max_q_mean)
    print('g_min_q_mean:', g_min_q_mean)
    print('d_min_loss:', d_min_loss)
    print('mc_sample_size:', mc_sample_size)
    print('mc_sample_factor:', mc_sample_factor)
    print('enable_logging:', enable_logging)
    print('model_load_path:', model_load_path)
    print('model_loss:', g_loss_func)
    print('model_optimizer_class:', model_optimizer_class)
    print('model_gradient_clipping:', model_gradient_clipping)
    print('model_learning_rate:', model_learning_rate)
    print('reset_model_optimizer:', reset_model_optimizer)
    print('discriminator_load_path:', discriminator_load_path)
    print('discriminator_loss:', d_loss_func)
    print('discriminator_optimizer_class:', discriminator_optimizer_class)
    print('discriminator_gradient_clipping:', discriminator_gradient_clipping)
    print('discriminator_learning_rate:', discriminator_learning_rate)
    print('reset_discriminator_optimizer:', reset_discriminator_optimizer)
    print('first_to_train:', first_to_train)
    print('-' * 70)
    print(f'Generator from "{sess_path if loaded_from_session else model_load_path}"')
    print(model)
    print(model_optimizer)
    print('-' * 70)
    print(f'Discriminator from "{sess_path if loaded_from_session else discriminator_load_path}"')
    print(discriminator)
    print(discriminator_optimizer)
    print('-' * 70)

    def save():
        print(f'Saving to "{sess_path}"')
        torch.save({
            'model_config': model_config,
            'model_state': model.state_dict(),
            'model_optimizer_state': model_optimizer.state_dict(),
            'discriminator_config': discriminator_config,
            'discriminator_state': discriminator.state_dict(),
            'discriminator_optimizer_state': discriminator_optimizer.state_dict()
        }, sess_path)
        print('Done saving')

    def mc_rollout(generated, hidden, total_steps, controls=None):
        """Complete a partial sequence mc_sample_size times by sampling.

        Returns an index tensor of shape
        [total_steps, mc_sample_size * batch_size].
        """
        # generated: list of t tensors, each [1, batch_size]
        # hidden: [n_layers, batch_size, hidden_dim]
        # controls: [total_steps - t, batch_size, control_dim]
        generated = torch.cat(generated, 0)
        generated_steps, batch_size = generated.shape  # t, b
        steps = total_steps - generated_steps  # s

        generated = generated.unsqueeze(1)  # [t, 1, b]
        generated = generated.repeat(1, mc_sample_size, 1)  # [t, mcs, b]
        generated = generated.view(generated_steps, -1)  # [t, mcs * b]

        hidden = hidden.unsqueeze(1).repeat(1, mc_sample_size, 1, 1)
        hidden = hidden.view(model.gru_layers, -1, model.hidden_dim)

        if controls is not None:
            assert controls.shape == (steps, batch_size, model.control_dim)
            controls = controls.unsqueeze(1)  # [s, 1, b, c]
            controls = controls.repeat(1, mc_sample_size, 1, 1)  # [s, mcs, b, c]
            controls = controls.view(steps, -1, model.control_dim)  # [s, mcs * b, c]

        event = generated[-1].unsqueeze(0)  # [1, mcs * b]
        control = None  # default when controls is None
        outputs = []

        for i in range(steps):
            if controls is not None:
                control = controls[i].unsqueeze(0)  # [1, mcs * b, c]

            output, hidden = model.forward(event, control=control, hidden=hidden)
            probs = model.output_fc_activation(output / mc_sample_factor)
            event = Categorical(probs).sample()  # [1, mcs * b]
            outputs.append(event)

        sequences = torch.cat([generated, *outputs], 0)
        assert sequences.shape == (total_steps, mc_sample_size * batch_size)
        return sequences

    def train_generator(batch_size, init, events, controls):
        """Run one policy-gradient update; return (mean Q value, grad norm)."""
        # Derived from the arguments instead of being read from the
        # enclosing loop's variables.
        steps = events.shape[0]
        use_control = controls is not None

        # Generator step
        hidden = model.init_to_hidden(init)
        event = model.get_primary_event(batch_size)
        outputs = []
        generated = []
        q_values = []

        for step in Bar('MC Rollout').iter(range(steps)):
            control = controls[step].unsqueeze(0) if use_control else None
            output, hidden = model.forward(event, control=control, hidden=hidden)
            outputs.append(output)
            probs = model.output_fc_activation(output / mc_sample_factor)
            # Feed the sampled event back in as the next input. Previously
            # `event` was never updated, so every step was conditioned on
            # the primary event only.
            event = Categorical(probs).sample()  # [1, batch_size]
            generated.append(event)

            with torch.no_grad():
                if step < steps - 1:
                    # `controls` is None when this batch runs without
                    # controls; slicing it would raise TypeError.
                    remaining_controls = controls[step + 1:] if use_control else None
                    sequences = mc_rollout(generated, hidden, steps, remaining_controls)
                    mc_score = discriminator(sequences)  # [mcs * b]
                    mc_score = mc_score.view(mc_sample_size, batch_size)  # [mcs, b]
                    q_value = mc_score.mean(0, keepdim=True)  # [1, batch_size]

                else:
                    # Sequence is complete: score it directly.
                    q_value = discriminator(torch.cat(generated, 0))
                    q_value = q_value.unsqueeze(0)  # [1, batch_size]

                q_values.append(q_value)

        # Compute loss
        q_values = torch.cat(q_values, 0)  # [steps, batch_size]
        q_mean = q_values.mean().detach()
        q_values = q_values - q_mean  # mean baseline
        generated = torch.cat(generated, 0)  # [steps, batch_size]
        outputs = torch.cat(outputs, 0)  # [steps, batch_size, event_dim]
        loss = F.cross_entropy(outputs.view(-1, model.event_dim),
                               generated.view(-1),
                               reduction='none')
        loss = (loss * q_values.view(-1)).mean()

        # Backprop
        model.zero_grad()
        loss.backward()

        # Gradient clipping (the norm is measured before clipping)
        norm = utils.compute_gradient_norm(model.parameters())
        if model_gradient_clipping:
            nn.utils.clip_grad_norm_(model.parameters(), model_gradient_clipping)

        model_optimizer.step()

        q_mean = q_mean.item()
        norm = norm.item()
        return q_mean, norm

    def train_discriminator(batch_size, init, events, controls):
        """Run one discriminator update.

        Returns (loss, real_loss, fake_loss, grad norm) as Python numbers.
        """
        steps = events.shape[0]

        # Discriminator step
        with torch.no_grad():
            generated = model.generate(init, steps, None, controls,
                                       greedy=0, temperature=mc_sample_factor)

        fake_logit = discriminator(generated, output_logits=True)
        real_logit = discriminator(events, output_logits=True)
        fake_target = torch.zeros_like(fake_logit)
        real_target = torch.ones_like(real_logit)

        # Compute loss
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, fake_target)
        real_loss = F.binary_cross_entropy_with_logits(real_logit, real_target)
        loss = (real_loss + fake_loss) / 2

        # Backprop
        discriminator.zero_grad()
        loss.backward()

        # Gradient clipping (the norm is measured before clipping)
        norm = utils.compute_gradient_norm(discriminator.parameters())
        if discriminator_gradient_clipping:
            nn.utils.clip_grad_norm_(discriminator.parameters(),
                                     discriminator_gradient_clipping)

        discriminator_optimizer.step()

        real_loss = real_loss.item()
        fake_loss = fake_loss.item()
        loss = loss.item()
        norm = norm.item()
        return loss, real_loss, fake_loss, norm

    try:
        last_save_time = time.time()
        step_for = first_to_train
        g_steps = 0
        d_steps = 0

        for i, (events, controls) in enumerate(batch_data_generator):
            steps, batch_size = events.shape
            init = torch.randn(batch_size, model.init_dim).to(device)
            events = torch.LongTensor(events).to(device)

            use_control = np.random.random() <= control_ratio
            controls = torch.FloatTensor(controls).to(device) if use_control else None

            if step_for == 'G':
                q_mean, norm = train_generator(batch_size, init, events, controls)
                g_steps += 1

                print(f'{i} (G-step) Q_mean: {q_mean}, norm: {norm}')
                if writer is not None:
                    writer.add_scalar('adversarial/G/Q_mean', q_mean, i)
                    writer.add_scalar('adversarial/G/norm', norm, i)

                if q_mean < g_min_q_mean:
                    print(f'Q is too small: {q_mean}, exiting')
                    # Reuses the interrupt handler below to save and stop.
                    raise KeyboardInterrupt

                if q_mean > g_max_q_mean or (g_max_steps and g_steps >= g_max_steps):
                    step_for = 'D'
                    d_steps = 0

            # NOTE: plain `if`, not `elif`: when the G-step above hands
            # over, the first D-step runs on this same batch.
            if step_for == 'D':
                loss, real_loss, fake_loss, norm = train_discriminator(
                    batch_size, init, events, controls)
                d_steps += 1

                print(f'{i} (D-step) loss: {loss} (real: {real_loss}, fake: {fake_loss}), norm: {norm}')
                if writer is not None:
                    writer.add_scalar('adversarial/D/loss', loss, i)
                    writer.add_scalar('adversarial/D/norm', norm, i)

                if fake_loss <= real_loss < d_min_loss or (d_max_steps and d_steps >= d_max_steps):
                    step_for = 'G'
                    g_steps = 0

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save()

    except KeyboardInterrupt:
        save()
    else:
        # Keep the progress made since the last periodic save when the
        # batch generator runs out.
        save()
    finally:
        if writer is not None:
            writer.close()


#========================================================================
# Script Arguments
#========================================================================

def batch_generator(args):
    """Build the dataset named by args and return its batch iterator."""
    print('-' * 70)
    dataset = Dataset(args.dataset_path, verbose=True)
    print(dataset)
    return dataset.batches(args.batch_size, args.window_size, args.stride_size)


def pretrain(args):
    """Entry point of the `pretrain` subcommand."""
    pretrain_discriminator(
        model_sess_path=args.generator_session_path,
        discriminator_sess_path=args.discriminator_session_path,
        discriminator_config_overwrite=utils.params2dict(args.discriminator_parameters),
        batch_data_generator=args.batch_generator(args),
        gradient_clipping=args.gradient_clipping,
        sample_factor=args.sample_factor,
        auto_sample_factor=args.auto_sample_factor,
        control_ratio=args.control_ratio,
        num_iter=args.stop_iteration,
        save_interval=args.save_interval,
        discriminator_lr=args.discriminator_learning_rate,
        enable_logging=args.enable_logging)


def adversarial(args):
    """Entry point of the `adversarial` subcommand."""
    train_adversarial(
        sess_path=args.session_path,
        batch_data_generator=args.batch_generator(args),
        model_load_path=args.generator_load_path,
        discriminator_load_path=args.discriminator_load_path,
        model_optimizer_class=getattr(optim, args.generator_optimizer),
        discriminator_optimizer_class=getattr(optim, args.discriminator_optimizer),
        model_gradient_clipping=args.generator_gradient_clipping,
        discriminator_gradient_clipping=args.discriminator_gradient_clipping,
        model_learning_rate=args.generator_learning_rate,
        discriminator_learning_rate=args.discriminator_learning_rate,
        reset_model_optimizer=args.reset_generator_optimizer,
        reset_discriminator_optimizer=args.reset_discriminator_optimizer,
        g_max_q_mean=args.g_max_q_mean,
        g_min_q_mean=args.g_min_q_mean,
        d_min_loss=args.d_min_loss,
        g_max_steps=args.g_max_steps,
        d_max_steps=args.d_max_steps,
        mc_sample_size=args.monte_carlo_sample_size,
        mc_sample_factor=args.monte_carlo_sample_factor,
        first_to_train=args.first_to_train,
        control_ratio=args.control_ratio,
        save_interval=args.save_interval,
        enable_logging=args.enable_logging)


def get_args():
    """Parse the command line.

    The returned namespace carries `main` (the subcommand entry point) and
    `batch_generator` (the batch factory) as defaults.
    """
    parser = argparse.ArgumentParser()
    # A subcommand is mandatory: without one `args.main` would not exist.
    # `dest` is set so argparse can name the missing argument in its error.
    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True

    parser.add_argument('-d', '--dataset-path', type=str, required=True)
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    parser.add_argument('-w', '--window-size', type=int, default=200)
    parser.add_argument('-s', '--stride-size', type=int, default=10)
    parser.set_defaults(batch_generator=batch_generator)

    pre_parser = subparsers.add_parser('pretrain', aliases=['p', 'pre'])
    # Was `default=True`, which handed the boolean True to torch.load when
    # the option was omitted; the generator session is mandatory.
    pre_parser.add_argument('-G', '--generator-session-path', type=str, required=True)
    pre_parser.add_argument('-D', '--discriminator-session-path', type=str, required=True)
    pre_parser.add_argument('-p', '--discriminator-parameters', type=str, default='')
    pre_parser.add_argument('-l', '--discriminator-learning-rate', type=float, default=0.001)
    pre_parser.add_argument('-g', '--gradient-clipping', type=float, default=False)
    pre_parser.add_argument('-f', '--sample-factor', type=float, default=1.0)
    pre_parser.add_argument('-af', '--auto-sample-factor', action='store_true', default=False)
    pre_parser.add_argument('-c', '--control-ratio', type=float, default=1.0)
    pre_parser.add_argument('-n', '--stop-iteration', type=int, default=-1)
    pre_parser.add_argument('-i', '--save-interval', type=float, default=60.0)
    pre_parser.add_argument('-L', '--enable-logging', action='store_true', default=False)
    pre_parser.set_defaults(main=pretrain)

    adv_parser = subparsers.add_parser('adversarial', aliases=['a', 'adv'])
    adv_parser.add_argument('-S', '--session-path', type=str, required=True)
    adv_parser.add_argument('-Gp', '--generator-load-path', type=str)
    adv_parser.add_argument('-Dp', '--discriminator-load-path', type=str)
    adv_parser.add_argument('-Go', '--generator-optimizer', type=str, default='Adam')
    adv_parser.add_argument('-Do', '--discriminator-optimizer', type=str, default='RMSprop')
    adv_parser.add_argument('-Gg', '--generator-gradient-clipping', type=float, default=False)
    adv_parser.add_argument('-Dg', '--discriminator-gradient-clipping', type=float, default=False)
    adv_parser.add_argument('-Gl', '--generator-learning-rate', type=float, default=0.001)
    adv_parser.add_argument('-Dl', '--discriminator-learning-rate', type=float, default=0.001)
    adv_parser.add_argument('-Gr', '--reset-generator-optimizer', action='store_true', default=False)
    adv_parser.add_argument('-Dr', '--reset-discriminator-optimizer', action='store_true', default=False)
    adv_parser.add_argument('-Gq', '--g-max-q-mean', type=float, default=0.5)
    adv_parser.add_argument('-Gm', '--g-min-q-mean', type=float, default=0.0)
    adv_parser.add_argument('-Dm', '--d-min-loss', type=float, default=0.5)
    adv_parser.add_argument('-Gs', '--g-max-steps', type=int, default=0)
    adv_parser.add_argument('-Ds', '--d-max-steps', type=int, default=0)
    adv_parser.add_argument('-f', '--first-to-train', type=str, default='G', choices=['G', 'D'])
    adv_parser.add_argument('-ms', '--monte-carlo-sample-size', type=int, default=8)
    adv_parser.add_argument('-mf', '--monte-carlo-sample-factor', type=float, default=1.0)
    adv_parser.add_argument('-c', '--control-ratio', type=float, default=1.0)
    adv_parser.add_argument('-i', '--save-interval', type=float, default=60.0)
    adv_parser.add_argument('-L', '--enable-logging', action='store_true', default=False)
    adv_parser.set_defaults(main=adversarial)

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    args.main(args)