#!/usr/bin/env python """Sample script of recurrent neural network language model. This code is ported from the following implementation written in Torch. https://github.com/tomsercu/lstm """ from __future__ import division from __future__ import print_function import argparse import json import warnings import numpy as np import chainer from chainer import cuda import chainer.functions as F import chainer.links as L from chainer import training from chainer.training import extensions from chainer import reporter embed_init = chainer.initializers.Uniform(.25) def embed_seq_batch(embed, seq_batch, dropout=0., context=None): x_len = [len(seq) for seq in seq_batch] x_section = np.cumsum(x_len[:-1]) ex = embed(F.concat(seq_batch, axis=0)) ex = F.dropout(ex, dropout) if context is not None: ids = [embed.xp.full((l, ), i).astype('i') for i, l in enumerate(x_len)] ids = embed.xp.concatenate(ids, axis=0) cx = F.embed_id(ids, context) ex = F.concat([ex, cx], axis=1) exs = F.split_axis(ex, x_section, 0) return exs class NormalOutputLayer(L.Linear): def __init__(self, *args, **kwargs): super(NormalOutputLayer, self).__init__(*args, **kwargs) def output_and_loss(self, h, t, reduce='mean'): logit = self(h) return F.softmax_cross_entropy( logit, t, normalize=False, reduce=reduce) def output(self, h, t=None): return self(h) class MLP(chainer.Chain): def __init__(self, n_hidden, in_units, hidden_units, out_units, dropout=0.): super(MLP, self).__init__() with self.init_scope(): self.l1 = L.Linear(in_units, hidden_units) self.lo = L.Linear(hidden_units, out_units) for i in range(2, n_hidden + 2): setattr(self, 'l{}'.format(i), L.Linear(hidden_units, hidden_units)) self.n_hidden = n_hidden self.dropout = dropout def __call__(self, x, label=None): x = self.l1(x) for i in range(2, self.n_hidden + 2): x = F.relu(x) x = F.dropout(x, self.dropout) x = getattr(self, 'l{}'.format(i))(x) x = F.relu(x) x = F.dropout(x, self.dropout) x = self.lo(x) x = F.relu(x) if hasattr(self, 'l1_label') and label is not None: x += self.l1_label(label) return x class BiLanguageModel(chainer.Chain): def __init__(self, n_vocab, n_units, n_layers=2, dropout=0.5): super(BiLanguageModel, self).__init__() with self.init_scope(): self.embed = L.EmbedID(n_vocab, n_units) RNN = L.NStepLSTM self.encoder_fw = RNN(n_layers, n_units, n_units, dropout) self.encoder_bw = RNN(n_layers, n_units, n_units, dropout) self.output = NormalOutputLayer(n_units, n_vocab) self.mlp = MLP(1, n_units * 2, n_units, n_units, dropout) self.dropout = dropout self.n_units = n_units self.n_layers = n_layers def add_label_condition_nets(self, n_labels, label_units): with self.init_scope(): self.mlp.add_link( 'l1_label', L.Linear(None, self.mlp.l1.b.size, nobias=True, initialW=chainer.initializers.Uniform(0.4))) self.n_labels = n_labels def encode(self, seq_batch, labels=None): seq_batch_wo_2bos = [seq[2::] for seq in seq_batch] revseq_batch_wo_2bos = [seq[::-1] for seq in seq_batch_wo_2bos] seq_batch_wo_2eos = [seq[:-2] for seq in seq_batch] bwe_seq_batch = self.embed_seq_batch(revseq_batch_wo_2bos) fwe_seq_batch = self.embed_seq_batch(seq_batch_wo_2eos) bwt_out_batch = self.encode_seq_batch( bwe_seq_batch, self.encoder_bw)[-1] fwt_out_batch = self.encode_seq_batch( fwe_seq_batch, self.encoder_fw)[-1] revbwt_concat = F.concat( [b[::-1] for b in bwt_out_batch], axis=0) fwt_concat = F.concat(fwt_out_batch, axis=0) t_out_concat = F.concat([fwt_concat, revbwt_concat], axis=1) t_out_concat = F.dropout(t_out_concat, self.dropout) if hasattr(self.mlp, 'l1_label') and labels is not None: labels = [[labels[i]] * f.shape[0] for i, f in enumerate(fwt_out_batch)] labels = self.xp.concatenate(sum(labels, []), axis=0) label_concat = self.xp.zeros( (t_out_concat.shape[0], self.n_labels)).astype('f') label_concat[self.xp.arange(len(labels)), labels] = 1. t_out_concat = self.mlp(t_out_concat, label_concat) else: t_out_concat = self.mlp(t_out_concat) return t_out_concat def embed_seq_batch(self, x_seq_batch, context=None): e_seq_batch = embed_seq_batch( self.embed, x_seq_batch, dropout=self.dropout, context=context) return e_seq_batch def encode_seq_batch(self, e_seq_batch, encoder): hs, cs, y_seq_batch = encoder(None, None, e_seq_batch) return hs, cs, y_seq_batch def calculate_loss(self, input_chain, **args): seq_batch = sum(input_chain, []) t_out_concat = self.encode(seq_batch) seq_batch_mid = [seq[1:-1] for seq in seq_batch] seq_mid_concat = F.concat(seq_batch_mid, axis=0) n_tok = sum(len(s) for s in seq_batch_mid) loss = self.output_and_loss_from_concat( t_out_concat, seq_mid_concat, normalize=n_tok) reporter.report({'perp': self.xp.exp(loss.data)}, self) return loss def output_and_loss_from_concat(self, y, t, normalize=None): y = F.dropout(y, ratio=self.dropout) loss = self.output.output_and_loss(y, t) if normalize is not None: loss *= 1. * t.shape[0] / normalize else: loss *= t.shape[0] return loss def calculate_loss_with_labels(self, seq_batch_with_labels): seq_batch, labels = seq_batch_with_labels t_out_concat = self.encode(seq_batch, labels=labels) seq_batch_mid = [seq[1:-1] for seq in seq_batch] seq_mid_concat = F.concat(seq_batch_mid, axis=0) n_tok = sum(len(s) for s in seq_batch_mid) loss = self.output_and_loss_from_concat( t_out_concat, seq_mid_concat, normalize=n_tok) reporter.report({'perp': self.xp.exp(loss.data)}, self) return loss def predict(self, xs, labels=None): with chainer.using_config('train', False), chainer.no_backprop_mode(): t_out_concat = self.encode(xs, labels=labels, add_original=0.) prob_concat = F.softmax(self.output.output(t_out_concat)).data x_len = [len(x) for x in xs] x_section = np.cumsum(x_len[:-1]) ps = np.split(cuda.to_cpu(prob_concat), x_section, 0) return ps def predict_embed(self, xs, embedW, labels=None, dropout=0., mode='sampling', temp=1., word_lower_bound=0., gold_lower_bound=0., gumbel=True, residual=0., wordwise=True, add_original=0., augment_ratio=0.25): x_len = [len(x) for x in xs] with chainer.using_config('train', False), chainer.no_backprop_mode(): t_out_concat = self.encode(xs, labels=labels) prob_concat = self.output.output(t_out_concat).data prob_concat /= temp prob_concat += self.xp.random.gumbel( size=prob_concat.shape).astype('f') prob_concat = F.softmax(prob_concat).data out_concat = F.embed_id( self.xp.argmax(prob_concat, axis=1).astype(np.int32), embedW) # insert eos eos = embedW[0][None] new_out = [] count = 0 for i, x in enumerate(xs): new_out.append(eos) new_out.append(out_concat[count:count + len(x) - 2]) new_out.append(eos) count += len(x) - 2 out_concat = F.concat(new_out, axis=0) def embed_func(x): return F.embed_id(x, embedW, ignore_label=-1) raw_concat = F.concat( sequence_embed(embed_func, xs, self.dropout), axis=0) b, u = raw_concat.shape mask = self.xp.broadcast_to( (self.xp.random.rand(b, 1) < augment_ratio), raw_concat.shape) out_concat = F.where(mask, out_concat, raw_concat) x_len = [len(x) for x in xs] x_section = np.cumsum(x_len[:-1]) out_concat = F.dropout(out_concat, dropout) exs = F.split_axis(out_concat, x_section, 0) return exs def sequence_embed(embed, xs, dropout=0.): """Efficient embedding function for variable-length sequences This output is equally to "return [F.dropout(embed(x), ratio=dropout) for x in xs]". However, calling the functions is one-shot and faster. Args: embed (callable): A :func:`~chainer.functions.embed_id` function or :class:`~chainer.links.EmbedID` link. xs (list of :class:`~chainer.Variable` or :class:`numpy.ndarray` or \ :class:`cupy.ndarray`): i-th element in the list is an input variable, which is a :math:`(L_i, )`-shaped int array. dropout (float): Dropout ratio. Returns: list of ~chainer.Variable: Output variables. i-th element in the list is an output variable, which is a :math:`(L_i, N)`-shaped float array. :math:`(N)` is the number of dimensions of word embedding. """ x_len = [len(x) for x in xs] x_section = np.cumsum(x_len[:-1]) ex = embed(F.concat(xs, axis=0)) ex = F.dropout(ex, ratio=dropout) exs = F.split_axis(ex, x_section, 0) return exs def block_embed(embed, x, dropout=0.): """Embedding function followed by convolution Args: embed (callable): A :func:`~chainer.functions.embed_id` function or :class:`~chainer.links.EmbedID` link. x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \ :class:`cupy.ndarray`): Input variable, which is a :math:`(B, L)`-shaped int array. Its first dimension :math:`(B)` is assumed to be the *minibatch dimension*. The second dimension :math:`(L)` is the length of padded sentences. dropout (float): Dropout ratio. Returns: ~chainer.Variable: Output variable. A float array with shape of :math:`(B, N, L, 1)`. :math:`(N)` is the number of dimensions of word embedding. """ e = embed(x) e = F.dropout(e, ratio=dropout) e = F.transpose(e, (0, 2, 1)) e = e[:, :, :, None] return e class PredictiveEmbed(chainer.Chain): def __init__(self, n_vocab, n_units, bilm, dropout=0., initialW=embed_init): super(PredictiveEmbed, self).__init__() with self.init_scope(): self.embed = L.EmbedID(n_vocab, n_units, ignore_label=-1, initialW=initialW) self.bilm = bilm self.n_vocab = n_vocab self.n_units = n_units self.dropout = dropout def __call__(self, x): return self.embed(x) def setup(self, mode='weighted_sum', temp=1., word_lower_bound=0., gold_lower_bound=0., gumbel=True, residual=0., wordwise=True, add_original=1., augment_ratio=0.5, ignore_unk=-1): self.config = { 'dropout': self.dropout, 'mode': mode, 'temp': temp, 'word_lower_bound': 0., 'gold_lower_bound': 0., 'gumbel': gumbel, 'residual': residual, 'wordwise': wordwise, 'add_original': add_original, 'augment_ratio': augment_ratio } if ignore_unk >= 0: self.bilm.output.b.data[ignore_unk] = -1e5 def embed_xs(self, xs, batch='concat'): if batch == 'concat': x_block = chainer.dataset.convert.concat_examples(xs, padding=-1) ex_block = block_embed(self.embed, x_block, self.dropout) return ex_block elif batch == 'list': exs = sequence_embed(self.embed, xs, self.dropout) return exs else: raise NotImplementedError def embed_xs_with_prediction(self, xs, labels=None, batch='concat'): predicted_exs = self.bilm.predict_embed( xs, self.embed.W, labels=labels, dropout=self.config['dropout'], mode=self.config['mode'], temp=self.config['temp'], word_lower_bound=self.config['word_lower_bound'], gold_lower_bound=self.config['gold_lower_bound'], gumbel=self.config['gumbel'], residual=self.config['residual'], wordwise=self.config['wordwise'], add_original=self.config['add_original'], augment_ratio=self.config['augment_ratio']) if batch == 'concat': predicted_ex_block = F.pad_sequence(predicted_exs, padding=0.) predicted_ex_block = F.transpose( predicted_ex_block, (0, 2, 1))[:, :, :, None] return predicted_ex_block elif batch == 'list': return predicted_exs else: raise NotImplementedError