# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from keras import backend as K
from keras import regularizers
from keras.engine import Layer
from keras.initializers import RandomNormal


class KerasMatrixFactorizer(Layer):
    """Keras layer scoring (i, j) index pairs with a factorized matrix.

    For every input row ``(i, j)`` the layer returns::

        <i_embedding[i], j_embedding[j]> + i_bias[i] + j_bias[j] + constant

    The bias and constant terms are only added when ``use_bias`` is true.

    Args:
        rank: Size of the second axis of both embedding matrices.
        input_dim_i: Number of rows of the ``i`` embedding (and ``i`` bias).
        input_dim_j: Number of rows of the ``j`` embedding (and ``j`` bias).
        embeddings_regularizer: Regularizer identifier resolved with
            ``keras.regularizers.get`` and applied to both embedding
            matrices.
        use_bias: Whether to create and add the per-index biases and the
            global constant.
        **kwargs: Forwarded unchanged to ``keras.engine.Layer``.
    """

    def __init__(
        self,
        rank,
        input_dim_i,
        input_dim_j,
        embeddings_regularizer=None,
        use_bias=True,
        **kwargs
    ):
        self.rank = rank
        self.input_dim_i = input_dim_i
        self.input_dim_j = input_dim_j
        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
        self.use_bias = use_bias
        super(KerasMatrixFactorizer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the trainable weights of the layer.

        Args:
            input_shape: Shape of the layer input; it is only forwarded to
                the parent ``build``.
        """
        # Both embeddings share the same initializer configuration:
        # zero-mean normal with standard deviation 1 / sqrt(rank).
        stddev = 1 / np.sqrt(self.rank)

        self.i_embedding = self.add_weight(
            shape=(self.input_dim_i, self.rank),
            initializer=RandomNormal(mean=0.0, stddev=stddev),
            name='i_embedding',
            regularizer=self.embeddings_regularizer
        )
        self.j_embedding = self.add_weight(
            shape=(self.input_dim_j, self.rank),
            initializer=RandomNormal(mean=0.0, stddev=stddev),
            name='j_embedding',
            regularizer=self.embeddings_regularizer
        )

        if self.use_bias:
            self.i_bias = self.add_weight(
                shape=(self.input_dim_i, 1),
                initializer='zeros',
                name='i_bias'
            )
            self.j_bias = self.add_weight(
                shape=(self.input_dim_j, 1),
                initializer='zeros',
                name='j_bias'
            )
            # The constant is only read on the ``use_bias`` path of
            # ``call``, so it is only created here.
            self.constant = self.add_weight(
                shape=(1, 1),
                initializer='zeros',
                name='constant',
            )

        self.built = True
        super(KerasMatrixFactorizer, self).build(input_shape)

    def call(self, inputs):
        """Compute the score of every (i, j) pair in ``inputs``.

        Args:
            inputs: 2-D tensor; column 0 holds the ``i`` indices and
                column 1 the ``j`` indices. It is cast to ``int32`` when it
                has any other dtype.

        Returns:
            The batch-wise dot product of the gathered embeddings, plus the
            gathered biases and the constant when ``use_bias`` is true.
        """
        # The indices are used for gathering, so they must be integers.
        if K.dtype(inputs) != 'int32':
            inputs = K.cast(inputs, 'int32')

        # get the embeddings
        i = inputs[:, 0]  # by convention
        j = inputs[:, 1]
        i_embedding = K.gather(self.i_embedding, i)
        j_embedding = K.gather(self.j_embedding, j)

        # <i_embed, j_embed> + i_bias + j_bias + constant
        out = K.batch_dot(i_embedding, j_embedding, axes=[1, 1])
        if self.use_bias:
            i_bias = K.gather(self.i_bias, i)
            j_bias = K.gather(self.j_bias, j)
            out += (i_bias + j_bias + self.constant)
        return out

    def compute_output_shape(self, input_shape):
        """Return ``(input_shape[0], 1)``: one value per input row."""
        return (input_shape[0], 1)

    def get_config(self):
        """Return the constructor arguments merged into the base config.

        Without this override the config would not contain the required
        ``rank`` / ``input_dim_i`` / ``input_dim_j`` arguments, so the layer
        could not be rebuilt from its config.
        """
        config = {
            'rank': self.rank,
            'input_dim_i': self.input_dim_i,
            'input_dim_j': self.input_dim_j,
            'embeddings_regularizer':
                regularizers.serialize(self.embeddings_regularizer),
            'use_bias': self.use_bias,
        }
        base_config = super(KerasMatrixFactorizer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))