import tensorflow as tf
from tensorflow.contrib import slim
import logging
from consensus_gan.ops import lrelu

# Logger for this module, named after the module's import path.
logger = logging.getLogger(name=__name__)


def generator(z, f_dim, output_size, c_dim, is_training=True):
    """Build the generator graph that maps latent codes to images.

    ``z`` is projected by a linear dense layer to an ``s x s x f_dim``
    feature map (``s = output_size // 16``) and passed through lrelu. It is
    then upsampled by four 5x5, stride-2 transposed convolutions. The last
    of these emits ``c_dim`` channels with no activation, and a final
    ``tanh`` squashes the result into (-1, 1).

    Args:
        z: Latent input fed to ``slim.fully_connected``; presumably a 2-D
            ``[batch, z_dim]`` tensor -- TODO confirm with callers.
        f_dim: Number of feature maps in every hidden layer.
        output_size: Target spatial side length. With slim's default 'SAME'
            padding each transposed conv doubles the side, so the produced
            side is ``16 * (output_size // 16)``. This equals
            ``output_size`` only when it is a multiple of 16.
        c_dim: Number of output channels.
        is_training: Unused; kept for interface compatibility.

    Returns:
        Tensor of shape ``[batch, 16*s, 16*s, c_dim]`` with values in
        (-1, 1).

    Raises:
        ValueError: If ``output_size < 16`` (the initial feature map would
            be empty).
    """
    # Side length of the initial feature map. Computed once and reused:
    # `//` and `*` share precedence, so an unparenthesized
    # `output_size//16 * output_size//16` is NOT (output_size//16)**2.
    s = output_size // 16
    if s < 1:
        raise ValueError(
            "output_size must be at least 16, got %r" % (output_size,))
    if output_size % 16:
        logger.warning(
            "generator: output_size=%d is not a multiple of 16; "
            "generated images will be %dx%d",
            output_size, 16 * s, 16 * s)

    # Project and reshape z into the initial s x s x f_dim feature map.
    net = slim.fully_connected(z, s * s * f_dim, activation_fn=None)
    net = tf.reshape(net, [-1, s, s, f_dim])
    net = lrelu(net)

    # Upsample s -> 2s -> 4s -> 8s -> 16s. Hidden layers use lrelu; the last
    # layer is linear so that tanh below is the only output nonlinearity.
    # NOTE: slim derives default variable scope names from call order, so
    # keep this sequence stable for checkpoint compatibility.
    with slim.arg_scope([slim.conv2d_transpose], kernel_size=[5, 5],
                        stride=[2, 2], activation_fn=lrelu):
        net = slim.conv2d_transpose(net, f_dim)
        net = slim.conv2d_transpose(net, f_dim)
        net = slim.conv2d_transpose(net, f_dim)
        net = slim.conv2d_transpose(net, c_dim, activation_fn=None)

    return tf.nn.tanh(net)


def discriminator(x, f_dim, output_size, c_dim, is_training=True):
    """Build the discriminator graph that maps images to unnormalized logits.

    Four 5x5, stride-2 convolutions with ``f_dim`` filters and lrelu
    activations downsample ``x``. The result is flattened, and a linear
    dense layer produces one logit per example.

    Args:
        x: Input images; presumably ``[batch, output_size, output_size,
            c_dim]`` -- TODO confirm with callers.
        f_dim: Number of filters in every conv layer.
        output_size: Spatial side length of ``x``. Only used to size the
            flatten when the feature-map shape is not statically known.
        c_dim: Unused here (slim.conv2d infers input channels from ``x``).
        is_training: Unused; kept for interface compatibility.

    Returns:
        Rank-1 tensor of raw logits (no sigmoid applied), one per example.
    """
    # Each stride-2 conv with slim's default 'SAME' padding maps a side
    # length n to ceil(n / 2).
    # NOTE: slim derives default variable scope names from call order, so
    # keep this sequence stable for checkpoint compatibility.
    with slim.arg_scope([slim.conv2d], kernel_size=[5, 5], stride=[2, 2],
                        activation_fn=lrelu):
        net = slim.conv2d(x, f_dim)
        net = slim.conv2d(net, f_dim)
        net = slim.conv2d(net, f_dim)
        net = slim.conv2d(net, f_dim)

    # Flatten for the dense head. Prefer the statically-known feature-map
    # size so the reshape always keeps exactly one row per example. If it is
    # not known, derive it from output_size: four ceil-halvings equal
    # ceil(output_size / 16). (A floor here, or an unparenthesized
    # `output_size//16 * output_size//16`, is wrong when output_size is not
    # a multiple of 16.)
    flat_dim = net.get_shape()[1:].num_elements()
    if flat_dim is None:
        s = -(-output_size // 16)  # ceiling division
        flat_dim = s * s * f_dim
    net = tf.reshape(net, [-1, flat_dim])

    logits = slim.fully_connected(net, 1, activation_fn=None)
    # Drop the trailing size-1 axis: [batch, 1] -> [batch].
    return tf.squeeze(logits, -1)