"""Loss functions for CycleGAN."""
import keras.backend as k
from keras.losses import mean_squared_error

# Discriminator losses that are not strictly greater than this value are
# zeroed out by discriminator_loss. At 0.0 only an exactly-zero loss is masked.
_disc_train_thresh = 0.0


def discriminator_loss(y_true, y_pred):
    """Mean-squared-error loss, masked to zero unless it exceeds a threshold.

    The MSE is multiplied by a 0/1 mask that is 1.0 where the loss is strictly
    greater than ``_disc_train_thresh`` and 0.0 elsewhere. This binary
    threshold is meant to keep the discriminator from being overtrained once
    its loss is already small.

    Args:
        y_true: Target tensor.
        y_pred: Discriminator output tensor.

    Returns:
        The loss from ``keras.losses.mean_squared_error`` multiplied by the
        threshold mask.
    """
    loss = mean_squared_error(y_true, y_pred)
    # Binary mask in the backend's float dtype: 1.0 where loss > threshold.
    is_large = k.greater(loss, k.constant(_disc_train_thresh))
    is_large = k.cast(is_large, k.floatx())
    return loss * is_large


def cycle_loss(y_true, y_pred):
    """L1 cycle-consistency loss summed over the two spatial axes.

    The axis layout is chosen from ``keras.backend.image_data_format()``:
    axes (2, 3) are summed for 'channels_first', axes (1, 2) otherwise.
    NOTE(review): this assumes 4-D image batches, i.e. (batch, channels,
    rows, cols) or (batch, rows, cols, channels) -- confirm against callers.

    Args:
        y_true: Original images.
        y_pred: Images after the forward/backward generator cycle.

    Returns:
        Tensor of absolute differences summed over the two spatial axes.
    """
    # Compare by value, not identity: the format string may come from parsed
    # config and need not be the same object as the literal, so `is` could
    # silently pick the wrong axes.
    if k.image_data_format() == 'channels_first':
        spatial_axes = (3, 2)
    else:
        spatial_axes = (2, 1)
    loss = k.abs(y_true - y_pred)
    # Reduce the higher axis first so the lower axis index stays valid.
    for axis in spatial_axes:
        loss = k.sum(loss, axis=axis)
    return loss