# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf

from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization

"""
The vanilla DNC memory unit.
"""


class DNCMemoryUnitCell(BaseMemoryUnitCell):
    """The vanilla DNC memory unit.

    Shape symbols used in the comments of this class:
    B = ``self.h_B`` (static batch size, set in ``__call__``), N = ``self.h_N``,
    W = ``self.h_W`` and RH = ``self.h_RH``.
    NOTE(review): N, W and RH are presumably set by the base class from
    ``memory_length``, ``memory_width`` and ``read_heads`` -- confirm in
    ``BaseMemoryUnitCell``.
    """

    def __init__(self, input_size, memory_length, memory_width, read_heads,
                 bypass_dropout=False, dnc_norm=False, seed=100, reuse=False,
                 analyse=False, dtype=tf.float32, name='dnc_mu'):
        """Forward every argument unchanged (positionally) to the base cell."""
        super().__init__(input_size, memory_length, memory_width, read_heads,
                         bypass_dropout, dnc_norm, seed, reuse, analyse, dtype, name)

    @property
    def state_size(self):
        """Per-example shapes of the six recurrent state tensors.

        Returns:
            Tuple of ``tf.TensorShape`` in the order memory, usage vector,
            write weighting, precedence weighting, link matrix, read weighting.
        """
        cells, width, heads = self.h_N, self.h_W, self.h_RH
        return (
            tf.TensorShape([cells, width]),  # memory
            tf.TensorShape([cells]),  # usage vector
            tf.TensorShape([cells]),  # write weighting
            tf.TensorShape([cells]),  # precedence weighting
            tf.TensorShape([cells, cells]),  # link matrix
            tf.TensorShape([heads, cells]),  # read weighting
        )

    def zero_state(self, batch_size, dtype=tf.float32):
        """Build the initial recurrent state.

        The memory is filled with the constant ``1 / (N * W)``; usage vector,
        precedence weighting and link matrix start at zero; write and read
        weightings come from ``unit_simplex_initialization`` driven by
        ``self.rng``.

        Args:
            batch_size: leading dimension of every state tensor.
            dtype: dtype of every state tensor.

        Returns:
            Tuple ordered like ``state_size``.
        """
        cells, width, heads = self.h_N, self.h_W, self.h_RH

        uniform_value = tf.cast(1 / (cells * width), dtype=dtype)
        memory = tf.fill([batch_size, cells, width], uniform_value)
        usage = tf.zeros([batch_size, cells], dtype=dtype)
        # The write weighting is drawn before the read weighting; both consume self.rng,
        # so this order must not change.
        write_weighting = unit_simplex_initialization(self.rng, batch_size, [cells], dtype=dtype)
        precedence = tf.zeros([batch_size, cells], dtype=dtype)
        link_matrix = tf.zeros([batch_size, cells, cells], dtype=dtype)
        read_weighting = unit_simplex_initialization(self.rng, batch_size, [heads, cells], dtype=dtype)

        return (memory, usage, write_weighting, precedence, link_matrix, read_weighting,)

    def analyse_state(self, batch_size, dtype=tf.float32):
        """Build all-zero placeholders shaped like the control signals.

        Args:
            batch_size: leading dimension of every tensor.
            dtype: dtype of every tensor.

        Returns:
            Tuple of ten zero tensors ordered alloc_gate, free_gates, write_gate,
            write_keys, write_strengths, write_vector, erase_vector, read_keys,
            read_strengths, read_modes.
        """
        heads, width = self.h_RH, self.h_W
        trailing_dims = (
            [1],  # alloc_gate
            [heads, 1],  # free_gates
            [1],  # write_gate
            [1, width],  # write_keys
            [1],  # write_strengths
            [1, width],  # write_vector
            [1, width],  # erase_vector
            [heads, width],  # read_keys
            [heads, 1],  # read_strengths
            [heads, 3],  # read_modes
        )
        return tuple(tf.zeros([batch_size] + dims, dtype=dtype) for dims in trailing_dims)

    def __call__(self, inputs, pre_states, scope=None):
        """Run one step of the memory unit.

        Args:
            inputs: 2-D tensor; its static first dimension is stored as ``self.h_B``.
            pre_states: previous state tuple ordered like ``state_size``.
            scope: unused in this method.

        Returns:
            ``(output, new_states)``. ``output`` is the flattened read vectors
            ``[B, W * RH]`` concatenated with the (optionally dropped-out) inputs
            along the last axis; when ``self.analyse`` is truthy it is the pair
            ``(output, control_signals)`` instead. ``new_states`` is ordered like
            ``state_size``.
        """
        self.h_B = inputs.get_shape()[0].value

        (self.const_link_matrix_inv_eye,
         self.const_memory_ones,
         self.const_batch_memory_range) = self._create_constant_value_tensors(self.h_B, self.dtype)

        (prev_memory, prev_usage, prev_write_w, prev_precedence,
         prev_link, prev_read_w) = pre_states

        projected = self._weight_input(inputs)
        control_signals = self._create_control_signals(projected)
        (alloc_gate, free_gates, write_gate, write_keys, write_strengths,
         write_vector, erase_vector, read_keys, read_strengths, read_modes) = control_signals

        # Write path: allocation and usage, content lookup, then the memory update.
        alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(
            prev_write_w, prev_read_w, prev_usage, free_gates)
        write_content_weighting = self._calculate_content_weightings(prev_memory, write_keys, write_strengths)
        write_weighting = self._update_write_weighting(
            alloc_weightings, write_content_weighting, write_gate, alloc_gate)
        memory = self._update_memory(prev_memory, write_weighting, write_vector, erase_vector)

        # Temporal linkage, followed by the read path on the updated memory.
        link_matrix, precedence_weighting = self._update_link_matrix(
            prev_link, write_weighting, prev_precedence)
        forward_weightings, backward_weightings = self._make_read_forward_backward_weightings(
            link_matrix, prev_read_w)
        read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
        read_weightings = self._make_read_weightings(
            forward_weightings, backward_weightings, read_content_weightings, read_modes)

        read_vectors = self._read_memory(memory, read_weightings)
        read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])

        # A truthy bypass_dropout is passed as the second argument of tf.nn.dropout.
        input_bypass = tf.nn.dropout(inputs, self.bypass_dropout) if self.bypass_dropout else inputs

        output = tf.concat([read_vectors, input_bypass], axis=-1)
        if self.analyse:
            output = (output, control_signals)

        new_states = (memory, usage_vector, write_weighting, precedence_weighting, link_matrix, read_weightings)
        return output, new_states

    def _create_constant_value_tensors(self, batch_size, dtype):
        """Create the constant helper tensors used by the update equations.

        Args:
            batch_size: static batch size.
            dtype: dtype of the floating point constants.

        Returns:
            ``(link_matrix_inv_eye, memory_ones, batch_memory_range)``:
            an ``[N, N]`` matrix of ones with a zero diagonal, a ``[B, N, W]``
            tensor of ones, and a ``[B, N]`` int32 tensor whose row ``b`` holds
            ``b * N`` in every column.
        """
        identity = tf.constant(np.identity(self.h_N), dtype=dtype, name="link_matrix_inv_eye")
        link_matrix_inv_eye = 1 - identity
        memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")

        batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
        length_scalar = tf.constant(self.h_N, dtype=tf.int32)
        repeat_memory_length = tf.fill([self.h_N], length_scalar, name="repeat_memory_length")
        # Outer product [B, 1] x [1, N]: entry (b, n) equals b * N.
        batch_column = tf.expand_dims(batch_range, -1)
        length_row = tf.expand_dims(repeat_memory_length, 0)
        batch_memory_range = tf.matmul(batch_column, length_row, name="batch_memory_range")

        return link_matrix_inv_eye, memory_ones, batch_memory_range

    def _weight_input(self, inputs):
        """Project the inputs linearly onto the flat control-signal vector.

        Creates or reuses ``mu_w_x`` and ``mu_b_x`` in a variable scope named
        after ``self.name`` and, when ``self.dnc_norm`` is truthy, applies
        ``layer_norm`` to the projection.

        Args:
            inputs: 2-D tensor; its static second dimension sizes ``mu_w_x``.

        Returns:
            Tensor with ``(3 + RH) * W + 5 * RH + 3`` columns.
        """
        input_size = inputs.get_shape()[1].value
        total_signal_size = (3 + self.h_RH) * self.h_W + 5 * self.h_RH + 3
        variable_collections = ['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES]

        with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
            w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
                                  initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
                                  collections=variable_collections, dtype=self.dtype)
            b_x = tf.get_variable("mu_b_x", (total_signal_size,),
                                  initializer=tf.constant_initializer(0.),
                                  collections=variable_collections, dtype=self.dtype)

            weighted_input = tf.matmul(inputs, w_x) + b_x

            if self.dnc_norm:
                weighted_input = layer_norm(weighted_input, name='layer_norm', dtype=self.dtype,
                                            collection='memory_unit')
        return weighted_input

    def _create_control_signals(self, weighted_input):
        """Slice the flat projection into control signals and squash them.

        Column layout of ``weighted_input``: write key (W), write strength (1),
        erase vector (W), write vector (W), allocation gate (1), write gate (1),
        read keys (RH * W), read strengths (RH), read modes (3 * RH),
        free gates (RH).

        Args:
            weighted_input: tensor returned by ``_weight_input``.

        Returns:
            Tuple ``(alloc_gates [B, 1], free_gates [B, RH, 1], write_gates [B, 1],
            write_keys [B, 1, W], write_strengths [B, 1], write_vector [B, 1, W],
            erase_vector [B, 1, W], read_keys [B, RH, W], read_strengths [B, RH, 1],
            read_modes [B, RH, 3])``.
        """
        batch, width, heads = self.h_B, self.h_W, self.h_RH

        # Column offsets of each signal inside the flat projection.
        erase_start = width + 1
        write_vector_start = 2 * width + 1
        gates_start = 3 * width + 1
        read_keys_start = 3 * width + 3
        read_strengths_start = (heads + 3) * width + 3
        read_modes_start = read_strengths_start + 1 * heads
        free_gates_start = read_strengths_start + 4 * heads
        signal_end = read_strengths_start + 5 * heads

        write_keys = weighted_input[:, : width]  # W
        write_strengths = weighted_input[:, width: erase_start]  # 1
        erase_vector = weighted_input[:, erase_start: write_vector_start]  # W
        write_vector = weighted_input[:, write_vector_start: gates_start]  # W
        alloc_gates = weighted_input[:, gates_start: gates_start + 1]  # 1
        write_gates = weighted_input[:, gates_start + 1: read_keys_start]  # 1
        read_keys = weighted_input[:, read_keys_start: read_strengths_start]  # R * W
        read_strengths = weighted_input[:, read_strengths_start: read_modes_start]  # R
        read_modes = weighted_input[:, read_modes_start: free_gates_start]  # 3R
        free_gates = weighted_input[:, free_gates_start: signal_end]  # R

        alloc_gates = tf.sigmoid(alloc_gates, name='alloc_gates')
        free_gates = tf.sigmoid(free_gates, name='free_gates')
        free_gates = tf.expand_dims(free_gates, 2)
        write_gates = tf.sigmoid(write_gates, name='write_gates')

        write_keys = tf.expand_dims(write_keys, axis=1)
        # write_strengths stays [B, 1]; unlike read_strengths it is not expanded to rank 3.
        write_strengths = oneplus(write_strengths)
        write_vector = tf.reshape(write_vector, [batch, 1, width])
        erase_vector = tf.sigmoid(erase_vector, name='erase_vector')
        erase_vector = tf.reshape(erase_vector, [batch, 1, width])

        read_keys = tf.reshape(read_keys, [batch, heads, width])
        read_strengths = oneplus(read_strengths)
        read_strengths = tf.expand_dims(read_strengths, axis=2)
        read_modes = tf.reshape(read_modes, [batch, heads, 3])  # 3 read modes
        read_modes = tf.nn.softmax(read_modes, dim=2)

        return (alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector,
                erase_vector, read_keys, read_strengths, read_modes)

    def _update_alloc_and_usage_vectors(self, pre_write_weightings, pre_read_weightings, pre_usage_vector,
                                        free_gates):
        """Update the usage vector and derive the allocation weighting.

        Args:
            pre_write_weightings: previous write weighting, ``[B, N]``.
            pre_read_weightings: previous read weightings, ``[B, RH, N]``.
            pre_usage_vector: previous usage vector, ``[B, N]``.
            free_gates: free gates, ``[B, RH, 1]``.

        Returns:
            ``(alloc_weighting, usage_vector)``, both ``[B, N]``.
        """
        # Product over the read-head axis of (1 - free_gate * previous read weighting).
        retention_vector = tf.reduce_prod(1 - free_gates * pre_read_weightings, axis=1, keepdims=False,
                                          name='retention_prod')
        combined = pre_usage_vector + pre_write_weightings
        overlap = pre_usage_vector * pre_write_weightings
        usage_vector = (combined - overlap) * retention_vector

        # top_k over the negated usage orders the cells by ascending usage.
        negated_sorted, free_list = tf.nn.top_k(-1 * usage_vector, self.h_N)
        ascending_usage = -1 * negated_sorted
        cumulative_usage = tf.cumprod(ascending_usage, axis=1, exclusive=True)

        # Adding b * N to row b turns per-example cell indices into indices of the flattened
        # [B * N] layout, so one dynamic_stitch restores the original cell order for the batch.
        flat_positions = free_list + self.const_batch_memory_range
        flat_values = [tf.reshape(cumulative_usage, [-1, ]), ]
        flat_indices = [tf.reshape(flat_positions, [-1]), ]
        unsorted_usage = tf.dynamic_stitch(flat_indices, flat_values)
        unsorted_usage = tf.reshape(unsorted_usage, [self.h_B, self.h_N])

        alloc_weighting = (1 - usage_vector) * unsorted_usage
        return alloc_weighting, usage_vector

    @staticmethod
    def _update_write_weighting(alloc_weighting, write_content_weighting, write_gate, alloc_gate):
        """Interpolate allocation and content weightings and scale by the write gate.

        Returns:
            ``write_gate * (alloc_gate * alloc_weighting
            + (1 - alloc_gate) * write_content_weighting)``.
        """
        from_allocation = alloc_gate * alloc_weighting
        from_content = (1 - alloc_gate) * write_content_weighting
        return write_gate * (from_allocation + from_content)

    def _update_memory(self, pre_memory, write_weighting, write_vector, erase_vector):
        """Erase from and then add to the memory matrix.

        Args:
            pre_memory: previous memory, ``[B, N, W]``.
            write_weighting: write weighting, ``[B, N]``.
            write_vector: content to write, ``[B, 1, W]``.
            erase_vector: erase strengths, ``[B, 1, W]``.

        Returns:
            ``pre_memory * (1 - w e^T) + w v^T`` with ``w`` the write weighting
            as a column vector.
        """
        write_column = tf.expand_dims(write_weighting, 2)
        keep_mask = self.const_memory_ones - tf.matmul(write_column, erase_vector)
        retained = tf.multiply(pre_memory, keep_mask)
        added = tf.matmul(write_column, write_vector)
        return retained + added

    def _update_link_matrix(self, pre_link_matrix, write_weighting, pre_precedence_weighting):
        """Update the temporal link matrix and the precedence weighting.

        Args:
            pre_link_matrix: previous link matrix, ``[B, N, N]``.
            write_weighting: current write weighting, ``[B, N]``.
            pre_precedence_weighting: previous precedence weighting, ``[B, N]``.

        Returns:
            ``(link_matrix, precedence_weighting)``; the diagonal of
            ``link_matrix`` is zeroed by ``self.const_link_matrix_inv_eye``.
        """
        total_written = tf.reduce_sum(write_weighting, 1, keepdims=True)
        precedence_weighting = (1 - total_written) * pre_precedence_weighting + write_weighting

        # add_mat[i, j] = write_weighting[i] * pre_precedence_weighting[j]
        add_mat = tf.matmul(tf.expand_dims(write_weighting, axis=2),
                            tf.expand_dims(pre_precedence_weighting, axis=1))
        # erase_mat[i, j] = 1 - write_weighting[j] - write_weighting[i]
        erase_mat = 1 - tf.expand_dims(write_weighting, 1) - tf.expand_dims(write_weighting, 2)

        updated_link_mat = erase_mat * pre_link_matrix + add_mat
        link_matrix = self.const_link_matrix_inv_eye * updated_link_mat
        return link_matrix, precedence_weighting

    @staticmethod
    def _make_read_forward_backward_weightings(link_matrix, pre_read_weightings):
        """Propagate the previous read weightings through the link matrix.

        ``forward_weightings`` is ``pre_read_weightings @ link_matrix`` and
        ``backward_weightings`` is ``pre_read_weightings @ adjoint(link_matrix)``.

        NOTE(review): with ``L[i, j]`` built from ``w[i] * p[j]`` in
        ``_update_link_matrix``, the DNC paper's forward weighting ``L w`` would
        be ``r @ L^T`` in this row-vector layout, so the two names look swapped
        relative to the paper. Read modes are learned, so this would only
        relabel modes 0 and 2 -- confirm before changing anything.

        Returns:
            ``(forward_weightings, backward_weightings)``.
        """
        forward_weightings = tf.matmul(pre_read_weightings, link_matrix)
        backward_weightings = tf.matmul(pre_read_weightings, link_matrix, adjoint_b=True)
        return forward_weightings, backward_weightings

    @staticmethod
    def _make_read_weightings(forward_weightings, backward_weightings, read_content_weightings, read_modes):
        """Blend backward, content and forward weightings with the read modes.

        Mode 0 scales the backward weighting, mode 1 the content weighting and
        mode 2 the forward weighting; each mode slice is expanded on axis 2 so
        it broadcasts over the last axis of the weightings.

        Returns:
            The blended read weighting.
        """
        read_weighting = tf.expand_dims(read_modes[:, :, 0], 2) * backward_weightings
        read_weighting = read_weighting + tf.expand_dims(read_modes[:, :, 1], 2) * read_content_weightings
        read_weighting = read_weighting + tf.expand_dims(read_modes[:, :, 2], 2) * forward_weightings
        return read_weighting