import tensorflow as tf


class CopyModel(tf.keras.Model):
    """Recurrent model that reads a sequence and then emits one output per step.

    The model wraps a single recurrent cell (a stack of LSTM cells or an
    NTM cell) and drives it manually in three phases: feed the input
    steps, feed an end-of-input marker, then feed zeros while collecting
    the cell outputs.

    Args:
        batch_size: Fixed batch size used for the constant marker/zero
            inputs and for the cell's initial state.
        vector_dim: Number of output columns kept from each cell output.
            Cell inputs are ``vector_dim + 1`` wide.
        model_type: ``'LSTM'`` or ``'NTM'``.
        cell_params: Mapping of cell hyper-parameters. ``'rnn_size'`` is
            always read; ``'rnn_num_layers'`` is read for ``'LSTM'``;
            ``'memory_size'``, ``'memory_vector_dim'``, ``'read_head_num'``
            and ``'write_head_num'`` are read for ``'NTM'``.

    Raises:
        ValueError: If ``model_type`` is neither ``'LSTM'`` nor ``'NTM'``.
    """

    def __init__(self, batch_size, vector_dim, model_type, cell_params):
        super().__init__()
        self.batch_size = batch_size
        self.vector_dim = vector_dim
        # End-of-input marker: one-hot at index ``vector_dim`` (the extra,
        # last column) for every row of the batch.
        self.eof = tf.one_hot([self.vector_dim] * batch_size, depth=self.vector_dim + 1)
        # All-zero input fed to the cell while outputs are being collected.
        self.zero = tf.zeros([batch_size, vector_dim + 1], dtype=tf.float32)
        self.model_type = model_type
        self.cell_params = cell_params
        self.cell = self._make_cell()

    def _make_cell(self):
        """Build the recurrent cell selected by ``self.model_type``."""
        params = self.cell_params

        if self.model_type == 'LSTM':
            lstm_layers = [
                tf.keras.layers.LSTMCell(units=params['rnn_size'])
                for _ in range(params['rnn_num_layers'])
            ]
            return tf.keras.layers.StackedRNNCells(lstm_layers)

        if self.model_type == 'NTM':
            # Imported lazily so the NTM package is only required when
            # this model type is actually requested.
            from ntm.ntm_cell_v2 import NTMCell

            return NTMCell(
                rnn_size=params['rnn_size'],
                memory_size=params['memory_size'],
                memory_vector_dim=params['memory_vector_dim'],
                read_head_num=params['read_head_num'],
                write_head_num=params['write_head_num'],
                addressing_mode='content_and_location',
                output_dim=self.vector_dim,
            )

        raise ValueError('Model type not supported')

    @tf.function
    def call(self, inputs):
        """Run the read / marker / emit phases and return sigmoid outputs.

        Args:
            inputs: A pair ``(x, seq_length)``. ``x`` is a rank-3 float
                tensor whose axis 1 is iterated over for ``seq_length``
                steps (presumably ``(batch_size, seq_length, vector_dim)``
                -- TODO confirm against the data pipeline).

        Returns:
            Sigmoid of the collected cell outputs, restricted to the first
            ``vector_dim`` columns, with the step axis moved back to
            axis 1.
        """
        x, seq_length = inputs

        # Put the step axis first so each step can be read by index.
        input_steps = tf.TensorArray(dtype=tf.float32, size=seq_length)
        step_major = tf.transpose(x, perm=[1, 0, 2])
        input_steps = input_steps.unstack(step_major)

        state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)

        # Phase 1: feed every input step, padded with a zero in the extra
        # (marker) column. Outputs of this phase are discarded.
        for step in range(seq_length):
            marker_column = tf.zeros([self.batch_size, 1])
            cell_input = tf.concat([input_steps.read(step), marker_column], axis=1)
            _, state = self.cell(cell_input, state)

        # Phase 2: a single end-of-input marker step.
        _, state = self.cell(self.eof, state)

        # Phase 3: feed zeros and keep the first ``vector_dim`` columns of
        # each output.
        collected = tf.TensorArray(dtype=tf.float32, size=seq_length)
        for step in range(seq_length):
            output, state = self.cell(self.zero, state)
            collected = collected.write(step, output[:, 0:self.vector_dim])

        batch_major = tf.transpose(collected.stack(), perm=[1, 0, 2])
        return tf.sigmoid(batch_major)