#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""TensorFlow helpers: activations, cost functions, dropout and batch norm."""

__author__ = 'maxim'

import tensorflow as tf
from tensorflow.python.training import moving_averages


def leaky_relu(x, alpha=0.1):
    """Leaky ReLU: returns ``x`` where ``x > 0`` and ``alpha * x`` elsewhere.

    Args:
        x: input tensor.
        alpha: slope applied to the negative part of ``x``.

    Returns:
        A tensor computed as ``relu(x) - alpha * relu(-x)``.
    """
    # The negative part must be taken from the *original* input. Computing it
    # after ``x`` has been rectified always yields zero, which silently turns
    # the activation into a plain ReLU.
    negative_part = tf.nn.relu(-x)
    positive_part = tf.nn.relu(x)
    return positive_part - alpha * negative_part


def prelu(x):
    """Parametric ReLU with a trainable, zero-initialised negative slope.

    A new ``alpha`` variable is created on every call; its shape is the shape
    of ``x`` without the first dimension.

    Args:
        x: input tensor (the first dimension is presumably the batch
            dimension -- confirm against callers).

    Returns:
        ``relu(x) + alpha * min(x, 0)``.
    """
    shape = x.get_shape()
    alpha = tf.Variable(initial_value=tf.zeros(shape=shape[1:]), name='alpha')
    # (x - |x|) * 0.5 equals min(x, 0), i.e. the negative part of x.
    x = tf.nn.relu(x) + tf.multiply(alpha, (x - tf.abs(x))) * 0.5
    return x


# Activation functions addressable by name.
ACTIVATIONS = {'leaky_relu': leaky_relu, 'prelu': prelu}
ACTIVATIONS.update({name: getattr(tf, name) for name in ['tanh']})
ACTIVATIONS.update({name: getattr(tf.nn, name) for name in ['relu', 'elu', 'sigmoid']})

# Cost functions addressable by name; each maps (output, y) to a scalar mean.
COST_FUNCTIONS = {
    'l1': lambda output, y: tf.reduce_mean(tf.abs(output - y)),
    'l2': lambda output, y: tf.reduce_mean(tf.pow(output - y, 2.0)),
}


def dropout(incoming, is_training, keep_prob):
    """Apply dropout to ``incoming`` only when ``is_training`` is true.

    Args:
        incoming: input tensor.
        is_training: predicate handed to ``tf.cond`` that selects the branch.
        keep_prob: value forwarded to ``tf.nn.dropout``; ``None`` disables
            dropout entirely.

    Returns:
        ``incoming`` itself when ``keep_prob`` is ``None``; otherwise the
        result of the ``tf.cond`` choosing between the dropped-out tensor and
        the unchanged input.
    """
    if keep_prob is None:
        return incoming
    return tf.cond(is_training,
                   lambda: tf.nn.dropout(incoming, keep_prob),
                   lambda: incoming)


def batch_normalization(incoming, is_training, beta=0.0, gamma=1.0, epsilon=1e-5, decay=0.9):
    """Batch-normalise ``incoming`` over every axis except the last one.

    When ``is_training`` is true, the batch mean and variance are used and the
    moving averages are updated as a side effect; otherwise the stored moving
    averages are used.

    Args:
        incoming: input tensor; its static rank must be known because the
            reduction axes are derived from ``len(incoming.get_shape())``.
        is_training: predicate handed to ``tf.cond`` that selects the branch.
        beta: initial value of the trainable offset variable.
        gamma: initial value of the trainable scale variable.
        epsilon: variance epsilon passed to ``tf.nn.batch_normalization``.
        decay: decay passed to ``moving_averages.assign_moving_average``.

    Returns:
        The normalised tensor, with its static shape set to that of
        ``incoming``.
    """
    shape = incoming.get_shape()
    dimensions_num = len(shape)
    axis = list(range(dimensions_num - 1))

    with tf.variable_scope('batchnorm'):
        beta_var = tf.Variable(initial_value=tf.ones(shape=[shape[-1]]) * beta, name='beta')
        gamma_var = tf.Variable(initial_value=tf.ones(shape=[shape[-1]]) * gamma, name='gamma')
        # Both moving statistics start at zero; assign_moving_average is
        # called with its default arguments below.
        moving_mean = tf.Variable(initial_value=tf.zeros(shape=shape[-1:]),
                                  trainable=False, name='moving_mean')
        moving_variance = tf.Variable(initial_value=tf.zeros(shape=shape[-1:]),
                                      trainable=False, name='moving_variance')

    def update_mean_var():
        """Return batch moments, forcing the moving-average updates to run."""
        mean, variance = tf.nn.moments(incoming, axis)
        update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, decay)
        update_moving_variance = moving_averages.assign_moving_average(moving_variance,
                                                                       variance, decay)
        # tf.identity under control_dependencies ties the update ops to the
        # returned tensors, so fetching the moments also runs the updates.
        with tf.control_dependencies([update_moving_mean, update_moving_variance]):
            return tf.identity(mean), tf.identity(variance)

    mean, var = tf.cond(is_training,
                        update_mean_var,
                        lambda: (moving_mean, moving_variance))
    inference = tf.nn.batch_normalization(incoming, mean, var, beta_var, gamma_var, epsilon)
    inference.set_shape(shape)
    return inference