# Keras layer implementation of "Fix your classifier: the marginal value of training the last weight layer"
# (Hoffer et al., https://arxiv.org/abs/1801.04540).
# Implementation by Andres Torrubia, licensed under GPL 3: https://www.gnu.org/licenses/gpl-3.0.en.html

from keras import backend as K
from keras.engine.topology import Layer
from keras import activations
from keras.initializers import Constant, RandomUniform
import numpy as np
from scipy.linalg import hadamard
import math


class HadamardClassifier(Layer):
    """Final classification layer with a fixed (non-trainable) Hadamard projection.

    Only a scalar `scale` and (optionally) a per-class `bias` are learned;
    the projection matrix itself is a cropped Hadamard matrix, as proposed
    in the paper above.
    """

    def __init__(self, output_dim, activation=None, use_bias=True,
                 l2_normalize=True, output_raw_logits=False, **kwargs):
        self.output_dim = output_dim
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.l2_normalize = l2_normalize
        self.output_raw_logits = output_raw_logits
        super(HadamardClassifier, self).__init__(**kwargs)

    def build(self, input_shape):
        # Hadamard matrices exist for power-of-two sizes, so build the smallest
        # one covering both dimensions and crop it to (input_dim, output_dim).
        hadamard_size = 2 ** int(math.ceil(math.log(max(input_shape[1], self.output_dim), 2)))
        self.hadamard = K.constant(
            value=hadamard(hadamard_size, dtype=np.int8)[:input_shape[1], :self.output_dim])

        init_scale = 1. / math.sqrt(self.output_dim)

        self.scale = self.add_weight(name='scale',
                                     shape=(1,),
                                     initializer=Constant(init_scale),
                                     trainable=True)

        if self.use_bias:
            self.bias = self.add_weight(name='bias',
                                        shape=(self.output_dim,),
                                        initializer=RandomUniform(-init_scale, init_scale),
                                        trainable=True)

        super(HadamardClassifier, self).build(input_shape)

    def call(self, x, training=None):
        output = K.l2_normalize(x, axis=-1) if self.l2_normalize else x
        # Pity K.dot requires both tensors to be the same type; the Hadamard
        # matrix could otherwise stay int8.
        output = -self.scale * K.dot(output, self.hadamard)
        if self.output_raw_logits:
            # Raw (un-normalized) logits; probably better to reuse output * l2norm.
            output_logits = -self.scale * K.dot(x, self.hadamard)
        if self.use_bias:
            output = K.bias_add(output, self.bias)
            if self.output_raw_logits:
                output_logits = K.bias_add(output_logits, self.bias)
        if self.activation is not None:
            output = self.activation(output)
        return [output, output_logits] if self.output_raw_logits else output

    def compute_output_shape(self, input_shape):
        if self.output_raw_logits:
            return [(input_shape[0], self.output_dim)] * 2
        return (input_shape[0], self.output_dim)

    def get_config(self):
        config = {
            'output_dim': self.output_dim,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'l2_normalize': self.l2_normalize,
            'output_raw_logits': self.output_raw_logits,
        }
        base_config = super(HadamardClassifier, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
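

# A minimal usage sketch: wiring the layer in as the final classifier of a
# toy feed-forward Keras model. The input width, hidden width and class count
# below are arbitrary example values, not taken from the paper.
if __name__ == '__main__':
    from keras.models import Model
    from keras.layers import Input, Dense

    inputs = Input(shape=(784,))                      # e.g. flattened 28x28 images
    features = Dense(256, activation='relu')(inputs)  # trainable feature extractor
    # Fixed Hadamard classifier: only its `scale` and `bias` weights train.
    probs = HadamardClassifier(10, activation='softmax')(features)

    model = Model(inputs, probs)
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    model.summary()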