""" Embedding layers useful for recommender models. """ import numpy as np from sklearn.utils import murmurhash3_32 import torch import torch.nn as nn SEEDS = [ 179424941, 179425457, 179425907, 179426369, 179424977, 179425517, 179425943, 179426407, 179424989, 179425529, 179425993, 179426447, 179425003, 179425537, 179426003, 179426453, 179425019, 179425559, 179426029, 179426491, 179425027, 179425579, 179426081, 179426549 ] class ScaledEmbedding(nn.Embedding): """ Embedding layer that initialises its values to using a normal variable scaled by the inverse of the embedding dimension. """ def reset_parameters(self): """ Initialize parameters. """ self.weight.data.normal_(0, 1.0 / self.embedding_dim) if self.padding_idx is not None: self.weight.data[self.padding_idx].fill_(0) class ZeroEmbedding(nn.Embedding): """ Embedding layer that initialises its values to using a normal variable scaled by the inverse of the embedding dimension. Used for biases. """ def reset_parameters(self): """ Initialize parameters. """ self.weight.data.zero_() if self.padding_idx is not None: self.weight.data[self.padding_idx].fill_(0) class ScaledEmbeddingBag(nn.EmbeddingBag): """ EmbeddingBag layer that initialises its values to using a normal variable scaled by the inverse of the embedding dimension. """ def reset_parameters(self): """ Initialize parameters. """ self.weight.data.normal_(0, 1.0 / self.embedding_dim) class BloomEmbedding(nn.Module): """ An embedding layer that compresses the number of embedding parameters required by using bloom filter-like hashing. Parameters ---------- num_embeddings: int Number of entities to be represented. embedding_dim: int Latent dimension of the embedding. compression_ratio: float, optional The underlying number of rows in the embedding layer after compression. Numbers below 1.0 will use more and more compression, reducing the number of parameters in the layer. num_hash_functions: int, optional Number of hash functions used to compute the bloom filter indices. bag: bool, optional Whether to use the ``EmbeddingBag`` layer for the underlying embedding. This should be faster in principle, but currently seems to perform very poorly. Notes ----- Large embedding layers are a performance problem for fitting models: even though the gradients are sparse (only a handful of user and item vectors need parameter updates in every minibatch), PyTorch updates the entire embedding layer at every backward pass. Computation time is then wasted on applying zero gradient steps to whole embedding matrix. To alleviate this problem, we can use a smaller underlying embedding layer, and probabilistically hash users and items into that smaller space. With good hash functions, collisions should be rare, and we should observe fitting speedups without a decrease in accuracy. The idea follows the RecSys 2017 "Getting recommenders fit"[1]_ paper. The authors use a bloom-filter-like approach to hashing. Their approach uses one-hot encoded inputs followed by fully connected layers as well as softmax layers for the output, and their hashing reduces the size of the fully connected layers rather than embedding layers as implemented here; mathematically, however, the two formulations are identical. The hash function used is murmurhash3, hashing the indices with a different seed for every hash function, modulo the size of the compressed embedding layer. The hash mapping is computed once at the start of training, and indexed into for every minibatch. References ---------- .. 


class BloomEmbedding(nn.Module):
    """
    An embedding layer that compresses the number of embedding
    parameters required by using bloom filter-like hashing.

    Parameters
    ----------

    num_embeddings: int
        Number of entities to be represented.
    embedding_dim: int
        Latent dimension of the embedding.
    compression_ratio: float, optional
        Fraction of ``num_embeddings`` used as the number of rows of the
        underlying (compressed) embedding layer. Lower values apply more
        compression, reducing the number of parameters in the layer.
    num_hash_functions: int, optional
        Number of hash functions used to compute the bloom filter indices.
    bag: bool, optional
        Whether to use the ``EmbeddingBag`` layer for the underlying
        embedding. This should be faster in principle, but currently
        seems to perform very poorly.

    Notes
    -----

    Large embedding layers are a performance problem for fitting models:
    even though the gradients are sparse (only a handful of user and item
    vectors need parameter updates in every minibatch), PyTorch updates
    the entire embedding layer at every backward pass. Computation time
    is then wasted on applying zero gradient steps to the whole embedding
    matrix.

    To alleviate this problem, we can use a smaller underlying embedding
    layer and probabilistically hash users and items into that smaller
    space. With good hash functions, collisions should be rare, and we
    should observe fitting speedups without a decrease in accuracy.

    The idea follows the RecSys 2017 paper "Getting deep recommenders
    fit" [1]_. The authors use a bloom-filter-like approach to hashing.
    Their approach uses one-hot encoded inputs followed by fully connected
    layers, as well as softmax layers for the output, and their hashing
    reduces the size of the fully connected layers rather than of the
    embedding layers as implemented here; mathematically, however, the two
    formulations are identical.

    The hash function used is murmurhash3, hashing the indices with a
    different seed for every hash function, modulo the size of the
    compressed embedding layer. The hash mapping is computed once at the
    start of training, and indexed into for every minibatch.

    References
    ----------

    .. [1] Serra, Joan, and Alexandros Karatzoglou.
       "Getting deep recommenders fit: Bloom embeddings for
       sparse binary input/output networks."
       arXiv preprint arXiv:1706.03993 (2017).
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 compression_ratio=0.2,
                 num_hash_functions=4,
                 bag=False,
                 padding_idx=0):

        super(BloomEmbedding, self).__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.compression_ratio = compression_ratio
        self.compressed_num_embeddings = int(compression_ratio *
                                             num_embeddings)
        self.num_hash_functions = num_hash_functions
        self.padding_idx = padding_idx
        self._bag = bag

        if num_hash_functions > len(SEEDS):
            raise ValueError('Can use at most {} hash functions ({} requested)'
                             .format(len(SEEDS), num_hash_functions))

        self._masks = SEEDS[:self.num_hash_functions]

        if self._bag:
            self.embeddings = ScaledEmbeddingBag(self.compressed_num_embeddings,
                                                 self.embedding_dim,
                                                 mode='sum')
        else:
            self.embeddings = ScaledEmbedding(self.compressed_num_embeddings,
                                              self.embedding_dim,
                                              padding_idx=self.padding_idx)

        # Hash cache. We pre-hash all the indices, and then just
        # map the indices to their pre-hashed values as we go
        # through the minibatches.
        self._hashes = None
        self._offsets = None

    def __repr__(self):

        return ('<BloomEmbedding (compression_ratio: {}): {}>'
                .format(self.compression_ratio,
                        repr(self.embeddings)))

    def _get_hashed_indices(self, original_indices):

        def _hash(x, seed):

            # TODO: integrate with padding index
            result = murmurhash3_32(x, seed=seed)
            result[self.padding_idx] = 0

            return result % self.compressed_num_embeddings

        if self._hashes is None:
            # Pre-hash all possible indices once; subsequent calls only
            # index into this cache.
            indices = np.arange(self.num_embeddings,
                                dtype=np.int32)
            hashes = np.stack([_hash(indices, seed)
                               for seed in self._masks],
                              axis=1).astype(np.int64)
            assert hashes[self.padding_idx].sum() == 0

            self._hashes = torch.from_numpy(hashes)

            if original_indices.is_cuda:
                self._hashes = self._hashes.cuda()

        hashed_indices = torch.index_select(self._hashes,
                                            0,
                                            original_indices.squeeze())

        return hashed_indices

    def forward(self, indices):
        """
        Retrieve embeddings corresponding to indices.

        See documentation on PyTorch ``nn.Embedding`` for details.
        """

        if indices.dim() == 2:
            batch_size, seq_size = indices.size()
        else:
            batch_size, seq_size = indices.size(0), 1

        if not indices.is_contiguous():
            indices = indices.contiguous()

        # Flatten the minibatch into a single column of indices.
        indices = indices.data.view(batch_size * seq_size, 1)

        if self._bag:
            if (self._offsets is None or
                    self._offsets.size(0) != (batch_size * seq_size)):

                self._offsets = torch.arange(0,
                                             indices.numel(),
                                             indices.size(1)).long()

                if indices.is_cuda:
                    self._offsets = self._offsets.cuda()

            hashed_indices = self._get_hashed_indices(indices)

            embedding = self.embeddings(hashed_indices.view(-1), self._offsets)
            embedding = embedding.view(batch_size, seq_size, -1)
        else:
            hashed_indices = self._get_hashed_indices(indices)

            # Look up every hashed row for each index and sum them to
            # obtain the final representation.
            embedding = self.embeddings(hashed_indices)
            embedding = embedding.sum(1)
            embedding = embedding.view(batch_size, seq_size, -1)

        return embedding
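

# Minimal usage sketch, added for illustration and not part of the original
# module: the catalogue size and hyperparameters below are arbitrary example
# values.
if __name__ == '__main__':

    layer = BloomEmbedding(num_embeddings=10000,
                           embedding_dim=32,
                           compression_ratio=0.2,
                           num_hash_functions=4)

    # A minibatch of item ids: each id is hashed to four rows of the
    # compressed table, which are summed into a single vector.
    item_ids = torch.arange(1, 129).long()
    representations = layer(item_ids)

    # (batch_size, sequence_size, embedding_dim) == (128, 1, 32)
    print(representations.size())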