DL-Seq2Seq

This repository contains implementations of research papers on sequence-to-sequence learning. The following implementations are currently supported:

Sketch Generation

The Sketch-RNN model is described in the paper A Neural Representation of Sketch Drawings. It is a Sequence-to-Sequence Variational Autoencoder (VAE) that generates pen strokes of various shapes by learning the latent distribution of the drawings.

The output of the encoder is used to compute the latent parameters (mu, sigma and z), and the latent vector z is fed to the decoder. The output of the decoder is passed to a mixture density network (MDN), which fits a mixture of k Gaussians to model the distribution of pen strokes.
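The step from encoder output to latent vector can be sketched with the reparameterization trick. This is a minimal pure-Python illustration, not the repository's code: it assumes the encoder emits `mu` and a raw `sigma_hat`, with sigma = exp(sigma_hat / 2) as in the Sketch-RNN paper. The function names `sample_latent` and `kl_divergence` are hypothetical.

```python
import math
import random


def sample_latent(mu, sigma_hat, rng=random):
    """Draw z = mu + sigma * eps with eps ~ N(0, I).

    sigma_hat is the raw (unconstrained) encoder output; exponentiating
    it keeps sigma positive. Returns the sampled z and the sigma values.
    """
    sigma = [math.exp(s / 2.0) for s in sigma_hat]
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    z = [m + s * e for m, s, e in zip(mu, sigma, eps)]
    return z, sigma


def kl_divergence(mu, sigma_hat):
    """KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions.

    Uses sigma^2 = exp(sigma_hat). The paper additionally averages
    this term over the latent size; here it is left as a plain sum.
    """
    return -0.5 * sum(1.0 + s - m * m - math.exp(s) for m, s in zip(mu, sigma_hat))
```

Sampling through `mu + sigma * eps` rather than drawing z directly keeps the path from the encoder outputs to z differentiable, which is what lets the encoder be trained jointly with the decoder.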

Unconditional Generation

For unconditional generation, the decoder is trained on its own while the encoder parameters are kept non-trainable. The MDN parameters are computed by passing the output of the decoder to the MDN layer.
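The mapping from a raw decoder output to MDN parameters can be sketched as follows. In the Sketch-RNN paper each of the M mixture components contributes six values (pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy), followed by three pen-state logits, for 6M + 3 values in total. The per-component ordering and the helper name `split_mdn_output` are assumptions for illustration; the repository's MDN layer may lay the values out differently.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def split_mdn_output(y, num_mixtures):
    """Turn a raw decoder output of length 6*M + 3 into MDN parameters."""
    if len(y) != 6 * num_mixtures + 3:
        raise ValueError("expected 6 * num_mixtures + 3 values")
    comps = [y[6 * k: 6 * k + 6] for k in range(num_mixtures)]
    pi = softmax([c[0] for c in comps])          # mixture weights sum to 1
    mu_x = [c[1] for c in comps]                 # means are unconstrained
    mu_y = [c[2] for c in comps]
    sigma_x = [math.exp(c[3]) for c in comps]    # std devs must be positive
    sigma_y = [math.exp(c[4]) for c in comps]
    rho = [math.tanh(c[5]) for c in comps]       # correlation in (-1, 1)
    pen = softmax(y[6 * num_mixtures:])          # pen down / pen up / end
    return pi, mu_x, mu_y, sigma_x, sigma_y, rho, pen
```

The activations exist only to enforce each parameter's valid range: softmax for probabilities, exp for standard deviations, and tanh for the correlation.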

Train the models

To train the model from scratch, use the following command. The hyperparameters are set in the main.py script; set the cond_gen flag to False. The model can be trained on either the cat or the kanji character dataset, selected with the data_type flag. Once trained, the models are saved in the saved_model folder.

$ python main.py

Let's run some inference

For inference I have provided trained models in the saved_model folder. If you train your own model, it overwrites the pre-trained model in that same folder.

>>> from data_load import get_data
>>> from model import encoder_skrnn, decoder_skrnn, skrnn_loss, skrnn_sample
>>> from eval_skrnn import draw_image, load_pretrained_uncond  # load_pretrained_uncond assumed to live in eval_skrnn
>>> data_type = 'cat'  # either the cat or the kanji character dataset

>>> encoder, decoder, hid_dim, latent_dim, t_step, cond_gen, mode, device = load_pretrained_uncond(data_type)
>>> strokes, mix_params = skrnn_sample(encoder, decoder, hid_dim, latent_dim, time_step=t_step, random_state=98,
...                                    cond_gen=cond_gen, device=device, bi_mode=mode)
>>> draw_image(strokes)

Conditional Generation

Here, the sketch to be generated is conditioned on some input strokes. The input is passed through the encoder, whose output (hidden state) is used to compute the latent parameters (mu, sigma and z). The decoder input is the given input concatenated with the latent vector z. Finally, the output of the decoder is fed to the MDN. Once trained, the Sketch-RNN model can be used to sample new data points.
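Building the decoder input can be sketched as follows. This is a minimal illustration under two assumptions not stated above: strokes use the five-element stroke-5 format from the Sketch-RNN paper ([dx, dy, p1, p2, p3]), and the decoder is trained with teacher forcing, so each step receives the previous stroke (starting from a fixed start token) concatenated with the same z. The name `build_decoder_inputs` is hypothetical.

```python
def build_decoder_inputs(strokes, z, start_token=(0.0, 0.0, 1.0, 0.0, 0.0)):
    """Concatenate the latent vector z onto every decoder input step.

    strokes: list of stroke-5 points [dx, dy, p1, p2, p3].
    Step t receives the stroke from step t - 1 (the start token at t = 0)
    followed by z, so each input has length 5 + len(z).
    """
    previous = [list(start_token)] + [list(s) for s in strokes[:-1]]
    return [p + list(z) for p in previous]
```

Because z is repeated at every step, the decoder can consult the encoded input throughout generation rather than only through its initial hidden state.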

Train the models

The conditional model can be trained from scratch using the main.py script with the cond_gen hyperparameter set to True.

$ python main.py

Let's run some inference

For inference I have provided trained models in the saved_model folder. If you train your own model, it overwrites the pre-trained model in that same folder.

>>> from data_load import get_data
>>> from model import encoder_skrnn, decoder_skrnn, skrnn_loss, skrnn_sample
>>> from eval_skrnn import draw_image, load_pretrained_cond  # load_pretrained_cond assumed to live in eval_skrnn
>>> import numpy as np
>>> import torch
>>> data_type = 'cat'  # either the cat or the kanji character dataset

>>> data_enc, _, _ = get_data(data_type=data_type)
>>> encoder, decoder, hid_dim, latent_dim, t_step, cond_gen, mode, device = load_pretrained_cond(data_type)
>>> # pick a random sketch as the conditioning input and add a batch dimension
>>> idx = np.random.randint(0, len(data_enc))
>>> enc_rnd = data_enc[idx].unsqueeze(0).to(device=device, dtype=torch.float)

>>> strokes, mix_params = skrnn_sample(encoder, decoder, hid_dim, latent_dim, time_step=t_step, inp_enc=enc_rnd,
...                                    cond_gen=cond_gen, device=device, bi_mode=mode)
>>> draw_image(strokes)

Neural Machine Translation

For this task, I have followed the attentional encoder-decoder model described in Luong's paper, focusing specifically on the content-based attention strategy.
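Content-based attention scores every encoder hidden state against the current decoder state, normalizes the scores with a softmax, and averages the encoder states into a context vector. The sketch below uses Luong's "dot" score, score(h_t, h_s) = h_t · h_s. Luong's paper also defines "general" and "concat" content-based scores, and which one the repository uses is not stated here. The name `luong_dot_attention` is hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def luong_dot_attention(decoder_state, encoder_states):
    """Return (attention weights, context vector) for one decoder step.

    scores[s] = decoder_state . encoder_states[s]
    weights   = softmax(scores)
    context   = sum_s weights[s] * encoder_states[s]
    """
    scores = [sum(d * e for d, e in zip(decoder_state, h_s)) for h_s in encoder_states]
    weights = softmax(scores)
    dim = len(encoder_states[0])
    context = [sum(w * h_s[i] for w, h_s in zip(weights, encoder_states)) for i in range(dim)]
    return weights, context
```

The weights returned for each output word are what an attention plot, such as the one produced by viz_attn, displays.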

Train the models

To train the model from scratch, use the following command. The hyperparameters are set in the main.py script, and the trained model is saved in the saved_model folder.

$ python main.py

Let's run some inference

For inference I have provided trained models. For handwriting synthesis, the pre-trained models are included in the GitHub repository, but for machine translation please download them from the download pre-trained models for machine translation link. Keep the downloaded saved_model folder inside the neural machine translation folder. You can also train your own model, and its parameters will be saved there.

>>> from eval_nmt import load_pre_trained, evalText, viz_attn

>>> encoder_e2f, decoder_e2f = load_pre_trained('eng-fra') # 'eng-fra' or 'fra-eng'
>>> encoder_f2e, decoder_f2e = load_pre_trained('fra-eng') # 'eng-fra' or 'fra-eng'

>>> eng_text = "i m not giving you any money ."
>>> fra_text = "je crains de vous avoir offense ."

>>> inp1, out1, attn1 = evalText(eng_text, encoder_e2f, decoder_e2f)
English Text - "i m not giving you any money ."
French o/p   - "je ne te donnerai pas argent . <EOS>"

>>> inp2, out2, attn2 = evalText(fra_text, encoder_f2e, decoder_f2e, inp_lang='French', out_lang='English')
French Text  - "je crains de vous avoir offense ."
English o/p  - "i m afraid i ve offended you . <EOS>"

>>> viz_attn(inp1, out1, attn1)
>>> viz_attn(inp2, out2, attn2)

Handwriting Synthesis

The handwriting generation problem falls under the category of inverse problems, where a single input can correspond to multiple plausible outputs at a given time step. The idea, as described in Alex Graves' paper, is to place a Mixture Density Network (a mixture of Gaussian distributions) on top of a recurrent model. Handwriting generation is divided into two categories: unconditional and conditional generation. In unconditional generation, the recurrent model alone is used to draw samples, while in conditional generation, handwriting is synthesized from a given text.
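Drawing one pen offset from the mixture's output can be sketched as follows: choose a component according to its weight, sample a correlated 2-D Gaussian offset, and sample the end-of-stroke flag from a Bernoulli distribution. This is a minimal pure-Python sketch. It assumes the parameters have already been passed through their activations (weights sum to 1, standard deviations are positive, correlations lie in (-1, 1), and `e` is a probability), and the name `sample_point` is hypothetical rather than the repository's sample_uncond.

```python
import math
import random


def sample_point(e, pi, mu1, mu2, sigma1, sigma2, rho, rng=random):
    """Sample (dx, dy, end_of_stroke) from a bivariate Gaussian mixture."""
    # choose a mixture component with probability pi[k]
    k = rng.choices(range(len(pi)), weights=pi)[0]
    # build a correlated pair from two independent standard normals
    z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    dx = mu1[k] + sigma1[k] * z1
    dy = mu2[k] + sigma2[k] * (rho[k] * z1 + math.sqrt(1.0 - rho[k] ** 2) * z2)
    # the pen is lifted (end of stroke) with probability e
    eos = 1 if rng.random() < e else 0
    return dx, dy, eos
```

Repeating this step, and feeding each sampled point back in as the next input, produces the stroke sequences that plot_stroke draws.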

Unconditional Generation

The unconditional model uses skip connections from the input to the outer recurrent layers, as shown by the arrows in the architecture diagram. The MDN parameters are computed by passing the output of the recurrent model to the MDN layer.

Train the models

To train the model from scratch, use the following command. The hyperparameters are set in the main_uncond.py script, and the trained model is saved in the saved_model folder.

$ python main_uncond.py

Let's run some inference

For inference I have provided trained models in the saved_model folder. If you train your own model, it overwrites the pre-trained model in that same folder.

>>> from eval_hand import load_pretrained_uncond, gauss_params_plot, plot_stroke
>>> from model import model_uncond, mdn_loss, sample_uncond, scheduled_sample

>>> lr_model, h_size = load_pretrained_uncond()
>>> strokes, mix_params = sample_uncond(lr_model, h_size)
>>> plot_stroke(strokes)
>>> gauss_params_plot(mix_params)
>>> strokes, mix_params = sample_uncond(lr_model, h_size)
>>> plot_stroke(strokes)
>>> gauss_params_plot(mix_params)

Conditional Generation

For handwriting synthesis, a location-based attention mechanism is used, in which an attention window (w_t) is convolved with the character encodings. The attention parameters κ_t control the location of the window, the β_t parameters control its width, and the α_t parameters control the importance of the window within the mixture.
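One step of this soft window can be sketched as follows, using the formulation from Graves' paper: α = exp(α̂), β = exp(β̂), κ_t = κ_{t-1} + exp(κ̂), φ(u) = Σ_k α_k exp(-β_k (κ_k - u)²), and w_t = Σ_u φ(u) c_u. The sketch indexes characters from 0 (the paper counts from 1). The name `soft_window` and its argument layout are assumptions and are not taken from sample_congen.

```python
import math


def soft_window(alpha_hat, beta_hat, kappa_hat, kappa_prev, char_vectors):
    """Compute (phi, window, kappa) for one time step.

    alpha_hat, beta_hat, kappa_hat: raw outputs, one per window component.
    kappa_prev: window locations from the previous step.
    char_vectors: one encoding vector (e.g. one-hot) per character.
    """
    alpha = [math.exp(a) for a in alpha_hat]   # importance of each component
    beta = [math.exp(b) for b in beta_hat]     # larger beta -> narrower window
    # adding a positive offset means the window can only move forward
    kappa = [kp + math.exp(k) for kp, k in zip(kappa_prev, kappa_hat)]
    phi = [
        sum(a * math.exp(-b * (k - u) ** 2) for a, b, k in zip(alpha, beta, kappa))
        for u in range(len(char_vectors))
    ]
    dim = len(char_vectors[0])
    window = [sum(p * c[i] for p, c in zip(phi, char_vectors)) for i in range(dim)]
    return phi, window, kappa
```

Returning κ lets the caller pass it back in as `kappa_prev` at the next step. The per-character φ values are what a phi plot visualizes over time.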

Train the models

To train the model from scratch, use the following command. The hyperparameters are set in the main_congen.py script, and the trained model is saved in the saved_model folder.

$ python main_congen.py

Let's run some inference

For inference I have provided trained models in the saved_model folder. If you train your own model, it overwrites the pre-trained model in that same folder.

>>> from eval_hand import load_pretrained_congen, gauss_params_plot, plot_stroke, phi_window_plots  # phi_window_plots assumed to live in eval_hand
>>> from model import model_congen, mdn_loss, sample_congen

>>> lr_model, char_to_vec, h_size = load_pretrained_congen()
>>> strokes, mix_params, phi, win = sample_congen(lr_model, 'kiki do you love me ?', char_to_vec, h_size)
>>> plot_stroke(strokes)
>>> gauss_params_plot(mix_params)
>>> phi_window_plots(phi, win) 
>>> strokes, mix_params, phi, win = sample_congen(lr_model, 'a thing of beauty is joy forever', char_to_vec, h_size)
>>> plot_stroke(strokes)
>>> gauss_params_plot(mix_params)
>>> phi_window_plots(phi, win)